Keras-vis opt.minimize()InvalidArgumentError:

时间:2019-01-27 02:40:31

标签: keras deep-learning visualization conv-neural-network

我需要了解我的CNN模型,希望借助显著性图(saliency map)查看激活的输出。我了解了keras-vis,并参照keras-vis文档中的示例实现了该方法,只不过用的是我自己的模型来查看其输出。

起初我无法导入

from vis.modifiers import Jitter

错误

No module named 'vis.modifiers'

而vis的其他模块都能成功导入。

我的第二个问题是:删掉这个导入并去掉抖动(Jitter)之后,调用Optimizer.minimize时出现了下面的错误,代码如下:

我的模型:

# Build and train a small two-class CNN, then use keras-vis to synthesize an
# input image that maximizes one unit of the output layer.
model = Sequential()
model.add(Conv2D(16,kernel_size = (5,5),activation = 'relu', activity_regularizer=regularizers.l2(1e-8)))
model.add(Conv2D(32,kernel_size = (5,5),activation = 'relu', activity_regularizer = regularizers.l2(1e-8)))
model.add(MaxPooling2D(3,3))
model.add(Conv2D(64,kernel_size = (5,5),activation = 'relu', activity_regularizer = regularizers.l2(1e-8)))
model.add(MaxPooling2D(3,3))
model.add(Conv2D(128,activation = 'relu',kernel_size = (3,3),activity_regularizer = regularizers.l2(1e-8)))
model.add(Flatten())
model.add(Dropout(0.8))
model.add(Dense(64,activation = 'relu',activity_regularizer = regularizers.l2(1e-8)))
model.add(Dropout(0.8))
model.add(Dense(64,activation = 'relu',activity_regularizer = regularizers.l2(1e-8)))
model.add(Dropout(0.8))
model.add(Dense(2,activation = 'softmax'))
model.compile(loss=keras.losses.binary_crossentropy, optimizer=keras.optimizers.SGD(lr = 0.001,clipnorm = 1,momentum= 0.9), metrics=["accuracy"])
# NOTE(review): training uses `y_train` but validation uses
# `y_test_Categorical`; confirm both label arrays have the same (one-hot) encoding.
model.fit(X_train,y_train, epochs = 10 ,batch_size = 16,validation_data=(X_test,y_test_Categorical))
model.summary()

# Target the final Dense(2, softmax) layer. Resolve it positionally rather
# than by an auto-generated name such as 'dense_6'. That name's numeric suffix
# depends on how many Dense layers were created earlier in the session, so
# re-running a notebook cell changes it and the dict lookup raises KeyError.
layer_name = model.layers[-1].name
# Index every layer by name. The keras-vis VGG16 example skips layers[0]
# because it is an InputLayer there. In this Sequential model, layers[0] is
# the first Conv2D, so nothing needs to be skipped.
layer_dict = {layer.name: layer for layer in model.layers}
output_class = [0]
# NOTE(review): the keras-vis docs recommend swapping a softmax output for a
# linear activation before using ActivationMaximization. Otherwise a class
# score can be "maximized" simply by suppressing the other classes. Confirm
# whether that swap is wanted here.
losses = [
    (ActivationMaximization(layer_dict[layer_name], output_class), 1),
    (LPNorm(model.input), 1),
    (TotalVariation(model.input), 1)
]
# NOTE(review): "<input>:0 is both fed and fetched" means the Optimizer's
# compute function both feeds and fetches model.input. This is reportedly a
# bug in older keras-vis releases (e.g. the PyPI 0.4.1 package) when
# wrt_tensor defaults to the input tensor. Installing keras-vis from its
# GitHub master is the commonly reported fix; confirm against the installed
# version. In those builds Jitter is importable from vis.input_modifiers,
# not vis.modifiers.
opt = Optimizer(model.input, losses)
opt.minimize(max_iter=500, verbose=True, callbacks=[GifGenerator('opt_progress')])

错误:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-31-b9be9bcb5c26> in <module>()
      9 ]
     10 opt = Optimizer(model.input, losses)
---> 11 opt.minimize(max_iter=500, verbose=True, callbacks=[GifGenerator('opt_progress')])

/opt/conda/lib/python3.6/site-packages/vis/optimizer.py in minimize(self, seed_input, max_iter, input_modifiers, grad_modifier, callbacks, verbose)
    141 
    142             # 0 learning phase for 'test'
--> 143             computed_values = self.compute_fn([seed_input, 0])
    144             losses = computed_values[:len(self.loss_names)]
    145             named_losses = zip(self.loss_names, losses)

/opt/conda/lib/python3.6/site-packages/Keras-2.2.4-py3.6.egg/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2919                     return self._legacy_call(inputs)
   2920 
-> 2921             return self._call(inputs)
   2922         else:
   2923             if py_any(is_tensor(x) for x in inputs):

/opt/conda/lib/python3.6/site-packages/Keras-2.2.4-py3.6.egg/keras/backend/tensorflow_backend.py in _call(self, inputs)
   2873                                 feed_symbols,
   2874                                 symbol_vals,
-> 2875                                 session)
   2876         if self.run_metadata:
   2877             fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)

/opt/conda/lib/python3.6/site-packages/Keras-2.2.4-py3.6.egg/keras/backend/tensorflow_backend.py in _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session)
   2825             callable_opts.run_options.CopyFrom(self.run_options)
   2826         # Create callable.
-> 2827         callable_fn = session._make_callable_from_options(callable_opts)
   2828         # Cache parameters corresponding to the generated callable, so that
   2829         # we can detect future mismatches and refresh the callable.

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _make_callable_from_options(self, callable_options)
   1469     """
   1470     self._extend_graph()
-> 1471     return BaseSession._Callable(self, callable_options)
   1472 
   1473 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, session, callable_options)
   1423         with errors.raise_exception_on_not_ok_status() as status:
   1424           self._handle = tf_session.TF_SessionMakeCallable(
-> 1425               session._session, options_ptr, status)
   1426       finally:
   1427         tf_session.TF_DeleteBuffer(options_ptr)

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    526             None, None,
    527             compat.as_text(c_api.TF_Message(self.status.status)),
--> 528             c_api.TF_GetCode(self.status.status))
    529     # Delete the underlying status object from memory otherwise it stays alive
    530     # as there is a reference to status from this from the traceback due to

InvalidArgumentError: sequential_2_input:0 is both fed and fetched.

0 个答案:

没有答案