在tensorflow中从磁盘读取numpy数组

时间:2018-03-01 13:01:31

标签: python numpy tensorflow

我是张力流初学者,尝试使用TextLineReader将存储在磁盘上的numpy数组读入TF。但是当我在TF中读取数组时,我看到的值与原始数组不同。有人可以指出我在这里犯的错误吗?请参阅下面的示例代码。谢谢

import tensorflow as tf
import numpy as np
import csv

#Write two numpy arrays to disk 
a = np.arange(15).reshape(3, 5)
np.save("a.npy",a,allow_pickle=False)

b = np.arange(30).reshape(5, 6)
np.save("b.npy",b,allow_pickle=False)

with open('files.csv', 'w') as csvfile:
    filewriter = csv.writer(csvfile, delimiter=',')
    filewriter.writerow(['a.npy', 'b.npy'])


# Load a csv with the two array filenames

csv_filename = "files.csv"
filename_queue = tf.train.string_input_producer([csv_filename])

reader = tf.TextLineReader()
_, csv_filename_tf = reader.read(filename_queue)


record_defaults = [tf.constant([], dtype=tf.string), tf.constant([], dtype=tf.string)]
filename_i,filename_j = tf.decode_csv(
    csv_filename_tf, record_defaults=record_defaults)

file_contents_i = tf.read_file(filename_i)
file_contents_j = tf.read_file(filename_j)

bytes_i = tf.decode_raw(file_contents_i, tf.int16)
array_i = tf.reshape(tf.cast(tf.slice(bytes_i, [0], [3*5]), tf.int16), [3, 5])

bytes_j = tf.decode_raw(file_contents_j, tf.int16)
array_j = tf.reshape(tf.cast(tf.slice(bytes_j, [0], [5*6]), tf.int16), [5, 6])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    a_out, b_out = (sess.run([array_i, array_j]))

    print(a)
    print(a_out)

    coord.request_stop()
    coord.join(threads)

以下是我得到的输出:

预期产出(a)

[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]]

收到输出:(a_out)

[[20115 19797 22864     1   118]
 [10107 25956 25459 10098  8250]
 [15399 14441 11303 10016 28518]]

3 个答案:

答案 0 :(得分:0)

我认为tensorflow decode_raw和numpy的np.save不兼容。

答案 1 :(得分:0)

使用numpy的zip文件.npz

保存变量a b

weights = {w.name : sess.run(w) for w in [a, b]}
np.savez(path, **weights)

加载:

weights = [a, b]
npz_weights = np.load(path)
for i,k in enumerate([w.name for w in weights]):
    sess.run(weights[i].assign(npz_weights[k]))

答案 2 :(得分:0)

为了弄清楚发生了什么,我打印了bytes_i而不是array_i

a_out, b_out = (sess.run([bytes_i, bytes_j]))
print(a_out)

我获得了以下列表:

[20115 19797 22864     1   118 10107 25956 25459 10098  8250 15399 14441
 11303 10016 28518 29810 24946 24430 29295 25956 10098  8250 24902 29548
 11365 10016 26739 28769 10085  8250 13096  8236 10549  8236  8317  8224
  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224
  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224  8224
  8224  8224  8224  2592     0     0     0     0     1     0     0     0
     2     0     0     0     3     0     0     0     4     0     0     0
     5     0     0     0     6     0     0     0     7     0     0     0
     8     0     0     0     9     0     0     0    10     0     0     0
    11     0     0     0    12     0     0     0    13     0     0     0
    14     0     0     0]

在numpy文件中存储的数据前面似乎有一个标头。此外,似乎数据值另存为int64,而不是int16

解决方案

首先指定数组中值的类型:

a = np.arange(15).reshape(3, 5).astype(np.int16)
b = np.arange(30).reshape(5, 6).astype(np.int16)

然后读取文件的最后一个字节:

array_i = tf.reshape(tf.cast(tf.slice(bytes_i,
                                      begin=[tf.size(bytes_i) - (3*5)],
                                      size=[3*5]), tf.int16), [3, 5])
array_j = tf.reshape(tf.cast(tf.slice(bytes_j,
                                      begin=[tf.size(bytes_j) - (5*6)],
                                      size=[5*6]), tf.int16), [5, 6])