TensorFlow: TimeFreqLSTMCell with a variable batch size

Posted: 2018-06-02 03:56:50

Tags: tensorflow lstm

I'm running into a strange problem with tf.contrib.rnn.TimeFreqLSTMCell. When I try to use a 'None' (variable) batch size, it fails with this error:

Traceback (most recent call last):
  File "/home/nezin/.venv/ascc/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 1620, in zeros
    tensor_shape.TensorShape(shape))
  File "/home/nezin/.venv/ascc/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 256, in _tensor_shape_tensor_conversion_function
    "Cannot convert a partially known TensorShape to a Tensor: %s" % s)
ValueError: Cannot convert a partially known TensorShape to a Tensor: (?, 128)

If I change the line

cell = tf.contrib.rnn.TimeFreqLSTMCell(128, feature_size = window_size)

to

cell = tf.contrib.rnn.BasicLSTMCell(128)

everything works as expected.

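I think the problem comes from my use of

batch_size = tf.shape(x)[0]

because if I use

batch_size = 1024

instead, training works fine. The catch is that my batches are not always the same size: for example, the last batch of an epoch can be smaller, so as soon as a batch of a different size is fed in, this crashes.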
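If the graph is built with a fixed batch_size = 1024 (the setting reported above as working), one possible workaround is to make sure every batch really has that size, for example by skipping the smaller batch at the end of an epoch. The sketch below is plain Python; `full_batches` is a hypothetical helper, not part of the custom train module used here. The trade-off is that up to batch_size - 1 samples are skipped each epoch. Later TensorFlow releases offer the same behavior through the drop_remainder argument of tf.data.Dataset.batch.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def full_batches(data: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive batches of exactly `batch_size` items.

    Hypothetical helper: any trailing items that do not fill a whole
    batch are skipped, so every yielded batch has the same size.
    """
    n_full = len(data) // batch_size
    for i in range(n_full):
        yield list(data[i * batch_size:(i + 1) * batch_size])


# Usage: 10 samples with batch_size 4 give two full batches; samples 8 and 9 are skipped.
for batch in full_batches(list(range(10)), 4):
    print(batch)
```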

The full code is below:

import tensorflow as tf
from train.train import train

config = {}
model_name = 'timefreq'

class Tfrnn:
    def __init__(self,x,config):
        max_time = 32
        window_size = 32
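        batch_size = tf.shape(x)[0]  # dynamic batch size: a scalar int32 tensor, not a Python int
        # Shape comments below assume x arrives as [batch, 2, max_time * window_size]
        x = tf.transpose(x, perm = [0,2,1])                        # -> [batch, 1024, 2]
        x = tf.reshape(x, [batch_size, max_time, window_size, 2])  # -> [batch, 32, 32, 2]
        x = tf.transpose(x, perm = [0,1,3,2])                      # -> [batch, 32, 2, 32]
        x = x[:,:,0]                                               # keep the first channel -> [batch, 32, 32]
        print(x)
        x = tf.reshape(x, [batch_size, max_time, window_size])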
        cell = tf.contrib.rnn.TimeFreqLSTMCell(128, feature_size = window_size)
        #cell = tf.contrib.rnn.BasicLSTMCell(128)
        init = cell.zero_state(batch_size, tf.float32)
        print(x)
        output, state = tf.nn.dynamic_rnn(
            cell, x, initial_state = init, dtype = tf.float32,
            sequence_length = tf.tile([max_time], [batch_size]))
        print(output)
        flat = tf.contrib.layers.flatten(output[:,max_time-1,:])
        self.logits = tf.contrib.layers.fully_connected(flat,24,activation_fn=None)

        self.train_feed_dict = {}
        self.valid_feed_dict = {}


train(Tfrnn, model_name, config, batch_size = 1024, recover = False, shuffle = False, epochs=100,
      log_directory = './timefreq_log/')
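As a sanity check on the shape handling in __init__, the same transpose/reshape/slice sequence can be replayed in NumPy. This sketch assumes the input arrives as [batch, 2, 1024]; that layout is inferred from the reshape calls and is not stated in the question. It uses -1 for the batch dimension so no explicit batch size is needed. A full batch and a small leftover batch both come out as [batch, 32, 32], which is consistent with the observation that BasicLSTMCell runs on the same preprocessed input; the failure only appears once the cell is swapped.

```python
import numpy as np

MAX_TIME = 32
WINDOW_SIZE = 32


def preprocess(x: np.ndarray) -> np.ndarray:
    """Replay the preprocessing steps from Tfrnn.__init__ in NumPy.

    Assumes x has shape [batch, 2, MAX_TIME * WINDOW_SIZE]; the batch
    dimension is inferred with -1 instead of being passed explicitly.
    """
    x = np.transpose(x, (0, 2, 1))                     # -> [batch, 1024, 2]
    x = np.reshape(x, (-1, MAX_TIME, WINDOW_SIZE, 2))  # -> [batch, 32, 32, 2]
    x = np.transpose(x, (0, 1, 3, 2))                  # -> [batch, 32, 2, 32]
    x = x[:, :, 0]                                     # first channel -> [batch, 32, 32]
    return x


# A full batch and a smaller end-of-epoch batch go through the same steps.
for batch in (1024, 7):
    out = preprocess(np.zeros((batch, 2, MAX_TIME * WINDOW_SIZE), dtype=np.float32))
    print(batch, out.shape)
```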

Note that I'm using my own custom train function (imported from train.train) to handle some of the dataset processing.
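Another possible workaround that keeps batch_size = 1024 fixed is to pad the short final batch up to the full size and discard the extra predictions afterwards. The helpers below (`pad_batch`, `unpad`) are hypothetical and shown in plain Python; with real data the filler would be an all-zero example with the model's input shape.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def pad_batch(batch: Sequence[T], batch_size: int, filler: T) -> Tuple[List[T], int]:
    """Pad `batch` to exactly `batch_size` items by appending `filler`.

    Hypothetical helper. Returns the padded batch and the number of real
    items so predictions for the padding can be discarded afterwards.
    Every padding slot refers to the same `filler` object.
    """
    n_real = len(batch)
    if n_real > batch_size:
        raise ValueError(f"batch has {n_real} items, more than batch_size={batch_size}")
    return list(batch) + [filler] * (batch_size - n_real), n_real


def unpad(outputs: Sequence[T], n_real: int) -> List[T]:
    """Keep only the outputs that belong to real (non-padding) items."""
    return list(outputs[:n_real])


# Usage: a leftover batch of 3 items padded to a fixed size of 5.
padded, n_real = pad_batch(["a", "b", "c"], 5, "pad")
print(padded)                 # ['a', 'b', 'c', 'pad', 'pad']
print(unpad(padded, n_real))  # ['a', 'b', 'c']
```

Padded rows still flow through the loss and gradients during training unless they are masked out, so this approach fits evaluation better; for training, simply skipping the last partial batch avoids that issue.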
