TensorFlow,什么时候可以使用类似Python的负索引?

时间:2017-08-01 10:00:49

标签: python numpy tensorflow

我是TensorFlow(版本1.2)的新手,但不是Python或Numpy。我正在建立一个模型来预测蛋白质分子的形状。我需要在一些额外的代码中包含TensorFlow的标准tf.losses.cosine_distance函数,因为我需要停止将一些NaN值传播到损失计算中。

我确切地知道哪些细胞是NaN。无论我的机器学习系统预测哪些细胞都不计算在内。我计划在总结损失函数之前将tf.losses.cosine_distance输出的NaN部分转换为零。

这里是一段工作代码,使用tf.scatter_nd_update进行元素分配:

def custom_distance(predict, actual):
    with tf.name_scope("CustomDistance"):
        loss = tf.losses.cosine_distance(predict, actual, -1, 
               reduction=tf.losses.Reduction.NONE)
        loss = tf.Variable(loss) # individual elements can be modified
        indices = tf.constant([[0,0,0],[29,1,0],[29,2,0]])
        updates = tf.constant([0., 0., 0.])
        loss = tf.scatter_nd_update(loss, indices, updates)
        return loss

但是,这只适用于我所拥有的一种长度为30个氨基酸的蛋白质。如果我有不同长度的蛋白质怎么办?我会有很多。  在Numpy中,我只使用Python的负索引,并将-1' s替换为索引行上的两个29。 Tensorflow不接受这一点。如果我进行替换,我会得到一个很长的追溯,但我认为最重要的部分是:

File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Invalid indices: [0,1] = [-1, 1, 0] is not in [0, 30)

(我还可以修改预测 Tensor,以便在计算损失之前,所讨论的单元格与实际 Tensor完全匹配,但在每种情况下挑战都是相同的:分配TensorFlow对象中各个元素的值。)

我应该忘记TensorFlow中的负索引吗?我正在通过TensorFlow文档仔细研究这个问题的正确方法。我假设我可以检索我的输入Tensors长度的主轴并使用它。但在看到TensorFlow和Numpy之间的强烈相似之后,我不禁要问这是不是很笨拙。

感谢您的建议。

1 个答案:

答案 0 :(得分:0)

它可以与tensorflow绑定到python切片运算符一起使用。例如,loss[-1]loss的有效切片。

在您的情况下,如果您只有三个切片,则可以单独指定它们:

update_op0 = indices[0,0,0].assign(updates[0])
update_op1 = indices[-1,1,0].assign(updates[1])
update_op2 = indices[-1,2,0].assign(updates[2])

如果你有更多的切片,或者切片数量可变,那么之前的方法是不实际的。你可以写一个像这样的小辅助函数来将“正或负指数”转换为“仅正指数”:

def to_pos_idx(idx, x):
  # todo: shape & bound checking
  idx = tf.convert_to_tensor(idx)
  s = tf.shape(x)[:tf.size(idx)]
  idx = tf.where(idx < 0, s + idx, idx)
  return idx

并修改你的代码:

indices = tf.constant([[0,0,0],[-1,1,0],[-1,2,0]])
indices = tf.map_fn(lambda i: to_pos_idx(i, loss), indices) # transform indices here
updates = tf.constant([0., 0., 0.])
loss = tf.scatter_nd_update(loss, indices, updates)