model.fit ValueError after upgrading to TensorFlow 2 (tf.keras)

Date: 2020-02-23 12:51:37

Tags: tensorflow keras tensorflow2.0

I am trying to upgrade from TF 1.15 with Keras to TensorFlow 2, i.e. tf.keras.

With TF 1.15 and Keras, everything works correctly.

When calling model.fit(), I get a ValueError (shown below).

import tensorflow as tf

# Train the model
model.fit(data, [labels, data], batch_size=1, epochs=1, verbose=1)

Input and target data: data has dtype('float32') and labels has dtype('uint8').
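Because labels is uint8 while the model outputs are float32 (the traceback shows model/Dec_GT_Output/Sigmoid:0 as float32), one thing worth checking is that every target array passed to model.fit already has a float dtype. The following is a minimal sketch, assuming data and labels are plain NumPy arrays; the helper name as_float32 and the array shapes are hypothetical stand-ins, not taken from the question.

```python
import numpy as np


def as_float32(*arrays):
    """Return the arrays with any non-float32 array cast to float32 (hypothetical helper)."""
    return [a if a.dtype == np.float32 else a.astype(np.float32) for a in arrays]


# Hypothetical stand-ins for the question's arrays; the shapes are illustrative only.
data = np.zeros((4, 1, 8, 8, 8), dtype=np.float32)
labels = np.ones((4, 3, 8, 8, 8), dtype=np.uint8)

# Cast the targets once, before training, so no uint8 tensor reaches the losses.
labels_f, data_f = as_float32(labels, data)
print(labels_f.dtype, data_f.dtype)

# The training call would then become (model is not defined in this sketch):
# model.fit(data, [labels_f, data_f], batch_size=1, epochs=1, verbose=1)
```

Casting in NumPy keeps the fix outside the model graph, so the loss functions see the same dtype for y_true and y_pred regardless of how they are written.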

Ultimately, the code fails in a TensorFlow multiplication ('Mul') op: TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type uint8 of argument 'x'.
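The error says the Mul op received a uint8 tensor as 'x' and the float32 sigmoid output as 'y'. One plausible source is a custom loss that multiplies y_true by y_pred without casting. The sketch below assumes exactly that; the Dice-style loss and its name dice_loss are hypothetical, since the question does not show the actual loss. It casts y_true to y_pred's dtype before any arithmetic.

```python
import tensorflow as tf


def dice_loss(y_true, y_pred, smooth=1.0):
    """Hypothetical Dice-style loss that tolerates integer (e.g. uint8) targets."""
    # Assumption: the failing Mul is y_true * y_pred inside a custom loss.
    # Cast the targets to the prediction dtype (float32 for a sigmoid output)
    # so both Mul inputs share one dtype.
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred)
    denom = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    # Smoothed Dice coefficient turned into a loss (0 means perfect overlap).
    return 1.0 - (2.0 * intersection + smooth) / (denom + smooth)


# Usage: uint8 targets against float32 predictions no longer raise a dtype error.
y_true = tf.constant([[1, 0, 1, 0]], dtype=tf.uint8)
y_pred = tf.constant([[1.0, 0.0, 1.0, 0.0]])
print(float(dice_loss(y_true, y_pred)))
```

Such a function would be passed to model.compile for the output that receives labels; the losses actually used in the question are not shown, so this only illustrates where the cast would go.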

I tried converting the NumPy labels array to tf.float32, i.e. casting the labels to tf.float32. I also tried a simpler loss function.

Any pointers would be appreciated. Thanks, Jay.

model.fit() output:
Train on 4 samples
1/4 [======>.......................] - ETA: 3s
ValueError Traceback (most recent call last)
~/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(op_type_name, name, **keywords)
469 as_ref=input_arg.is_ref,
--> 470 preferred_dtype=default_dtype)
471 except TypeError as err:

~/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types)
1316 "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
-> 1317 (dtype.name, value.dtype.name, value))
1318 return value

ValueError: Tensor conversion requested dtype uint8 for Tensor with dtype float32: <tf.Tensor 'model/Dec_GT_Output/Sigmoid:0' shape=(1, 3, 80, 96, 64) dtype=float32>

During handling of the above exception, another exception occurred:

TypeError Traceback (most recent call last)
in
2 # .astype("float32").values
3
----> 4 model.fit(data, [labels, data], batch_size=1, epochs=1, verbose=1)

...

~/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(op_type_name, name, **keywords)
504 "%s type %s of argument '%s'." %
505 (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
--> 506 inferred_from[input_arg.type_attr]))
507
508 types = [values.dtype]

TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type uint8 of argument 'x'.
