I am using this model:
# Imports were not shown in the original post; standalone Keras is assumed
# (with tf.keras, import from tensorflow.keras instead).
import numpy as np
from keras.layers import Input, Reshape, GRU, Dense, Add
from keras.models import Model

def controller(Cin, Cout, Ein, Eout, batch=1, load=1):
    # Controller network: (batch, Cin) -> (batch, Cout)
    inc1 = Input(batch_shape=(batch, Cin))
    h1 = Reshape(target_shape=[1, Cin])(inc1)
    h1 = GRU(activation='relu', units=8, return_sequences=True, stateful=True)(h1)
    h1 = Dense(units=8, activation='relu')(h1)
    h1 = Dense(units=Cout, activation='relu')(h1)
    h1 = Reshape(target_shape=[Cout,])(h1)
    controller = Model(inc1, h1)
    controller.compile(loss='mse', optimizer='adam')
    if load:
        try:
            controller.load_weights('controller.md5')
            print("Controller loaded weights successfully")
        except Exception:
            print("couldn't load the weights")
    controller.summary()

    # Estimator network: [(batch, Ein-Cout), (batch, Cout)] -> (batch, Eout)
    ine1 = Input(batch_shape=(batch, Ein - Cout))
    h1 = Dense(units=8, activation='relu')(ine1)
    ine2 = Input(batch_shape=(batch, Cout))
    h2 = Dense(units=8, activation='relu')(ine2)
    h3 = Add()([h1, h2])
    h3 = Dense(units=8, activation='relu')(h3)
    h3 = Reshape(target_shape=[1, 8])(h3)
    h3 = GRU(activation='relu', units=8, return_sequences=True, stateful=True)(h3)
    h3 = Dense(units=Eout, activation='linear')(h3)
    h3 = Reshape(target_shape=[Eout,])(h3)
    estimator = Model([ine1, ine2], h3)
    estimator.compile(loss='mse', optimizer='adam')
    if load:
        try:
            estimator.load_weights('estimator.md5')
            print("Estimator loaded weights successfully")
        except Exception:
            print("couldn't load the weights")
    estimator.summary()

    # Combined model: the controller output is fed into the estimator
    in1 = Input(batch_shape=(batch, Cin))
    cont = controller(in1)
    in2 = Input(batch_shape=(batch, Ein - Cout))
    estm = estimator([in2, cont])
    model = Model([in1, in2], estm)
    model.compile(loss='mse', optimizer='adam')
    model.summary()
    return controller, estimator, model

# Test lines that trigger the error
control, estimator, model = controller(1, 1, 2, 1, batch=10, load=1)
a, b = np.zeros((10, 1)), np.zeros((10, 1))
model.predict([a, b], batch_size=10)
When I run the few test lines at the end, I get this error:
couldn't load the weights
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_269 (InputLayer) (10, 1) 0
_________________________________________________________________
reshape_225 (Reshape) (10, 1, 1) 0
_________________________________________________________________
gru_94 (GRU) (10, 1, 8) 240
_________________________________________________________________
dense_421 (Dense) (10, 1, 8) 72
_________________________________________________________________
dense_422 (Dense) (10, 1, 1) 9
_________________________________________________________________
reshape_226 (Reshape) (10, 1) 0
=================================================================
Total params: 321
Trainable params: 321
Non-trainable params: 0
_________________________________________________________________
couldn't load the weights
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_270 (InputLayer) (10, 1) 0
__________________________________________________________________________________________________
input_271 (InputLayer) (10, 1) 0
__________________________________________________________________________________________________
dense_423 (Dense) (10, 8) 16 input_270[0][0]
__________________________________________________________________________________________________
dense_424 (Dense) (10, 8) 16 input_271[0][0]
__________________________________________________________________________________________________
add_65 (Add) (10, 8) 0 dense_423[0][0]
dense_424[0][0]
__________________________________________________________________________________________________
dense_425 (Dense) (10, 8) 72 add_65[0][0]
__________________________________________________________________________________________________
reshape_227 (Reshape) (10, 1, 8) 0 dense_425[0][0]
__________________________________________________________________________________________________
gru_95 (GRU) (10, 1, 8) 408 reshape_227[0][0]
__________________________________________________________________________________________________
dense_426 (Dense) (10, 1, 1) 9 gru_95[0][0]
__________________________________________________________________________________________________
reshape_228 (Reshape) (10, 1) 0 dense_426[0][0]
==================================================================================================
Total params: 521
Trainable params: 521
Non-trainable params: 0
__________________________________________________________________________________________________
Tensor("input_272:0", shape=(10, 1), dtype=float32)
Tensor("model_141/reshape_226/Reshape:0", shape=(10, 1), dtype=float32)
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_272 (InputLayer) (10, 1) 0
__________________________________________________________________________________________________
input_273 (InputLayer) (10, 1) 0
__________________________________________________________________________________________________
model_141 (Model) (10, 1) 321 input_272[0][0]
__________________________________________________________________________________________________
model_142 (Model) (10, 1) 521 input_273[0][0]
model_141[1][0]
==================================================================================================
Total params: 842
Trainable params: 842
Non-trainable params: 0
__________________________________________________________________________________________________
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
1291 try:
-> 1292 return fn(*args)
1293 except errors.OpError as e:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
1276 return self._call_tf_sessionrun(
-> 1277 options, feed_dict, fetch_list, target_list, run_metadata)
1278
/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
1366 self._session, options, feed_dict, fetch_list, target_list,
-> 1367 run_metadata)
1368
InvalidArgumentError: You must feed a value for placeholder tensor 'input_269' with dtype float and shape [10,1]
[[{{node input_269}} = Placeholder[dtype=DT_FLOAT, shape=[10,1], _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
[[{{node model_142/reshape_228/Reshape/_4033}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_653_model_142/reshape_228/Reshape", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
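Simply put, the model is built from two NNs.
The first, called the controller, takes one input (the first input) of shape (#batches, 1) and produces an output of shape (#batches, 1). The second, called the estimator, takes two inputs: one is the output of the controller NN, and the other (the second input) is a separate input of shape (#batches, 1).
The model called `model` automates the process: it takes the first and second inputs, calls the controller, feeds the controller's output to the estimator, and finally returns the estimator's output.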
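To make the intended data flow concrete, here is a shape-only sketch in NumPy. The functions `controller_fn`, `estimator_fn`, and `combined_fn` are hypothetical stand-ins that use simple arithmetic, not the GRU networks from my code. Only the wiring and the shapes mirror what I expect `model` to do.

```python
import numpy as np

def controller_fn(first_input):
    # Hypothetical stand-in for the controller NN: (batch, 1) -> (batch, 1).
    return first_input * 2.0

def estimator_fn(second_input, controller_output):
    # Hypothetical stand-in for the estimator NN:
    # [(batch, 1), (batch, 1)] -> (batch, 1).
    return second_input + controller_output

def combined_fn(first_input, second_input):
    # Mirrors the wiring of `model`: the controller output feeds the estimator.
    return estimator_fn(second_input, controller_fn(first_input))

batch = 10
a = np.zeros((batch, 1))
b = np.zeros((batch, 1))
out = combined_fn(a, b)
print(out.shape)  # (10, 1)
```

In other words, I expect to feed only the two outer inputs, and I do not expect to have to feed the controller's own input placeholder separately.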
I can't work out where the error comes from. Thanks.