在tensorflow网站上有关数据集创建的教程之后,我编写了以下图像分类代码。问题在于创建完数据集变量后,代码将停止执行并终止。
path_ds = tf.data.Dataset.from_sparse_tensor_slices(all_images)
# PROGRAM TERMINATES AND DOES NOT EXECUTE THE FOLLOWING LINE
image_ds = path_ds.map(load_image)
我最初的猜测是急切执行的问题。 我尝试使用Visual Studio代码调试器进行调试,但步履蹒跚。
import os
import numpy as np
import tensorflow as tf
from PIL import Image
tf.compat.v1.enable_eager_execution()
test_folder = "../Test"
train_folder = "../Train"
BATCH_SIZE = 32
folders = os.listdir(train_folder)
labels = sorted(folders, key = lambda x : int(x.split("_")[1]) if "c" in x else \
int(x.split("_")[1])+ 100)
char_to_int = dict((label, index) for index, label in enumerate(labels))
int_to_char = dict((index, label) for index, label in enumerate(labels))
def load_image(infilename) :
img = Image.open( infilename )
img.load()
data = np.asarray( img, dtype="int32" )
return tf.convert_to_tensor(np.expand_dims(data, axis=2))
def get_all_image_names(folder):
children = os.listdir(folder)
all_images = []
all_labels = []
for i in children:
images = os.listdir(os.path.join(folder,i))
for image in images:
path = os.path.join(folder, i, image)
all_images.append(path)
all_labels.append(char_to_int[i])
path_ds = tf.data.Dataset.from_sparse_tensor_slices(all_images)
# PROGRAM TERMINATES AFTER THIS LINE
image_ds = path_ds.map(load_image)
label_ds = tf.data.Dataset.from_sparse_tensor_slices(tf.cast(all_labels, tf.int64))
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
return image_label_ds
dataset = get_all_image_names(train_folder)
ds = dataset.shuffle(buffer_size=10000)
ds = ds.batch(BATCH_SIZE)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(32, 32,1)),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(64, activation=tf.nn.relu),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(46, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.summary()
model.fit(ds, epochs=55, verbose=1)
print("MOdel fitted")