在 main 函数中使用 tf.Session()

时间:2018-07-05 14:26:19

标签: python tensorflow

我正在尝试对自己拍摄的一些照片运行 CNN 模型,但从调试结果来看,当我在 main 函数中启动 tf.Session() 时,程序似乎卡住了。在 main 函数中这样做是不允许的吗?有哪些替代方案?我使用会话是为了把张量转换为 numpy 数组。

_IMAGE_SIDE = 256  # images are resized to _IMAGE_SIDE x _IMAGE_SIDE RGB


def _label_for_path(addr):
    """Return the class id for one image path: 0 = duck, 1 = finch, 2 = other."""
    # Test the single path ``addr``; testing the whole address list (as the
    # posted code did for 'finch') means no image is ever labelled 1.
    if 'duck' in addr:
        return 0
    if 'finch' in addr:
        return 1
    return 2


def _split_dataset(addrs, labels, train_end_frac=0.6, val_end_frac=0.8):
    """Split parallel address/label sequences into training and validation parts.

    Items in the fraction range ``[0, train_end_frac)`` go to training, items in
    ``[train_end_frac, val_end_frac)`` go to validation, and the rest is unused.

    Returns:
        A tuple ``(train_addrs, train_labels, val_addrs, val_labels)``.
    """
    n = len(addrs)
    train_end = int(train_end_frac * n)
    val_end = int(val_end_frac * n)
    return (addrs[:train_end], labels[:train_end],
            addrs[train_end:val_end], labels[train_end:val_end])


def _load_image(addr):
    """Read ``addr`` with OpenCV, resize it and return it as a float32 RGB array.

    Raises:
        IOError: if OpenCV cannot read the file.
    """
    img = cv2.imread(addr)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('Could not read image: {}'.format(addr))
    img = cv2.resize(img, (_IMAGE_SIDE, _IMAGE_SIDE),
                     interpolation=cv2.INTER_CUBIC)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img.astype(np.float32)


def _int64_feature(value):
    """Wrap a single integer as a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    """Wrap a single bytes string as a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _write_tfrecords(filename, addrs, labels, prefix, tag):
    """Serialize image/label pairs to ``filename`` as tf.train.Example records.

    Each record has the features '<prefix>/image' (raw float32 image bytes) and
    '<prefix>/label' (int64). ``tag`` is only used in progress messages.
    """
    total = len(addrs)
    with tf.python_io.TFRecordWriter(filename) as writer:  # closes the file
        for i, (addr, label) in enumerate(zip(addrs, labels)):
            # Progress report every 10 images.
            if not i % 10:
                print('{} data: {}/{}'.format(tag, i, total))
                sys.stdout.flush()
            img = _load_image(addr)
            feature = {prefix + '/label': _int64_feature(label),
                       prefix + '/image': _bytes_feature(img.tobytes())}
            example = tf.train.Example(
                features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
    sys.stdout.flush()


def _read_single_example(filename, prefix):
    """Build graph ops that dequeue one (image, label) record from ``filename``.

    Reading is queue-based: the returned ops only yield values after
    ``tf.train.start_queue_runners`` has been called for the session, and,
    because ``num_epochs=1`` is used, after local variables are initialized.
    Every ``sess.run`` that touches these ops dequeues a new record.
    """
    feature = {prefix + '/image': tf.FixedLenFeature([], tf.string),
               prefix + '/label': tf.FixedLenFeature([], tf.int64)}
    filename_queue = tf.train.string_input_producer([filename], num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features=feature)
    image = tf.decode_raw(features[prefix + '/image'], tf.float32)
    image = tf.reshape(image, [_IMAGE_SIDE, _IMAGE_SIDE, 3])
    label = tf.cast(features[prefix + '/label'], tf.int32)
    return image, label


def main(unused_argv):
    """Convert the bird images to TFRecords, then read one example of each split back.

    The CNN itself ("birds") is still to be implemented. The NumPy values
    fetched at the end are what a model input function would consume.

    Args:
        unused_argv: command-line arguments (ignored).

    Raises:
        IOError: if no image matches the glob pattern or an image is unreadable.
    """
    shuffle_data = True
    bird_train_path = 'C:/Users/first/Pictures/imds/*.png'

    addrs = glob.glob(bird_train_path)
    if not addrs:
        # zip(*[]) below would otherwise fail with an unhelpful unpacking error.
        raise IOError('No images match {}'.format(bird_train_path))
    labels = [_label_for_path(addr) for addr in addrs]

    if shuffle_data:
        c = list(zip(addrs, labels))
        shuffle(c)
        addrs, labels = zip(*c)

    train_addrs, train_labels, val_addrs, val_labels = _split_dataset(
        addrs, labels)

    train_filename = 'train.tfrecords'
    val_filename = 'val.tfrecords'
    _write_tfrecords(train_filename, train_addrs, train_labels, 'train', 'Train')
    _write_tfrecords(val_filename, val_addrs, val_labels, 'val', 'Val')

    # Each split gets its own filename queue and reader. The validation ops
    # must read val.tfrecords, since only that file contains the 'val/*' keys.
    train_image, train_label = _read_single_example(train_filename, 'train')
    val_image, val_label = _read_single_example(val_filename, 'val')

    with tf.Session() as sess:
        # string_input_producer(num_epochs=1) keeps its epoch counter in a
        # local variable, which must be initialized before the queue can run.
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])
        # Without queue runners nothing ever fills the filename queues, so
        # evaluating the reader ops blocks forever (the reported "hang").
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # Fetch everything in ONE run. Each separate run/.eval() dequeues
            # a new record, so separate calls would pair an image with the
            # label of a different example.
            train_data, train_label_value, eval_data, eval_label_value = (
                sess.run([train_image, train_label, val_image, val_label]))
        finally:
            coord.request_stop()
            coord.join(threads)

    print('Train example: image {} label {}'.format(
        train_data.shape, train_label_value))
    print('Val example: image {} label {}'.format(
        eval_data.shape, eval_label_value))

https://gist.github.com/FirstSintax/1e64447da6d9be0ebd525319e8722d89

0 个答案:

没有答案