[深度学习] 门控循环单元GRU

门控循环单元(Gated Recurrent Unit, GRU)是一种用于处理序列数据的递归神经网络(Recurrent Neural Network, RNN)变体,它通过引入门控机制来解决传统RNN在处理长序列时的梯度消失问题。GRU与长短期记忆网络(LSTM)相似,但结构更为简化。以下是GRU的详细介绍:

1. GRU的结构

GRU由以下几个主要部分组成:

  • 重置门(reset gate):控制当前时间步的输入如何与之前的记忆结合,用于决定要丢弃多少过去的信息。
  • 更新门(update gate):控制上一时间步的记忆如何流入当前时间步的记忆,用于决定要保留多少过去的信息。

具体来说,GRU的计算过程如下:

2. 公式表示

假设xt是当前时间步的输入,ht−1​是上一时间步的隐状态,则GRU的更新过程可以用以下公式表示:

  1. 重置门(reset gate)
    在这里插入图片描述
  2. 更新门(update gate)
    在这里插入图片描述
  3. 候选隐状态(candidate hidden state)
    在这里插入图片描述
  4. 当前隐状态(current hidden state)
    在这里插入图片描述
    其中:
  • σ 是sigmoid激活函数。
  • tanh 是tanh激活函数。
  • W和U 是权重矩阵,b是偏置项。
  • ⊙ 表示元素乘法(Hadamard积)。

3. GRU的工作原理

  • 重置门rt:决定了多少过去的记忆需要被重置或忽略。重置门的值接近0时,意味着更多的过去信息被丢弃;值接近1时,意味着保留更多的过去信息。
  • 更新门zt:决定了当前时间步的记忆如何与之前的记忆进行权衡。更新门的值接近0时,更多的过去记忆被保留;值接近1时,更多的当前信息被引入。

4. GRU与LSTM的比较

  • 结构:GRU比LSTM结构更简单,LSTM有三个门(输入门、遗忘门和输出门),而GRU只有两个门(重置门和更新门)。
  • 参数:由于结构较为简化,GRU的参数量比LSTM少,因此在某些任务中计算效率更高。
  • 性能:在许多任务上,GRU与LSTM的表现相当,有时GRU甚至表现得更好,特别是在数据量较少的情况下。

5. 应用场景

GRU广泛应用于自然语言处理(NLP)、语音识别、时间序列预测等领域,尤其适合需要处理长序列数据的任务。

6. 实现示例

在TensorFlow中,可以使用tf.keras.layers.GRU来实现一个GRU层:

import tensorflow as tf
import numpy as np

# 生成示例数据
# 输入序列(样本数量,时间步长,特征维度)
input_seq = np.random.randn(3, 5, 10).astype(np.float32)

# 定义GRU模型
model = tf.keras.Sequential([
    tf.keras.layers.GRU(20, return_sequences=True, input_shape=(5, 10)),  # 隐状态维度为20
    tf.keras.layers.GRU(20)  # 第二个GRU层
])

# 编译模型
model.compile(optimizer='adam', loss='mse')

# 打印模型摘要
model.summary()

# 生成示例标签(样本数量,输出维度)
output_seq = np.random.randn(3, 20).astype(np.float32)

# 训练模型
model.fit(input_seq, output_seq, epochs=10)

# 预测
predictions = model.predict(input_seq)
print(predictions)

代码解释
  1. 数据生成

    input_seq = np.random.randn(3, 5, 10).astype(np.float32)
    

    这里生成了一个随机的输入序列,假设有3个样本,每个样本有5个时间步,每个时间步有10个特征。

  2. 定义GRU模型

    model = tf.keras.Sequential([
        tf.keras.layers.GRU(20, return_sequences=True, input_shape=(5, 10)),
        tf.keras.layers.GRU(20)
    ])
    

    使用tf.keras.Sequential定义了一个简单的GRU模型。第一个GRU层的隐状态维度为20,并且返回所有时间步的输出。第二个GRU层的隐状态维度也为20,但只返回最后一个时间步的输出。

  3. 编译模型

    model.compile(optimizer='adam', loss='mse')
    

    使用Adam优化器和均方误差损失函数来编译模型。

  4. 打印模型摘要

    model.summary()
    

    打印模型的摘要信息,以查看模型的结构和参数数量。

  5. 生成示例标签并训练模型

    output_seq = np.random.randn(3, 20).astype(np.float32)
    model.fit(input_seq, output_seq, epochs=10)
    

    生成与输入序列匹配的随机标签,并使用这些标签来训练模型。

  6. 预测

    predictions = model.predict(input_seq)
    print(predictions)
    

    使用训练好的模型进行预测,并打印预测结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/745224.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

