
时间:2018-02-21 20:55:11

标签: python tensorflow





# <make X* and alist* inputs placeholders>
# NOTE(review): X0..X10, alist0..alist9, pnames, SCOPE, network_fwd, loss_fn,
# learning_rate, num_steps, X_train, mb_size, K, numpy_ops and sess are all
# defined outside this snippet.

# forward through the layer-networks: layer k consumes the previous layer's
# output (X0 for the first layer) plus an adjacency-list placeholder alist(k-1)
# that is fed from numpy at run time.
H1 = network_fwd(X0, alist0, var_scope=SCOPE.format(0)) # params_0/*
H2 = network_fwd(H1, alist1, var_scope=SCOPE.format(1)) # params_1/*
H3 = network_fwd(H2, alist2, var_scope=SCOPE.format(2))
# ...
H10 = network_fwd(H9, alist9, var_scope=SCOPE.format(9))

# optimize on sum network error: layer output Hk is scored against the true
# input Xk (presumably the H4..H9 terms are elided here as well).
error = loss_fn(H1, X1)
error += loss_fn(H2, X2)
error += loss_fn(H3, X3)
error += loss_fn(H10, X10)

# Single optimizer step over the summed error of all layer networks.
train = tf.train.AdamOptimizer(learning_rate).minimize(error)

# train
for step in range(num_steps):
    x_batch = numpy_ops.next_minibatch(X_train, mb_size) # list of data inputs

    # fill placeholders for true inputs X0, X1, ..., X10
    # (assumes pnames lists those placeholders in the same order as x_batch)
    fdict = {placeholder_name: x for placeholder_name, x in zip(pnames, x_batch)}

    # layer-network predictions and adj lists
    # Each layer's adjacency list depends on the previous layer's numeric
    # output, so it is computed in numpy and fed back in, one sess.run per
    # layer. Each sess.run re-evaluates every layer upstream of the fetched
    # tensor, so this loop repeats work across calls.
    adj_list0 = numpy_ops.get_adjacency_list(x_batch[0], K)
    fdict[alist0] = adj_list0
    h1 = sess.run(H1, feed_dict=fdict)
    fdict[alist1] = numpy_ops.get_adjacency_list(h1, K)
    h2 = sess.run(H2, feed_dict=fdict)
    fdict[alist2] = numpy_ops.get_adjacency_list(h2, K)
    # ...

    # and finally the actual training pass




1 个答案:

答案 0(得分:0)



# <initialize all child network variables>

# make adjacency list function interface for tf.py_func
def alist_func(h_in):
    """Adapt ``numpy_ops.get_adjacency_list`` to the ``tf.py_func`` interface.

    ``tf.py_func`` calls its ``func`` with positional numpy arrays only (one
    per entry of ``inp``), so the helper's extra ``K`` argument is bound here
    from the enclosing scope. Given these constraints on the input to the
    ``func`` arg of ``tf.py_func``, a similar wrapper may be needed for any
    numpy function whose signature has other args or kwargs.

    Args:
        h_in: numpy array holding one layer input, as passed by ``tf.py_func``.

    Returns:
        The adjacency list as an ``int32`` numpy array. ``tf.py_func`` requires
        the returned dtype to match the declared ``Tout`` (``tf.int32`` where
        this wrapper is used), while numpy index arrays are commonly int64.
    """
    import numpy as np  # local import: the file's top-level imports are not shown

    alist = numpy_ops.get_adjacency_list(h_in, K)
    # Explicit cast; np.asarray does not copy when the dtype already matches.
    return np.asarray(alist, dtype=np.int32)

# direct graph: all layer inputs are fed through a single placeholder and the
# adjacency lists are computed inside the graph by tf.py_func, so each training
# step needs only one run call.
# NOTE(review): the leading 11 assumes 10 layer networks (true inputs X0..X10),
# i.e. len(var_scopes) == 10 -- confirm. The None axis is presumably the
# minibatch dimension.
data_in_shape = (11, None, num_points, D)
X_input = tf.placeholder(tf.float32, shape=data_in_shape, name='X_input')
# tensor_ops.meta_model_fwd chains the layer networks and returns
# (last layer output, summed per-layer loss).
X_pred, loss = tensor_ops.meta_model_fwd(X_input, var_scopes, alist_func)

# optimizer
train = tf.train.AdamOptimizer(learning_rate).minimize(loss)

# train
for step in range(num_steps):
    x_batch = numpy_ops.next_minibatch(X_train, mb_size)  # list of data inputs
    # Operation.run executes in the default session; assumes one is active
    # (e.g. inside ``with sess.as_default():``).
    train.run(feed_dict={X_input: x_batch})

# in tensor_ops script, this is how I use tf.py_func
def meta_model_fwd(x_in, var_scopes, alist_fn, *args, **kwargs):
    """Chain the layer networks, computing adjacency lists in-graph.

    Layer ``k`` (1-based, one per entry of ``var_scopes``) consumes the
    previous layer's output (``x_in[0]`` for the first layer) together with an
    adjacency list produced from that input by ``alist_fn`` through
    ``tf.py_func``. Its output ``H_k`` is scored against the true input
    ``x_in[k]``, i.e. loss = loss_fn(H1, X1) + loss_fn(H2, X2) + ...

    Args:
        x_in: stacked true inputs indexed along the first axis; must hold at
            least ``len(var_scopes) + 1`` entries.
        var_scopes: variable-scope names, one per layer network.
        alist_fn: numpy callable mapping a layer input to its adjacency list;
            must return int32 to match the declared ``tf.py_func`` output.
        *args, **kwargs: forwarded unchanged to ``model_fwd``.

    Returns:
        (h, loss): output of the last layer network and the summed loss.
    """
    # First layer: fed directly from the true input X0.
    alist = tf.py_func(alist_fn, [x_in[0]], tf.int32)  # inp is float32, out is int32
    h = model_fwd(x_in[0], alist, var_scopes[0], *args, **kwargs)
    loss = loss_fn(h, x_in[1])
    # Remaining layers: the k-th layer's output is compared with x_in[k], so
    # the target index starts at 2 for var_scopes[1] (was off by one).
    for target_idx, vscope in enumerate(var_scopes[1:], start=2):
        alist = tf.py_func(alist_fn, [h], tf.int32)
        h = model_fwd(h, alist, vscope, *args, **kwargs)
        loss += loss_fn(h, x_in[target_idx])
    return h, loss


# Alias: the training script calls this function as
# ``tensor_ops.multi_func_model_fwd``.
multi_func_model_fwd = meta_model_fwd
