My current input_fn is:
def augment(image, label):
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_flip_left_right(image)
    return image, label

def train_input_fn(images, labels, batch_size=BATCH_SIZE):
    dataset = tf.data.Dataset.from_tensor_slices((images, labels.values))
    dataset = dataset.cache()
    dataset = dataset.map(augment)
    dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset
images and labels are in-memory numpy arrays.
I set up a Keras model and train it on a TPU:
model = from_base_model(tf.keras.applications.ResNet50)

tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)))

tpu_model.fit(lambda: train_input_fn(images, labels),
              steps_per_epoch=10,
              epochs=10)
This works fine, but it is somewhat slow, and I have a feeling that is related to my input_fn. So my first question is: am I doing something wrong in my input_fn?
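For context, the kind of reordering I have been wondering about (parallel map and prefetch, using the TF 1.x tf.data API; the num_parallel_calls and prefetch values are just guesses) would look roughly like this:

def train_input_fn(images, labels, batch_size=BATCH_SIZE):
    dataset = tf.data.Dataset.from_tensor_slices((images, labels.values))
    dataset = dataset.cache()                             # cache raw examples before augmentation
    dataset = dataset.shuffle(batch_size * 10)            # shuffle before repeat
    dataset = dataset.repeat()
    dataset = dataset.map(augment, num_parallel_calls=4)  # augment in parallel, after the cache
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(1)                         # overlap preprocessing with TPU steps
    return dataset

I am not sure whether that ordering (cache/shuffle/repeat before map) is actually what I should be doing here.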
Apart from that, when I enlarge the images array I get the following:
/usr/local/lib/python3.6/dist-packages/keras_applications/resnet50.py:265: UserWarning: The output shape of `ResNet50(include_top=False)` has been changed since Keras 2.2.0.
warnings.warn('The output shape of `ResNet50(include_top=False)` '
INFO:tensorflow:Querying Tensorflow master (b'grpc://10.120.179.10:8470') for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 1028257885212618014)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 3818076851124023299)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_GPU:0, XLA_GPU, 17179869184, 17760761706563925312)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1175599564916411245)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 15427033274328587252)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 11995323192741160700)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 10572054117553235355)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 1187876647131221859)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 8130720152970464558)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 240703176264203696)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 2342271245137911487)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 12994409452324068660)
WARNING:tensorflow:tpu_model (from tensorflow.contrib.tpu.python.tpu.keras_support) is experimental and may change or be removed at any time, and without warning.
INFO:tensorflow:Cloning Adam {'lr': 0.0010000000474974513, 'beta_1': 0.8999999761581421, 'beta_2': 0.9990000128746033, 'decay': 0.0, 'epsilon': 1e-07, 'amsgrad': False}
INFO:tensorflow:Cloning Adam {'lr': 0.0010000000474974513, 'beta_1': 0.8999999761581421, 'beta_2': 0.9990000128746033, 'decay': 0.0, 'epsilon': 1e-07, 'amsgrad': False}
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-12-b92ce1ea0c16> in <module>()
8 tpu_model.fit(lambda : train_input_fn(images, labels),
9 steps_per_epoch=10,
---> 10 epochs=10)
/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1455 with ops.device(
1456 '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
-> 1457 dataset = x()
1458 if steps_per_epoch is None:
1459 raise ValueError('When using tf.data as input to a model, you '
<ipython-input-12-b92ce1ea0c16> in <lambda>()
6 tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)))
7
----> 8 tpu_model.fit(lambda : train_input_fn(images, labels),
9 steps_per_epoch=10,
10 epochs=10)
<ipython-input-10-2f782939c783> in train_input_fn(images, labels, batch_size)
5
6 def train_input_fn(images, labels, batch_size=BATCH_SIZE):
----> 7 dataset = tf.data.Dataset.from_tensor_slices((images, labels.values))
8 dataset = dataset.cache()
9 dataset = dataset.map(augment)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py in from_tensor_slices(tensors)
287 Dataset: A `Dataset`.
288 """
--> 289 return TensorSliceDataset(tensors)
290
291 @staticmethod
/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, tensors)
1563 if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
1564 t, name="component_%d" % i)
-> 1565 for i, t in enumerate(nest.flatten(tensors))
1566 ])
1567 flat_tensors = nest.flatten(tensors)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py in <listcomp>(.0)
1563 if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
1564 t, name="component_%d" % i)
-> 1565 for i, t in enumerate(nest.flatten(tensors))
1566 ])
1567 flat_tensors = nest.flatten(tensors)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, preferred_dtype)
1048 name=name,
1049 preferred_dtype=preferred_dtype,
-> 1050 as_ref=False)
1051
1052
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, ctx)
1144
1145 if ret is None:
-> 1146 ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
1147
1148 if ret is NotImplemented:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
227 as_ref=False):
228 _ = as_ref
--> 229 return constant(v, dtype=dtype, name=name)
230
231
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name, verify_shape)
206 tensor_value.tensor.CopyFrom(
207 tensor_util.make_tensor_proto(
--> 208 value, dtype=dtype, shape=shape, verify_shape=verify_shape))
209 dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
210 const_tensor = g.create_op(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape)
504 if nparray.size * nparray.itemsize >= (1 << 31):
505 raise ValueError(
--> 506 "Cannot create a tensor proto whose content is larger than 2GB.")
507 tensor_proto.tensor_content = nparray.tostring()
508 return tensor_proto
ValueError: Cannot create a tensor proto whose content is larger than 2GB.
The array size grew from 0.57 GB to 2.28 GB.
Is there anything I can do to fix this?
The following guide suggests creating a tf.placeholder for my images, but then where do I feed the real images into the dataset?
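My understanding of the pattern from the guide, outside of Keras, is roughly the following (the placeholder names are my own), but I don't see where the feed_dict / iterator initialization would fit into tpu_model.fit:

import tensorflow as tf

# Placeholders for the full arrays, so they are not embedded as graph constants
images_ph = tf.placeholder(images.dtype, images.shape)
labels_ph = tf.placeholder(labels.values.dtype, labels.values.shape)

dataset = tf.data.Dataset.from_tensor_slices((images_ph, labels_ph))
dataset = dataset.map(augment).shuffle(1000).repeat()
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    # This is where the real numpy arrays get fed in, at iterator init time
    sess.run(iterator.initializer,
             feed_dict={images_ph: images, labels_ph: labels.values})

With tpu_model.fit I only pass a function that returns the dataset, so I don't see where that initialization step would happen.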
The main purpose of using tf.data.Dataset was to be able to do the augmentation easily with map() and shuffle().
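One alternative I have been wondering about is tf.data.Dataset.from_generator, which as far as I understand avoids embedding the arrays as graph constants, but I am not sure whether it plays well with the TPU Keras input pipeline (the dtypes below are assumptions about my data):

def gen():
    # yield one (image, label) pair at a time from the in-memory arrays
    for img, lab in zip(images, labels.values):
        yield img, lab

dataset = tf.data.Dataset.from_generator(
    gen,
    output_types=(tf.float32, tf.int64),
    output_shapes=(images.shape[1:], labels.values.shape[1:]))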