我目前正在处理facebook评论预测数据集(http://uksim.info/uksim2015/data/8713a015.pdf)。作为评估指标,我使用top_10和AUC(如文中所述)。当我尝试在测试集上评估top_10时,我得到错误,“TensorFlow错误:Shape [10]的排名为1< 2。”。
我的数据集具有以下大小:X_test_set_1有10个数据集,每个数据集的长度为100x53,y_test_set_1有10个数据集,每个数据集的长度为10x1。
我的代码如下:
n_inputs = 53
n_hidden1 = 300
n_hidden2 = 200
n_hidden3 = 100
n_outputs = 1
reset_graph()
X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.float32, shape=(None), name="y")
# Implement dropout
training = tf.placeholder_with_default(False, shape=(), name='training')
dropout_rate = 0.5 # == 1 - keep_prob
X_drop = tf.layers.dropout(X, dropout_rate, training=training)
with tf.name_scope("dnn"):
hidden1 = tf.layers.dense(X_drop, n_hidden1, name="hidden1",
activation=tf.nn.selu)
hidden2 = tf.layers.dense(hidden1, n_hidden2, name="hidden2",
activation=tf.nn.selu)
hidden3 = tf.layers.dense(hidden2, n_hidden3, name="hidden3",
activation=tf.nn.selu)
logit = tf.layers.dense(hidden3, n_outputs, name="outputs")
logits = logit[:, 0]
with tf.name_scope("loss"):
mse = tf.losses.mean_squared_error(labels=y, predictions=logits)
loss = tf.reduce_mean(mse, name="loss")
loss_summary = tf.summary.scalar('loss', loss)
learning_rate = 0.0001
with tf.name_scope("train"):
optimizer = tf.train.AdamOptimizer(learning_rate)
training_op = optimizer.minimize(loss)
with tf.name_scope("eval"):
# MAE
#correct = tf.metrics.mean_absolute_error(labels = y, predictions = logits)
mae = tf.reduce_mean(tf.abs(y-logits))
mae_summary = tf.summary.scalar('mae', mae) # MAE DONE
# HITS
# First, get top 10 values for labels and logits; this will be used for HITS and AUC
# Next, find the intersection of indexes for the top 10 values and use this for HITS
top_10_labels_vals, top_10_labels_index = tf.nn.top_k(y, 10)
top_10_logits_vals, top_10_logits_index = tf.nn.top_k(logits, 10)
# Add an extra column of None to the top_10_logits_index since we want to match the shape of top_10_labels_index
hits = tf.sets.set_intersection(top_10_labels_index, top_10_logits_index[None])
hits_10 = tf.size(hits)
hits_summary = tf.summary.scalar('hits_10', hits_10)
# AUC
label_bool = tf.greater_equal(y, tf.reduce_min(top_10_labels_vals))
label = tf.cast(label_bool, tf.int32)
predictions_bool = tf.greater_equal(logits, tf.reduce_min(top_10_logits_vals))
prediction = tf.cast(predictions_bool, tf.int32)
# ORIGINAL! DO NOT DELETE!
#auc_pre = tf.metrics.auc(labels = tf.greater_equal(y, tf.reduce_min(top_10_labels_vals)) , predictions = tf.greater_equal(logits, tf.reduce_min(top_10_logits_vals)))
auc = tf.metrics.auc(label, prediction)
#auc = tf.reduce_mean(auc_pre) # Is it necessary to the tf.reduce_mean for the auc?
auc_summary = tf.summary.scalar('auc', auc)
init = tf.global_variables_initializer()
saver = tf.train.Saver()
...
# HITS@10 test_set_1
with tf.Session() as sess:
saver.restore(sess, final_model_path_1)
hits_val_list_1 = []
for i in range(10):
hits_val = hits_10.eval(feed_dict={X : X_test_set_1[i], y : y_test_set_1[i]})
hits_val_list_1.append(hits_val)
hits_val_list_1_mean = np.mean(hits_val_list_1)
INFO:tensorflow:Restoring parameters from ./1_train_dnn_reg_model
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
1322 try:
-> 1323 return fn(*args)
1324 except errors.OpError as e:
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
1301 feed_dict, fetch_list, target_list,
-> 1302 status, run_metadata)
1303
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
472 compat.as_text(c_api.TF_Message(self.status.status)),
--> 473 c_api.TF_GetCode(self.status.status))
474 # Delete the underlying status object from memory otherwise it stays alive
InvalidArgumentError: Shape [10] has rank 1 < 2
[[Node: eval_1/DenseToDenseSetOperation = DenseToDenseSetOperation[T=DT_INT32, set_operation="intersection", validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](eval_1/TopKV2:1, eval_1/strided_slice)]]
During handling of the above exception, another exception occurred:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-124-74de8454a00c> in <module>()
4 hits_val_list_1 = []
5 for i in range(10):
----> 6 hits_val = hits_10.eval(feed_dict={X : X_test_set_1[i], y : y_test_set_1[i]})
7 #mae_val = mae.eval(feed_dict={X : X_test_set_1, y : y_test_set_1})
8 hits_val_list_1.append(hits_val)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in eval(self, feed_dict, session)
568
569 """
--> 570 return _eval_using_default_session(self, feed_dict, self.graph, session)
571
572 def _dup(self):
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _eval_using_default_session(tensors, feed_dict, graph, session)
4453 "the tensor's graph is different from the session's "
4454 "graph.")
-> 4455 return session.run(tensors, feed_dict)
4456
4457
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
887 try:
888 result = self._run(None, fetches, feed_dict, options_ptr,
--> 889 run_metadata_ptr)
890 if run_metadata:
891 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1118 if final_fetches or final_targets or (handle and feed_dict_tensor):
1119 results = self._do_run(handle, final_targets, final_fetches,
-> 1120 feed_dict_tensor, options, run_metadata)
1121 else:
1122 results = []
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1315 if handle is None:
1316 return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1317 options, run_metadata)
1318 else:
1319 return self._do_call(_prun_fn, self._session, handle, feeds, fetches)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
1334 except KeyError:
1335 pass
-> 1336 raise type(e)(node_def, op, message)
1337
1338 def _extend_graph(self):
InvalidArgumentError: Shape [10] has rank 1 < 2
[[Node: eval_1/DenseToDenseSetOperation = DenseToDenseSetOperation[T=DT_INT32, set_operation="intersection", validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](eval_1/TopKV2:1, eval_1/strided_slice)]]
Caused by op 'eval_1/DenseToDenseSetOperation', defined at:
File "/home/isaac/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/isaac/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
app.launch_new_instance()
File "/home/isaac/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
app.start()
File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 477, in start
ioloop.IOLoop.instance().start()
File "/home/isaac/anaconda3/lib/python3.6/site-packages/zmq/eventloop/ioloop.py", line 177, in start
super(ZMQIOLoop, self).start()
File "/home/isaac/anaconda3/lib/python3.6/site-packages/tornado/ioloop.py", line 888, in start
handler_func(fd_obj, events)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
return fn(*args, **kwargs)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
self._handle_recv()
File "/home/isaac/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
self._run_callback(callback, msg)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
callback(*args, **kwargs)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 277, in null_wrapper
return fn(*args, **kwargs)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
return self.dispatch_shell(stream, msg)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 235, in dispatch_shell
handler(stream, idents, msg)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
user_expressions, allow_stdin)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 533, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2728, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2850, in run_ast_nodes
if self.run_code(code, result):
File "/home/isaac/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2910, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-90-419aa85d8f25>", line 14, in <module>
hits = tf.sets.set_intersection(top_10_labels_index, top_10_logits_index[None])
File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/sets_impl.py", line 197, in set_intersection
return _set_operation(a, b, "intersection", validate_indices)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/sets_impl.py", line 130, in _set_operation
a, b, set_operation, validate_indices)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_set_ops.py", line 69, in dense_to_dense_set_operation
name=name)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
op_def=op_def)
File "/home/isaac/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): Shape [10] has rank 1 < 2
[[Node: eval_1/DenseToDenseSetOperation = DenseToDenseSetOperation[T=DT_INT32, set_operation="intersection", validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](eval_1/TopKV2:1, eval_1/strided_slice)]]