Multiprocessing with Keras - Tensor is not an element of this graph

Time: 2018-05-28 19:05:38

Tags: python-3.x tensorflow keras pathos

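I am trying to use the model's predict function with multiprocessing, but I am running into problems. My Keras model is compiled in a different script (nnGenerator) in the same package. The function that compiles the model returns it to the main script, where it is used.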
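Because the model is produced by a factory function from plain inputs (in the code, nnGenerator.create_models(geno, input_dims_actor, 'actor')), one way to avoid moving a built model between processes is to move only a small, picklable description of it and let each process build and cache its own copy. This is a minimal sketch under assumptions: build_model is a hypothetical stand-in for nnGenerator.create_models, the "model" is a plain Python callable rather than a Keras model (so the sketch runs without TensorFlow), and the spec values are taken from the model summary, not from parameter_ga.json.

```python
from functools import lru_cache

# Records every build so the per-process caching can be observed.
BUILD_CALLS = []


def build_model(input_dims, n_outputs):
    # Hypothetical stand-in for nnGenerator.create_models: returns a plain
    # callable instead of a Keras model so this runs without TensorFlow.
    BUILD_CALLS.append((input_dims, n_outputs))

    def model(batch):
        # Fake "prediction": one row of n_outputs zeros per sample.
        return [[0.0] * n_outputs for _ in batch]

    return model


@lru_cache(maxsize=None)
def get_model(spec):
    # spec is a hashable tuple of plain values, so it is cheap to pickle and
    # send to a worker. Each process that calls get_model builds its own
    # model once and reuses it for later calls.
    input_dims, n_outputs = spec
    return build_model(input_dims, n_outputs)


def predict(spec, batch):
    # Workers receive (spec, batch), never a model object built elsewhere.
    return get_model(spec)(batch)
```

For example, predict(((40, 40, 3), 3), batch) builds the model on the first call in a given process and reuses it afterwards. Whether rebuilding per process is enough for the TensorFlow 1.x graph issue in this exact setup is an assumption, not something the question confirms.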

Here is a simplified version of my code:

from pathos.multiprocessing import ProcessingPool
from pathos.helpers import freeze_support

from GeneticAlgorithm import genotype
from GeneticAlgorithm import nnGenerator
import h5py
import numpy as np
import json
import dill
import tensorflow as tf


###############################

with open('parameter_ga.json') as json_data:
    params = json.load(json_data)

params_l = params["layers"]

input_dims_actor = [(params_l["input_dims_actor"][0], params_l["input_dims_actor"][1])]
output_dims_actor = [(params_l["output_dims_actor"][0],)]

geno = genotype.Genotype('actor', 1, 1, input_dims_actor, output_dims_actor)

model = nnGenerator.create_models(geno, input_dims_actor, 'actor')
#####################################

#retrieve data
with h5py.File(r"C:\Users\Lennart\Masterarbeit\data.h5", 'r') as f:
    states = (f['states'][:]+1)/2

experience_buffer = range(states.shape[0])

experience = np.random.choice(experience_buffer, size=32)

batch_states = np.array([states[i] for i in experience])  # bx(sxs)

data = batch_states.copy()

print("shape:", data.shape)
print("type:", type(data))

def predict(model, data):
    result = model.predict(data)
    return result

def main():
    with ProcessingPool() as pool:
        M = pool.map(predict, [model], [data])
    print(M)

if __name__=="__main__":
    freeze_support()
    main()
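Independently of the graph error, pathos's ProcessingPool.map zips its iterables (the traceback shows _pool.map(star(f), zip(*args))), so pool.map(predict, [model], [data]) creates exactly one task: all 32 samples go to a single worker. To spread the work, the batch has to be split into chunks first. A minimal sketch with hypothetical helpers split_batch and merge_results, using plain lists; for NumPy arrays, np.array_split splits in the same way (earlier chunks get the extra elements).

```python
def split_batch(samples, n_chunks):
    """Split samples into n_chunks contiguous, nearly equal slices."""
    if n_chunks < 1:
        raise ValueError("n_chunks must be at least 1")
    size, extra = divmod(len(samples), n_chunks)
    chunks, start = [], 0
    for i in range(n_chunks):
        # The first `extra` chunks get one additional element each.
        end = start + size + (1 if i < extra else 0)
        chunks.append(samples[start:end])
        start = end
    return chunks


def merge_results(chunk_results):
    """Concatenate per-chunk predictions back into one list, in order."""
    merged = []
    for part in chunk_results:
        merged.extend(part)
    return merged
```

With 32 samples and 4 chunks, each chunk holds 8 samples and becomes its own task, e.g. pool.map(f, chunks) for a one-argument worker function f; merge_results then restores the original sample order.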

When I pass this model to a pathos multiprocessing pool, I get errors such as ValueError: Tensor Tensor("concatenate_1/concat:0", shape=(?, 3), dtype=float32) is not an element of this graph. The full traceback:

multiprocess.pool.RemoteTraceback:
"""
Traceback (most recent call last):
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\multiprocess\pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\multiprocess\pool.py", line 44, in mapstar
    return list(map(*args))
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\pathos\helpers\mp_helper.py", line 15, in <lambda>
    func = lambda args: f(*args)
  File "C:/Users/Lennart/PycharmProjects/demonstrator_RLAlgorithm/GeneticAlgorithm/test.py", line 55, in predict
    result = model.predict(data)
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\keras\engine\training.py", line 1710, in predict
    self._make_predict_function()
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\keras\engine\training.py", line 999, in _make_predict_function
    **kwargs)
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\keras\backend\tensorflow_backend.py", line 2297, in function
    return Function(inputs, outputs, updates=updates, **kwargs)
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\keras\backend\tensorflow_backend.py", line 2246, in __init__
    with tf.control_dependencies(self.outputs):
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\tensorflow\python\framework\ops.py", line 3782, in control_dependencies
    return get_default_graph().control_dependencies(control_inputs)
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\tensorflow\python\framework\ops.py", line 3511, in control_dependencies
    c = self.as_graph_element(c)
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\tensorflow\python\framework\ops.py", line 2584, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\tensorflow\python\framework\ops.py", line 2663, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("concatenate_1/concat:0", shape=(?, 3), dtype=float32) is not an element of this graph.
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:/Users/Lennart/PycharmProjects/demonstrator_RLAlgorithm/GeneticAlgorithm/test.py", line 66, in <module>
    main()
  File "C:/Users/Lennart/PycharmProjects/demonstrator_RLAlgorithm/GeneticAlgorithm/test.py", line 61, in main
    M = pool.map(predict, [model], [data])
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\pathos\multiprocessing.py", line 137, in map
    return _pool.map(star(f), zip(*args)) # chunksize
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\multiprocess\pool.py", line 260, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "C:\Users\Lennart\AppData\Local\Continuum\anaconda3\envs\Cense5\lib\site-packages\multiprocess\pool.py", line 608, in get
    raise self._value
ValueError: Tensor Tensor("concatenate_1/concat:0", shape=(?, 3), dtype=float32) is not an element of this graph.

The model summary looks like this:

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 40, 40, 3)     0                                            
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 4800)          0           input_1[0][0]                    
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 4800)          0           flatten_1[0][0]                  
____________________________________________________________________________________________________
forward_1_actor_ (Dense)         (None, 1)             4801        dropout_1[0][0]                  
____________________________________________________________________________________________________
sideways_1_actor_ (Dense)        (None, 1)             4801        dropout_1[0][0]                  
____________________________________________________________________________________________________
rotation_1_actor_ (Dense)        (None, 1)             4801        dropout_1[0][0]                  
____________________________________________________________________________________________________
concatenate_1 (Concatenate)      (None, 3)             0           forward_1_actor_[0][0]           
                                                                   sideways_1_actor_[0][0]          
                                                                   rotation_1_actor_[0][0]          
====================================================================================================
Total params: 14,403
Trainable params: 14,403
Non-trainable params: 0
____________________________________________________________________________________________________
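The parameter counts in the summary can be checked by hand: Flatten turns each 40x40x3 input into 40 * 40 * 3 = 4800 values, and each single-unit Dense head has 4800 weights plus one bias, i.e. 4801 parameters; the three heads give 3 * 4801 = 14,403, matching Total params, while Flatten, Dropout and Concatenate add none. The same arithmetic explains the shape (?, 3) in the error: concatenate_1 joins the three one-unit heads, so the tensor Keras cannot find is the model's output. A small sketch of the arithmetic (dense_params is an illustrative helper name):

```python
def dense_params(n_inputs, n_units):
    # A Dense layer with a bias has one weight per input per unit plus one bias per unit.
    return n_inputs * n_units + n_units


flat = 40 * 40 * 3                 # Flatten output size for a (40, 40, 3) input
per_head = dense_params(flat, 1)   # forward_1_actor_, sideways_1_actor_, rotation_1_actor_
total = 3 * per_head               # three heads, no other trainable layers
concat_width = 3 * 1               # Concatenate output: three one-unit heads

print(flat, per_head, total, concat_width)  # 4800 4801 14403 3
```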

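So, after all this, I still don't know why this code doesn't work. I have searched through all the related questions, but as a beginner I may not have understood them correctly.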

Any help is appreciated. Thanks in advance, Lennart

0 answers

No answers