我尝试在培训期间获取输出图层。我正在尝试对模型进行实时3D可视化并使其具有交互性。我正在将tensorflow 2.0和python 3用于google colab。
这是我的代码:
进口
from __future__ import absolute_import, division, print_function, unicode_literals
try:
# Use the %tensorflow_version magic if in colab.
%tensorflow_version 2.x
except Exception:
pass
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_hub as hub
import tensorflow_datasets as tfds
from tensorflow.keras import datasets, layers, models
from tensorflow.keras import backend as K
from tensorflow.keras.backend import clear_session
from tensorflow.keras.callbacks import Callback as Callback
import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)
获取数据
splits = tfds.Split.TRAIN.subsplit([70, 30])
(training_set, validation_set), dataset_info = tfds.load('tf_flowers',with_info=True, as_supervised=True, split=splits)
for i, example in enumerate(training_set.take(5)):
print('Image {} shape: {} label: {}'.format(i+1, example[0].shape, example[1]))
检查班级和图片的数量
num_classes = dataset_info.features['label'].num_classes
num_training_examples = 0
num_validation_examples = 0
for example in training_set:
num_training_examples += 1
for example in validation_set:
num_validation_examples += 1
print('Total Number of Classes: {}'.format(num_classes))
print('Total Number of Training Images: {}'.format(num_training_examples))
print('Total Number of Validation Images: {} \n'.format(num_validation_examples))
开始创建
IMAGE_RES = 299
BATCH_SIZE = 32
def format_image(image, label):
image = tf.image.resize(image, (IMAGE_RES, IMAGE_RES))/255.0
return image, label
(training_set, validation_set), dataset_info = tfds.load('tf_flowers', with_info=True, as_supervised=True, split=splits)
train_batches = training_set.shuffle(num_training_examples//4).map(format_image).batch(BATCH_SIZE).prefetch(1)
validation_batches = validation_set.map(format_image).batch(BATCH_SIZE).prefetch(1)
URL = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"
feature_extractor = hub.KerasLayer(URL,
input_shape=(IMAGE_RES, IMAGE_RES, 3),
trainable=False)
model_inception = tf.keras.Sequential([
feature_extractor,
layers.Dense(num_classes, activation='softmax')
])
model_inception.summary()
这是自定义回调,我在训练期间尝试获取输出层
import datetime
from keras.callbacks import Callback
class MyCustomCallback(tf.keras.callbacks.Callback):
def on_train_batch_begin(self, batch, logs=None):
print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))
def on_train_batch_end(self, batch, logs=None):
for i in range(len(model_inception.layers)):
inp = self.model.input # input placeholder
outputs = [layer.output for layer in self.model.layers] # all layer outputs
functors = [K.function([inp, K.learning_phase()], [out]) for out in outputs] # evaluation functions
input_shape = [1] + list(self.model.input_shape[1:])
test = np.random.random(input_shape)
layer_outs = [func([test, 1.]) for func in functors]
print('\n Training: batch {} ends at {}'.format( layer_outs , datetime.datetime.now().time()))
def on_test_batch_begin(self, batch, logs=None):
print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))
def on_test_batch_end(self, batch, logs=None):
# layer_output = get_3rd_layer_output(self.validation_data)[0]
print('Training: batch {} ends at {} with the output layer {}'.format(batch, datetime.datetime.now().time()))
The problem is in callback of how i can get the output/input of each layer at the end of each batch
这是我的自定义回调的模型编译和培训
model_inception.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
EPOCHS = 2
history = model_inception.fit(train_batches,
epochs=EPOCHS,
steps_per_epoch=20,
validation_data=validation_batches,callbacks=[MyCustomCallback()])
我尝试运行它时出现当前错误
AttributeError Traceback (most recent call last)
<ipython-input-10-5909c67ba93f> in <module>()
9 epochs=EPOCHS,
10 steps_per_epoch=20,
---> 11 validation_data=validation_batches,callbacks=[MyCustomCallback()])
12
13 # #Testing
11 frames
/tensorflow-2.0.0/python3.6/tensorflow_core/python/eager/lift_to_graph.py in <listcomp>(.0)
247 # Check that the initializer does not depend on any placeholders.
248 sources = object_identity.ObjectIdentitySet(sources or [])
-->249 visited_ops = set([x.op for x in sources])
250 op_outputs = collections.defaultdict(set)
251
AttributeError: 'int' object has no attribute 'op'