I'm running into a strange problem with tf.contrib.rnn.TimeFreqLSTMCell. When I try to use a 'None' batch size, it throws the following error:
Traceback (most recent call last):
  File "/home/nezin/.venv/ascc/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 1620, in zeros
    tensor_shape.TensorShape(shape))
  File "/home/nezin/.venv/ascc/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 256, in _tensor_shape_tensor_conversion_function
    "Cannot convert a partially known TensorShape to a Tensor: %s" % s)
ValueError: Cannot convert a partially known TensorShape to a Tensor: (?, 128)
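As an aside, the ValueError itself is easy to reproduce in isolation. This is just my reading of the traceback (tf.zeros being handed a partially known static shape), not a claim about what the cell does internally:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 128])
print(x.shape)     # (?, 128): the batch dimension is unknown at graph-build time
tf.zeros(x.shape)  # ValueError: Cannot convert a partially known TensorShape ...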
If I replace the line
cell = tf.contrib.rnn.TimeFreqLSTMCell(128, feature_size = window_size)
with
cell = tf.contrib.rnn.BasicLSTMCell(128)
everything works as expected.
I believe the problem stems from the fact that I'm using
batch_size = tf.shape(x)[0]
because if I instead use
batch_size = 1024
training works fine. The problem is that I sometimes have batches of different sizes; at the end of an epoch, for example, the batch can be smaller, so the graph crashes when I try to feed it a batch of a different size.
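To spell out what I mean by that (a minimal, standalone sketch with a hypothetical input shape, not my real graph):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 64, 32])  # hypothetical input; batch dim unknown

static_bs = x.shape[0]       # Dimension(None): unknown while the graph is built
dynamic_bs = tf.shape(x)[0]  # scalar int32 Tensor, resolved per batch at run time

# BasicLSTMCell is happy to build its zero state from the dynamic value:
cell = tf.contrib.rnn.BasicLSTMCell(128)
init = cell.zero_state(dynamic_bs, tf.float32)

With dynamic_bs the graph accepts whatever batch size I feed; with a hard-coded 1024 it does not.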
The full code is below:

import tensorflow as tf
from train.train import train
config = {}
model_name = 'timefreq'
class Tfrnn:
    def __init__(self, x, config):
        max_time = 32
        window_size = 32
        # dynamic batch size: a scalar Tensor, not known until run time
        batch_size = tf.shape(x)[0]
        x = tf.transpose(x, perm = [0,2,1])
        x = tf.reshape(x, [batch_size, max_time, window_size, 2])
        x = tf.transpose(x, perm = [0,1,3,2])
        x = x[:,:,0]
        print(x)
        x = tf.reshape(x, [batch_size, max_time, window_size])
        cell = tf.contrib.rnn.TimeFreqLSTMCell(128, feature_size = window_size)
        #cell = tf.contrib.rnn.BasicLSTMCell(128)  # swapping this in works
        init = cell.zero_state(batch_size, tf.float32)
        print(x)
        output, state = tf.nn.dynamic_rnn(
            cell, x, initial_state = init, dtype = tf.float32,
            sequence_length = tf.tile([max_time], [batch_size]))
        print(output)
        flat = tf.contrib.layers.flatten(output[:,max_time-1,:])
        self.logits = tf.contrib.layers.fully_connected(flat,24,activation_fn=None)
        self.train_feed_dict = {}
        self.valid_feed_dict = {}

train(Tfrnn, model_name, config, batch_size = 1024, recover = False, shuffle = False, epochs=100,
      log_directory = './timefreq_log/')
Note that I'm using my own custom train function, which handles some dataset bookkeeping.
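For context, the uneven batch sizes come from ordinary end-of-epoch slicing. Schematically (a standalone sketch with made-up numbers, not my actual train code):

import numpy as np

data = np.zeros((10000, 64, 32), dtype=np.float32)  # hypothetical dataset
batch_size = 1024

for start in range(0, len(data), batch_size):
    batch = data[start:start + batch_size]
    print(batch.shape[0])  # nine batches of 1024, then a final batch of 784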