深度学习--循环神经网络

目录

  • 第1关:学习单步的RNN:RNNCell
  • 第2关:探幽入微LSTM
  • 第3关:进阶RNN:学习一次执行多步以及堆叠RNN的构建

第1关:学习单步的RNN:RNNCell

本关任务:学习 RNN 循环神经网络的基本概念,并构建单个 RNNCell 。


# -*- coding: utf-8 -*-
import tensorflow as tf

# 参数 a 是 BasicRNNCell所含的神经元数, 参数 b 是 batch_size, 参数 c 是单个 input 的维数,shape = [ b , c ]
def creatRNNCell(a,b,c):
    # 请在此添加代码 完成本关任务
    # ********** Begin *********#
    output_size = a
    batch_size = b
    cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
    print(cell.output_size)
    input = tf.placeholder(dtype=tf.float32,shape=[batch_size,c])
    h0 = cell.zero_state(batch_size=batch_size,dtype=tf.float32)
    output,h1 = cell(input,h0)
    print(h1)

    # ********** End **********#

第2关:探幽入微LSTM

本关任务了解 LSTM 的核心思想,根据教程指导创建 LSTM 网络的一个 LSTMCell 。

# -*- coding: utf-8 -*-
import tensorflow as tf

# 参数 a 是 BasicLSTMCell所含的神经元数, 参数 b 是 batch_size, 参数 c 是单个 input 的维数,shape = [ b , c ]
def creatLSTMCell(a,b,c):
    # 请在此添加代码 完成本关任务
    # ********** Begin *********#
    inputs = tf.placeholder(tf.float32, [b,c])
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=a)
    h0 = lstm_cell.zero_state(batch_size=b,dtype=tf.float32)
    output, h1 = lstm_cell.__call__(inputs, h0)
    print(h1.h)  
    print(h1.c) 
    # ********** End **********#


第3关:进阶RNN:学习一次执行多步以及堆叠RNN的构建

本关任务:学习如何一次执行得到 RNNCell 多步调用 call 方法得到的结果,其次掌握构建堆叠 RNNCell 的方法。

# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np

# 参数 a 是RNN的层数, 参数 b 是每个BasicRNNCell包含的神经元数即state_size
# 参数 c 是输入序列的批量大小即batch_size,参数 d 是时间序列的步长即time_steps,参数 e 是单个输入input的维数即input_size

def get_a_cell(b):
   return tf.nn.rnn_cell.BasicRNNCell(num_units=b)

def MultiRNNCell_dynamic_call(a,b,c,d,e):
    # 用tf.nn.rnn_cell MultiRNNCell创建a层RNN,并调用tf.nn.dynamic_rnn
    # 请在此添加代码 完成本关任务
    # ********** Begin *********#
    cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(b) for _ in range(a)])
    inputs = tf.placeholder(np.float32, shape=(c,d,e))
    h0 = cell.zero_state(c, np.float32)
    output, h1 = tf.nn.dynamic_rnn(cell,inputs,initial_state=h0)
    print(output)
    # ********** End **********#

你可能感兴趣的:(深度学习框架,rnn,深度学习,lstm)