tf.keras, tf.estimator and tf.dataset

Posted: 2019-05-06 17:11:37

Tags: python tensorflow keras tensorflow-datasets tensorflow-estimator

I am trying to update my code to work with TF 2.0. As a start, I used a premade Keras model:

import tensorflow as tf

def train_input_fn(batch_size=1):
  """An input function for training"""
  print("train_input_fn: start function")

  train_dataset = tf.data.experimental.make_csv_dataset(CSV_PATH_TRAIN, batch_size=batch_size,label_name='label',
                                                        select_columns=["sample","label"])
  print('train_input_fn: finished make_csv_dataset')
  train_dataset = train_dataset.map(parse_features_vector)
  print("train_input_fn: finished the map with parse_features_vector")
  train_dataset = train_dataset.repeat().batch(batch_size)
  print("train_input_fn: finished batch size. train_dataset is %s ", train_dataset)
  return train_dataset

IMG_SHAPE = (160,160,3)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                              include_top = False,
                                              weights = 'imagenet')

base_model.trainable = False
# `model`, built on top of base_model, is not shown in the question
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
             loss='binary_crossentropy',
             metrics=['accuracy'])

estimator = tf.keras.estimator.model_to_estimator(keras_model = model, model_dir = './date')

# train_input_fn reads a CSV of images, resizes them and returns a batched dataset
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=20)

# eval_input_fn reads a CSV of images, resizes them and returns batches of one sample
eval_spec = tf.estimator.EvalSpec(eval_input_fn)

tf.estimator.train_and_evaluate(estimator, train_spec=train_spec, eval_spec=eval_spec)

The log is:

train_input_fn: finished batch size. train_dataset is %s  <BatchDataset shapes: ({mobilenetv2_1.00_160_input: (None, 1, 160, 160, 3)}, (None, 1)), types: ({mobilenetv2_1.00_160_input: tf.float32}, tf.int32)>

Error:

ValueError: Input 0 of layer Conv1_pad is incompatible with the layer: expected ndim=4, found ndim=5. Full shape received: [None, 1, 160, 160, 3]

What is the correct way to combine tf.keras with the Dataset API? Is this the problem, or is it something else?

Thanks, Eran

1 answer:

Answer (score: 1)

You don't need this line:

  train_dataset = train_dataset.repeat().batch(batch_size)

The function you use to create the dataset, tf.data.experimental.make_csv_dataset, has already batched it. You can still repeat the data by using repeat.
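A minimal sketch of the fix, assuming TF 2.x with eager execution. The CSV file, its numeric "sample" column and the batch size of 2 are hypothetical stand-ins: the question's image decoding (parse_features_vector) is not shown, so it is left out here. The sketch shows that make_csv_dataset already returns batched tensors, and that calling .batch() on top of it adds a second leading dimension, the same extra dimension that turns the expected ndim=4 image batch into ndim=5.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical CSV standing in for CSV_PATH_TRAIN: a numeric "sample"
# column instead of image data, so no parse step is needed.
csv_path = os.path.join(tempfile.mkdtemp(), "train.csv")
with open(csv_path, "w") as f:
    f.write("sample,label\n")
    for i in range(6):
        f.write(f"{i}.0,{i % 2}\n")


def train_input_fn(batch_size=2):
    """Input function without the extra .batch() call."""
    # make_csv_dataset batches the records itself; num_epochs=None
    # (its default) also makes it cycle through the file forever.
    return tf.data.experimental.make_csv_dataset(
        csv_path,
        batch_size=batch_size,
        label_name="label",
        select_columns=["sample", "label"],
        num_epochs=None,
        shuffle=False,
    )


good = train_input_fn(batch_size=2)
features, labels = next(iter(good))
print(features["sample"].shape)  # a single batch dimension

# Re-applying .batch() (the line the answer says to remove) wraps
# each existing batch in another batch dimension.
bad = good.batch(2)
bad_features, bad_labels = next(iter(bad))
print(bad_features["sample"].shape)  # an extra leading dimension
```

Because num_epochs defaults to None, the dataset returned by make_csv_dataset already repeats indefinitely, so a trailing .repeat() is harmless but redundant; passing an integer num_epochs limits the number of passes over the file instead.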