Advertisement

TensorFlow 堆叠循环神经网络(深层循环神经网络)

阅读量:

将多个循环神经网络串联起来的体系被称为深层循环神经网络(RNN)。其本质是将若干个循环神经元组合在一起构成一个复杂的非线性变换模型。在TensorFlow框架中提供了tf.keras.layers.StackedRNNCells()这一类模块来实现这种多层结构的设计。需要注意的是这种结构虽然复杂但本质上还是一个序列模型因此在实际应用中还需将其嵌入到完整的序列处理框架中才能发挥应有的作用

在这里插入图片描述

左侧为堆叠循环神经网络,右侧为展开图。

代码如下:

复制代码
    import tensorflow as tf
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train=x_train/255.
    x_test=x_test/255.
    batch_size=128
    cell_size=32
    lstm_cells = [tf.keras.layers.LSTMCell(cell_size) for _ in range(2)]
    stacked_lstm = tf.keras.layers.StackedRNNCells(lstm_cells)
    stacked_lstm_layer = tf.keras.layers.RNN(stacked_lstm)
    model = tf.keras.Sequential()
    model.add(stacked_lstm_layer)
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(10,activation='softmax'))
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.fit(x_train, y_train, batch_size=batch_size, epochs=3)
    loss, accuracy = model.evaluate(x_test, y_test)
    print('test loss', loss)
    print('test accuracy', accuracy)

全部评论 (0)

还没有任何评论哟~