这是我想用tensorflow实现f(x)
输入x =(x1,x2,x3,x4,x5,x6,x7,x8,x9)
定义f(x)= f1(x1,x2,x3,x4,x5)+ f2(x5,x6,x7,x8,x9)
,其中
f1(x1,x2,x3,x4,x5)= {1 if (X1,X2,X3,X4,X5)=(0,0,0,0,0),
g1(x1,x2,x3,x4,x5) otherwise}
f2(x5,x6,x7,x8,x9)= {1 if (X5,X6,X7,X8,X9)=(0,0,0,0,0),
g2(x5,x6,x7,x8,x9) otherwise}
import tensorflow as tf
import numpy as np
ph = tf.placeholder(dtype=tf.float32, shape=[None, 9])
x1 = tf.slice(ph, [0, 0], [-1, 5])
x2 = tf.slice(ph, [0, 4], [-1, 5])
fixed1 = tf.placeholder(dtype=tf.float32, shape=[1, 5])
fixed2 = tf.placeholder(dtype=tf.float32, shape=[1, 5])
# MLP 1
w1 = tf.Variable(tf.ones([5, 1]))
g1 = tf.matmul(x1, w1)
# MLP 2
w2 = tf.Variable(-tf.ones([5, 1]))
g2 = tf.matmul(x2, w2)
check1 = tf.reduce_all(tf.equal(x1, fixed1), axis=1, keep_dims=True)
check2 = tf.reduce_all(tf.equal(x2, fixed2), axis=1, keep_dims=True)
#### with Problem
f1 = tf.cond(check1,
lambda: tf.constant([2], dtype=tf.float32), lambda: g1)
f2 = tf.cond(check2,
lambda: tf.constant([1], dtype=tf.float32), lambda: g2)
####
f = tf.add(f1, f2)
x = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0, 0, 0, 0],
[2, 0, 0, 0, 0, 0, 0, 0, 0],
[9, 0, 0, 0, 0, 0, 0, 0, 0]])
fixed = np.array([[0, 0, 0, 0, 0]])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print('(1)\n', sess.run(check1, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
print('(2)\n', sess.run(check2, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
print('(3)\n', sess.run(f, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
print('(4)\n', sess.run(f1, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
print('(5)\n', sess.run(f2, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
在这种情况下,
check1为[[True],[True],[False],[False],[False]],形状为(5,1)
check2是[[True],[False],[True],[True],[True]],形状为(5,1)
我希望f的结果是[[3],[1],[2],[3],[10]]
但似乎tf.cond()无法将输入处理为具有形状(5,1)的布尔张量
请问如何使用tensorflow建议如何实现f(x)。
这是我收到的错误消息
Traceback(最近一次调用最后一次):文件 “C:\用户\香港\应用程序数据\本地\连续\ Anaconda3 \ LIB \站点包\ tensorflow \ python的\框架\ common_shapes.py” 第670行,在_call_cpp_shape_fn_impl中 status)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ contextlib.py”, 第66行,在退出 next(self.gen)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ errors_impl.py”, 第469行,在raise_exception_on_not_ok_status中 pywrap_tensorflow.TF_GetCode(status))tensorflow.python.framework.errors_impl.InvalidArgumentError:Shape 必须是0级,但是'cond / Switch'(op:'Switch')的排名是2 输入形状:[?,1],[?,1]。
在处理上述异常期间,发生了另一个异常:
Traceback(最近一次调用最后一次):文件 “C:/Users/hong/Dropbox/MLILAB/Research/GM-MLP/code/tensorflow_cond.py” 第23行,在 lambda:tf.constant([2],dtype = tf.float32),lambda:g1)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ ops \ control_flow_ops.py” 第1765行,在cond p_2,p_1 = switch(pred,pred)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ ops \ control_flow_ops.py”, 318行,在开关中 return gen_control_flow_ops._switch(data,pred,name = name)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ ops \ gen_control_flow_ops.py”, 第368行,在_switch result = _op_def_lib.apply_op(“Switch”,data = data,pred = pred,name = name)文件 “C:\用户\香港\应用程序数据\本地\连续\ Anaconda3 \ LIB \站点包\ tensorflow \ python的\框架\ op_def_library.py” 第759行,在apply_op中 op_def = op_def)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py”, 第2242行,在create_op中 set_shapes_for_outputs(ret)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py”, 第1617行,在set_shapes_for_outputs中 shapes = shape_func(op)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py”, 第1568行,在call_with_requiring中 return call_cpp_shape_fn(op,require_shape_fn = True)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ common_shapes.py”, 第610行,在call_cpp_shape_fn中 debug_python_shape_fn,require_shape_fn)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ common_shapes.py”, 第675行,在_call_cpp_shape_fn_impl中 raise ValueError(err.message)ValueError:Shape必须为0级,但对于'cond / Switch'(op:'Switch')的排名为2,输入形状为:[?,1], [?,1]。
使用退出代码1完成处理