我有一个来自tf.data
数据生成器和tf.keras
模型的神经网络,如下所示(简化的版本,因为它太长了):
dataset = ...
使用tf.data.Dataset
方法的next_x
对象为get_next
迭代器调用x_train
,为next_y
方法调用get_next
y_train
迭代器。每个标签都是一个(1, 67)
数组,形式为一字通。
层:
input_tensor = tf.keras.layers.Input(shape=(240, 240, 3)) # dim of x
output = tf.keras.layers.Flatten()(input_tensor)
output= tf.keras.Dense(67, activation='softmax')(output) # 67 is the number of classes
型号:
model = tf.keras.models.Model(inputs=input_tensor, outputs=prediction)
model.compile(optimizer=tf.train.AdamOptimizer(), loss=tf.losses.softmax_cross_entropy, metrics=['accuracy'])
model.fit_generator(gen(dataset.next_x(), dataset.next_y()), steps_per_epochs=100)
gen
的定义如下:
def gen(x, y):
while True:
yield(x, y)
我的问题是,当我尝试运行它时,在model.fit
部分出现了错误:
ValueError: Cannot take the length of Shape with unknown rank.
任何想法都值得赞赏!
答案 0 :(得分:1)
您可以发布更长的堆栈跟踪吗?我认为您的问题可能与最近的张量流问题有关:
https://github.com/tensorflow/tensorflow/issues/24520
还有一个简单的PR可以修复它(尚未合并)。也许自己尝试一下?
编辑
这里是PR:
打开tensorflow/python/keras/engine/training_utils.py
替换以下内容(当前为232行):
if (x.shape is not None
and len(x.shape) == 1
与此:
if tensor_util.is_tensor(x):
x_shape_ndims = x.shape.ndims if x.shape is not None else None
else:
x_shape_ndims = len(x.shape)
if (x_shape_ndims == 1
答案 1 :(得分:0)
我发现了问题所在。实际上,我必须run
tf.Session
中的下一批,才能产生它。
这是它的工作方式(我不会编写其余代码,因为它保持不变):
model.fit_generator(gen(), steps_per_epochs=100)
def gen():
with tf.Session() as sess:
next_x = dataset.next_x()
next_y = dataset.next_y()
while True:
x_batch = sess.run(next_x)
y_batch = sess.run(next_y)
yield x_batch, y_batch
答案 2 :(得分:0)
对于问题Cannot take the length of Shape with unknown rank
,
多亏了以上答案,我根据此issue comment将output_shape
添加到from_generator来解决了。
就我而言,我在数据集管道中使用Dataset.from_generator
。
之前:
Dataset.from_generator(_generator_factory,
output_types=(tf.float32, tf.int8))
为我工作的代码:
Dataset.from_generator(_generator_factory,
output_types = (tf.float32, tf.int8),
output_shapes = (
tf.TensorShape([2, 224, 224, 3]),
tf.TensorShape([1,])
))
还发现this dataset official guide from tensorflow表示:
...
不需要
output_shapes
参数,但强烈建议使用该参数,因为许多张量流操作不支持未知秩的张量。如果特定轴的长度未知或变量,请在output_shapes中将其设置为None。...