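I'm trying to use a policy network (PN) and a CNN in the same Python file. I'm only training the PN and am using the CNN as a feed-forward feature extractor. The problem is that on the Train-PC the imported Keras graph doesn't get "built" in the session, so when I try to run the feed-forward function on the CNN (feature_vector = self.extract_feature_vector([X_img_norm])[0]), I get the error posted below.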
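The only difference between the files on my home PC (where I develop) and on the Train-PC is the way I import the model, which I suspect may be the cause (perhaps the model isn't being recognised). I posted about this error previously here.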
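Since the Train-PC isn't mine, I can't change the package versions. So I'm looking for a solution that doesn't involve downgrading, especially since downgrading wouldn't fix the import issue either.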
Home PC (Mac):
tensorflow==1.11.0
keras 2.2.4
Train-PC (Ubuntu):
tensorflow-gpu==1.10.1
keras 2.1.2
Code for the CNN:
# Imports assumed from the rest of keras_pn.py (not shown in this excerpt);
# CNN_MODEL_DIR and IMG_SIZE are constants defined elsewhere in that file.
import numpy as np
import tensorflow as tf
from keras import models
from keras import backend as K

class ResNetCNN:
    def __init__(self):
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.sess.graph.as_default():
            # Load pre-trained model from file
            self.CNN_ResNet_Pascal = models.load_model(CNN_MODEL_DIR)
            # self.CNN_ResNet_Pascal = tf.keras.models.load_model(CNN_MODEL_DIR)
            self.extract_feature_vector = K.function([self.CNN_ResNet_Pascal.layers[0].input],
                                                     [self.CNN_ResNet_Pascal.layers[-4].output])

    def feed_forward(self, img):
        """Extract the feature vector of the passed image using the pre-trained CNN."""
        # Reshape observable region into input volume
        X_img = np.array(img, dtype=np.float32).reshape(-1, IMG_SIZE, IMG_SIZE, 3)
        # Normalise feature data
        X_img_norm = X_img / 255.0
        # Extract feature vector from the pre-trained CNN (1, 4096)
        feature_vector = self.extract_feature_vector([X_img_norm])[0]
        # Reshape tensor to (4096, )
        feature_vector = np.array(feature_vector).reshape(4096, )
        # Confirm that the correct output layer was used
        assert feature_vector.shape == (4096, ), "Incorrect CNN output layer: shape = {}".format(feature_vector.shape)
        return feature_vector
Error:
Traceback (most recent call last):
  File "keras_pn.py", line 856, in <module>
    s, a, r, d_r, n = rollout(epsilon, RENDER, PolicyNetwork, ResNetCNN)
  File "keras_pn.py", line 682, in rollout
    feature_vector = ResNetCNN.feed_forward(observable_region)
  File "keras_pn.py", line 87, in feed_forward
    feature_vector = self.extract_feature_vector([X_img_norm])[0]
  File "/home/name/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2357, in __call__
    **self.session_kwargs)
  File "/home/name/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 877, in run
    run_metadata_ptr)
  File "/home/name/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1025, in _run
    raise RuntimeError('The Session graph is empty. Add operations to the '
RuntimeError: The Session graph is empty. Add operations to the graph before calling run().
Edit:
def __init__(self):
    self.graph = tf.Graph()
    with self.graph.as_default():
        # Load pre-trained model from file
        self.CNN_ResNet_Pascal = models.load_model(CNN_MODEL_DIR)
        # self.CNN_ResNet_Pascal = tf.keras.models.load_model(CNN_MODEL_DIR)
        self.extract_feature_vector = K.function([self.CNN_ResNet_Pascal.layers[0].input],
                                                 [self.CNN_ResNet_Pascal.layers[-4].output])
    self.sess = tf.Session(graph=self.graph)
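A minimal, self-contained sketch of one plausible mechanism behind this error, written against `tf.compat.v1` so it also runs on newer TensorFlow (on TF 1.10, replace `tf.compat.v1` with plain `tf`). It assumes, as in the Keras 2.x TensorFlow backend, that the callable returned by `K.function` runs through `K.get_session()`, which uses `tf.get_default_session()` when one is active and otherwise falls back to a module-level session. `ToyExtractor`, `keras_style_run` and `_FALLBACK_SESSION` are hypothetical stand-ins, not Keras APIs. If the fallback session sits on a graph with no ops, `sess.run` raises the same `RuntimeError`; entering `self.graph.as_default()` and `self.sess.as_default()` around the call makes the object's own session the one that gets used. Under that assumption, creating `tf.Session` after the graph is built, as in the edit, does not by itself change which session the call goes through.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # on TF 1.10 the same names live directly under `tf`

# Hypothetical stand-in for Keras's module-level session: it is bound to a
# graph that never receives any ops, i.e. an "empty" session graph.
_FALLBACK_SESSION = tf1.Session(graph=tf1.Graph())


def keras_style_run(fetch, feed_dict):
    """Simplified imitation (assumption) of keras.backend.get_session():
    use the active default session if there is one, else the fallback."""
    sess = tf1.get_default_session()
    if sess is None:
        sess = _FALLBACK_SESSION
    return sess.run(fetch, feed_dict=feed_dict)


class ToyExtractor:
    """Hypothetical stand-in for ResNetCNN: all ops live in a private graph."""

    def __init__(self, in_dim=6, out_dim=3):
        self.graph = tf1.Graph()
        with self.graph.as_default():
            self.x = tf1.placeholder(tf1.float32, shape=(None, in_dim))
            weights = tf1.constant(np.ones((in_dim, out_dim), dtype=np.float32))
            self.features = tf1.matmul(self.x, weights)
        # Session explicitly bound to the private graph.
        self.sess = tf1.Session(graph=self.graph)

    def feed_forward_unscoped(self, batch):
        # Like the question's feed_forward: no graph/session context is active,
        # so the lookup lands on the empty fallback session.
        return keras_style_run(self.features, {self.x: batch})

    def feed_forward(self, batch):
        # Make this object's graph and session the defaults for the call.
        with self.graph.as_default(), self.sess.as_default():
            return keras_style_run(self.features, {self.x: batch})


cnn = ToyExtractor()
batch = np.arange(12, dtype=np.float32).reshape(2, 6)
try:
    cnn.feed_forward_unscoped(batch)
except RuntimeError as err:
    print("unscoped call failed:", err)
print(cnn.feed_forward(batch))  # each column is the row sum: [[15. 15. 15.] [51. 51. 51.]]
```

In the real class, the same two contexts (or `K.set_session(self.sess)`) would presumably also need to be active while `models.load_model` runs, since loading the weights goes through the same session lookup.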