Keras soft multiplexer layer

Time: 2017-10-01 13:45:49

Tags: neural-network keras

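I'm new to Keras (and to neural networks), so my question may be simple. Still, I can't figure out how to implement the following layer in Keras:

It should take 3 inputs: 2D, 0D, 0D (a matrix, a scalar, and a scalar).

The layer should return the element of the first argument at the position defined by the second and third arguments. So if the input is (m, i, j), it should return m[i, j]. If the pair (i, j) "hits between elements" (e.g. i = 2.5 and j = 3.7), it should return a linear approximation computed from the elements surrounding the point defined by (i, j).

This function is differentiable with respect to the elements of m, i, and j (at least enough for Keras to differentiate it), so defining it as an NN layer should be fine.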
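
For reference, the lookup described above is ordinary bilinear interpolation. A minimal NumPy sketch of the intended behaviour, outside Keras, might look like this (the name `bilinear_lookup` and the clamping at the last row/column are assumptions made for illustration, not part of the question):

```python
import numpy as np

def bilinear_lookup(m, i, j):
    """Return m[i, j], interpolating linearly when i or j is fractional.

    Hypothetical helper: assumes 0 <= i <= m.shape[0] - 1 and
    0 <= j <= m.shape[1] - 1.
    """
    m = np.asarray(m, dtype=float)
    low_i, low_j = int(np.floor(i)), int(np.floor(j))
    # Clamp the upper neighbours so an index on the last row/column stays in bounds.
    high_i = min(low_i + 1, m.shape[0] - 1)
    high_j = min(low_j + 1, m.shape[1] - 1)
    di, dj = i - low_i, j - low_j  # fractional distances in [0, 1)

    # Interpolate along i at both neighbouring columns, then along j.
    left = m[low_i, low_j] + (m[high_i, low_j] - m[low_i, low_j]) * di
    right = m[low_i, high_j] + (m[high_i, high_j] - m[low_i, high_j]) * di
    return left + (right - left) * dj

m = np.arange(25).reshape(5, 5)
print(bilinear_lookup(m, 2, 3))      # exact hit on an element
print(bilinear_lookup(m, 2.5, 3.7))  # point between elements
```

Every step is a sum or product of the inputs, which is why the result is differentiable in the elements of m everywhere and in i and j everywhere except on the integer grid lines.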

1 answer:

Answer 0 (score: 0)

We can try the following function, which we'll pass to a Lambda layer:

from keras.layers import Input, Lambda
import keras.backend as K
from keras.models import Model

import tensorflow as tf

def getValue(x):

    #x is a list of 3 tensors here: m, i and j
    m = x[0]
    i = x[1]
    j = x[2]

    #take the integer indices just below the actual point,
    #as well as the distances between them and the float indices
    lowI, distI = getLowIndexAndDistance(i)
    lowJ, distJ = getLowIndexAndDistance(j)

    #higher indices
    #when i (or j) is already an integer, the distance is 0,
    #so the value at the higher index gets zero weight in the interpolation
    highI = lowI + 1
    highJ = lowJ + 1

    #special case when i or j is exactly the maximum index:
    #the higher index would fall outside m, so pull it back by one
    mShape = K.shape(m)
    highI = highI - K.cast(K.equal(highI,mShape[1]),'int32')
    highJ = highJ - K.cast(K.equal(highJ,mShape[2]),'int32')


    #interpolations  
    valILeft = getInterpolated(getValueFromM(m,lowI,lowJ),
                               getValueFromM(m,highI,lowJ), 
                               distI)
    valIRight = getInterpolated(getValueFromM(m,lowI,highJ), 
                                getValueFromM(m,highI,highJ),
                                distI)   

    return getInterpolated(valILeft,valIRight,distJ)


#function to get the index rounded down
#unfortunately I couldn't find K.floor() or a similar function in the Keras backend
def getLowIndexAndDistance(i):

    #getting the closest round number
    roundI = K.round(i)

    #comparison to check whether the rounded index is greater than i
    #(1 if true, 0 if false)
    isGreater = K.cast(K.greater(roundI,i),K.floatx())

    #if greater, let's take one number below:
    lowI = roundI - isGreater

    #returns the integer lowI and the distance between i and lowI
    return K.cast(lowI,'int32'), i - lowI




#function to get interpolated values
def getInterpolated(val1, val2, distanceFromLowI):

    valRange = val2 - val1
    #the span between the two indices is 1, so no division is needed

    return val1 + (valRange * distanceFromLowI)


def getEntireIndexMatrix(i,j):

    batchIndex = K.ones_like(i)
    batchIndex = K.cumsum(batchIndex) - 1 #equivalent to range(batch)

    #warning: i and j must have shape (?,1); if that axis is dropped, the results will be wrong
    #the result is a matrix of indices from which to get values in m:
    #the first element in the last axis is the batch index,
    #the second element is i and the third is j
    return K.stack([batchIndex,i,j],axis=-1)


def getValueFromM(m, i, j):

    indexMatrix = getEntireIndexMatrix(i,j)

    #tensorflow is an easy solution here. Keras doesn't have this available,
    #but there may be a workaround using K.gather 3 times, one for each dimension
    return tf.gather_nd(m, indexMatrix)
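
The round-then-correct logic in `getLowIndexAndDistance` amounts to a floor. A small NumPy check of that equivalence, as a sketch only (the helper name `low_index_and_distance` is made up for this illustration, and NumPy stands in for the Keras backend):

```python
import numpy as np

def low_index_and_distance(i):
    """NumPy mirror of the round-then-correct logic (hypothetical helper)."""
    i = np.asarray(i, dtype=float)
    rounded = np.round(i)
    # Subtract 1 wherever rounding went up past i.
    low = rounded - (rounded > i).astype(float)
    return low.astype(np.int32), i - low

samples = np.array([0.0, 0.4, 0.5, 1.5, 2.3, 3.99, 4.0])
low, dist = low_index_and_distance(samples)
print(low)   # integer part of each sample
print(dist)  # fractional part of each sample

# The lower index should agree with a plain floor.
assert np.array_equal(low, np.floor(samples).astype(np.int32))
```

Since TensorFlow is imported anyway, `tf.floor(i)` would presumably be a simpler way to get the same lower index.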

Testing it in a very basic model:

import numpy as np

m = Input((5,5))
i = Input((1,))
j = Input((1,))

out = Lambda(getValue, output_shape=(1,))([m,i,j])

model = Model([m,i,j],out)

mVals = np.asarray(range(75)).reshape((3,5,5))
#iVals = np.asarray([[4],[2.3],[4]]) #for special cases
#jVals = np.asarray([[4],[4],[1.7]]) #for special cases
iVals = np.random.uniform(0,4,(3,1)) #for all cases
jVals = np.random.uniform(0,4,(3,1)) #for all cases

print(mVals)
print(iVals)
print(jVals)

print(model.predict([mVals,iVals,jVals]))
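
To sanity-check the printed predictions, the same batched lookup can be reproduced in plain NumPy and compared with the model output. This is only a sketch: `expected_values` is a hypothetical helper, and it assumes `m` has shape (batch, rows, cols) and that `i` and `j` have shape (batch, 1), as in the test above.

```python
import numpy as np

def expected_values(m, i, j):
    """Batched bilinear lookup in NumPy (hypothetical reference helper)."""
    m = np.asarray(m, dtype=float)
    i, j = np.asarray(i, dtype=float), np.asarray(j, dtype=float)
    low_i, low_j = np.floor(i).astype(int), np.floor(j).astype(int)
    # Keep the upper neighbours inside the matrix.
    high_i = np.minimum(low_i + 1, m.shape[1] - 1)
    high_j = np.minimum(low_j + 1, m.shape[2] - 1)
    di, dj = i - low_i, j - low_j

    # Batch index with the same shape as i, so m[b, r, c] picks one value
    # per sample, much like tf.gather_nd with [batch, i, j] triples.
    b = np.arange(m.shape[0]).reshape(-1, 1)
    left = m[b, low_i, low_j] + (m[b, high_i, low_j] - m[b, low_i, low_j]) * di
    right = m[b, low_i, high_j] + (m[b, high_i, high_j] - m[b, low_i, high_j]) * di
    return left + (right - left) * dj

m_vals = np.arange(75).reshape((3, 5, 5))
i_vals = np.array([[4.0], [2.3], [4.0]])  # the special cases from the test
j_vals = np.array([[4.0], [4.0], [1.7]])
print(expected_values(m_vals, i_vals, j_vals))
```

Because `mVals` is filled with consecutive integers, each entry equals 25*b + 5*i + j, so the interpolated result should follow the same formula for fractional i and j. One would expect `model.predict` to agree with this reference up to float32 rounding.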