tf.nn.dynamic_rnn + Dataset iterator

Date: 2018-04-17 12:16:58

Tags: tensorflow rnn

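I am building an LSTM network using the Dataset API. The input tensor (named x in the code) has a different shape for the train and val sets, and the iterator is defined without specifying output shapes.

The problem is that when the tf.nn.dynamic_rnn op is added to the graph, the shape of x is unknown, which raises the following error: ValueError: as_list() is not defined on an unknown TensorShape.

Using tf.nn.dynamic_rnn without the Dataset API works as expected. How can I resolve this error?

TF version: 1.4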

import tensorflow as tf
import numpy as np

"""
Array dimensions:
1st: number of examples per epoch
2nd: number of time steps
3rd: batch size, i.e. the number of independent time series
4th: number of points fed to the LSTM at each time step

The batch size is usually smaller in the val set because most of the data is used for training.
The number of time steps is larger in the val set to speed up inference.
"""
x_train = np.random.rand(100, 8, 12, 2).astype(np.float32)
x_val = np.random.rand(8, 100, 4, 2).astype(np.float32)


use_dataset_api = True

# Reset the graph before entering the device scope so the scope applies to the new default graph.
tf.reset_default_graph()

with tf.device('/gpu:0'):

    if not use_dataset_api:

        batch_size_pl = tf.placeholder(shape=[], dtype=tf.int32)
        x_pl = tf.placeholder(shape=[None, None, 2], dtype=tf.float32)

        cell = tf.contrib.rnn.LSTMCell(num_units=11)
        init_state = cell.zero_state(batch_size=batch_size_pl, dtype=tf.float32)

        rnn_outputs, current_state = tf.nn.dynamic_rnn(cell, x_pl, initial_state=init_state,
                                                       time_major=True, dtype=tf.float32)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

        # Use first example of train set
        rnn_outputs_, current_state_ = sess.run([rnn_outputs, current_state],
                                                feed_dict={batch_size_pl: 12, x_pl: x_train[0]})

        # Use first example of val set
        rnn_outputs_, current_state_ = sess.run([rnn_outputs, current_state],
                                                feed_dict={batch_size_pl: 4, x_pl: x_val[0]})

    else:
        batch_size_pl = tf.placeholder(shape=[], dtype=tf.int32)

        train_set = tf.data.Dataset.from_tensor_slices((x_train))
        val_set = tf.data.Dataset.from_tensor_slices((x_val))

        iterator = tf.data.Iterator.from_structure(train_set.output_types)  # , train_set.output_shapes)

        train_init_op = iterator.make_initializer(train_set)
        val_init_op = iterator.make_initializer(val_set)

        x = iterator.get_next()

        cell = tf.contrib.rnn.LSTMCell(num_units=11)
        init_state = cell.zero_state(batch_size=batch_size_pl, dtype=tf.float32)

        # Raises error for tensor x: as_list() is not defined on an unknown TensorShape.
        rnn_outputs, current_state = tf.nn.dynamic_rnn(cell, x, initial_state=init_state,
                                                       time_major=True, dtype=tf.float32)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

        # Use first example of train set
        sess.run(train_init_op)
        rnn_outputs_, current_state_ = sess.run([rnn_outputs, current_state],
                                                feed_dict={batch_size_pl: 12})

        # Use first example of val set
        sess.run(val_init_op)
        rnn_outputs_, current_state_ = sess.run([rnn_outputs, current_state],
                                                feed_dict={batch_size_pl: 4})

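1 Answer:

Answer 0 (score: 0)

The solution is to pass the output shapes to the iterator, leaving the dimensions that differ between the train and val sets as None. Change the following line: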

iterator = tf.data.Iterator.from_structure(train_set.output_types)

to:

iterator = tf.data.Iterator.from_structure(train_set.output_types, [None, None, 2])
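The shape [None, None, 2] follows from the per-element shapes of the two datasets: (8, 12, 2) for train and (100, 4, 2) for val. Only the last dimension agrees, and a static last dimension is presumably what the LSTM cell needs in order to size its weights. As a minimal sketch of how such a shared shape could be derived rather than hard-coded, the snippet below uses a hypothetical helper, most_specific_compatible_shape, which is not part of TensorFlow and is introduced here only for illustration. It runs with NumPy alone; the TensorFlow call appears only as a comment.

```python
import numpy as np


def most_specific_compatible_shape(*shapes):
    """Return a shape list with None wherever the given shapes disagree.

    Hypothetical helper for illustration; it is not a TensorFlow function.
    """
    ranks = {len(s) for s in shapes}
    if len(ranks) != 1:
        raise ValueError("All shapes must have the same rank, got %s" % (shapes,))
    # Keep a dimension only if every shape has the same size there.
    return [dims[0] if len(set(dims)) == 1 else None for dims in zip(*shapes)]


# Same array shapes as in the question.
x_train = np.random.rand(100, 8, 12, 2).astype(np.float32)
x_val = np.random.rand(8, 100, 4, 2).astype(np.float32)

# Dataset.from_tensor_slices slices along the first axis, so each dataset
# element has the shape of x[0].
train_element_shape = x_train[0].shape  # (8, 12, 2)
val_element_shape = x_val[0].shape      # (100, 4, 2)

shared_shape = most_specific_compatible_shape(train_element_shape, val_element_shape)
print(shared_shape)  # [None, None, 2]

# The result could then be passed as the second argument, as in the answer:
# iterator = tf.data.Iterator.from_structure(train_set.output_types, shared_shape)
```

With time_major=True, the two None entries correspond to the time-step and batch dimensions, which stay dynamic, so the same iterator can be initialized with either dataset.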