Using tf.data.Dataset with a TF Hub module

Date: 2020-05-14 22:28:21

Tags: tensorflow2.0 tensorflow-datasets tf.keras tensorflow-hub

How do I feed data from a tf.data.Dataset into a tf.keras model that contains a TF Hub module taking 1-D input?

(The ultimate goal is to use a single tf.data.Dataset with a multi-input, multi-output Keras functional API model.)

I tried this:

import tensorflow as tf
import tensorflow_hub as hub

embed = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1"
hub_layer = hub.KerasLayer(embed, output_shape=[20], input_shape=[], 
                           dtype=tf.string, trainable=True, name='hub_layer')

# From tf hub webpage: "The module takes a batch of sentences in a 1-D tensor of strings as input."

input_tensor = tf.keras.Input(shape=(), dtype=tf.string)
hub_tensor = hub_layer(input_tensor)
x = tf.keras.layers.Dense(16, activation='relu')(hub_tensor)
main_output = tf.keras.layers.Dense(units=4, activation='softmax', name='main_output')(x)

model = tf.keras.models.Model(inputs=[input_tensor], outputs=[main_output])

# This works as expected.
X_tensor = tf.constant(['Hello World', 'The Quick Brown Fox'])
model(X_tensor)

# This fails
X_ds = tf.data.Dataset.from_tensors(X_tensor)
X_ds.element_spec
model(X_ds)

I expected the model to automatically extract the 1-D tensor from the dataset and consume it.
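A minimal sketch of what X_ds actually contains, assuming TensorFlow 2.x in eager mode. `from_tensors` wraps the whole (2,) tensor as a single element, whereas `from_tensor_slices` would yield one scalar string per element. In either case the Dataset object itself is not a tensor, which is why `model(X_ds)` fails in the tensor-conversion step shown in the traceback below.

```python
import tensorflow as tf

X_tensor = tf.constant(['Hello World', 'The Quick Brown Fox'])

# from_tensors treats the whole tensor as ONE dataset element of shape (2,)
ds_whole = tf.data.Dataset.from_tensors(X_tensor)
# from_tensor_slices splits along axis 0: two elements, each a scalar string
ds_sliced = tf.data.Dataset.from_tensor_slices(X_tensor)

print(ds_whole.element_spec)   # TensorSpec(shape=(2,), dtype=tf.string, name=None)
print(ds_sliced.element_spec)  # TensorSpec(shape=(), dtype=tf.string, name=None)

# A Dataset is an iterable of tensors, not a tensor, so a Keras model's
# __call__ cannot convert it; the elements have to be taken out first.
for batch in ds_whole:
    print(batch.shape)  # (2,), the same batch that model(X_tensor) accepted
```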

Error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
 in 
     21 X_ds = tf.data.Dataset.from_tensors(X_tensor)
     22 X_ds.element_spec
---> 23 model(X_ds)
     24 
     25 

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    966           with base_layer_utils.autocast_context_manager(
    967               self._compute_dtype):
--> 968             outputs = self.call(cast_inputs, *args, **kwargs)
    969           self._handle_activity_regularization(inputs, outputs)
    970           self._set_mask_metadata(inputs, outputs, input_masks)

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py in call(self, inputs, training, mask)
    717     return self._run_internal_graph(
    718         inputs, training=training, mask=mask,
--> 719         convert_kwargs_to_constants=base_layer_utils.call_context().saving)
    720 
    721   def compute_output_shape(self, input_shape):

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py in _run_internal_graph(self, inputs, training, mask, convert_kwargs_to_constants)
    835     tensor_dict = {}
    836     for x, y in zip(self.inputs, inputs):
--> 837       y = self._conform_to_reference_input(y, ref_input=x)
    838       x_id = str(id(x))
    839       tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py in _conform_to_reference_input(self, tensor, ref_input)
    959     # Dtype handling.
    960     if isinstance(ref_input, (ops.Tensor, composite_tensor.CompositeTensor)):
--> 961       tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
    962 
    963     return tensor

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    178     """Call target, and fall back on dispatchers if there is a TypeError."""
    179     try:
--> 180       return target(*args, **kwargs)
    181     except (TypeError, ValueError):
    182       # Note: convert_to_eager_tensor currently raises a ValueError, not a

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py in cast(x, dtype, name)
    785       # allows some conversions that cast() can't do, e.g. casting numbers to
    786       # strings.
--> 787       x = ops.convert_to_tensor(x, name="x")
    788       if x.dtype.base_dtype != base_type:
    789         x = gen_math_ops.cast(x, base_type, name=name)

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types)
   1339 
   1340     if ret is None:
-> 1341       ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
   1342 
   1343     if ret is NotImplemented:

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
    319                                          as_ref=False):
    320   _ = as_ref
--> 321   return constant(v, dtype=dtype, name=name)
    322 
    323 

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name)
    260   """
    261   return _constant_impl(value, dtype, shape, name, verify_shape=False,
--> 262                         allow_broadcast=True)
    263 
    264 

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast)
    268   ctx = context.context()
    269   if ctx.executing_eagerly():
--> 270     t = convert_to_eager_tensor(value, ctx, dtype)
    271     if shape is None:
    272       return t

~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in convert_to_eager_tensor(value, ctx, dtype)
     94       dtype = dtypes.as_dtype(dtype).as_datatype_enum
     95   ctx.ensure_initialized()
---> 96   return ops.EagerTensor(value, ctx.device_name, dtype)
     97 
     98 

ValueError: Attempt to convert a value () with an unsupported type () to a Tensor.

1 answer:

Answer 0 (score: 0)

A tf.data.Dataset is not itself a tensor; its purpose is to provide a sequence of tensors (its elements), for example:

all_data = tf.constant([['Hello', 'World'], ['Brown Fox', 'lazy dog']])
ds = tf.data.Dataset.from_tensor_slices(all_data)
for tensor in ds:
  print(tensor)

Output:

tf.Tensor([b'Hello' b'World'], shape=(2,), dtype=string)
tf.Tensor([b'Brown Fox' b'lazy dog'], shape=(2,), dtype=string)
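For a model whose input is a 1-D batch of strings, a common pattern is to slice a 1-D string tensor into individual sentences and regroup them with `batch`, so that each element has shape (batch_size,). This is a sketch; the extra sentences are made up for illustration.

```python
import tensorflow as tf

sentences = tf.constant(['Hello World', 'The Quick Brown Fox',
                         'Brown Fox', 'lazy dog', 'one more'])

# Slice into individual sentences (shape ()), then regroup into batches.
# Each batch is a 1-D string tensor, the input shape the hub layer expects.
ds = tf.data.Dataset.from_tensor_slices(sentences).batch(2)

for batch in ds:
    print(batch.shape)  # (2,), (2,), then (1,) for the remainder
```

Because `drop_remainder` defaults to False, the last batch may be smaller and `ds.element_spec` reports a shape of (None,).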

Instead of just printing each tensor, you can also compute with it:

for tensor in ds:
  print(hub_layer(tensor))

This outputs two tensors, each of shape (2, 20).
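The same idea applies to the whole model, not just `hub_layer`. You can either pass the dataset to `model.predict`, which iterates over it and treats each element as one batch, or loop over it yourself and call the model on each batch. In the sketch below, a small numeric model is a hypothetical stand-in for the hub-based model so that it runs without downloading the module. With the question's model, `model.predict(X_ds)` should behave the same way, since X_ds's single element is already a (2,) batch of strings.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the hub-based model: 3 numeric features in,
# 4-way softmax out (the hub layer is replaced so the sketch runs offline).
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(16, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(4, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

features = np.random.rand(5, 3).astype('float32')
ds = tf.data.Dataset.from_tensor_slices(features).batch(2)

# Option 1: let Keras iterate the dataset (each element is one batch).
preds = model.predict(ds, verbose=0)
print(preds.shape)  # (5, 4)

# Option 2: iterate yourself and call the model on each batch tensor.
per_batch = [model(batch) for batch in ds]
print([tuple(p.shape) for p in per_batch])
```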

For more information, see https://www.tensorflow.org/guide/data
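The question's ultimate goal is one tf.data.Dataset feeding a multi-input, multi-output functional model. Keras `fit` accepts a dataset whose elements are `(inputs, targets)` pairs, and both parts may be dicts keyed by the `Input` names and the output-layer names. The sketch below assumes numeric inputs. The names `feat_a`, `feat_b`, and `aux_output`, along with all of the data, are hypothetical placeholders. In the real model, one of the inputs would be the `tf.string` input feeding `hub_layer`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical two-input, two-output functional model.
in_a = tf.keras.Input(shape=(3,), name='feat_a')
in_b = tf.keras.Input(shape=(2,), name='feat_b')
merged = tf.keras.layers.Concatenate()([in_a, in_b])
hidden = tf.keras.layers.Dense(8, activation='relu')(merged)
main_out = tf.keras.layers.Dense(4, activation='softmax', name='main_output')(hidden)
aux_out = tf.keras.layers.Dense(1, name='aux_output')(hidden)
model = tf.keras.Model(inputs=[in_a, in_b],
                       outputs={'main_output': main_out, 'aux_output': aux_out})
model.compile(optimizer='adam',
              loss={'main_output': 'sparse_categorical_crossentropy',
                    'aux_output': 'mse'})

n = 6
x = {'feat_a': np.random.rand(n, 3).astype('float32'),
     'feat_b': np.random.rand(n, 2).astype('float32')}
y = {'main_output': np.random.randint(0, 4, size=(n,)),
     'aux_output': np.random.rand(n, 1).astype('float32')}

# One dataset whose elements are (inputs_dict, targets_dict) pairs,
# keyed by the Input names and the output-layer names.
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(2)
history = model.fit(ds, epochs=1, verbose=0)
```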