如何找出给定的内置tensorflow函数接受哪些dtype张量?

时间:2019-03-25 06:42:24

标签: python tensorflow

我正在使用复数进行张量流项目,因此我经常需要在复数输入上应用内置函数。那么如何检查哪些tensorflow函数接受复杂的参数作为输入呢?

例如, 当我尝试如下使用函数tf.math.scalar_mul()时-

...
self.scalar = tf.Variable(3, tf.int16)
output = tf.math.scalar_mul(x, self.scalar)
...

它会产生以下错误-

ValueError: Tensor conversion requested dtype int32 for Tensor with dtype complex64: 'Tensor("fourier__conv2d_5/mul:0", shape=(?, 28, 28, 17), dtype=complex64)'

我认为这可能是由于tf.math.scalar_mul()不接受复杂的输入。我是正确的,如果不是,那可能是错误的。 (我尝试使用tf函数代替基本的python函数,因为我认为在GPU上运行时可能会受益)

在此先感谢您的帮助。

1 个答案:

答案 0 :(得分:1)

您可以找到答案,但是结果将根据操作和内核给出,它们无法精确映射到更高级别的Python函数。如果您不熟悉TensorFlow的体系结构,它是围绕“ ops”的概念构建的,“ ops”只是对具有张量的操作的正式描述(例如,op“ Add”采用两个值并输出第三个值)。 TensorFlow计算图由互连的操作节点组成。 Ops本身并不实现任何逻辑,它们仅指定操作的名称和属性,包括可以将其应用于的数据类型。操作的实现由内核给出,内核是完成工作的实际代码。一个op可以具有许多注册的内核,这些内核可以使用不同的数据类型和/或不同的设备(CPU,GPU)运行。

TensorFlow将所有这些信息保存为“注册表”,并存储为不同的Protocol Buffers消息。尽管它不是公共API的一部分,但您实际上可以查询这些注册表以获得符合特定条件的操作或内核列表。例如,这是您如何获取所有使用某种复杂类型进行操作的操作的方法:

import tensorflow as tf

def get_ops_with_dtypes(dtypes):
    from tensorflow.python.framework import ops
    valid_ops = []
    dtype_enums = set(dtype.as_datatype_enum for dtype in dtypes)
    reg_ops = ops.op_def_registry.get_registered_ops()
    for op in reg_ops.values():
        for attr in op.attr:
            if (attr.type == 'type' and
                any(t in dtype_enums for t in attr.allowed_values.list.type)):
                valid_ops.append(op)
                break
    # Sort by name for convenience
    return sorted(valid_ops, key=lambda op: op.name)

complex_dtypes = [tf.complex64, tf.complex128]
complex_ops = get_ops_with_dtypes(complex_dtypes)

# Print one op
print(complex_ops[0])
# name: "AccumulateNV2"
# input_arg {
#   name: "inputs"
#   type_attr: "T"
#   number_attr: "N"
# }
# output_arg {
#   name: "sum"
#   type_attr: "T"
# }
# attr {
#   name: "N"
#   type: "int"
#   has_minimum: true
#   minimum: 1
# }
# attr {
#   name: "T"
#   type: "type"
#   allowed_values {
#     list {
#       type: DT_FLOAT
#       type: DT_DOUBLE
#       type: DT_INT32
#       type: DT_UINT8
#       type: DT_INT16
#       type: DT_INT8
#       type: DT_COMPLEX64
#       type: DT_INT64
#       type: DT_QINT8
#       type: DT_QUINT8
#       type: DT_QINT32
#       type: DT_BFLOAT16
#       type: DT_UINT16
#       type: DT_COMPLEX128
#       type: DT_HALF
#       type: DT_UINT32
#       type: DT_UINT64
#     }
#   }
# }
# attr {
#   name: "shape"
#   type: "shape"
# }
# is_aggregate: true
# is_commutative: true

# Print op names
print(*(op.name for op in complex_ops), sep='\n')
# AccumulateNV2
# AccumulatorApplyGradient
# AccumulatorTakeGradient
# Acos
# Acosh
# Add
# AddN
# AddV2
# Angle
# ApplyAdaMax
# ...

这里complex_ops中的元素是OpDef消息,您可以检查它们以找到操作的确切结构。在这种情况下,get_ops_with_dtypes仅返回其type属性中具有给定数据类型之一的每个操作,因此复数值可以应用于输入或输出之一。

另一种选择是直接查找与您感兴趣的数据类型配合使用的内核。内核存储为KernelDef消息,其中不包含有关操作的所有信息,但是例如有关它们可以在其上运行的设备的信息,因此您还可以查询支持特定设备的内核。

import tensorflow as tf

