用于训练无效类型tf.int32

时间:2018-06-21 15:32:57

标签: python tensorflow training-data

我正在做实验,通过实现tf.casttf.py_functf.RegisterGradient函数提供渐变,如下所示:

# Define custom py_func which takes also a grad op as argument:
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):

    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))

    tf.RegisterGradient(rnd_name)(grad)  
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

def castInt32(x, name=None):
    with ops.name_scope(name, "CastInt32", [x]) as name:
        sqr_x = py_func(np.int32,
                        [x],
                        [tf.int32],
                        name=name,
                        grad=_castInt32)  # <-- here's the call to the gradient
        return sqr_x[0]

# Actual gradient:
def _castInt32(op, grad):
    t = [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]
    src_type = op.inputs[0].dtype.base_dtype
    dst_type = grad.dtype.base_dtype

    if src_type in t and dst_type in t:
        return math_ops.cast(grad, src_type)
    else:
        return None

with tf.Session() as session:
    #session.run(init)
    x = tf.constant([20., 1.])
    y = castInt32(x)
    z = tf.gradients(y,x)[0]
    print z

它输出[1. , 1.]作为输出。但是,当我使用这些功能来训练简单的网络时,执行optimizer.compute_gradient时会出现错误。

错误显示如下:

ValueError: Invalid type tf.int32 for Cast:0, expected: [tf.float32, tf.float64, tf.float16, tf.bfloat16].

如果是因为渐变的数据类型,则该数据类型已经在tf.float32中。

有人可以帮助我了解这种情况吗?预先感谢。

0 个答案:

没有答案