I have a loss function that contains three nested sums, as follows:
sum for each j in words:
sum for each i in documents:
sum for each r in Tj
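Written out, what I believe the code below computes is roughly the following (this is my own reconstruction from the loops, so treat the exact indexing as approximate):

\mathcal{L} \;=\; \sum_{j=1}^{M} \frac{1}{\sigma_j} \sum_{l=1}^{L} \Bigg( \sum_{i \in T_j} \Big( (w_i R_l)(w_j R_l)^{\top} + b_j - \mathrm{PMI}_{i,j} \Big)^{2} \Bigg)^{p} \;+\; \lambda \sum_{l=1}^{L} \lVert R_l \rVert_{*}

\sigma_j \;=\; \frac{1}{|T_j|} \min_{l} \sum_{i \in T_j} \Big( (w_i R_l)(w_j R_l)^{\top} + b_j - \mathrm{PMI}_{i,j} \Big)^{2}

with p = -2 and lambda = 0.1.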
The reason I implemented it this way is that Tj is a list of lists, where each inner list has a different length. I tried to implement the loss as a numpy function and then wrap it with tf.py_func, but now I don't know how to optimize it.
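For concreteness, here is a hypothetical Tj with the kind of ragged shape I am dealing with (the indices are made up, only the shape matters):

# Hypothetical example of the ragged structure (indices are made up):
# t_j[j] holds the entity indices associated with word j, and the rows
# have different lengths, so it cannot be turned into a rectangular array.
t_j = [
    [0, 17, 42],      # word 0 -> 3 entities
    [5],              # word 1 -> 1 entity
    [3, 8, 21, 40],   # word 2 -> 4 entities
]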
import numpy as np
import tensorflow as tf

def loss_fun(COM, Wj, Wi, B, R, Tj):
    # All arguments arrive as numpy arrays inside tf.py_func.
    def loss_factor(j):
        # sigma_j: the smallest (over the relation slices R[l]) summed squared
        # distance for word j, normalised by the number of entities in Tj[j].
        L = R.shape[0]
        sum_container = []
        for l in range(L):
            third_sum = 0
            for i in range(len(Tj[j])):
                entity_value = Tj[j][i]
                wi = np.reshape(Wi[entity_value], [1, 300])
                Rl = R[l, :, :]
                wj = np.reshape(Wj[j, :], [1, 300])
                # Use numpy here: tf ops inside the py_func body would build
                # graph nodes instead of computing values.
                wiRl = np.matmul(wi, Rl)
                wjRl = np.matmul(wj, Rl)
                wiRlxwjRl = np.matmul(wiRl, np.transpose(wjRl))
                pmi = COM[entity_value, j]
                bj = B[j, 0]
                distance_expr = np.square(wiRlxwjRl + bj - pmi)
                third_sum = third_sum + distance_expr
            sum_container.append(third_sum)
        return np.amin(sum_container) / len(Tj[j])

    L = R.shape[0]
    M = B.shape[0]
    p = -2
    first_sum = 0
    for j in range(M):
        sigma = loss_factor(j)
        sigma = 1 / sigma
        second_sum = 0
        for l in range(L):
            third_sum = 0
            for i in range(len(Tj[j])):
                entity_value = Tj[j][i]
                wi = np.reshape(Wi[entity_value], [1, 300])
                Rl = R[l, :, :]
                wj = np.reshape(Wj[j, :], [1, 300])
                wiRl = np.matmul(wi, Rl)
                wjRl = np.matmul(wj, Rl)
                wiRlxwjRl = np.matmul(wiRl, np.transpose(wjRl))
                pmi = COM[entity_value, j]
                bj = B[j, 0]
                distance_expr = np.square(wiRlxwjRl + bj - pmi)
                third_sum = third_sum + distance_expr
            second_sum = second_sum + np.power(third_sum, p)
        first_sum = first_sum + sigma * second_sum

    # Nuclear-norm regularisation over the relation matrices
    nuc_norm = 0
    lambdaa = 0.1
    for l in range(L):
        Rl = R[l, :, :]
        nuc_norm = nuc_norm + np.linalg.norm(Rl, ord='nuc')
    nuc_norm = lambdaa * nuc_norm

    # Cast so the result matches the Tout=tf.float32 given to tf.py_func
    return np.float32(first_sum + nuc_norm)
###### End of the loss function
K = 50
COM = tf.constant(co_occurrence_matrix, name='COM')
Tj = tf.constant(t_j, name='Tj')  # t_j is the ragged list of lists of entity indices
Wj = tf.Variable(Wj, name='Wj')
Wi = tf.Variable(tf.random_uniform([len(entities_list), 300], -1.0, 1.0), name='Wi')
B = tf.Variable(tf.random_uniform([len(features_list) - 1, 1], -1.0, 1.0), name='B')
R = tf.Variable(tf.random_normal(shape=(K, 300, 300)), name='R')
loss_val = tf.py_func(loss_fun, [COM, Wj, Wi, B, R, Tj], tf.float32)
optimizer = tf.train.AdagradOptimizer(0.01).minimize(loss_val)
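One direction I have considered is dropping tf.py_func for the innermost sum and expressing it with native TF ops, so that gradients can flow. This is only a sketch, assuming Tj is padded into a rectangular int32 matrix tj_padded of shape [M, max_len] (padding filled with any valid index) together with a float mask of the same shape (1.0 for real entries, 0.0 for padding), and assuming all float tensors share the same dtype; none of these names exist in my code above:

def inner_sums_tf(COM, Wj, Wi, B, R, tj_padded, mask):
    # Embeddings of the entities listed in each Tj[j]: [M, max_len, 300]
    Wi_sel = tf.gather(Wi, tj_padded)
    # (w_i R_l) for every relation l: [K, M, max_len, 300]
    wiR = tf.einsum('mtd,kde->kmte', Wi_sel, R)
    # (w_j R_l): [K, M, 300]
    wjR = tf.einsum('md,kde->kme', Wj, R)
    # Inner products <w_i R_l, w_j R_l>: [K, M, max_len]
    scores = tf.einsum('kmte,kme->kmt', wiR, wjR)
    # PMI values COM[entity, j] gathered per (j, i) position: [M, max_len]
    j_idx = tf.ones_like(tj_padded) * tf.expand_dims(tf.range(tf.shape(tj_padded)[0]), 1)
    pmi = tf.gather_nd(COM, tf.stack([tj_padded, j_idx], axis=-1))
    # Squared distances, with B broadcast as [1, M, 1]: [K, M, max_len]
    dist = tf.square(scores + tf.reshape(B, [1, -1, 1]) - pmi)
    # Zero out the padded positions, then sum over i in Tj[j]: [K, M]
    return tf.reduce_sum(dist * mask, axis=-1)

If this works, the remaining pieces (the min over l for sigma_j, the power p, and the nuclear norm) could presumably be built on top of the resulting [K, M] tensor with tf.reduce_min, tf.pow and tf.reduce_sum(tf.svd(R, compute_uv=False)), but I have not verified that end to end.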
I would appreciate any ideas.