我用TensorFlow训练了一个模型,训练花了我几周时间。现在,我想在当前模型中添加新的层,并基于已训练的权重继续训练。但是,如果我直接恢复检查点,就会出现
import numpy as np
import multiprocessing
import Queue
import h5py
import glob
import time
def read_file(filename):
    """Load every top-level entry of an HDF5 file into memory.

    Args:
        filename: Path to the HDF5 file.

    Returns:
        dict mapping each top-level key of the file to ``np.array(...)`` of
        that entry (presumably datasets; a top-level group would not convert
        meaningfully -- verify against the data files).
    """
    # Open read-only explicitly (h5py's default mode has differed between
    # versions) and close the handle deterministically instead of leaking it.
    with h5py.File(filename, "r") as reader:
        # np.array copies the contents into memory, so the returned arrays
        # remain valid after the file is closed.
        return {key: np.array(reader[key]) for key in reader.keys()}
def producer(filename_queue, data_queue):
    """Worker loop: load each queued HDF5 file and forward its contents.

    Pulls file paths from ``filename_queue`` until it receives the ``False``
    sentinel, loads each one with ``read_file`` and puts the resulting dict on
    ``data_queue``. Prints the time spent per file and the average over the
    files this worker handled.

    Args:
        filename_queue: Queue of file paths, terminated by ``False``.
        data_queue: Queue that receives one loaded dict per file.
    """
    total_begin = time.time()
    num_files = 0
    while True:
        begin = time.time()
        # Blocking get(): on a multiprocessing.Queue, get_nowait() can raise
        # Empty spuriously before the parent's feeder thread has flushed the
        # items. The False sentinel is what terminates this loop.
        f = filename_queue.get()
        if f is False:
            break
        data = read_file(f)
        data_queue.put(data)
        num_files += 1
        time_cost = time.time() - begin
        print("Time cost of producer: %.2f ms" % (time_cost * 1000))
    total_time_cost = time.time() - total_begin
    # Average over the files this worker actually processed. The parent's
    # global file list does not exist in a spawned child process, and it
    # over-counts when several producers share the work.
    if num_files:
        print("Avg time cost of producer: %.2f ms" % (
            total_time_cost / num_files * 1000))
if __name__ == "__main__":
    # Benchmark: process_num producer processes read every ./data/*.hdf5 file
    # and push the contents onto data_queue; the parent consumes one item per
    # file and reports how long each get() waited.
    process_num = 1
    filenames = glob.glob("./data/*.hdf5")
    # Capacity for every filename plus one False sentinel per producer, so
    # none of the puts below can block.
    filename_queue = multiprocessing.Queue(len(filenames) + process_num)
    for f in filenames:
        filename_queue.put(f)
    for _ in range(process_num):
        filename_queue.put(False)
    data_queue = multiprocessing.Queue(1024)
    processes = []
    for _ in range(process_num):
        p = multiprocessing.Process(
            target=producer, args=(filename_queue, data_queue))
        # Daemonic workers are terminated if the parent exits early.
        p.daemon = True
        processes.append(p)
    for p in processes:
        p.start()
    # Reading data from queue
    total_begin = time.time()
    for _ in range(len(filenames)):
        begin = time.time()
        data = data_queue.get()
        time_cost = time.time() - begin
        print("Time cost of consumer: %.2f ms" % (time_cost * 1000))
    total_time_cost = time.time() - total_begin
    # Avoid ZeroDivisionError when the glob matched no files.
    if filenames:
        print("Avg time cost of consumer: %.2f ms" % (
            total_time_cost / len(filenames) * 1000))
    # Safe to join now: every queued item has been drained above, so the
    # workers' queue feeder threads can flush and the processes can exit.
    for p in processes:
        p.join()
之类的错误。
那么,如何在TensorFlow中把已有检查点中的权重恢复到修改后的模型中?