反射及动态代理

反射 定义: 反射允许对封装类的字段,方法和构造 函数的信息进行编程访问 图来自黑马程序员 获取class对象的三种方式: 1)Class.forName("全类名") 2)类名.class 3) 对象.getClass() 图来自黑马程序员 pac…

前端JS必用工具【js-tool-big-box】学习,数值型数组的正向排序和倒向排序

这一小节,我们说一下前端 js-tool-big-box 这个工具库,添加的数值型数组的正向排序和倒向排序。 以前呢,我们的数组需要排序的时候,都是在项目的utils目录里,写一段公共方法,弄个冒泡排序啦,弄…

JNI详解

JNI简介 Java是跨平台的语言,但在有的时候仍需要调用本地代码(这些代码通常由C/C编写的)。 Sun公司提供的JNI是Java平台的一个功能强大的接口,JNI接口提供了Java与操作系统本地代码互相调用的功能。 Java调C 1)使用…

Spring Boot 学习第八天:AOP代理机制对性能的影响

1 概述 在讨论动态代理机制时,一个不可避免的话题是性能。无论采用JDK动态代理还是CGLIB动态代理,本质上都是在原有目标对象上进行了封装和转换,这个过程需要消耗资源和性能。而JDK和CGLIB动态代理的内部实现过程本身也存在很大差异。下面将讨…

VMware vSphere 8.0 Update 3 发布下载 - 企业级工作负载平台

VMware vSphere 8.0 Update 3 发布下载 - 企业级工作负载平台 vSphere 8.0U3 | ESXi 8.0U3 & vCenter Server 8.0U3 请访问原文链接:https://sysin.org/blog/vmware-vsphere-8-u3/,查看最新版。原创作品,转载请保留出处。 作者主页&am…

Java面试八股之JVM内存溢出的原因及解决方案

JVM内存溢出的原因及解决方案 JVM内存溢出(Out Of Memory,OOM)通常是由于程序运行过程中内存使用不当造成的,常见原因及相应的解决方案如下: 原因及解决方案 内存中加载的数据量过大 原因:一次性从数据…

运维入门技术——监控的三个维度(非常详细)零基础收藏这一篇就够了_监控维度怎么区分

一个好的监控系统最后要做到的形态:实现Metrics、Tracing、Logging的融合。监控的三个维度也就是Metrics、Tracing、Logging。 Metrics Metrics也就是我们常说的指标。 首先它的典型特征就是可聚合(aggregatable).什么是可聚合的呢,简单讲可聚合就是一种基本单位可以在一种维…

Verilog刷题笔记48——FSM1型异步复位

