在 TensorFlow 中实现阈值激活函数

时间:2018-08-28 16:15:04

标签: tensorflow

import tensorflow as tf

def return_1():  # branch value intended for the case u > 0
    return 1

def return_0():  # branch value intended for the case u <= 0
    return 0

w = tf.Variable([1.0, 1.0], tf.float32)  # weights of the neuron
b = tf.Variable(1.0, tf.float32)  # bias of the neuron

x = tf.placeholder(tf.float32)  # input, supplied through feed_dict below
u = tf.tensordot(x, w, axes = 1) + b  # pre-activation: x . w + b
y = tf.cond(u > 0, lambda: return_1, lambda: return_0)  # fails: each lambda returns the function object itself (no call parentheses), which cannot be converted to a Tensor

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialise w and b before evaluating
    U, Y = sess.run([u, y], feed_dict={x: [0.0, 0.0]})  # evaluate for the input [0, 0]
    print(Y)

我有一个简单的神经元。在计算 x 与 w 的点积并加上偏置 b(即 u = x·w + b)之后,我想让 u 经过一个阈值函数得到输出:如果 u > 0,则 y 应为 1,否则 y 为 0。上面的代码运行时报出如下错误:

Traceback (most recent call last):
  File "chapter1example2.py", line 14, in <module>
    y = tf.cond(u > 0, lambda: return_1, lambda: return_0)
  File "/Users/pcdessy/anaconda3/envs/tensorflowfyp/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 454, in new_func
    return func(*args, **kwargs)
  File "/Users/pcdessy/anaconda3/envs/tensorflowfyp/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2048, in cond
    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
  File "/Users/pcdessy/anaconda3/envs/tensorflowfyp/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1910, in BuildCondBranch
    result = nest.map_structure(self._BuildCondTensor, original_result)
  File "/Users/pcdessy/anaconda3/envs/tensorflowfyp/lib/python3.6/site-packages/tensorflow/python/util/nest.py", line 377, in map_structure
    structure[0], [func(*x) for x in entries])
  File "/Users/pcdessy/anaconda3/envs/tensorflowfyp/lib/python3.6/site-packages/tensorflow/python/util/nest.py", line 377, in <listcomp>
    structure[0], [func(*x) for x in entries])
  File "/Users/pcdessy/anaconda3/envs/tensorflowfyp/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1890, in _BuildCondTensor
    return self._ProcessOutputTensor(ops.convert_to_tensor(v))
  File "/Users/pcdessy/anaconda3/envs/tensorflowfyp/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 998, in convert_to_tensor
    as_ref=False)
  File "/Users/pcdessy/anaconda3/envs/tensorflowfyp/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1094, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/Users/pcdessy/anaconda3/envs/tensorflowfyp/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 217, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/Users/pcdessy/anaconda3/envs/tensorflowfyp/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 196, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/Users/pcdessy/anaconda3/envs/tensorflowfyp/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 525, in make_tensor_proto
    "supported type." % (type(values), values))
TypeError: Failed to convert object of type <class 'function'> to Tensor. Contents: <function return_1 at 0x109206e18>. Consider casting elements to a supported type.

我该怎么做?谢谢

1 个答案:

答案 0 :(得分:0)

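您忘记了在 lambda 表达式中加上调用括号 `()`。`lambda: return_1` 返回的是函数对象本身,而不是它的返回值,因此 TensorFlow 无法把它转换成张量;应写成 `lambda: return_1()`。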

这是正确的代码:

import tensorflow as tf

def return_1() -> int:
    """Return the constant 1, used as the ``tf.cond`` branch value when u > 0."""
    return 1

def return_0() -> int:
    """Return the constant 0, used as the ``tf.cond`` branch value when u <= 0."""
    return 0

# Threshold-activated neuron: y = 1 if (x . w + b) > 0, otherwise y = 0.
# The second positional parameter of tf.Variable is `trainable`, not the
# dtype, so the dtype has to be passed by keyword.
w = tf.Variable([1.0, 1.0], dtype=tf.float32)  # weights
b = tf.Variable(1.0, dtype=tf.float32)  # bias

x = tf.placeholder(tf.float32)  # input vector, supplied through feed_dict
u = tf.tensordot(x, w, axes=1) + b  # pre-activation
# tf.cond calls each branch function and converts whatever it returns to a
# tensor, so the lambdas must *call* return_1 / return_0 (note the
# parentheses) instead of returning the function objects themselves.
y = tf.cond(u > 0, lambda: return_1(), lambda: return_0())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Only y is printed, so only y is fetched.
    Y = sess.run(y, feed_dict={x: [0.0, 0.0]})
    print(Y)