将NumPy数组输入到tf.data.Dataset中

时间:2019-01-10 13:53:31

标签: python tensorflow

我当前的input_fn是:

def augment(image, label):
  """Apply random flip augmentation to one example.

  The image goes through two separate random flips, first up/down and then
  left/right. The label is returned unchanged.

  Args:
    image: Image tensor for a single example.
    label: Label for that example; passed through untouched.

  Returns:
    An `(augmented_image, label)` tuple.
  """
  # The inner call runs first, so the up/down flip op is still created
  # before the left/right one.
  flipped = tf.image.random_flip_left_right(
      tf.image.random_flip_up_down(image))
  return flipped, label

def train_input_fn(images, labels, batch_size=BATCH_SIZE,
                   num_parallel_calls=8, prefetch_batches=2):
  """Build an endless, shuffled, augmented and batched training dataset.

  Args:
    images: In-memory array of training images, sliced along its first axis.
    labels: Labels aligned with `images` along the first axis. This may be an
      object exposing `.values` (e.g. a pandas Series/DataFrame) or a plain
      NumPy array.
    batch_size: Number of examples per batch. Incomplete final batches are
      dropped (`drop_remainder=True`), presumably so batch shapes stay static
      for the TPU. TODO confirm.
    num_parallel_calls: Number of elements `augment` processes concurrently.
    prefetch_batches: Number of batches prepared ahead of the consumer.

  Returns:
    A `tf.data.Dataset` that yields `(image_batch, label_batch)` forever.

  Note:
    `from_tensor_slices` converts the arrays into graph constants. Any array
    of 2GB or more therefore fails with "Cannot create a tensor proto whose
    content is larger than 2GB." For data that large, consider writing it to
    TFRecord files and reading them with `tf.data.TFRecordDataset` instead.
  """
  # NumPy arrays have no `.values` attribute; pandas objects do.
  label_array = getattr(labels, "values", labels)
  dataset = tf.data.Dataset.from_tensor_slices((images, label_array))
  # Cache the raw, un-augmented examples so augmentation is re-drawn on
  # every pass over the data.
  dataset = dataset.cache()
  # Shuffle raw examples before augmenting. Filling the buffer then no
  # longer has to run `augment` serially, and the randomness is unchanged.
  dataset = dataset.shuffle(batch_size * 10)
  dataset = dataset.repeat()
  # Augment several elements at once instead of one at a time.
  dataset = dataset.map(augment, num_parallel_calls=num_parallel_calls)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  # Let input preparation overlap with training steps.
  dataset = dataset.prefetch(prefetch_batches)
  return dataset

`images`和`labels`是内存中的NumPy数组。

我创建了一个Keras模型,并在TPU上进行训练:

model = from_base_model(tf.keras.applications.ResNet50)

# Connect to the TPU at TPU_ADDRESS and wrap the Keras model for TPU training.
_tpu_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
_tpu_strategy = tf.contrib.tpu.TPUDistributionStrategy(_tpu_resolver)
tpu_model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=_tpu_strategy)


def _make_train_dataset():
  """Build the training dataset; `fit` calls this to obtain its input."""
  return train_input_fn(images, labels)


# The dataset repeats forever, so `steps_per_epoch` is what bounds each epoch.
tpu_model.fit(_make_train_dataset, steps_per_epoch=10, epochs=10)

这样可以正常运行,但速度有点慢,我感觉问题出在input_fn上。所以我的第一个问题是:我的input_fn有什么写得不对的地方吗?

此外,当我增大`images`数组时,得到了以下输出:

