我是张力流初学者,尝试使用TextLineReader将存储在磁盘上的numpy数组读入TF。但是当我在TF中读取数组时,我看到的值与原始数组不同。有人可以指出我在这里犯的错误吗?请参阅下面的示例代码。谢谢
import tensorflow as tf
import numpy as np
import csv
#Write two numpy arrays to disk
a = np.arange(15).reshape(3, 5)
np.save("a.npy",a,allow_pickle=False)
b = np.arange(30).reshape(5, 6)
np.save("b.npy",b,allow_pickle=False)
with open('files.csv', 'w') as csvfile:
filewriter = csv.writer(csvfile, delimiter=',')
filewriter.writerow(['a.npy', 'b.npy'])
# Load a csv with the two array filenames
csv_filename = "files.csv"
filename_queue = tf.train.string_input_producer([csv_filename])
reader = tf.TextLineReader()
_, csv_filename_tf = reader.read(filename_queue)
record_defaults = [tf.constant([], dtype=tf.string), tf.constant([], dtype=tf.string)]
filename_i,filename_j = tf.decode_csv(
csv_filename_tf, record_defaults=record_defaults)
file_contents_i = tf.read_file(filename_i)
file_contents_j = tf.read_file(filename_j)
bytes_i = tf.decode_raw(file_contents_i, tf.int16)
array_i = tf.reshape(tf.cast(tf.slice(bytes_i, [0], [3*5]), tf.int16), [3, 5])
bytes_j = tf.decode_raw(file_contents_j, tf.int16)
array_j = tf.reshape(tf.cast(tf.slice(bytes_j, [0], [5*6]), tf.int16), [5, 6])
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
a_out, b_out = (sess.run([array_i, array_j]))
print(a)
print(a_out)
coord.request_stop()
coord.join(threads)
以下是我得到的输出:
预期产出(a)
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]]
收到输出:(a_out)
[[20115 19797 22864 1 118]
[10107 25956 25459 10098 8250]
[15399 14441 11303 10016 28518]]
答案 0 :(得分:0)
我认为tensorflow decode_raw和numpy的np.save不兼容。
答案 1 :(得分:0)
使用numpy的zip文件.npz
保存变量a
b
:
weights = {w.name : sess.run(w) for w in [a, b]}
np.savez(path, **weights)
加载:
weights = [a, b]
npz_weights = np.load(path)
for i,k in enumerate([w.name for w in weights]):
sess.run(weights[i].assign(npz_weights[k]))
答案 2 :(得分:0)
为了弄清楚发生了什么,我打印了bytes_i
而不是array_i
。
a_out, b_out = (sess.run([bytes_i, bytes_j]))
print(a_out)
我获得了以下列表:
[20115 19797 22864 1 118 10107 25956 25459 10098 8250 15399 14441
11303 10016 28518 29810 24946 24430 29295 25956 10098 8250 24902 29548
11365 10016 26739 28769 10085 8250 13096 8236 10549 8236 8317 8224
8224 8224 8224 8224 8224 8224 8224 8224 8224 8224 8224 8224
8224 8224 8224 8224 8224 8224 8224 8224 8224 8224 8224 8224
8224 8224 8224 2592 0 0 0 0 1 0 0 0
2 0 0 0 3 0 0 0 4 0 0 0
5 0 0 0 6 0 0 0 7 0 0 0
8 0 0 0 9 0 0 0 10 0 0 0
11 0 0 0 12 0 0 0 13 0 0 0
14 0 0 0]
在numpy文件中存储的数据前面似乎有一个标头。此外,似乎数据值另存为int64
,而不是int16
。
首先指定数组中值的类型:
a = np.arange(15).reshape(3, 5).astype(np.int16)
b = np.arange(30).reshape(5, 6).astype(np.int16)
然后读取文件的最后一个字节:
array_i = tf.reshape(tf.cast(tf.slice(bytes_i,
begin=[tf.size(bytes_i) - (3*5)],
size=[3*5]), tf.int16), [3, 5])
array_j = tf.reshape(tf.cast(tf.slice(bytes_j,
begin=[tf.size(bytes_j) - (5*6)],
size=[5*6]), tf.int16), [5, 6])