TensorFlow: tf.cond does not execute true_fn or false_fn on the output of tf.reduce_mean

Time: 2018-02-21 14:03:11

Tags: tensorflow

I am trying to adjust the output of my loss function, which is computed with tf.reduce_mean, in order to avoid NaN errors. My code is:

     limit=[]
     for i in xrange(12):
         limit.append(10000.0)
     limit = tf.constant(limit)

     predictions["loss"] = tf.cond(
         tf.reduce_mean((prediction - transformed_values) ** 2, axis=-1) < limit,
         lambda: tf.reduce_mean((prediction - transformed_values) ** 2, axis=-1),
         lambda: tf.reduce_mean((prediction - transformed_values), axis=-1))

However, I get the following error:

INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpfnvr6j
INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_is_chief': True, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f7eaa5bd750>, '_save_checkpoints_steps': None, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_num_ps_replicas': 0, '_tf_random_seed': None, '_master': '', '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_model_dir': '/tmp/tmpfnvr6j', '_save_summary_steps': 100}
shape: pred  (12,)  true_t  (12,)  false_t  (12,)
Traceback (most recent call last):
  File "/home/paul/workspace/workspace/Master/Elec_Price_Prediction/Time_Series.py", line 302, in <module>
    obtain_prediction()
  File "/home/paul/workspace/workspace/Master/Elec_Price_Prediction/Time_Series.py", line 212, in obtain_prediction
    estimator.train(input_fn=train_input_fn, steps=10000)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 711, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 694, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/timeseries/python/timeseries/head.py", line 201, in create_estimator_spec
    return self._train_ops(features)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/timeseries/python/timeseries/head.py", line 60, in _train_ops
    estimator_lib.ModeKeys.TRAIN)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/timeseries/python/timeseries/state_management.py", line 67, in define_loss
    return model.define_loss(features, mode)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/timeseries/python/timeseries/model.py", line 196, in define_loss
    return self.get_batch_loss(features=features, mode=mode, state=start_state)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/timeseries/python/timeseries/model.py", line 509, in get_batch_loss
    features, mode, state)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/timeseries/python/timeseries/model.py", line 609, in per_step_batch_loss
    outputs=["loss"] + self._train_output_names)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/timeseries/python/timeseries/model.py", line 775, in _state_update_loop
    loop_vars=initial_loop_arguments)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2816, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2640, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2590, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/timeseries/python/timeseries/model.py", line 726, in _state_update_step
    state=state)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/timeseries/python/timeseries/model.py", line 605, in _batch_loss_filtering_step
    predictions=predictions)
  File "/home/paul/workspace/workspace/Master/Elec_Price_Prediction/Time_Series.py", line 105, in _filtering_step
    prediction=tf.cond(pred,lambda:true_t,lambda:false_t)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/deprecation.py", line 316, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1844, in cond
    p_2, p_1 = switch(pred, pred)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 305, in switch
    return gen_control_flow_ops._switch(data, pred, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_control_flow_ops.py", line 562, in _switch
    "Switch", data=data, pred=pred, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2958, in create_op
    set_shapes_for_outputs(ret)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2209, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2159, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/common_shapes.py", line 627, in call_cpp_shape_fn
    require_shape_fn)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/common_shapes.py", line 691, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Shape must be rank 0 but is rank 1 for 'head/model/while/state_update_step/cond/Switch' (op: 'Switch') with input shapes: [12], [12].

My question is why this is not possible and how to fix it. I have already checked that pred, true_fn, and false_fn all have the same shape.

1 answer:

Answer 0: (score: 0)

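I would use tf.where instead. tf.cond expects a scalar (rank-0) predicate, which is why the Switch op rejects the rank-1 predicate of shape (12,) with "Shape must be rank 0 but is rank 1". tf.where selects element-wise between two tensors of the same shape, so it fits a per-row comparison like this one. See the documentation on how to use tf.where.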
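
To make the difference concrete, here is a minimal sketch of the tf.where approach. It is not the asker's actual model code: the tensors `prediction`, `transformed_values`, and `limit` below are hypothetical stand-ins (12 rows, 3 features, and a limit of 100.0 rather than 10000.0 so that both branches get picked), and it assumes TensorFlow 2.x with eager execution. Under the 1.x version shown in the traceback, the same tf.where call should apply, but the result would have to be evaluated in a session.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the tensors in the question (assumption:
# 12 rows, 3 features per row). Row i holds the values 3i, 3i+1, 3i+2.
prediction = tf.constant(np.arange(36, dtype=np.float32).reshape(12, 3))
transformed_values = tf.zeros([12, 3], dtype=tf.float32)

# One limit per row, like the question's list of 12 constants.
# 100.0 is chosen only so this toy data exercises both branches.
limit = tf.fill([12], 100.0)

# Both candidate losses are computed up front; each has shape (12,).
squared = tf.reduce_mean((prediction - transformed_values) ** 2, axis=-1)
linear = tf.reduce_mean(prediction - transformed_values, axis=-1)

# Element-wise selection: the predicate is rank 1, which tf.where accepts.
loss = tf.where(squared < limit, squared, linear)

# If a single decision for the whole tensor is wanted instead, tf.cond
# works once the predicate is reduced to a scalar, e.g. with tf.reduce_all.
whole = tf.cond(tf.reduce_all(squared < limit),
                lambda: squared,
                lambda: linear)
```

Note that tf.where evaluates both candidate tensors and only then selects between them, so a NaN that already exists in `prediction` will show up in either branch. Which of the two variants is appropriate depends on whether the limit is meant to apply per row or to the tensor as a whole.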