/usr/local/lib/python3.6/dist-packages/keras_applications/resnet50.py:265: UserWarning: The output shape of `ResNet50(include_top=False)` has been changed since Keras 2.2.0.
warnings.warn('The output shape of `ResNet50(include_top=False)` '
INFO:tensorflow:Querying Tensorflow master (b'grpc://10.120.179.10:8470') for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 1028257885212618014)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 3818076851124023299)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_GPU:0, XLA_GPU, 17179869184, 17760761706563925312)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1175599564916411245)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 15427033274328587252)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 11995323192741160700)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 10572054117553235355)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 1187876647131221859)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 8130720152970464558)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 240703176264203696)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 2342271245137911487)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 12994409452324068660)
WARNING:tensorflow:tpu_model (from tensorflow.contrib.tpu.python.tpu.keras_support) is experimental and may change or be removed at any time, and without warning.
INFO:tensorflow:Cloning Adam {'lr': 0.0010000000474974513, 'beta_1': 0.8999999761581421, 'beta_2': 0.9990000128746033, 'decay': 0.0, 'epsilon': 1e-07, 'amsgrad': False}
INFO:tensorflow:Cloning Adam {'lr': 0.0010000000474974513, 'beta_1': 0.8999999761581421, 'beta_2': 0.9990000128746033, 'decay': 0.0, 'epsilon': 1e-07, 'amsgrad': False}
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-12-b92ce1ea0c16> in <module>()
    8 tpu_model.fit(lambda : train_input_fn(images, labels),
    9               steps_per_epoch=10,
---> 10               epochs=10)

/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1455         with ops.device(
1456             '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
-> 1457           dataset = x()
1458           if steps_per_epoch is None:
1459             raise ValueError('When using tf.data as input to a model, you '

<ipython-input-12-b92ce1ea0c16> in <lambda>()
    6         tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)))
    7 
----> 8 tpu_model.fit(lambda : train_input_fn(images, labels),
    9               steps_per_epoch=10,
    10               epochs=10)

<ipython-input-10-2f782939c783> in train_input_fn(images, labels, batch_size)
    5 
    6 def train_input_fn(images, labels, batch_size=BATCH_SIZE):
----> 7   dataset = tf.data.Dataset.from_tensor_slices((images, labels.values))
    8   dataset = dataset.cache()
    9   dataset = dataset.map(augment)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py in from_tensor_slices(tensors)
    287       Dataset: A `Dataset`.
    288     """
--> 289     return TensorSliceDataset(tensors)
    290 
    291   @staticmethod

/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, tensors)
1563           if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
1564               t, name="component_%d" % i)
-> 1565           for i, t in enumerate(nest.flatten(tensors))
1566       ])
1567       flat_tensors = nest.flatten(tensors)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py in <listcomp>(.0)
1563           if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
1564               t, name="component_%d" % i)
-> 1565           for i, t in enumerate(nest.flatten(tensors))
1566       ])
1567       flat_tensors = nest.flatten(tensors)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, preferred_dtype)
1048       name=name,
1049       preferred_dtype=preferred_dtype,
-> 1050       as_ref=False)
1051 
1052 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, ctx)
1144 
1145     if ret is None:
-> 1146       ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
1147 
1148     if ret is NotImplemented:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
    227                                          as_ref=False):
    228   _ = as_ref
--> 229   return constant(v, dtype=dtype, name=name)
    230 
    231 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name, verify_shape)
    206   tensor_value.tensor.CopyFrom(
    207       tensor_util.make_tensor_proto(
--> 208           value, dtype=dtype, shape=shape, verify_shape=verify_shape))
    209   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    210   const_tensor = g.create_op(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape)
    504     if nparray.size * nparray.itemsize >= (1 << 31):
    505       raise ValueError(
--> 506           "Cannot create a tensor proto whose content is larger than 2GB.")
    507     tensor_proto.tensor_content = nparray.tostring()
    508     return tensor_proto

ValueError: Cannot create a tensor proto whose content is larger than 2GB.

数组大小从0.57 GB增加到了2.28 GB。

有什么办法可以解决这个问题吗?

有一篇指南建议为图像创建一个`tf.placeholder`,但这样的话,我应该在哪里把真实的`images`数据提供给数据集?

使用`tf.data.Dataset`的主要目的,是能够方便地用`map()`进行数据增强,并用`shuffle()`打乱数据。

0 个答案:

没有答案