我以
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot add tensor to the batch: number of elements does not match. Shapes are: [tensor]: [5], [batch]: [3]
import tensorflow as tf
#from tensorflow.data import Dataset
#tf.data.Dataset.from_tensor_slices
data = tf.ragged.constant([
[[1, 1],
[2, 2]],
[[3, 3]],
[],
[[4, 4],
[5, 5]]
], dtype=tf.float32, ragged_rank=1)
lst = [[13, 15], [5, 5], [12, 12, 12], [20, 5, 15, 15, 15], [7, 7], [1], [2, 2, 15, 4, 4], [1]]
lst2 = [[11, 11], [5, 5], [12, 12, 12], [20, 5, 15, 15, 15], [7, 7], [1], [2, 2, 15, 4, 4], [1]]
#lst = [[[1, 1],[2, 2]],[[3, 3]],[],[[4, 4],[5, 5]],[[1, 1],[2, 2]],[[1, 1],[2, 2],[3,4]]]
#data = tf.ragged.constant(lst,lst2)
lst = tf.ragged.constant(lst)
lst2 = tf.ragged.constant(lst2)
def preprocessing(lst,lst2):
#print('preprocessing->',data)
return lst,lst2
def train(ds):
for i, batch in enumerate(ds):
print("batch", i)
for x in batch:
print(x.numpy())
ds = tf.data.Dataset.from_tensor_slices((lst,lst2))
train(ds.map(preprocessing).batch(2))
此示例在batch_size为1的情况下适用,但在批处理大小大于2时给出此错误