Sharing the same variables with different placeholders

Time: 2016-07-27 00:31:19

Tags: python tensorflow

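I want to write a script that builds a TensorFlow graph which can be used with tf.placeholder tensors of different sizes. When I run the following script, I get an InvalidArgumentError: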

# Loading training and validation data
# train data are in x_train, y_train
# validation data are in x_valid, y_valid
...

with tf.variable_scope("model", reuse=None):
    # Build a Graph that computes the logits
    network = Model(
        batch_size=BATCH_SIZE,
        input_dims=INPUT_DIMS,
        seq_length=SEQ_LENGTH,
        num_classes=NUM_CLASSES  # <---- line 528, in main_seq
    )
    # Calculate loss
    network.loss_op = softmax_seq_loss_by_example(
        logits=network.model, 
        labels=network.targets,
        batch_size=network.batch_size,
        seq_length=network.seq_length
    )
    # Calculate predictions
    network.pred_op = tf.argmax(network.model, 1)

with tf.variable_scope("model", reuse=True):
    # Build a Graph that computes the logits
    network_valid = Model(
        batch_size=1,
        input_dims=INPUT_DIMS,
        seq_length=1,
        num_classes=NUM_CLASSES
    )
    # Calculate loss
    network_valid.loss_op = softmax_seq_loss_by_example(
        logits=network_valid.model, 
        labels=network_valid.targets,
        batch_size=network_valid.batch_size,
        seq_length=network_valid.seq_length
    )
    # Calculate predictions
    network_valid.pred_op = tf.argmax(network_valid.model, 1)

with tf.Session() as sess:
    # Initialize variables in the graph
    sess.run(tf.initialize_all_variables())

    for epoch in xrange(n_epochs):
        # Update parameters and compute loss of training set
        y_true_train, y_pred_train, train_loss, train_duration = \
            run_seq_epoch(
                sess=sess, network=network, 
                inputs=x_train, targets=y_train, 
                train_op=train_op,
                is_train=True
            )

        # Evaluate the model on the validation set
        y_true_val, y_pred_val, valid_loss, valid_duration = \
            run_seq_epoch(
                sess=sess, network=network_valid,  # <---- replace_location
                inputs=x_valid, targets=y_valid,
                train_op=tf.no_op(),
                is_train=False  # <---- line 644, in main_seq
            )

        ...
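The two `with tf.variable_scope("model", ...)` blocks rely on two TF1 naming rules, and the placeholder names later in the question reflect both. First, tf.get_variable names variables by variable scope, so reuse=True returns the variables the first block created. Second, ordinary ops such as tf.placeholder are named in a name scope that gets a unique suffix when it is re-entered (model, then model_1). The toy model below illustrates these rules and shows why network and network_valid share weights but own separate placeholders. It assumes Model creates its weights with tf.get_variable, which the elided model code does not show. It is not the TensorFlow API: ToyGraph and build_toy_model are hypothetical names.

```python
# Toy model (NOT the TensorFlow API) of why the question's two Model instances
# share variables but get separate placeholders.


class ToyGraph(object):
    """Hypothetical stand-in for a TF1 graph's variable store and op naming."""

    def __init__(self):
        self._variables = {}    # full variable name -> value
        self._name_counts = {}  # name scope -> times it has been opened

    def open_name_scope(self, name):
        # Re-opening a scope name yields a uniquified op scope: model, model_1, ...
        count = self._name_counts.get(name, 0)
        self._name_counts[name] = count + 1
        return name if count == 0 else "%s_%d" % (name, count)

    def get_variable(self, scope, name, reuse, initial_value=None):
        # Variable names use the variable scope as given, without the _1 suffix.
        full_name = "%s/%s" % (scope, name)
        if reuse:
            # reuse=True: the variable must already exist; return the shared one.
            if full_name not in self._variables:
                raise ValueError("Variable %s does not exist" % full_name)
            return self._variables[full_name]
        # reuse=None/False: the variable must not exist yet; create it.
        if full_name in self._variables:
            raise ValueError("Variable %s already exists" % full_name)
        self._variables[full_name] = initial_value
        return initial_value


def build_toy_model(graph, reuse):
    """Mimics one `with tf.variable_scope("model", reuse=...)` block."""
    op_scope = graph.open_name_scope("model")
    placeholder = op_scope + "/Placeholder"  # a new placeholder on every call
    weights = graph.get_variable("model", "weights", reuse, initial_value=[0.1, 0.2])
    return placeholder, weights


g = ToyGraph()
train_placeholder, train_weights = build_toy_model(g, reuse=None)
valid_placeholder, valid_weights = build_toy_model(g, reuse=True)
print("%s %s" % (train_placeholder, valid_placeholder))  # model/Placeholder model_1/Placeholder
print(train_weights is valid_weights)                    # True: one shared variable
```

Under these rules, the placeholders of the second network are distinct graph nodes. Only the variables are shared.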

The function run_seq_epoch and the class Model are shown below:

class Model(object):

    def __init__(self, batch_size, input_dims, seq_length, num_classes):
        self.batch_size = batch_size
        self.input_dims = input_dims
        self.seq_length = seq_length
        self.num_classes = num_classes

        # Operations to compute loss and prediction, which will be assigned
        # after the initialization
        self.loss_op = None
        self.pred_op = None

        # Placeholder for input data
        self.inputs = tf.placeholder(
            tf.float32,
            shape=[batch_size*seq_length, input_dims, 1, 1]
        )
        self.targets = tf.placeholder(
            tf.int32,
            shape=[batch_size*seq_length, ]
        )
        self.is_train = tf.placeholder(tf.bool)

        # Use the defined placeholder above create a model
        ...


def run_seq_epoch(sess, network, inputs, targets, train_op, is_train):

    ...

    # Initial state for LSTM
    state = network.initial_state.eval()
    for x_batch, y_batch in iterate_batch_seq_minibatches(inputs,
                                                          targets,
                                                          network.batch_size,
                                                          network.seq_length):
        feed_dict = {
            network.inputs: x_batch, 
            network.targets: y_batch, 
            network.initial_state: state,
            network.is_train: is_train
        }
        _, loss_value, y_pred, state = sess.run(
            [train_op, network.loss_op, network.pred_op, network.final_state],
            feed_dict=feed_dict  # <---- line 471, in run_seq_epoch
        )

    ...

    return _y_true, y_pred, loss, duration
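The question does not show iterate_batch_seq_minibatches. Because run_seq_epoch feeds each step's final_state back in as the next initial_state, consecutive minibatches presumably continue the same sequences. Below is a minimal pure-Python sketch under that assumption. The splitting scheme is a common one for stateful RNNs and is not taken from the question. The real helper may order or shape the data differently, for example by returning NumPy arrays already reshaped to the placeholder shapes.

```python
def iterate_batch_seq_minibatches(inputs, targets, batch_size, seq_length):
    """Hypothetical sketch of the helper used by run_seq_epoch.

    Splits the data into batch_size contiguous streams. Step i yields the i-th
    window of seq_length items from every stream, flattened stream by stream,
    so each minibatch has batch_size * seq_length rows (the placeholders'
    first dimension) and the LSTM state can carry over between steps.
    """
    assert len(inputs) == len(targets)
    stream_len = len(inputs) // batch_size  # leftover items are dropped
    n_steps = stream_len // seq_length
    for step in range(n_steps):
        x_batch, y_batch = [], []
        for b in range(batch_size):
            start = b * stream_len + step * seq_length
            x_batch.extend(inputs[start:start + seq_length])
            y_batch.extend(targets[start:start + seq_length])
        yield x_batch, y_batch


for x_batch, y_batch in iterate_batch_seq_minibatches(list(range(12)), list(range(12)), 2, 3):
    print(x_batch)
# [0, 1, 2, 6, 7, 8]
# [3, 4, 5, 9, 10, 11]
```

With batch_size=1 and seq_length=1, as for network_valid, each minibatch holds a single sample. That matches the (1,) target shape listed below.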

Below are the names and shapes of the placeholders in network and network_valid, along with the shapes and types of x_batch and y_batch for training and validation:

# Placeholder of network.inputs and network.targets
model/Placeholder:0: (200, 7680, 1, 1)
model/Placeholder_1:0: (200,)

# x_batch and y_batch for training
x_batch.shape: (200, 7680, 1, 1), dtype=float32
y_batch.shape: (200, ), dtype=int32

# Placeholder of network_valid.inputs and network_valid.targets
model_1/Placeholder:0: (1, 7680, 1, 1)
model_1/Placeholder_1:0: (1,)

# x_batch and y_batch for validation
x_batch.shape: (1, 7680, 1, 1), dtype=float32
y_batch.shape: (1, ), dtype=int32
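In the listing above, each network's placeholders have a fully specified first dimension: batch_size*seq_length from Model.__init__. A batch therefore fits only the placeholders of the network built for its size. In TF1, Session.run rejects a fed value whose shape conflicts with the placeholder's static shape, and a None dimension matches any size. The toy sketch below reproduces that arithmetic and comparison. The function names are hypothetical, and this is not TensorFlow's own code. Only the validation values 1, 1 and 7680 come from the question, which shows the training product 200 but not BATCH_SIZE and SEQ_LENGTH separately.

```python
def model_placeholder_shapes(batch_size, seq_length, input_dims):
    """Static shapes that Model.__init__ gives self.inputs and self.targets."""
    rows = batch_size * seq_length  # first dimension baked into both placeholders
    return (rows, input_dims, 1, 1), (rows,)


def is_feed_compatible(placeholder_shape, value_shape):
    """Toy version of TF1's static feed-shape check (not TensorFlow's code).

    A None entry in placeholder_shape matches any size; placeholders of
    unknown rank (shape=None in TensorFlow) are not modelled here.
    """
    if len(placeholder_shape) != len(value_shape):
        return False
    return all(p is None or p == v for p, v in zip(placeholder_shape, value_shape))


valid_inputs, valid_targets = model_placeholder_shapes(1, 1, 7680)
print(valid_inputs)                                              # (1, 7680, 1, 1)
print(is_feed_compatible(valid_targets, (1,)))                   # True
print(is_feed_compatible((200,), (1,)))                          # False: training placeholder
print(is_feed_compatible((None, 7680, 1, 1), (1, 7680, 1, 1)))   # True
```

Declaring the placeholders with shape=[None, input_dims, 1, 1] would let one set of placeholders accept both batch sizes. Whether the rest of Model works with an unknown first dimension depends on code the question does not show. For example, an LSTM initial state built for a fixed batch size would not.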

However, when I run the script, I get a tensorflow.python.framework.errors.InvalidArgumentError for the targets placeholder:

Traceback (most recent call last):
  File "main.py", line 789, in <module>
    fold_idx=fold_idx
  File "main.py", line 644, in main_seq
    is_train=False
  File "main.py", line 471, in run_seq_epoch
    feed_dict=feed_dict
  File "/home/akara/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 372, in run
    run_metadata_ptr)
  File "/home/akara/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 636, in _run
    feed_dict_string, options, run_metadata)
  File "/home/akara/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 708, in _do_run
    target_list, options, run_metadata)
  File "/home/akara/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 728, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors.InvalidArgumentError: You must feed a value for placeholder tensor 'model/Placeholder_1' with dtype int32 and shape [200]
     [[Node: model/Placeholder_1 = Placeholder[dtype=DT_INT32, shape=[200], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
     [[Node: model_1/total_loss/_329 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_1326_model_1/total_loss", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
Caused by op u'model/Placeholder_1', defined at:
  File "main.py", line 789, in <module>
    fold_idx=fold_idx
  File "main.py", line 528, in main_seq
    num_classes=NUM_CLASSES
  File "/home/akara/Workspace/project/model.py", line 62, in __init__
    shape=[batch_size*seq_length, ]
  File "/home/akara/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 895, in placeholder
    name=name)
  File "/home/akara/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1238, in _placeholder
    name=name)
  File "/home/akara/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 704, in apply_op
    op_def=op_def)
  File "/home/akara/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2260, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/akara/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1230, in __init__
    self._traceback = _extract_stack()

If I rerun it several times, it reports the same error, but for a different placeholder each time:

For the is_train placeholder:

...
tensorflow.python.framework.errors.InvalidArgumentError: You must feed a value for placeholder tensor 'model/Placeholder_2' with dtype bool
     [[Node: model/Placeholder_2 = Placeholder[dtype=DT_BOOL, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
     [[Node: _recv_model_1/Placeholder_0/_349 = _Send[T=DT_FLOAT, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_2567__recv_model_1/Placeholder_0", _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_model_1/Placeholder_0)]]
Caused by op u'model/Placeholder_2', defined at:
...

For the inputs placeholder:

...
tensorflow.python.framework.errors.InvalidArgumentError: You must feed a value for placeholder tensor 'model/Placeholder' with dtype float and shape [200,7680,1,1]
     [[Node: model/Placeholder = Placeholder[dtype=DT_FLOAT, shape=[200,7680,1,1], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
     [[Node: model_1/ArgMax/_354 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_105_model_1/ArgMax", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
Caused by op u'model/Placeholder', defined at:
...

If I replace network_valid with network at replace_location, the script works perfectly. So I am not sure why the error occurs when I use network_valid.

Update

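I also found that the error comes from placeholders in the model variable scope (i.e., network, used for training) rather than the model_1 scope (i.e., network_valid, used for validation), even though the script stops while running validation.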

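I renamed the placeholders and moved them outside the Model class. Then I looked at my graph in TensorBoard (Graph from Tensorboard). There seems to be no link between the network placeholders and network_valid, and likewise none between the network_valid placeholders and network.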
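One way to reason about which placeholders a fetch needs is the TF1 feeding rule. Session.run executes only the ops that the fetched tensors depend on, directly or through their inputs, and every placeholder among those ops must be fed. Sharing a variable (model/weights below) does not by itself make one network need the other network's placeholders. That only happens when a chain of op inputs leads from a fetched op to one of them. The toy graph below reuses op names from the error messages, but its structure is hypothetical. In particular, the edge from model_1/total_loss to model/loss is an assumption for illustration and is not taken from the question, which elides how the loss is computed.

```python
# Toy dependency graph, NOT TensorFlow: op name -> names of the ops it reads.
def required_placeholders(graph, fetches, placeholders, fed=()):
    """Return the placeholders that fetching `fetches` would still need fed.

    Mimics the TF1 rule: Session.run executes only the ops the fetches
    (transitively) depend on, and every placeholder among them must be fed.
    """
    seen = set()
    stack = list(fetches)
    while stack:  # depth-first walk over the inputs of each op
        op = stack.pop()
        if op in seen:
            continue
        seen.add(op)
        stack.extend(graph.get(op, []))
    return sorted(p for p in seen if p in placeholders and p not in fed)


PLACEHOLDERS = {"model/Placeholder", "model/Placeholder_1",
                "model_1/Placeholder", "model_1/Placeholder_1"}
GRAPH = {
    # Both networks read the shared variable model/weights.
    "model/logits": ["model/Placeholder", "model/weights"],
    "model/loss": ["model/logits", "model/Placeholder_1"],
    "model_1/logits": ["model_1/Placeholder", "model/weights"],
    "model_1/loss": ["model_1/logits", "model_1/Placeholder_1"],
    "model_1/ArgMax": ["model_1/logits"],
    # Hypothetical edge: a total loss that also sums the training network's loss.
    "model_1/total_loss": ["model_1/loss", "model/loss"],
}
VALID_FEEDS = {"model_1/Placeholder", "model_1/Placeholder_1"}

print(required_placeholders(GRAPH, ["model_1/ArgMax"], PLACEHOLDERS, VALID_FEEDS))
# []
print(required_placeholders(GRAPH, ["model_1/total_loss"], PLACEHOLDERS, VALID_FEEDS))
# ['model/Placeholder', 'model/Placeholder_1']
```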

0 answers