TensorFlow错误:Shape [10]具有等级1< 2

时间:2018-04-06 00:22:41

标签: python-3.x tensorflow neural-network

我目前正在处理facebook评论预测数据集(http://uksim.info/uksim2015/data/8713a015.pdf)。作为评估指标,我使用top_10和AUC(如文中所述)。当我尝试在测试集上评估top_10时,我得到错误,“TensorFlow错误:Shape [10]的排名为1< 2。”。

我的数据集具有以下大小:X_test_set_1有10个数据集,每个数据集的长度为100x53,y_test_set_1有10个数据集,每个数据集的长度为10x1。

我的代码如下:

n_inputs = 53
n_hidden1 = 300
n_hidden2 = 200
n_hidden3 = 100
n_outputs = 1

reset_graph()

X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.float32, shape=(None), name="y")

# Implement dropout

training = tf.placeholder_with_default(False, shape=(), name='training')

dropout_rate = 0.5  # == 1 - keep_prob
X_drop = tf.layers.dropout(X, dropout_rate, training=training)

with tf.name_scope("dnn"):
    hidden1 = tf.layers.dense(X_drop, n_hidden1, name="hidden1",
                              activation=tf.nn.selu)
    hidden2 = tf.layers.dense(hidden1, n_hidden2, name="hidden2",
                              activation=tf.nn.selu)
    hidden3 = tf.layers.dense(hidden2, n_hidden3, name="hidden3",
                              activation=tf.nn.selu)
    logit = tf.layers.dense(hidden3, n_outputs, name="outputs")
    logits = logit[:, 0]

with tf.name_scope("loss"):
    mse = tf.losses.mean_squared_error(labels=y, predictions=logits)
    loss = tf.reduce_mean(mse, name="loss")
    loss_summary = tf.summary.scalar('loss', loss)

learning_rate = 0.0001

with tf.name_scope("train"):
    optimizer = tf.train.AdamOptimizer(learning_rate)
    training_op = optimizer.minimize(loss)

with tf.name_scope("eval"):
    # MAE
    #correct = tf.metrics.mean_absolute_error(labels = y, predictions = logits)
    mae = tf.reduce_mean(tf.abs(y-logits))
    mae_summary = tf.summary.scalar('mae', mae) # MAE DONE

    # HITS
    # First, get top 10 values for labels and logits; this will be used for HITS and AUC
    # Next, find the intersection of indexes for the top 10 values and use this for HITS
    top_10_labels_vals, top_10_labels_index = tf.nn.top_k(y, 10)
    top_10_logits_vals, top_10_logits_index = tf.nn.top_k(logits, 10)

    # Add an extra column of None to the top_10_logits_index since we want to match the shape of top_10_labels_index
    hits = tf.sets.set_intersection(top_10_labels_index, top_10_logits_index[None])
    hits_10 = tf.size(hits)
    hits_summary = tf.summary.scalar('hits_10', hits_10)

    # AUC
    label_bool = tf.greater_equal(y, tf.reduce_min(top_10_labels_vals))
    label = tf.cast(label_bool, tf.int32)

    predictions_bool = tf.greater_equal(logits, tf.reduce_min(top_10_logits_vals))
    prediction = tf.cast(predictions_bool, tf.int32)
    # ORIGINAL! DO NOT DELETE!
    #auc_pre = tf.metrics.auc(labels = tf.greater_equal(y, tf.reduce_min(top_10_labels_vals)) , predictions = tf.greater_equal(logits, tf.reduce_min(top_10_logits_vals)))
    auc = tf.metrics.auc(label, prediction)
    #auc = tf.reduce_mean(auc_pre) # Is it necessary to the tf.reduce_mean for the auc?
    auc_summary = tf.summary.scalar('auc', auc)

init = tf.global_variables_initializer()
saver = tf.train.Saver()

...

# HITS@10 test_set_1
with tf.Session() as sess:
    saver.restore(sess, final_model_path_1)
    hits_val_list_1 = []
    for i in range(10):
        hits_val = hits_10.eval(feed_dict={X : X_test_set_1[i], y : y_test_set_1[i]})
        hits_val_list_1.append(hits_val)
    hits_val_list_1_mean = np.mean(hits_val_list_1)

INFO:tensorflow:Restoring parameters from ./1_train_dnn_reg_model
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1322     try:
-> 1323       return fn(*args)
   1324     except errors.OpError as e:

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1301                                    feed_dict, fetch_list, target_list,
-> 1302                                    status, run_metadata)
   1303 

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    472             compat.as_text(c_api.TF_Message(self.status.status)),
--> 473             c_api.TF_GetCode(self.status.status))
    474     # Delete the underlying status object from memory otherwise it stays alive

InvalidArgumentError: Shape [10] has rank 1 < 2
     [[Node: eval_1/DenseToDenseSetOperation = DenseToDenseSetOperation[T=DT_INT32, set_operation="intersection", validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](eval_1/TopKV2:1, eval_1/strided_slice)]]

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-124-74de8454a00c> in <module>()
      4     hits_val_list_1 = []
      5     for i in range(10):
----> 6         hits_val = hits_10.eval(feed_dict={X : X_test_set_1[i], y : y_test_set_1[i]})
      7     #mae_val = mae.eval(feed_dict={X : X_test_set_1, y : y_test_set_1})
      8         hits_val_list_1.append(hits_val)

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in eval(self, feed_dict, session)
    568 
    569     """
--> 570     return _eval_using_default_session(self, feed_dict, self.graph, session)
    571 
    572   def _dup(self):

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _eval_using_default_session(tensors, feed_dict, graph, session)
   4453                        "the tensor's graph is different from the session's "
   4454                        "graph.")
-> 4455   return session.run(tensors, feed_dict)
   4456 
   4457 

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    887     try:
    888       result = self._run(None, fetches, feed_dict, options_ptr,
--> 889                          run_metadata_ptr)
    890       if run_metadata:
    891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1118     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1119       results = self._do_run(handle, final_targets, final_fetches,
-> 1120                              feed_dict_tensor, options, run_metadata)
   1121     else:
   1122       results = []

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1315     if handle is None:
   1316       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1317                            options, run_metadata)
   1318     else:
   1319       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1334         except KeyError:
   1335           pass
-> 1336       raise type(e)(node_def, op, message)
   1337 
   1338   def _extend_graph(self):

InvalidArgumentError: Shape [10] has rank 1 < 2
     [[Node: eval_1/DenseToDenseSetOperation = DenseToDenseSetOperation[T=DT_INT32, set_operation="intersection", validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](eval_1/TopKV2:1, eval_1/strided_slice)]]

Caused by op 'eval_1/DenseToDenseSetOperation', defined at:
  File "/home/isaac/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/isaac/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2850, in run_ast_nodes
    if self.run_code(code, result):
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-90-419aa85d8f25>", line 14, in <module>
    hits = tf.sets.set_intersection(top_10_labels_index, top_10_logits_index[None])
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/sets_impl.py", line 197, in set_intersection
    return _set_operation(a, b, "intersection", validate_indices)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/sets_impl.py", line 130, in _set_operation
    a, b, set_operation, validate_indices)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_set_ops.py", line 69, in dense_to_dense_set_operation
    name=name)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Shape [10] has rank 1 < 2
     [[Node: eval_1/DenseToDenseSetOperation = DenseToDenseSetOperation[T=DT_INT32, set_operation="intersection", validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](eval_1/TopKV2:1, eval_1/strided_slice)]]

0 个答案:

没有答案