Is this type of loss function possible to optimize in TensorFlow?

Posted: 2017-03-17 15:37:52

Tags: python tensorflow

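Normally, a loss function takes the squared difference (or some other operation) between the predicted and the actual values. I am trying to use the predicted values to construct a version of the input, and then take the difference between the input and that constructed output to get the loss. Since the predictions are not used directly to compute the loss, will the optimizer know what to update?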
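A framework-free sketch of the principle behind the question, not TensorFlow code (all names are hypothetical): gradient-based optimizers work through the chain rule, so a prediction does not have to appear in the loss directly, only along a differentiable path. Here the loss compares `x` with a version of `x_out` rebuilt from `pred`, and a central-difference estimate of d(loss)/d(pred) comes out nonzero.

```python
# Hypothetical illustration: `pred` never appears in the loss formula itself,
# but it reaches the loss through a smooth (differentiable) reconstruction.

def reconstruct(x_out, pred):
    # Build a new version of x_out from the prediction; smooth in `pred`.
    return [v - pred for v in x_out]

def loss(x, x_out, pred):
    # Squared difference between the input and the reconstructed version.
    xo2 = reconstruct(x_out, pred)
    return sum((a - b) ** 2 for a, b in zip(x, xo2))

def numeric_grad(f, p, eps=1e-6):
    # Central-difference approximation of df/dp.
    return (f(p + eps) - f(p - eps)) / (2 * eps)

x = [1.0, 2.0, 3.0]
x_out = [2.0, 3.0, 4.0]
# Here loss(p) = 3 * (p - 1) ** 2, so the exact derivative at p = 0.5 is -3.
g = numeric_grad(lambda p: loss(x, x_out, p), 0.5)
print(g)
```

A nonzero derivative is what lets an optimizer move `pred`; whether the predictions are used "directly" does not matter.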

Here is the input:

at org.apache.jasper.servlet.JspServletWrapper.handleJspException(JspServletWrapper.java:568)[apache-jsp-8.0.9.M3.jar:2.3]
at org.apache.jasper.servlet.JspServletWrapper.service(JspServletWrapper.java:470)[apache-jsp-8.0.9.M3.jar:2.3]
at org.apache.jasper.servlet.JspServlet.serviceJspFile(JspServlet.java:405)[apache-jsp-8.0.9.M3.jar:2.3]
at org.apache.jasper.servlet.JspServlet.service(JspServlet.java:349)[apache-jsp-8.0.9.M3.jar:2.3]
at org.eclipse.jetty.jsp.JettyJspServlet.service(JettyJspServlet.java:107)[apache-jsp-9.2.13.v20150730.jar:9.2.13.v20150730]
at javax.servlet.http.HttpServlet.service(HttpServlet.java:729)[tomcat-servlet-api-8.0.24.jar:]
at org.eclipse.jetty.servlet.ServletHolder.handle(ServletHolder.java:808)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.servlet.ServletHandler$CachedChain.doFilter(ServletHandler.java:1669)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:99)[spring-web-4.1.9.RELEASE.jar:4.1.9.RELEASE]
at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)[spring-web-4.1.9.RELEASE.jar:4.1.9.RELEASE]
at org.eclipse.jetty.servlet.ServletHandler$CachedChain.doFilter(ServletHandler.java:1652)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.apache.activemq.web.SessionFilter.doFilter(SessionFilter.java:45)[activemq-web-5.13.4.jar:5.13.4]
at org.eclipse.jetty.servlet.ServletHandler$CachedChain.doFilter(ServletHandler.java:1652)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.apache.activemq.web.filter.ApplicationContextFilter.doFilter(ApplicationContextFilter.java:102)[file:/home/openkm/activemq-5.13.4/webapps/admin/WEB-INF/classes/:]
at org.eclipse.jetty.servlet.ServletHandler$CachedChain.doFilter(ServletHandler.java:1652)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:585)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:143)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.security.SecurityHandler.handle(SecurityHandler.java:542)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.session.SessionHandler.doHandle(SessionHandler.java:223)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.handler.ContextHandler.doHandle(ContextHandler.java:1127)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.servlet.ServletHandler.doScope(ServletHandler.java:515)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.session.SessionHandler.doScope(SessionHandler.java:185)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.handler.ContextHandler.doScope(ContextHandler.java:1061)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.handler.HandlerCollection.handle(HandlerCollection.java:110)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.security.SecurityHandler.handle(SecurityHandler.java:542)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.handler.HandlerCollection.handle(HandlerCollection.java:110)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:97)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.Server.handle(Server.java:499)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.HttpChannel.handle(HttpChannel.java:310)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.server.HttpConnection.onFillable(HttpConnection.java:257)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.io.AbstractConnection$2.run(AbstractConnection.java:540)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:635)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at org.eclipse.jetty.util.thread.QueuedThreadPool$3.run(QueuedThreadPool.java:555)[jetty-all-9.2.13.v20150730.jar:9.2.13.v20150730]
at java.lang.Thread.run(Thread.java:745)[:1.8.0_121]

x = tf.placeholder(tf.float32, shape=[None, 150 * 150 * 3])

Here is another placeholder with the same shape as the input:

x_out = tf.placeholder(tf.float32, shape=[None, 150 * 150 * 3])

The predictions (intermediate operations removed):

W_fc_idx = weight_variable([4, 100])
b_fc_idx = bias_variable([100])
y_conv_idx = tf.matmul(h_fc1_drop, W_fc_idx) + b_fc_idx
y_conv_idx = tf.nn.relu(y_conv_idx)

W_fc_len = weight_variable([4, 300])
b_fc_len = bias_variable([300])
y_conv_len = tf.matmul(h_fc1_drop, W_fc_len) + b_fc_len
y_conv_len = tf.nn.relu(y_conv_len)

Basically, we feed the predictions, together with x_out, into a while loop; the loop applies the predictions in various ways and returns another version of x_out (named xout2 here):

def condition(outeri, y_idx, y_len, y_val, xo): return outeri < 1000
def body(outeri, y_idx, y_len, y_val, xo):
    #Various operations that use the y_conv predictions to create and return tensor of same shape as x_out
    topcorner = tf.reduce_sum(tf.to_int64(y_idx[0, outeri]))
    nesti = tf.constant(0, tf.int64)
    while_condition2 = lambda nesti, xo: tf.less(nesti, tf.to_int64(y_len[0, outeri]))
    def body2(nesti, xo):
        shape = [1, 150 * 150 * 3]
        for thickness in range(1):
            index = [[0, topcorner + nesti + (thickness * 450)]]
            value = tf.reshape(tf.reduce_sum(y_val[0, outeri]), shape=[1])#[1.5]  # [[0]]
            delta = tf.SparseTensor(index, value, shape)
            xo = xo - tf.sparse_tensor_to_dense(delta)
        return [tf.add(nesti, 1), xo]
    _, xo = tf.while_loop(while_condition2, body2, [nesti, xo])
    #A few more inner/nested while loops in here
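    # tf.while_loop requires the body to return the same structure as its loop_vars,
    # and outeri must be incremented or the outer loop never terminates.
    return [outeri + 1, y_idx, y_len, y_val, xo]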
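One thing worth checking in the loop code above: `y_idx` and `y_len` reach `xo` only through `tf.to_int64`, as a position and as a loop bound. Integer casts are not differentiable, and TensorFlow yields no gradient (`None`) for a path that passes only through them, so weights feeding only those predictions get nothing to update. The framework-free sketch below illustrates this numerically under stated assumptions: Python's `int()` stands in for `tf.to_int64`, and `subtract_at` is a hypothetical, simplified version of `body2`.

```python
# Hypothetical sketch: int() plays the role of tf.to_int64, so the loss is
# piecewise constant in the index prediction but smooth in the value prediction.

def subtract_at(x_out, idx_pred, val_pred):
    # Simplified body2: subtract val_pred at position int(idx_pred).
    i = int(idx_pred)
    out = list(x_out)
    out[i] -= val_pred
    return out

def loss(x, x_out, idx_pred, val_pred):
    xo2 = subtract_at(x_out, idx_pred, val_pred)
    return sum((a - b) ** 2 for a, b in zip(x, xo2))

def central_diff(f, p, eps=1e-4):
    # Central-difference approximation of df/dp.
    return (f(p + eps) - f(p - eps)) / (2 * eps)

x = [1.0, 2.0, 3.0]
x_out = [1.0, 5.0, 3.0]
idx, val = 1.3, 2.0

# Both idx + eps and idx - eps truncate to 1, so the loss does not change.
d_idx = central_diff(lambda p: loss(x, x_out, p, val), idx)
# The value enters the loss smoothly: loss = (val - 3) ** 2 here.
d_val = central_diff(lambda v: loss(x, x_out, idx, v), val)
print(d_idx, d_val)
```

Only the value path carries a usable gradient; in TensorFlow terms, only `y_val` (through the `SparseTensor` values) could pass a gradient back to the network.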

Finally, here is the loss function:

_, _, _, _, xout2 = tf.while_loop(condition, body, [outeri, y_conv_idx, y_conv_len, y_conv_val, x_out])

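Now, my question is: since the loss is not computed directly from the predictions or from any variable in the graph (it is computed from a tensor returned by a while loop that uses the predictions behind the scenes), will the optimizer still know which weights to update?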

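Note: I have implemented this, and with various learning rates the loss does not change at all, so I suspect it may not be working; hence the question.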
