我想通过定义自己的自定义指标(使用Theano后端)来监控y_pred
的维度
def shape_test(y_true, y_pred):
return K.shape(y_pred)[0]
我假设自定义指标函数中的y_pred
维度等于迷你批量大小。但是,我得到了奇怪的输出。请参阅下面的一个小型可重复示例。
#imports and definitions
import numpy
numpy.random.seed(1234)
import keras.backend as K
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
neuron_num=20
dim_input=2
#batch size will be important below!
batch_size=2048
TT=int(1e4)
#sample data
X=numpy.random.randn(TT,dim_input)
eps=numpy.random.randn(TT)
Y=0.3*X[:,0]+0.5*X[:,1]+eps
x={"is":X[:(TT/2),:],"os":X[(TT/2+1):,:]}
y={"is":Y[:(TT/2)],"os":Y[(TT/2+1):]}
这是上面给出的自定义指标
def shape_test(y_true, y_pred):
return K.shape(y_pred)[0]
现在定义一个简单的NN
sgd=SGD(lr=1e-2,nesterov=True)
model=Sequential()
model.add(Dense(neuron_num,
input_dim=x["is"].shape[1],
init="glorot_normal",
activation="tanh"))
model.add(Dense(neuron_num,init="glorot_normal",activation="tanh"))
model.add(Dense(1,init="glorot_normal",activation="linear"))
model.compile(loss="mean_squared_error",
optimizer=sgd,
metrics=["mean_squared_error",shape_test])
model.fit(x["is"],
y["is"],
validation_data=(x["os"],y["os"]),
nb_epoch=1,
batch_size=batch_size,
verbose=False).history
这给出了
#{'loss': [1.834826689338684],
# 'mean_squared_error': [1.834826689338684],
# 'shape_test': [1841],
# 'val_loss': [1.4931119817522769],
# 'val_mean_squared_error': [1.4931119817522769],
# 'val_shape_test': [1841.1716343268654]}
我希望看到'shape_test': [2048]
代替'shape_test': [1841]
,因为批量大小为2048.
这看起来很奇怪。这可能是个错误吗?
我使用的是Python 2.7.6
,Keras==1.0.8
,Theano==0.8.2
和CPU。
答案 0 :(得分:1)
使用neuron_num=2000
和verbose=True
,以下是我能够根据您的示例制作的内容:
Epoch 1/1
2048/5000 [========>............] - ETA: 9s - loss: 1.4507 - shape_test: 2048.000
4096/5000 [=================>...] - ETA: 3s - loss: 1.3577 - shape_test: 2048.000
5000/5000 [=====================] - 26s - loss: 1.3087 - shape_test: 1841.1648 - val_shape_test: 1841.1716
正如您所看到的,您的形状函数似乎工作正常。但由于batch_size不是训练集大小的除数,因此最后一批仅包含904个示例。我似乎无法猜测Keras如何在一分钟内提出1841,但它可能并不复杂。
batch_size=2500
的另一次尝试看起来更好:
2500/5000 [==========>..........] - ETA: 9s - loss: 1.4292 - shape_test: 2500.0000
5000/5000 [=====================] - 24s - loss: 1.3311 - shape_test: 2500.0000 - val_shape_test: 2499.5001