题目: 解题: module top_module(input clk,input areset, // Asynchronous reset to state Binput in,output out);// parameter A0, B1; reg state, next_state;always (*) begin // This is a combinational always block// State transition logiccase(…

加拿大魁北克IT人士的就业分析

魁北克省作为加拿大东部的一个重要省份,近年来在IT行业的就业市场上展现出了强劲的增长势头。随着数字化转型的加速,魁北克对IT专业人士的需求日益增加,特别是在软件开发、网络安全、数据分析和人工智能等领域。 热门职位方面,软…

禹晶、肖创柏、廖庆敏《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》Chapter 9插图

禹晶、肖创柏、廖庆敏《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》 Chapter 9插图

201.回溯算法:全排列(力扣)

class Solution { public:vector<int> res; // 用于存储当前排列组合vector<vector<int>> result; // 用于存储所有的排列组合void backtracing(vector<int>& nums, vector<bool>& used) {// 如果当前排列组合的长度等于 nums 的长度&am…

用 Rust 实现一个替代 WebSocket 的协议

很久之前我就对websocket颇有微词&#xff0c;它的确满足了很多情境下的需求&#xff0c;但是仍然有不少问题。对我来说&#xff0c;最大的一个问题是websocket的数据是明文传输的&#xff0c;这使得websocket的数据很容易遭到劫持和攻击。同时&#xff0c;WebSocket继承自HTTP…

【操作系统】操作系统发展简史

目录 1.前言 2.第一代&#xff08;1945~1955&#xff09;&#xff1a;真空管和穿孔卡片 3.第二代&#xff08;1955~1965&#xff09;&#xff1a;晶体管和批处理系统 4.第三代&#xff08;1965~1980&#xff09;&#xff1a;集成电路和多道程序设计 5.第四代&#xff08;1…

关于VMware遇到的一些问题

问题一&#xff1a;打不开磁盘…或它所依赖的某个快照磁盘&#xff0c;开启模块DiskEarly的操作失败&#xff0c;未能启动虚拟机 解决方法&#xff1a; 首先将centos 7关机&#xff0c;然后把快照1删掉 然后打开虚拟机所在目录&#xff0c;把提示的000001.vmdk全部删除&…

本地读取classNames txt文件

通过本地读取classNames,来减少程序修改代码,提高了程序的拓展性和自定义化。 步骤: 1、输入本地路径,分割字符串。 2、将className按顺序放入vector容器中。 3、将vector赋值给classNmaes;获取classNames.size(),赋值给CLASSES;这样,类别个数和类别都已经赋值完成。…

大厂面试官问我:Redis内存淘汰,LRU维护整个队列吗?【后端八股文四:Redis内存淘汰策略八股文合集】

往期内容&#xff1a; 大厂面试官问我&#xff1a;Redis处理点赞&#xff0c;如果瞬时涌入大量用户点赞&#xff08;千万级&#xff09;&#xff0c;应当如何进行处理&#xff1f;【后端八股文一&#xff1a;Redis点赞八股文合集】-CSDN博客 大厂面试官问我&#xff1a;布隆过滤…

vue3 Cesium 离线地图

1、vite-plugin-cesium 是一个专门为 Vite 构建工具定制的插件&#xff0c;用于在 Vite 项目中轻松使用 Cesium 库。它简化了在 Vite 项目中集成 Cesium 的过程。 npm i cesium vite-plugin-cesium vite -D 2、配置vite.config.js import cesium from vite-plugin-cesiumexp…

监测与管理:钢筋计在工程项目中的应用

在现代工程建设中&#xff0c;特别是大型长期工程项目&#xff0c;对结构安全性的监测与管理至关重要。钢筋计作为一种重要的监测工具&#xff0c;在工程项目中发挥着不可替代的作用。本文将探讨钢筋计在长期工程项目中的应用&#xff0c;包括安装方法、数据监测与分析以及实际…

“基于下垂的多电源分布式控制系统设计”,高分资源,匠心制作,查重5%,下载可用。强烈推荐!!!

“基于下垂的多电源分布式控制系统设计”&#xff0c;高分资源&#xff0c;匠心制作&#xff0c;查重5%&#xff0c;下载可用。强烈推荐&#xff01;&#xff01;&#xff01; 摘要 社会的进步与发展&#xff0c;人们对于能源的需求和依赖越来越大。与此同时&#xff0c;国家…

通达信擒牛亮剑出击抄底主升浪指标公式源码

通达信擒牛亮剑出击抄底主升浪指标公式源码&#xff1a; ABC1:(CLOSE-REF(CLOSE,1))/REF(CLOSE,1)*100; ABC2:IF(CLOSE>OPEN,CLOSE,OPEN); ABC3:IF(CLOSE>OPEN,OPEN,CLOSE); ABC4:LLV(ABC2,4); ABC5:HHV(ABC3,4); ABC6:ABC2>ABC4 AND ABC3<ABC4 AND ABC2>ABC5 …