当我运行以下代码时:
array([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)
一切正常,输出为:
arr
但是,如果我将tf.constant(arr)
替换为tf.keras.utils.to_categorical(tf.constant(arr), 10)
(或者我猜测但不确定任何张量),如以下代码所示:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-139-507d6ce7c0a3> in <module>()
----> 1 tf.keras.utils.to_categorical(tf.constant(arr), 10)
.../miniconda2/lib/python2.7/site-packages/tensorflow/python/keras/utils/np_utils.pyc in to_categorical(y, num_classes, dtype)
38 last.
39 """
---> 40 y = np.array(y, dtype='int')
41 input_shape = y.shape
42 if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
ValueError: setting an array element with a sequence.
我收到以下错误:
model.compile
配置:
如何摆脱/解决此问题?
一些上下文:
我的主要问题是,当我致电def loss(y_true, y_pred):
# Cross entropy loss
bin_true = y_true[:, 0]
print bin_true.eval()
dum = tf.keras.utils.to_categorical(bin_true, 66)
cls_loss = tf.keras.losses.categorical_crossentropy(dum, y_pred, True)
# MSE loss
cont_true = y_true[:, 1]
pred_cont = tf.keras.backend.sum(tf.nn.softmax(y_pred) * idx_tensor, 1) * 3 - 99
mse_loss = tf.keras.losses.mean_squared_error(cont_true, pred_cont)
# Total loss
return cls_loss + 0.5 * mse_loss
时,遭受了以下损失:
dum = tf.keras.utils.to_categorical(bin_true, 66))
我在第{{1}}行出现了完全相同的错误
(我提供了一些背景信息,因为我的整个“做事方式”可能是错误的...)
答案 0 :(得分:1)
您可以尝试使用tf.one_hot代替keras to_categorical
答案 1 :(得分:1)
您对问题的天真的答案是使用get_value(x)
bin_true_array = tf.keras.backend.get_value(bin_true)
dum = tf.keras.utils.to_categorical(bin_true_array, 66)
....
这将以一个numpy数组的形式检索bin_true
的值,然后将其输入到numpy实用程序utils.to_categorical
中。这是您的示例(已调整):
import numpy as np
import tensorflow as tf
arr = np.array([2., 4., 5., 9.])
tf.keras.utils.to_categorical(arr, 10)
tensor = tf.constant(arr)
tensor_as_array = tf.keras.backend.get_value(tensor)
tf.keras.utils.to_categorical(tensor_as_array, 10)
此方法的问题是退出后端,返回python和numpy(所有python好东西(例如GIL)附带),然后将其向下推送到后端。这可能会(也可能不会)在您的管道中造成瓶颈。在这种情况下,后端依赖的解决方案(如@Robin Beilvert建议的解决方案)可能会为您提供更好的服务。