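When parsing a serialized tensor, tf.io.parse_tensor has a required argument, out_type. However, TensorFlow apparently already knows the type of the serialized tensor, because when the wrong type is passed, the error message reports the correct one.
How can I parse without this argument?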
MWE:
tf.io.parse_tensor(tf.io.serialize_tensor(tf.constant([1])), tf.int32)
tf.io.parse_tensor(tf.io.serialize_tensor(tf.constant([1])))
Traceback (most recent call last):
File "<input>", line 1, in <module>
TypeError: parse_tensor() missing 1 required positional argument: 'out_type'
tf.io.parse_tensor(tf.io.serialize_tensor(tf.constant([1])), tf.float32)
Traceback (most recent call last):
File "<input>", line 1, in <module>
File "/Users/clementwalter/.pyenv/versions/keras_fsl/lib/python3.6/site-packages/tensorflow/python/ops/gen_parsing_ops.py", line 2160, in parse_tensor
_ops.raise_from_not_ok_status(e, name)
File "/Users/clementwalter/.pyenv/versions/keras_fsl/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 6653, in raise_from_not_ok_status
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Type mismatch between parsed tensor (int32) and dtype (float) [Op:ParseTensor]
Answer:
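EDIT:
I came up with a somewhat "hacky" solution that reads tensors of different types, casts them to a given type, and works with @tf.function (interestingly, it did not work without @tf.function). The idea is to read the second byte of the TensorProto message, which should indicate the data type: dtype is field 1 of TensorProto, so the serialized message starts with that field's tag byte (0x08) followed by the DataType enum value. tf.switch_case then does the cast from one of a set of possible source data types. It works like this: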
import tensorflow as tf
# Different sets of data types you could use
INTEGER_DTYPES = frozenset({tf.bool, tf.uint8, tf.uint16, tf.uint32, tf.uint64,
tf.int8, tf.int16, tf.int32, tf.int64})
FLOAT_DTYPES = frozenset({tf.float16, tf.bfloat16, tf.float32, tf.float64})
COMPLEX_DTYPES = frozenset({tf.complex64, tf.complex128})
REAL_DTYPES = INTEGER_DTYPES | FLOAT_DTYPES
NUMERICAL_DTYPES = REAL_DTYPES | COMPLEX_DTYPES
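@tf.function
def parse_tensor_cast(tensor_proto, out_dtype, possible_dtypes=REAL_DTYPES):
    # Prepare one branch per candidate source dtype
    branches = {}
    # Lookup table: DataType enum value -> branch index (enum values are < 128)
    dtype_idx = [0] * 128
    for i, dtype in enumerate(possible_dtypes):
        dtype_idx[dtype.as_datatype_enum] = i
        # Bind dtype as a default argument so each branch keeps its own dtype;
        # a plain closure would make every branch use the last dtype of the loop
        branches[i] = lambda dtype=dtype: tf.dtypes.cast(
            tf.io.parse_tensor(tensor_proto, dtype), out_dtype)
    dtype_idx = tf.constant(dtype_idx, tf.int32)
    # Extract the dtype byte ("hacky" part of the solution):
    # byte 0 is the tag of field 1 (dtype), byte 1 is the DataType enum value
    dtype_code = tf.strings.substr(tensor_proto, 1, 1)
    dtype_num = tf.io.decode_raw(dtype_code, tf.uint8)[0]
    dtype_num_idx = dtype_idx[tf.dtypes.cast(dtype_num, tf.int32)]
    # Run only the branch that matches the stored dtype
    return tf.switch_case(dtype_num_idx, branches)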
# Test
serialized_tensors = [
    tf.io.serialize_tensor(tf.constant([1, 2, 3], tf.int32)),
    tf.io.serialize_tensor(tf.constant([1, 2, 3], tf.float64))
]
for t in serialized_tensors:
    tf.print(parse_tensor_cast(t, tf.float32))
# [1 2 3]
# [1 2 3]
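To check the claim about the second byte, here is a minimal eager-mode sketch, assuming TF 2.x. `dtype_from_serialized` is a hypothetical helper name, and it only handles messages whose dtype field comes first with a single-byte enum value (true for all current dtypes, whose enum values are below 128).

```python
import tensorflow as tf


def dtype_from_serialized(serialized):
    """Return the tf.DType stored in a tf.io.serialize_tensor result (eager only)."""
    raw = serialized.numpy()  # raw bytes of the serialized TensorProto
    # Byte 0: tag of TensorProto field 1 (dtype) with varint wire type, i.e. 0x08.
    # Byte 1: DataType enum value, a single varint byte only if it is < 0x80.
    if len(raw) < 2 or raw[0] != 0x08 or raw[1] >= 0x80:
        raise ValueError("unexpected layout for a serialized TensorProto")
    return tf.dtypes.as_dtype(raw[1])


for dtype in (tf.int32, tf.float64, tf.bool):
    serialized = tf.io.serialize_tensor(tf.cast(tf.constant([1, 0, 1]), dtype))
    print(dtype.name, "->", dtype_from_serialized(serialized).name)
```

This reads the bytes in Python, so unlike parse_tensor_cast it cannot run inside a graph; it is only meant to show where the dtype lives in the message.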
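Unfortunately, you cannot avoid this argument. It would not be needed with eager mode alone, but if you want to "graph" the operation (e.g. inside a @tf.function), the data type has to be known beforehand: it is fixed when the graph is built, before any actual parsing happens.
If you only care about eager mode, working around it is not hard: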
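A small sketch of that point, assuming TF 2.x: tracing parse_tensor with a symbolic string input produces a graph whose output dtype is already fixed by out_type, although no serialized bytes exist yet. `parse_as_int32` is a hypothetical name.

```python
import tensorflow as tf


# Trace with a symbolic scalar string: there is no serialized data at trace time.
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
def parse_as_int32(serialized):
    # The graph needs an output dtype now, so it can only come from out_type.
    return tf.io.parse_tensor(serialized, out_type=tf.int32)


concrete = parse_as_int32.get_concrete_function()
out = concrete.structured_outputs  # symbolic tensor in the traced graph
print(out.dtype)  # fixed at trace time from out_type
print(out.shape)  # not known statically; only parsing the bytes reveals it

# Calling it with real data then works as usual.
print(parse_as_int32(tf.io.serialize_tensor(tf.constant([1, 2]))))
```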
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import tensor_pb2

input_tensor = tf.constant([1, 2, 3], tf.int32)
# A scalar tf.string tensor containing the serialized input_tensor
serialized_tensor = tf.io.serialize_tensor(input_tensor)
# Create a TensorProto from the bytes of serialized_tensor (.numpy() needs eager mode)
tensor_proto = tensor_pb2.TensorProto()
tensor_proto.ParseFromString(serialized_tensor.numpy())
# At this point, tensor_proto is equivalent to what tf.make_tensor_proto builds
print(tensor_proto)
print(tf.make_tensor_proto(input_tensor))
# Read the data back, taking the dtype from tensor_proto instead of hard-coding it
tensor_parsed = tf.io.parse_tensor(serialized_tensor.numpy(),
                                   tf.dtypes.as_dtype(tensor_proto.dtype))
tf.debugging.assert_equal(input_tensor, tensor_parsed)
# You can also build a NumPy array directly from the extracted message
numpy_parsed = tf.make_ndarray(tensor_proto)
np.testing.assert_array_equal(input_tensor.numpy(), numpy_parsed)
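The same eager-only idea can be wrapped in a reusable function. This is a minimal sketch, assuming TF 2.x; `parse_tensor_eager` is a hypothetical name, and it still fails inside a @tf.function because it needs the concrete bytes.

```python
import tensorflow as tf
from tensorflow.core.framework import tensor_pb2


def parse_tensor_eager(serialized):
    """Parse a tf.io.serialize_tensor result without passing out_type (eager only)."""
    proto = tensor_pb2.TensorProto()
    proto.ParseFromString(serialized.numpy())  # read the stored dtype from the message
    return tf.io.parse_tensor(serialized, tf.dtypes.as_dtype(proto.dtype))


for original in (tf.constant([1, 2, 3]), tf.constant([[1.5, 2.5]], tf.float64)):
    parsed = parse_tensor_eager(tf.io.serialize_tensor(original))
    tf.debugging.assert_equal(original, parsed)
    print(parsed.dtype, parsed.shape)
```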