我在多进程池中调用 Keras 模型的 predict 函数时遇到了一些问题。我的 Keras 模型是在同一个包的另一个脚本(nnGenerator)中构建并编译的;负责编译模型的函数会把模型返回给主脚本,再由主脚本使用。
以下是我的代码的简化版本:
from pathos.multiprocessing import ProcessingPool
from pathos.helpers import freeze_support
from GeneticAlgorithm import genotype
from GeneticAlgorithm import nnGenerator
import h5py
import numpy as np
import json
import dill
import tensorflow as tf
###############################
# Load the GA configuration; only its "layers" section is used below.
with open('parameter_ga.json') as json_data:
    params = json.load(json_data)
params_l = params["layers"]

# Wrap the configured dimensions in one-element lists of tuples, which is the
# form passed to Genotype and create_models below.
input_dims_actor = [(params_l["input_dims_actor"][0], params_l["input_dims_actor"][1])]
output_dims_actor = [(params_l["output_dims_actor"][0],)]

# Build the actor genotype and the model generated from it.
geno = genotype.Genotype('actor', 1, 1, input_dims_actor, output_dims_actor)
model = nnGenerator.create_models(geno, input_dims_actor, 'actor')
#####################################
# Retrieve data: read every stored state and rescale it with (x + 1) / 2.
with h5py.File(r"C:\Users\Lennart\Masterarbeit\data.h5", 'r') as f:
    states = (f['states'][:] + 1) / 2

# Draw 32 random row indices from the available states.
experience_buffer = range(states.shape[0])
experience = np.random.choice(experience_buffer, size=32)

# Integer-array indexing gathers the selected rows in one step and already
# returns a new array; original note on the layout: b x (s x s).
batch_states = states[experience]
data = batch_states.copy()
print("shape:", data.shape)
print("type:", type(data))
def predict(model, data):
    """Run ``model.predict`` on ``data`` and return its result.

    Args:
        model: Object exposing a ``predict(data)`` method (the model built
            at module level in this script).
        data: Input batch, forwarded unchanged to ``model.predict``.

    Returns:
        Whatever ``model.predict(data)`` returns.
    """
    return model.predict(data)
def main():
    """Run ``predict`` once in a pathos ``ProcessingPool`` and print the result.

    The module-level ``model`` and ``data`` are each wrapped in a
    one-element list, so ``pool.map`` issues a single
    ``predict(model, data)`` call.
    """
    # NOTE(review): ``model`` is created at module level in the parent
    # process. The traceback reported for this call shows Keras raising
    # "Tensor ... is not an element of this graph" inside the pool worker;
    # presumably the model must be created (or rebuilt from its config and
    # weights) inside the worker instead of being sent to it -- TODO confirm.
    with ProcessingPool() as pool:
        M = pool.map(predict, [model], [data])
    print(M)
if __name__ == "__main__":
    # Script entry point: freeze_support() runs first, then the pool is
    # started from main(). Keeping the pool start behind this guard means it
    # only runs when the file is executed as a script, not when imported.
    freeze_support()
    main()
当我把这个模型传给 pathos 的多进程池(ProcessingPool)时,会得到如下 ValueError:Tensor Tensor("concatenate_1/concat:0", shape=(?, 3), dtype=float32) is not an element of this graph. 完整的报错信息如下:
multiprocess.pool.RemoteTraceback:
"""
Traceback (most recent call last):
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\multiprocess\pool.py", line 119, in worker
result = (True, func(*args, **kwds))
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\multiprocess\pool.py", line 44, in mapstar
return list(map(*args))
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\pathos\helpers\mp_helper.py", line 15, in <lambda>
func = lambda args: f(*args)
File "C:/Users/Lennart/PycharmProjects/demonstrator_RLAlgorithm/GeneticAlgorithm/test.py", line 55, in predict
result = model.predict(data)
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\keras\engine\training.py", line 1710, in predict
self._make_predict_function()
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\keras\engine\training.py", line 999, in _make_predict_function
**kwargs)
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\keras\backend\tensorflow_backend.py", line 2297, in function
return Function(inputs, outputs, updates=updates, **kwargs)
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\keras\backend\tensorflow_backend.py", line 2246, in __init__
with tf.control_dependencies(self.outputs):
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\tensorflow\python\framework\ops.py", line 3782, in control_dependencies
return get_default_graph().control_dependencies(control_inputs)
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\tensorflow\python\framework\ops.py", line 3511, in control_dependencies
c = self.as_graph_element(c)
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\tensorflow\python\framework\ops.py", line 2584, in as_graph_element
return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\tensorflow\python\framework\ops.py", line 2663, in _as_graph_element_locked
raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("concatenate_1/concat:0", shape=(?, 3), dtype=float32) is not an element of this graph.
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "C:/Users/Lennart/PycharmProjects/demonstrator_RLAlgorithm/GeneticAlgorithm/test.py", line 66, in <module>
main()
File "C:/Users/Lennart/PycharmProjects/demonstrator_RLAlgorithm/GeneticAlgorithm/test.py", line 61, in main
M = pool.map(predict, [model], [data])
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\pathos\multiprocessing.py", line 137, in map
return _pool.map(star(f), zip(*args)) # chunksize
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\multiprocess\pool.py", line 260, in map
return self._map_async(func, iterable, mapstar, chunksize).get()
File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\multiprocess\pool.py", line 608, in get
raise self._value
ValueError: Tensor Tensor("concatenate_1/concat:0", shape=(?, 3), dtype=float32) is not an element of this graph.
模型摘要如下所示:
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_1 (InputLayer) (None, 40, 40, 3) 0
____________________________________________________________________________________________________
flatten_1 (Flatten) (None, 4800) 0 input_1[0][0]
____________________________________________________________________________________________________
dropout_1 (Dropout) (None, 4800) 0 flatten_1[0][0]
____________________________________________________________________________________________________
forward_1_actor_ (Dense) (None, 1) 4801 dropout_1[0][0]
____________________________________________________________________________________________________
sideways_1_actor_ (Dense) (None, 1) 4801 dropout_1[0][0]
____________________________________________________________________________________________________
rotation_1_actor_ (Dense) (None, 1) 4801 dropout_1[0][0]
____________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 3) 0 forward_1_actor_[0][0]
sideways_1_actor_[0][0]
rotation_1_actor_[0][0]
====================================================================================================
Total params: 14,403
Trainable params: 14,403
Non-trainable params: 0
____________________________________________________________________________________________________
所以我始终不明白这段代码为什么不起作用。我查阅了所有相关的问题,但由于我还是新手,可能没有正确理解它们。
非常感谢任何帮助,先行谢过。Lennart