model.fit在Keras中如何工作?

时间:2018-09-13 09:55:48

标签: tensorflow keras loss-function

我以前的帖子或错误是此one 。 因此,我发现了编写函数的另一种方式,使其与Tensorflow兼容。我对其进行了测试,并且工作正常。 但是,当我想将其集成到keras中时,我做不到。 这是我上一篇文章的解决方案:

graph = tf.Graph()
with graph.as_default():
i = tf.Variable(0)
error = tf.Variable(initial_value=0,dtype=tf.float64)
sol = tf.random_uniform(shape=[10, 36], dtype=tf.float64, 
maxval=1)
error_1 = tf.Variable(initial_value=0,dtype=tf.float64)
final_loss = tf.Variable(0)

def cond(i, sol, error):
    return tf.less(i, 9)
def body(i, sol,error):
    i = tf.add(i, 1)
    print('i',i)
    #sol = tf.add(sol, 1)
    original_reshaped_elem = original_dim* sol[i]
    original_reshaped_elem = tf.reshape(original_reshaped_elem, 
    [DIM,DIM])
    a = tf.reshape(original_reshaped_elem[:,DIM-1], [DIM,1])
    b = tf.reshape(original_reshaped_elem[:,1], [DIM,1])

    original_reshaped_elem = tf.concat 
    ([b,original_reshaped_elem], axis= 1)
    original_reshaped_elem = tf.concat 
    ([original_reshaped_elem,a], axis= 1)

    c= tf.reshape(original_reshaped_elem[DIM-1,:], [1,DIM+2])
    d= tf.reshape(original_reshaped_elem[1,:], [1,DIM+2])
    original_reshaped_elem = tf.concat 
    ([d,original_reshaped_elem],axis=0)
    reshaped_elem_extended = tf.concat 
    ([original_reshaped_elem,c],axis=0)
    print('reshaped shape', reshaped_elem_extended)


    error = 
tf.add(error,tf.norm(tf.norm(reshaped_elem_extended,ord=2,axis=0),ord=2,axis=0))
    error_1 = tf.divide(error, 36)
    return [i, sol, error_1]


with tf.Session(graph=graph) as session:
     tf.global_variables_initializer().run()

result = tf.while_loop(cond, body, [i, sol, error])
final_loss = tf.divide(result[2], 10)
print(final_loss.eval())
print(result[1].eval())

这是我在模型中的称呼方式:

result = tf.while_loop(cond, body, [i, inputs, error])
final_loss = tf.divide(result[2], 10)
vae.add_loss(final_loss)

然后我再次收到此错误

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

所以,我想知道model.fit在keras中如何工作?它会实例化图形吗?我没有找到任何有关其工作原理的清晰文档,因此可以相应地集成损失函数。

0 个答案:

没有答案