Keras:如何访问乘法的特定索引

时间:2018-02-04 18:05:35

标签: python tensorflow neural-network deep-learning keras

我正在构建一个函数,它将来自一个模型分支的输入以特定方式与来自另一个模型分支的输入相乘,但是访问张量的特定部分并不是我所期望的。

最小示例:想象一下,我们得到两个张量,其中一个包含[1, 2]而另一个[10, 20, 30],其中一个输出应该是[1] x [10, 20, 30],取第一个张量的第一个值

如果我从这样的变量开始:

import keras.backend as K
import numpy as np
from keras.layers import Multiply    
x = K.variable(value=np.array([1,2]))
y = K.variable(value=np.array([[10,20,30]]))

然后我可以轻松地访问x [0]:

print(K.eval(x[0]))

给出:1.0

但似乎相同的索引不适用于Multiply,就像这段代码一样:

z = Multiply()([x[0], y])

生成:

  

IndexError:元组索引超出范围

因此问题是:如何在keras中的Multiply层中访问特定的值索引(或者我还能如何做等效的?)

1 个答案:

答案 0 :(得分:1)

只是为了向您展示如何实现您想要的目标。我们假设我们有两个输入:

input_1 = Input(shape=(2,))
input_2 = Input(shape=(3,))

现在 - 让我们定义以下功能:

def custom_multiply(list_):
    x, y = list_[0], list_[1]
    y = K.reshape(y, (-1, 1, 3)) # (1, 2, 3) -> ((1), (2), (3))
    x = K.reshape(x, (-1, 2, 1)) # (1, 2) -> ((1, 2)) 
    partial_result = K.batch_dot(x, y)
    return K.reshape(partial_result, (-1, 6))

现在 - output = custom_multiply([input_1, input_2])应该按照您的预期行事。在一对[(1, 2), (3, 4, 5)]上调用应返回(3, 4, 5, 6, 8, 10)