def get_kernels_with_dtypes(dtypes, device_type=None):
    from tensorflow.python.framework import kernels
    valid_kernels = []
    dtype_enums = set(dtype.as_datatype_enum for dtype in dtypes)
    reg_kernels = kernels.get_all_registered_kernels()
    for kernel in reg_kernels.kernel:
        if device_type and kernel.device_type != device_type:
            continue
        for const in kernel.constraint:
            if any(t in dtype_enums for t in const.allowed_values.list.type):
                valid_kernels.append(kernel)
                break
    # Sort by name for convenience
    return sorted(valid_kernels, key=lambda kernel: kernel.op)

complex_dtypes = [tf.complex64, tf.complex128]
complex_gpu_kernels = get_kernels_with_dtypes(complex_dtypes, device_type='GPU')

# Print one kernel
print(complex_gpu_kernels[0])
# op: "Add"
# device_type: "GPU"
# constraint {
#   name: "T"
#   allowed_values {
#     list {
#       type: DT_COMPLEX64
#     }
#   }
# }

# Print kernel op names
print(*(kernel.op for kernel in complex_gpu_kernels), sep='\n')
# Add
# Add
# AddN
# AddN
# AddV2
# AddV2
# Assign
# Assign
# AssignVariableOp
# AssignVariableOp
# ...

问题在于,当您在Python中使用TensorFlow进行编程时,您从未真正真正地直接使用ops或内核。 Python函数接受您提供的参数,对其进行验证并在图中生成一个或多个新操作,通常将最后一个的输出值返回给您。因此,最终找出与您相关的运维/内核需要一些检查。例如,考虑以下示例:

import tensorflow as tf

with tf.Graph().as_default():
    # Matrix multiplication: (2, 3) x (3, 4)
    tf.matmul([[1, 2, 3], [4, 5, 6]], [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
    # Print all op names and types
    all_ops = tf.get_default_graph().get_operations()
    print(*(f'Op name: {op.name}, Op type: {op.type}' for op in all_ops), sep='\n')
    # Op name: MatMul/a, Op type: Const
    # Op name: MatMul/b, Op type: Const
    # Op name: MatMul, Op type: MatMul

with tf.Graph().as_default():
    # Matrix multiplication: (1, 2, 3) x (1, 3, 4)
    tf.matmul([[[1, 2, 3], [4, 5, 6]]], [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
    # Print all op names and types
    all_ops = tf.get_default_graph().get_operations()
    print(*(f'Op name: {op.name}, Op type: {op.type}' for op in all_ops), sep='\n')
    # Op name: MatMul/a, Op type: Const
    # Op name: MatMul/b, Op type: Const
    # Op name: MatMul, Op type: BatchMatMul

在这里,相同的Python函数tf.matmul在每种情况下都产生了op类型。在这两种情况下,前两个操作均为Const,这是由于将给定列表转换为TensorFlow张量而导致的,而第三个操作在一种情况下为MatMul,在另一情况下是BatchedMatMul,因为在第二种情况下,输入具有一个额外的初始尺寸。

在任何情况下,如果您可以结合以上方法来查找有关您感兴趣的一个op名称的所有op和内核信息:

def get_op_info(op_name):
    from tensorflow.python.framework import ops
    from tensorflow.python.framework import kernels
    reg_ops = ops.op_def_registry.get_registered_ops()
    op_def = reg_ops[op_name]
    op_kernels = list(kernels.get_registered_kernels_for_op(op_name).kernel)
    return op_def, op_kernels

# Get MatMul information
matmul_def, matmul_kernels = get_op_info('MatMul')

# Print op definition
print(matmul_def)
# name: "MatMul"
# input_arg {
#   name: "a"
#   type_attr: "T"
# }
# input_arg {
#   name: "b"
#   type_attr: "T"
# }
# output_arg {
#   name: "product"
#   type_attr: "T"
# }
# attr {
#   name: "transpose_a"
#   type: "bool"
#   default_value {
#     b: false
#   }
# }
# attr {
#   name: "transpose_b"
#   type: "bool"
#   default_value {
#     b: false
#   }
# }
# attr {
#   name: "T"
#   type: "type"
#   allowed_values {
#     list {
#       type: DT_BFLOAT16
#       type: DT_HALF
#       type: DT_FLOAT
#       type: DT_DOUBLE
#       type: DT_INT32
#       type: DT_COMPLEX64
#       type: DT_COMPLEX128
#     }
#   }
# }

# Total number of matrix multiplication kernels
print(len(matmul_kernels))
# 24

# Print one kernel definition
print(matmul_kernels[0])
# op: "MatMul"
# device_type: "CPU"
# constraint {
#   name: "T"
#   allowed_values {
#     list {
#       type: DT_FLOAT
#     }
#   }
# }