Does Theano have something like a Python dictionary that can be iterated over with symbolic keys?

Asked: 2015-08-03 16:33:15

Tags: theano

Basically, is the following toy problem possible?

import numpy as np
import theano
import theano.tensor as T
from theano import function, shared

def gpu(arg, name=None, _type=None):
    """Helper function to make variables shared"""
    aux = shared(np.array(arg,dtype=theano.config.floatX), name=name)
    if _type is None:
        return aux
    else:
        return T.cast(aux, _type)

init = [0 for each in range(10)] # weight initialization value
classes = list(range(1,4+1)) # I have four classes
weights = dict(zip(classes, [gpu(init, name=('w'+str(w))) for w in classes])) 

# Based on prior information I sometimes know that a dataset contains
# no knowledge about certain classes, so in those cases I don't
# want to "try" to learn them (there is nothing to learn).
dont_care_about_these = [1,3] # I don't want to try to learn these classes
no_w3 = [w for w in classes if w not in dont_care_about_these] 

i_seq = T.vector(dtype='int8') 
outputs, updates = theano.scan(lambda i: ({weights[i]: weights[i]+1}), 
                               sequences=i_seq)

f = function(inputs=[i_seq], updates=updates) 
f(no_w3)

print(weights[1].get_value()) # keys are the integers 1..4; should still contain only zeros
print(weights[2].get_value()) # Should be updated

The problem here is that I passed a dictionary as the scan updates, and it fails because the dictionary doesn't accept Theano tensors as keys.
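One way to see why the lookup fails, sketched in plain Python with no Theano involved: `weights[i]` runs once, while the graph is being built, and at that moment `i` is a placeholder object rather than an integer. The `Placeholder` class below is a hypothetical stand-in for a symbolic variable, not a Theano type, so this illustrates the likely mechanism rather than reproducing Theano's actual error.

```python
class Placeholder:
    """Hypothetical stand-in for a symbolic variable: it has a name but no value yet."""
    def __init__(self, name):
        self.name = name

weights = {1: "w1", 2: "w2", 3: "w3", 4: "w4"}  # concrete integer keys
i = Placeholder("i")  # the value of i would only be known when the function is called

try:
    weights[i]  # hashed and compared right now, against the keys 1..4
    lookup_failed = False
except KeyError:
    lookup_failed = True

print(lookup_failed)  # True: no integer key equals the placeholder object
```

Under that reading, the dictionary itself is not the issue; the issue is that the key has to be resolved before any data has been supplied.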

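The reason I'm curious is that, based on prior information, I sometimes know that a dataset contains no knowledge about certain classes, so in those cases I don't want to "try" to learn when I'm bound to learn nothing.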

Edit: not really an answer, but a workaround. See also DrBwts's comment below.

import numpy as np
import theano
import theano.tensor as T
from theano import function, shared

def gpu(arg, name=None, _type=None):
    """Helper function to make variables shared"""
    aux = shared(np.array(arg,dtype=theano.config.floatX), name=name)
    if _type is None:
        return aux
    else:
        return T.cast(aux, _type)

init = [0. for each in range(10)] # weight initialization value
classes = ['w'+str(i) for i in range(1,4+1)]
weights = dict(zip(classes, [init]*len(classes)))

dont_care_about_these = ['w1','w3'] # I don't want to try to learn these classes
picked = [w for w in classes if w not in dont_care_about_these] 
w_gpu = gpu([weights[key] for key in picked])
def replace_picked(w_gpu, picked):
    """Copy the updated rows of w_gpu back into the module-level weights dict."""
    updated_weights = w_gpu.get_value()
    for i, key in enumerate(picked):
        weights[key] = updated_weights[i]

i_seq = T.vector(dtype='int8') 
outputs, updates = theano.scan(lambda i: ({w_gpu: T.inc_subtensor(w_gpu[i], 1.)}), 
                               sequences=i_seq)

f = function(inputs=[i_seq], updates=updates) 

f(np.arange(len(picked), dtype='int8')) # row indices into w_gpu, matching i_seq's int8 dtype
replace_picked(w_gpu,picked)
print(weights['w1']) # not picked, so still the initial zeros
print(weights['w2']) # picked, so copied back from w_gpu after the update
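A possible variation on the workaround, shown here as a NumPy analogue rather than Theano code: keep every class in one stacked array and resolve the class names to row indices with an ordinary dictionary before any symbolic code runs. The names `W`, `row_of`, `skip`, and `picked_rows` are made up for this sketch. In Theano, the same integer indices could presumably be fed through an integer vector input and applied with `T.inc_subtensor`, as the workaround above already does.

```python
import numpy as np

classes = ['w1', 'w2', 'w3', 'w4']
skip = ['w1', 'w3']  # classes to leave untouched

W = np.zeros((len(classes), 10))  # one row of weights per class
row_of = {name: r for r, name in enumerate(classes)}  # ordinary dict, concrete keys

# Resolve names to row indices up front, outside of any symbolic code.
picked_rows = np.array([row_of[c] for c in classes if c not in skip], dtype='int8')

W[picked_rows] += 1.0  # only the picked rows change

print(W[row_of['w1']])  # skipped class: all zeros
print(W[row_of['w2']])  # picked class: all ones
```

Because the skipped rows live in the same array and are simply never indexed, this layout needs no copy-back step like `replace_picked`.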

0 answers:

No answers yet