如何使用带卷积的tf.where?

时间:2018-03-08 09:15:33

标签: python tensorflow

我想创建一个图表,根据分类结果在某个点之后分成几个其他图形。我认为tf.condtf.where可能是正确的,但我不确定如何。

这里不可能复制我的所有代码,但我创建了一个小段来说明问题。

import os
import sys
import tensorflow as tf
GPU_INDEX = 2

net_class = tf.constant([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1],[0.2, 0.4, 0.3, 0.1], [0.3, 0.2, 0.4, 0.1],[0.1, 0.3, 0.3, 0.4]]) # 3,0,1,2,3
classes = tf.argmax(net_class, axis=1)
cls_0_idx = tf.squeeze(tf.where(tf.equal(classes, 0)))
cls_3_idx = tf.squeeze(tf.where(tf.equal(classes, 3)))

cls_0 = tf.gather(params=net_class, indices=cls_0_idx)
cls_3 = tf.gather(params=net_class, indices=cls_3_idx)

params_0 = tf.constant([1.0,1,1,1])
params_3 = tf.constant([3.0,3,3,3])


output = tf.stack([tf.nn.conv1d(cls_0, params_0, 1,  padding='VALID'), tf.nn.conv1d(cls_3, params_3, 1,  padding='VALID')])

sess = tf.Session()
cls_0_idx_val = sess.run(output)

print(output)

这里我尝试提取分类为0或3的输入索引,并使用不同的变量将它们乘以输出(每个类的共享权重,这就是我使用卷积的原因)。

我收到以下错误:

ValueError: Shape must be rank 4 but is rank 2 for 'conv1d/Conv2D' (op: 'Conv2D') with input shapes: ?, [1,4].

我理解为什么会收到错误(因为tf.where并不“知道”它的大小)但问题是如何解决? (这些课程不平等,甚至在我的“真实”问题中也可能是空的)

1 个答案:

答案 0 :(得分:1)

我认为你应该

  1. axis

  2. 中将1设置为tf.squeeze
  3. tf.nn.conv1d更改为简单乘法

  4. tf.stack更改为tf.concat

  5. 然后你会有这样的事情:

    net_class = tf.constant([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1],[0.2, 0.4, 0.3, 0.1], [0.3, 0.2, 0.4, 0.1],[0.1, 0.3, 0.3, 0.4]]) # 3,0,1,2,3
    classes = tf.argmax(net_class, axis=1)
    cls_0_idx = tf.squeeze(tf.where(tf.equal(classes, 0)), -1)
    cls_3_idx = tf.squeeze(tf.where(tf.equal(classes, 3)), -1)
    
    cls_0 = tf.gather(params=net_class, indices=cls_0_idx)
    cls_3 = tf.gather(params=net_class, indices=cls_3_idx)
    
    params_0 = tf.constant([1.0,1,1,1])
    params_3 = tf.constant([3.0,3,3,3])
    output = tf.concat([cls_0 * params_0, cls_3 * params_3], axis = 0)