如何输入布尔张量到tf.cond()而不仅仅是一个布尔值?

时间:2017-04-04 09:20:34

标签: machine-learning tensorflow neural-network data-science

这是我想用tensorflow实现f(x)

输入x =(x1,x2,x3,x4,x5,x6,x7,x8,x9)

定义f(x)= f1(x1,x2,x3,x4,x5)+ f2(x5,x6,x7,x8,x9)

,其中

  

f1(x1,x2,x3,x4,x5)= {1 if   (X1,X2,X3,X4,X5)=(0,0,0,0,0),

                   g1(x1,x2,x3,x4,x5) otherwise}
     

f2(x5,x6,x7,x8,x9)= {1 if   (X5,X6,X7,X8,X9)=(0,0,0,0,0),

                   g2(x5,x6,x7,x8,x9) otherwise}

这是我的张量流代码

import tensorflow as tf
import numpy as np
ph = tf.placeholder(dtype=tf.float32, shape=[None, 9])
x1 = tf.slice(ph, [0, 0], [-1, 5])
x2 = tf.slice(ph, [0, 4], [-1, 5])
fixed1 = tf.placeholder(dtype=tf.float32, shape=[1, 5])
fixed2 = tf.placeholder(dtype=tf.float32, shape=[1, 5])

# MLP 1
w1 = tf.Variable(tf.ones([5, 1]))
g1 = tf.matmul(x1, w1)

# MLP 2
w2 = tf.Variable(-tf.ones([5, 1]))
g2 = tf.matmul(x2, w2)

check1 = tf.reduce_all(tf.equal(x1, fixed1), axis=1, keep_dims=True)

check2 = tf.reduce_all(tf.equal(x2, fixed2), axis=1, keep_dims=True)

#### with Problem
f1 = tf.cond(check1,
                   lambda: tf.constant([2], dtype=tf.float32), lambda: g1)
f2 = tf.cond(check2,
                   lambda: tf.constant([1], dtype=tf.float32), lambda: g2)
####
f = tf.add(f1, f2)

x = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0, 0, 0, 0, 1],
              [1, 0, 0, 0, 0, 0, 0, 0, 0],
              [2, 0, 0, 0, 0, 0, 0, 0, 0],
              [9, 0, 0, 0, 0, 0, 0, 0, 0]])

fixed = np.array([[0, 0, 0, 0, 0]])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('(1)\n', sess.run(check1, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
    print('(2)\n', sess.run(check2, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
    print('(3)\n', sess.run(f, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
    print('(4)\n', sess.run(f1, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))
    print('(5)\n', sess.run(f2, feed_dict={ph: x, fixed1: fixed, fixed2: fixed}))

在这种情况下,

check1为[[True],[True],[False],[False],[False]],形状为(5,1)

check2是[[True],[False],[True],[True],[True]],形状为(5,1)

我希望f的结果是[[3],[1],[2],[3],[10]]

但似乎tf.cond()无法将输入处理为具有形状(5,1)的布尔张量

请问如何使用tensorflow建议如何实现f(x)。

这是我收到的错误消息

  

Traceback(最近一次调用最后一次):文件   “C:\用户\香港\应用程序数据\本地\连续\ Anaconda3 \ LIB \站点包\ tensorflow \ python的\框架\ common_shapes.py”   第670行,在_call_cpp_shape_fn_impl中       status)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ contextlib.py”,   第66行,在退出       next(self.gen)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ errors_impl.py”,   第469行,在raise_exception_on_not_ok_status中       pywrap_tensorflow.TF_GetCode(status))tensorflow.python.framework.errors_impl.InvalidArgumentError:Shape   必须是0级,但是'cond / Switch'(op:'Switch')的排名是2   输入形状:[?,1],[?,1]。

     

在处理上述异常期间,发生了另一个异常:

     

Traceback(最近一次调用最后一次):文件   “C:/Users/hong/Dropbox/MLILAB/Research/GM-MLP/code/tensorflow_cond.py”   第23行,在       lambda:tf.constant([2],dtype = tf.float32),lambda:g1)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ ops \ control_flow_ops.py”   第1765行,在cond       p_2,p_1 = switch(pred,pred)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ ops \ control_flow_ops.py”,   318行,在开关中       return gen_control_flow_ops._switch(data,pred,name = name)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ ops \ gen_control_flow_ops.py”,   第368行,在_switch       result = _op_def_lib.apply_op(“Switch”,data = data,pred = pred,name = name)文件   “C:\用户\香港\应用程序数据\本地\连续\ Anaconda3 \ LIB \站点包\ tensorflow \ python的\框架\ op_def_library.py”   第759行,在apply_op中       op_def = op_def)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py”,   第2242行,在create_op中       set_shapes_for_outputs(ret)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py”,   第1617行,在set_shapes_for_outputs中       shapes = shape_func(op)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py”,   第1568行,在call_with_requiring中       return call_cpp_shape_fn(op,require_shape_fn = True)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ common_shapes.py”,   第610行,在call_cpp_shape_fn中       debug_python_shape_fn,require_shape_fn)文件“C:\ Users \ hong \ AppData \ Local \ Continuum \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ common_shapes.py”,   第675行,在_call_cpp_shape_fn_impl中       raise ValueError(err.message)ValueError:Shape必须为0级,但对于'cond / Switch'(op:'Switch')的排名为2,输入形状为:[?,1],   [?,1]。

     

使用退出代码1完成处理

1 个答案:

答案 0 :(得分:0)

我认为你需要tf.where,而不是tf.cond。

请参阅此问题的答案:How to use tf.cond for batch processing