Number of elements does not match. Shapes are: [tensor]: [5], [batch]: [3]

Time: 2020-07-12 14:30:18

Tags: python tensorflow batch-processing tensorflow2.0

I get the following error when running the example below:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot add tensor to the batch: number of elements does not match. Shapes are: [tensor]: [5], [batch]: [3]



import tensorflow as tf
#from tensorflow.data import Dataset
#tf.data.Dataset.from_tensor_slices
data = tf.ragged.constant([
    [[1, 1],
     [2, 2]],

    [[3, 3]],

    [],

    [[4, 4],
     [5, 5]]
], dtype=tf.float32, ragged_rank=1)
lst = [[13, 15], [5, 5], [12, 12, 12], [20, 5, 15, 15, 15], [7, 7], [1], [2, 2, 15, 4, 4], [1]]
lst2 = [[11, 11], [5, 5], [12, 12, 12], [20, 5, 15, 15, 15], [7, 7], [1], [2, 2, 15, 4, 4], [1]]

#lst = [[[1, 1],[2, 2]],[[3, 3]],[],[[4, 4],[5, 5]],[[1, 1],[2, 2]],[[1, 1],[2, 2],[3,4]]]
#data = tf.ragged.constant(lst,lst2)

lst = tf.ragged.constant(lst)
lst2 = tf.ragged.constant(lst2)
def preprocessing(lst,lst2):
    #print('preprocessing->',data)
    return lst,lst2


def train(ds):
    for i, batch in enumerate(ds):
        print("batch", i)
        for x in batch:
            print(x.numpy())

ds = tf.data.Dataset.from_tensor_slices((lst,lst2))
train(ds.map(preprocessing).batch(2))

This example works with a batch size of 1, but raises the error above when the batch size is 2 or larger.
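For context, the shapes in the message ([3] and [5]) match the third and fourth rows of lst, [12, 12, 12] and [20, 5, 15, 15, 15], which batch(2) would have to stack into one rectangular tensor. One possible workaround, sketched here under the assumption that zero-padding the rows is acceptable for the use case, is to convert the ragged tensors to dense ones with RaggedTensor.to_tensor before building the dataset. The shortened lists and the names dense, dense2 and batches are only illustrative and not part of the original code.

```python
import tensorflow as tf

# Shortened versions of the lists from the question (first four rows only).
lst = tf.ragged.constant([[13, 15], [5, 5], [12, 12, 12], [20, 5, 15, 15, 15]])
lst2 = tf.ragged.constant([[11, 11], [5, 5], [12, 12, 12], [20, 5, 15, 15, 15]])

# Assumption: padding with 0 is acceptable. Every row is padded to the
# length of the longest row, so all dataset elements have the same shape.
dense = lst.to_tensor(default_value=0)
dense2 = lst2.to_tensor(default_value=0)

# With equally shaped elements, batch(2) can stack them without error.
ds = tf.data.Dataset.from_tensor_slices((dense, dense2)).batch(2)

batches = [(x.numpy().tolist(), y.numpy().tolist()) for x, y in ds]
for i, (x, y) in enumerate(batches):
    print("batch", i)
    print(x)
    print(y)
```

If the padded zeros must not be seen downstream, the original row lengths (lst.row_lengths()) could be carried along as a third dataset component; whether that is needed depends on the model.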

0 answers:

No answers yet