TensorFlow 2.0 - from_generator shape mismatch

Time: 2019-09-04 22:32:26

Tags: python tensorflow keras tensorflow-datasets

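I am trying to use the Dataset.from_generator function to generate data for training and testing. Right now I am running into a problem with mismatched data shapes: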

InvalidArgumentError:  Incompatible shapes: [7,3] vs. [7]
     [[node metrics_16/sparse_categorical_accuracy/Equal (defined at <ipython-input-13-d214206d5c0a>:40) ]] [Op:__inference_keras_scratch_graph_9143]

Function call stack:
keras_scratch_graph
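The two shapes in the message can be reproduced without TensorFlow. The NumPy sketch below assumes that each `Dense` layer multiplies only the last axis of its input (which is how Keras applies `Dense` to inputs of rank greater than 2) and that `SparseCategoricalAccuracy` compares the arg-max of the predictions over the last axis with the integer labels. The weights are random placeholders, biases are omitted, and the softmax is skipped because it does not change the arg-max.

```python
import numpy as np

BATCH_SIZE = 7
INPUT_SHAPE = (3, 5)
NUM_CLASSES = 3

rng = np.random.default_rng(0)

# One batch as train.batch(BATCH_SIZE) would deliver it: x is (7, 3, 5), y is (7,)
x = rng.random((BATCH_SIZE, *INPUT_SHAPE))
y = rng.integers(0, NUM_CLASSES, size=BATCH_SIZE)

# Dense(100) then Dense(3), each acting only on the last axis (placeholder weights)
w1 = rng.random((INPUT_SHAPE[1], 100))
w2 = rng.random((100, NUM_CLASSES))
hidden = np.tanh(x @ w1)   # (7, 3, 100)
logits = hidden @ w2       # (7, 3, 3): a class vector for each of the 3 rows

# Sparse accuracy takes the arg-max over the class axis
pred_classes = logits.argmax(axis=-1)  # (7, 3), while the labels are (7,)

# np.broadcast raises ValueError when the shapes cannot be combined
try:
    np.broadcast(pred_classes, y)
    compatible = True
except ValueError:
    compatible = False

print("predictions:", pred_classes.shape, "labels:", y.shape, "compatible:", compatible)
```

The extra axis of size 3 comes from the first dimension of `INPUT_SHAPE`, which the `Dense` layers carry through unchanged to the output.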

Here is the complete minimal code:

import numpy as np
import pandas as pd
import random
import tensorflow as tf

INPUT_SHAPE=[3, 5]
NUM_POINTS=20
BATCH_SIZE=7
EPOCHS=3

def data_gen(num=10, in_shape=[5, 3]):
    print("IN_SHAPE: ", in_shape)
    for i in range(num):
        res = np.random.rand(in_shape[0], in_shape[1]), random.randint(0,2)
        print("Output shape of el: ", res[0].shape, " and training pair (x,y):\n", res)
        yield res

train = tf.data.Dataset.from_generator(
    generator=data_gen,
    output_types=(tf.float32, tf.int32),
#     output_shapes=(tf.TensorShape([None, INPUT_SHAPE[1]]), tf.TensorShape(None)),
#     output_shapes=(tf.TensorShape(INPUT_SHAPE), tf.TensorShape(None)),
    args=([NUM_POINTS, INPUT_SHAPE])
)

def create_model(input_shape):
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(100, activation="tanh", input_shape=input_shape),
        tf.keras.layers.Dense(3, activation="softmax", kernel_regularizer=tf.keras.regularizers.l2(0.001))
    ])
    return model

model = create_model(input_shape=INPUT_SHAPE)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4, clipvalue=1.0),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # 'sparse_categorical_crossentropy'
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
print(model.summary())
model.fit(train.batch(BATCH_SIZE), epochs=EPOCHS, verbose=2)
model.evaluate(train, steps=None, verbose=1)
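One likely remedy, offered as an assumption rather than a confirmed answer, is to make `tf.keras.layers.Flatten(input_shape=input_shape)` the first layer of the model (and drop `input_shape` from the first `Dense`), so that each (3, 5) example becomes a 15-element vector and the softmax output is (batch, 3). The NumPy sketch below mirrors that forward pass with placeholder weights. It also runs a copy of the question's `data_gen` (debug prints removed) and stacks a batch by hand, to show what the dataset actually yields. Separately, `model.evaluate(train, ...)` receives the unbatched dataset, so it probably also needs `train.batch(BATCH_SIZE)` as in `fit`.

```python
import random
import numpy as np

INPUT_SHAPE = [3, 5]
NUM_POINTS = 20
BATCH_SIZE = 7
NUM_CLASSES = 3

def data_gen(num=10, in_shape=(5, 3)):
    # Same logic as the question's generator, minus the debug prints
    for _ in range(num):
        yield np.random.rand(in_shape[0], in_shape[1]), random.randint(0, 2)

# Stack one batch by hand, as train.batch(BATCH_SIZE) would
samples = list(data_gen(NUM_POINTS, INPUT_SHAPE))[:BATCH_SIZE]
x = np.stack([s[0] for s in samples])   # (7, 3, 5)
y = np.array([s[1] for s in samples])   # (7,), integer labels in {0, 1, 2}

# Equivalent of a Flatten layer: (7, 3, 5) -> (7, 15)
flat = x.reshape(x.shape[0], -1)

rng = np.random.default_rng(0)
w1 = rng.random((flat.shape[1], 100))   # placeholder weights for Dense(100)
w2 = rng.random((100, NUM_CLASSES))     # placeholder weights for Dense(3)
logits = np.tanh(flat @ w1) @ w2        # (7, 3): one class vector per example
pred_classes = logits.argmax(axis=-1)   # (7,), now the same shape as y

accuracy = float((pred_classes == y).mean())
print(x.shape, y.shape, pred_classes.shape, accuracy)
```

With the time axis flattened away, the predicted classes and the labels are both (batch,), which is the shape pairing `SparseCategoricalAccuracy` expects. Whether flattening is appropriate depends on what the 3 rows of each sample mean; if they should be processed as a sequence, a recurrent or pooling layer would be the alternative.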

I am running the code with:

  • python = 3.7.4 (default, Aug 17 2019, 20:42:51) [Clang 10.0.1 (clang-1001.0.46.4)]
  • tensorflow = 2.0.0-beta1
  • numpy = 1.17.0
  • pandas = 0.24.2

Later edit: simplified the example.

0 answers:

No answers yet.