I am trying to run a CNN model on some photos I took, but from debugging it looks like the program gets stuck as soon as I start a tf.Session() in main(). Is this not allowed? What are the alternatives? I am using the session to convert the tensors into numpy arrays. Here is my code:
import glob
import sys
from random import shuffle

import cv2
import numpy as np
import tensorflow as tf


def main(unused_argv):
    # birds = to be implemented
    shuffle_data = True
    bird_train_path = 'C:/Users/first/Pictures/imds/*.png'
    addrs = glob.glob(bird_train_path)
    # label 0 = duck, 1 = finch, 2 = anything else
    labels = [0 if 'duck' in addr else 1 if 'finch' in addr else 2 for addr in addrs]
    if shuffle_data:
        c = list(zip(addrs, labels))
        shuffle(c)
        addrs, labels = zip(*c)
    # 60% train / 20% validation split
    train_addrs = addrs[0:int(0.6 * len(addrs))]
    train_labels = labels[0:int(0.6 * len(labels))]
    val_addrs = addrs[int(0.6 * len(addrs)):int(0.8 * len(addrs))]
    val_labels = labels[int(0.6 * len(addrs)):int(0.8 * len(addrs))]
    def load_image(addr):
        img = cv2.imread(addr)
        img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_CUBIC)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32)
        return img

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    # Write the training set to a TFRecords file
    train_filename = 'train.tfrecords'
    writer = tf.python_io.TFRecordWriter(train_filename)
    for i in range(len(train_addrs)):
        # Print progress every 10 images
        if not i % 10:
            print('Train data: {}/{}'.format(i, len(train_addrs)))
            sys.stdout.flush()
        # Load the image
        img = load_image(train_addrs[i])
        labeltrain = train_labels[i]
        # Create a feature
        feature = {'train/label': _int64_feature(labeltrain),
                   'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))}
        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        # Serialize to string and write to the file
        writer.write(example.SerializeToString())
    writer.close()
    sys.stdout.flush()
    # Write the validation set to a TFRecords file
    val_filename = 'val.tfrecords'
    writer = tf.python_io.TFRecordWriter(val_filename)
    for i in range(len(val_addrs)):
        # Print progress every 10 images
        if not i % 10:
            print('Val data: {}/{}'.format(i, len(val_addrs)))
            sys.stdout.flush()
        # Load the image
        img = load_image(val_addrs[i])
        labelval = val_labels[i]
        # Create a feature
        feature = {'val/label': _int64_feature(labelval),
                   'val/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))}
        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        # Serialize to string and write to the file
        writer.write(example.SerializeToString())
    writer.close()
    sys.stdout.flush()
    # Read the training set back from its TFRecords file
    data_path = 'train.tfrecords'
    feature = {'train/image': tf.FixedLenFeature([], tf.string),
               'train/label': tf.FixedLenFeature([], tf.int64)}
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features=feature)
    image = tf.decode_raw(features['train/image'], tf.float32)
    train_labels = tf.cast(features['train/label'], tf.int32)
    train_data = tf.reshape(image, [256, 256, 3])

    # Read the validation set back from its TFRecords file
    val_data_path = 'val.tfrecords'
    feature2 = {'val/image': tf.FixedLenFeature([], tf.string),
                'val/label': tf.FixedLenFeature([], tf.int64)}
    filename_queue2 = tf.train.string_input_producer([val_data_path], num_epochs=1)
    reader2 = tf.TFRecordReader()
    _, serialized_example2 = reader2.read(filename_queue2)
    features2 = tf.parse_single_example(serialized_example2, features=feature2)
    image2 = tf.decode_raw(features2['val/image'], tf.float32)
    eval_labels = tf.cast(features2['val/label'], tf.int32)
    eval_data = tf.reshape(image2, [256, 256, 3])
    # Convert the tensors to numpy arrays; this is where it gets stuck
    sess = tf.Session()
    with sess.as_default():
        train_data.eval()
        train_labels.eval()
        eval_data.eval()
        eval_labels.eval()


if __name__ == '__main__':
    tf.app.run()
Full code: https://gist.github.com/FirstSintax/1e64447da6d9be0ebd525319e8722d89
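
From what I have read so far, tf.train.string_input_producer only adds queue runners to the graph but does not start them, and num_epochs=1 also creates a local variable that has to be initialized, so my guess is that the .eval() calls block forever on an empty queue. Below is a minimal sketch of the pattern I think is expected in TF 1.x; the initializer calls, the tf.train.Coordinator / tf.train.start_queue_runners part, and the result names (train_images, etc.) are my own additions on top of the graph built above, not something I have verified yet. Is this the right way to get numpy arrays out, or should I switch to tf.data instead?

import tensorflow as tf

# Minimal sketch, assuming the same graph as above has already been built:
# train_data, train_labels, eval_data, eval_labels come from the
# string_input_producer + TFRecordReader pipeline with num_epochs=1.
sess = tf.Session()
with sess.as_default():
    # num_epochs creates a local variable, so both initializers are needed
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # Start the queue runners that feed the filename queues; without them
    # reader.read() never receives a filename and eval() blocks forever
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    train_images = train_data.eval()   # numpy array with shape (256, 256, 3)
    train_label = train_labels.eval()  # numpy int32 scalar
    eval_images = eval_data.eval()
    eval_label = eval_labels.eval()
    coord.request_stop()
    coord.join(threads)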