How to parse a tensor in TensorFlow without giving out_type?

Date: 2020-06-12 16:33:53

Tags: python tensorflow serialization

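When parsing a serialized tensor, tf.io.parse_tensor has a required argument, out_type. However, TensorFlow does not seem to actually need it to know the type of the serialized tensor: when the wrong type is passed, the error message reports the correct one (int32 in the example below).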

How can I parse it without this argument?

MWE:

tf.io.parse_tensor(tf.io.serialize_tensor(tf.constant([1])), tf.int32)

tf.io.parse_tensor(tf.io.serialize_tensor(tf.constant([1])))
Traceback (most recent call last):
  File "<input>", line 1, in <module>
TypeError: parse_tensor() missing 1 required positional argument: 'out_type'

tf.io.parse_tensor(tf.io.serialize_tensor(tf.constant([1])), tf.float32)
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Users/clementwalter/.pyenv/versions/keras_fsl/lib/python3.6/site-packages/tensorflow/python/ops/gen_parsing_ops.py", line 2160, in parse_tensor
    _ops.raise_from_not_ok_status(e, name)
  File "/Users/clementwalter/.pyenv/versions/keras_fsl/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 6653, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Type mismatch between parsed tensor (int32) and dtype (float) [Op:ParseTensor]

1 answer:

Answer 0 (score: 1)

EDIT:

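I came up with a somewhat "hacky" solution that can read tensors of different types, cast them to one given type, and works with @tf.function (interestingly, it also works without @tf.function). The idea is to read the second byte of the serialized TensorProto message, which should hold the data type: the first byte is the protobuf tag of the dtype field (field 1, varint-encoded, i.e. 0x08), and because every DataType enum value used here is below 128, its varint encoding fits in the single byte that follows. tf.switch_case then selects the branch that parses with the matching source dtype and casts the result to the requested output dtype. It works like this: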

import tensorflow as tf

# Different sets of data types you could use
INTEGER_DTYPES = frozenset({tf.bool, tf.uint8, tf.uint16, tf.uint32, tf.uint64,
                            tf.int8, tf.int16, tf.int32, tf.int64})
FLOAT_DTYPES = frozenset({tf.float16, tf.bfloat16, tf.float32, tf.float64})
COMPLEX_DTYPES = frozenset({tf.complex64, tf.complex128})
REAL_DTYPES = INTEGER_DTYPES | FLOAT_DTYPES
NUMERICAL_DTYPES = REAL_DTYPES | COMPLEX_DTYPES

@tf.function
def parse_tensor_cast(tensor_proto, out_dtype, possible_dtypes=REAL_DTYPES):
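    # Prepare one branch per possible source dtype
    branches = {}
    # Lookup table: DataType enum value -> branch index (enum values are < 128)
    dtype_idx = [0] * 128
    for i, dtype in enumerate(possible_dtypes):
        dtype_idx[dtype.as_datatype_enum] = i
        # Bind dtype as a default argument; a plain closure would make every
        # branch use the last dtype of the loop (Python late binding)
        branches[i] = lambda dtype=dtype: tf.dtypes.cast(
            tf.io.parse_tensor(tensor_proto, dtype), out_dtype)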
    dtype_idx = tf.constant(dtype_idx, tf.int32)
    # Extract dtype byte ("hacky" part of the solution)
    dtype_code = tf.strings.substr(tensor_proto, 1, 1)
    dtype_num = tf.io.decode_raw(dtype_code, tf.uint8)[0]
    dtype_num_idx = dtype_idx[tf.dtypes.cast(dtype_num, tf.int32)]
    # Switch operation
    return tf.switch_case(dtype_num_idx, branches)

# Test
serialized_tensors = [
    tf.io.serialize_tensor(tf.constant([1, 2, 3], tf.int32)),
    tf.io.serialize_tensor(tf.constant([1, 2, 3], tf.float64))
]
for t in serialized_tensors:
    tf.print(parse_tensor_cast(t, tf.float32))
# [1 2 3]
# [1 2 3]
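The "hacky" part depends on byte 1 of the serialized message being the DataType enum value. A quick way to check that assumption on a given TensorFlow build is a sketch like the one below. It runs eagerly, and the dtype list is an assumption: a subset of REAL_DTYPES chosen to avoid types that some builds may not serialize. The helper name dtype_byte_matches is hypothetical.

```python
import tensorflow as tf

# Subset of REAL_DTYPES from the answer (assumed to be serializable everywhere)
DTYPES = [tf.bool, tf.uint8, tf.int8, tf.int16, tf.int32, tf.int64,
          tf.float16, tf.bfloat16, tf.float32, tf.float64]

# Protobuf tag of TensorProto.dtype: (field number 1 << 3) | wire type 0 (varint)
DTYPE_FIELD_TAG = (1 << 3) | 0  # 0x08


def dtype_byte_matches(dtype):
    """Hypothetical helper: does byte 1 of the serialized tensor equal the dtype enum?"""
    raw = tf.io.serialize_tensor(tf.zeros([2], dtype)).numpy()
    # Byte 0 should be the dtype field tag, byte 1 the single-byte varint enum
    return raw[0] == DTYPE_FIELD_TAG and raw[1] == dtype.as_datatype_enum


for dtype in DTYPES:
    print(dtype.name, dtype_byte_matches(dtype))
```

This works because protobuf serializers emit fields in field-number order, so dtype (field 1) comes before tensor_shape (field 2) and the data.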

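Unfortunately, you cannot skip this argument. It would not be needed if there were only eager mode, but if you want to build this operation into a graph (e.g. inside a @tf.function), the output data type has to be known when the graph is built, before any actual parsing happens.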
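To make the graph-mode constraint concrete, here is a small sketch. Once a function is traced with only a scalar tf.string as input, its output dtype is already fixed (it comes from out_type), while the shape stays unknown until the data is actually parsed. The function name parse_int32 is made up for illustration.

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec([], tf.string)])
def parse_int32(serialized):
    # out_type must be a Python-level constant: it sets the output dtype
    # of the ParseTensor node when the graph is built
    return tf.io.parse_tensor(serialized, out_type=tf.int32)


concrete = parse_int32.get_concrete_function()
# Known at trace time, before any bytes have been parsed
print(concrete.structured_outputs.dtype.name)  # int32
# Not known until runtime
print(concrete.structured_outputs.shape)  # <unknown>

result = parse_int32(tf.io.serialize_tensor(tf.constant([1, 2, 3])))
print(result.numpy())  # [1 2 3]
```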

If you are only interested in eager mode, working around this is not difficult:

import numpy as np
import tensorflow as tf
from tensorflow.core.framework import tensor_pb2

input_tensor = tf.constant([1, 2, 3], tf.int32)

# A scalar tf.string tensor containing the serialized input_tensor
serialized_tensor = tf.io.serialize_tensor(input_tensor)

# Create a TensorProto from serialized_tensor content
tensor_proto = tensor_pb2.TensorProto()
tensor_proto.ParseFromString(serialized_tensor.numpy())

# At this point, tensor_proto holds the same kind of message that
# tf.make_tensor_proto returns for input_tensor
tf.make_tensor_proto(input_tensor)

# Read data back from tensor_proto
tensor_parsed = tf.io.parse_tensor(serialized_tensor.numpy(),
                                   tf.dtypes.as_dtype(tensor_proto.dtype))
tf.debugging.assert_equal(input_tensor, tensor_parsed)

# You can also just directly create the tensor from the extracted message
numpy_parsed = tf.make_ndarray(tensor_proto)
np.testing.assert_array_equal(input_tensor.numpy(), numpy_parsed)
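The eager-mode workaround can be wrapped in a small helper. parse_tensor_any below is a hypothetical name, not part of TensorFlow. It only works eagerly, because it needs the serialized bytes as a Python bytes object to read the dtype before calling tf.io.parse_tensor.

```python
import tensorflow as tf
from tensorflow.core.framework import tensor_pb2


def parse_tensor_any(serialized):
    """Hypothetical helper: parse a serialized tensor without passing out_type.

    Eager mode only, since it calls .numpy() on tensor inputs.
    """
    if isinstance(serialized, tf.Tensor):
        serialized = serialized.numpy()
    proto = tensor_pb2.TensorProto()
    proto.ParseFromString(serialized)
    # The dtype stored in the message is exactly the out_type parse_tensor needs
    return tf.io.parse_tensor(serialized, tf.dtypes.as_dtype(proto.dtype))


x = parse_tensor_any(tf.io.serialize_tensor(tf.constant([1, 2, 3], tf.int32)))
y = parse_tensor_any(tf.io.serialize_tensor(tf.constant([1.5, 2.5], tf.float64)))
print(x.dtype.name, x.numpy())  # int32 [1 2 3]
print(y.dtype.name, y.numpy())  # float64 [1.5 2.5]
```

Each tensor comes back with its original dtype, so no cast is needed. Inside a @tf.function you would still need something like the switch_case approach above.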