How to convert Keras float predictions to int?

Time: 2018-02-24 16:14:25

Tags: python tensorflow keras

I get the following error during the training phase:

  File "/anaconda/envs/tf3/lib/python3.6/site-packages/keras/engine/training.py", line 830, in compile
    sample_weight, mask)
  File "/anaconda/envs/tf3/lib/python3.6/site-packages/keras/engine/training.py", line 445, in weighted
    score_array *= weights
  File "/anaconda/envs/tf3/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py", line 898, in binary_op_wrapper
    y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
  File "/anaconda/envs/tf3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 932, in convert_to_tensor
    as_ref=False)
  File "/anaconda/envs/tf3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1022, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/anaconda/envs/tf3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 866, in _TensorTensorConversionFunction
    (dtype.name, t.dtype.name, str(t)))
ValueError: Tensor conversion requested dtype int32 for Tensor with dtype float32: 'Tensor("global_average_pooling2d_1_sample_weights:0", shape=(?,), dtype=float32)'
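The failing line inside Keras's `weighted` helper is `score_array *= weights`. Because of the `K.cast` calls, the loss returned by `test_loss` is int32, while the sample-weight tensor named in the error is float32, and TensorFlow refuses to convert the float32 weights to int32. The snippet below is only a NumPy analogy (not Keras or TensorFlow code), assuming NumPy is installed and using made-up values: an in-place multiply of an int32 array by float32 weights is rejected in the same spirit, while the same operation on a float32 array succeeds.

```python
import numpy as np

# Stand-in for the int32 loss tensor produced by test_loss (values are made up)
score_array = np.array([4, 1, 9], dtype=np.int32)
# Stand-in for the float32 per-sample weights
weights = np.ones(3, dtype=np.float32)

def weight_in_place(scores, w):
    """Mimic `score_array *= weights`; return the result or the error message."""
    try:
        scores *= w
        return scores
    except TypeError as exc:  # NumPy will not cast the float result back into int32 in place
        return str(exc)

int_result = weight_in_place(score_array.copy(), weights)
float_result = weight_in_place(score_array.astype(np.float32), weights)

print(type(int_result).__name__)  # str: the int32 *= float32 update was rejected
print(float_result.tolist())      # [4.0, 1.0, 9.0]
```

In Keras terms, the analogous change would be to keep the loss value in a float dtype. Note also that in TensorFlow a cast to an integer dtype passes no gradient back, so an int-valued loss could not drive the optimizer even if the dtype check passed.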

I'm running the latest Keras (2.1.3) and TensorFlow (1.5) under Conda.

Here is minimal code that reproduces the error:

from keras.layers import Input, Conv2D, GlobalAveragePooling2D
from keras.models import Model

import keras.backend as K
import numpy as np

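def test_loss(y_input, x_input):
    # Keras calls a custom loss as loss(y_true, y_pred), so y_input is the
    # target and x_input is the model output
    x1 = K.cast(x_input, dtype='int32')
    y1 = K.cast(y_input, dtype='int32')

    # Squared difference, computed in int32
    loss = K.square(x1 - y1)

    # K.cumsum accumulates along axis 0 by default; it does not reduce the shape
    reduced_loss = K.cumsum(loss)

    return reduced_loss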

# Random synthetic data: inputs are 18x18x512 feature maps, targets are 803 floats in [0, 1)
train_data = 10*np.random.rand(1600, 18,18,512)
validation_data = 10*np.random.rand(200, 18,18,512)

Y_train = np.random.rand(1600, 803)
Y_test = np.random.rand(200, 803)

#model
inputs = Input(shape=train_data.shape[1:])
x = Conv2D(803, (1,1), activation='sigmoid')(inputs)
predictions = GlobalAveragePooling2D(input_shape=train_data.shape[1:])(x)
model = Model(inputs=inputs, outputs=predictions)

model.summary()

model.compile(optimizer='adam', loss=test_loss,  metrics=['accuracy'])

model.fit(train_data, Y_train,
          epochs=200,
          batch_size=1,
          validation_data=(validation_data, Y_test))
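If the goal is integer outputs rather than an integer loss (an assumption about intent; the question does not say), one common pattern is to train with a float loss and convert only after `model.predict(...)`. The sketch below uses a hypothetical NumPy array `preds` in place of real predictions; because the model ends in a sigmoid followed by global average pooling, real predictions would also lie in [0, 1].

```python
import numpy as np

# Hypothetical stand-in for model.predict(validation_data): float32 values in [0, 1]
preds = np.array([[0.12, 0.67, 0.50],
                  [0.91, 0.49, 0.03]], dtype=np.float32)

# Option 1: round to the nearest integer (np.rint rounds an exact .5 to the nearest even value)
rounded = np.rint(preds).astype(np.int32)

# Option 2: explicit threshold, e.g. for 0/1 labels
thresholded = (preds >= 0.5).astype(np.int32)

print(rounded.tolist())      # [[0, 1, 0], [1, 0, 0]]
print(thresholded.tolist())  # [[0, 1, 1], [1, 0, 0]]
```

Doing the conversion outside the model keeps the training graph entirely in float32; the two options above differ only for values exactly equal to 0.5.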
