Adding new layers and resuming training from a checkpoint in TensorFlow

Time: 2017-09-20 08:27:56

Tags: tensorflow

I trained a model with TensorFlow, and training took several weeks. Now I want to add new layers to the current model and continue training from the trained weights. However, when I restore the checkpoint, I get an error.

So how do I restore a checkpoint into a modified model in TensorFlow?
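A common cause of this kind of error is that a default `tf.train.Saver()` tries to load every variable in the graph, including the new layers that the checkpoint does not contain. One widely used workaround is a partial restore: load only the variables whose name and shape still match the checkpoint, and initialize the rest from scratch. The sketch below assumes the TF 1.x graph/session API (reached here through `tensorflow.compat.v1`; on an old 1.x release `import tensorflow as tf` works directly, and under TF 2.x graph mode must be in use) and assumes the original layers keep their variable names in the modified graph. `partition_for_restore` and `restore_partially` are hypothetical helper names; `tf.train.list_variables`, `tf.train.Saver(var_list=...)` and `tf.variables_initializer` are existing TF 1.x APIs.

```python
def partition_for_restore(model_shapes, ckpt_shapes):
    """Split variable names into (restore, init) lists.

    model_shapes: {name: shape} for every variable in the modified graph.
    ckpt_shapes:  {name: shape} for every variable stored in the checkpoint.
    A variable is restored only if the checkpoint has the same name AND shape;
    everything else (new layers, resized layers) is freshly initialized.
    """
    restore, init = [], []
    for name, shape in model_shapes.items():
        saved = ckpt_shapes.get(name)
        if saved is not None and tuple(saved) == tuple(shape):
            restore.append(name)
        else:
            init.append(name)
    return restore, init


def restore_partially(sess, ckpt_path):
    """Restore matching variables from ckpt_path and initialize the rest.

    Assumes the modified graph is the default graph, sess was created on it,
    and ckpt_path is a checkpoint prefix (e.g. from tf.train.latest_checkpoint).
    """
    import tensorflow.compat.v1 as tf  # imported lazily; only needed here

    # list_variables returns [(name, shape_list), ...] read from the checkpoint.
    ckpt_shapes = dict(tf.train.list_variables(ckpt_path))
    # Checkpoint keys use op names ("fc/weights"), not tensor names ("fc/weights:0").
    graph_vars = {v.op.name: v for v in tf.global_variables()}
    model_shapes = {n: v.shape.as_list() for n, v in graph_vars.items()}

    restore, init = partition_for_restore(model_shapes, ckpt_shapes)
    # Fresh initial values for the new or resized variables.
    sess.run(tf.variables_initializer([graph_vars[n] for n in init]))
    # Trained values for every variable that still matches the checkpoint.
    if restore:
        saver = tf.train.Saver(var_list=[graph_vars[n] for n in restore])
        saver.restore(sess, ckpt_path)
    return restore, init


if __name__ == "__main__":
    # Hypothetical shapes: the old net had fc (4096 -> 10); the new net
    # inserts new_fc (4096 -> 256), so fc now takes 256 inputs.
    ckpt = {"conv1/weights": [3, 3, 3, 64], "conv1/biases": [64],
            "fc/weights": [4096, 10]}
    model = {"conv1/weights": [3, 3, 3, 64], "conv1/biases": [64],
             "new_fc/weights": [4096, 256], "fc/weights": [256, 10]}
    restore, init = partition_for_restore(model, ckpt)
    print(restore)  # ['conv1/weights', 'conv1/biases']
    print(init)     # ['new_fc/weights', 'fc/weights']
```

After building the modified graph, calling `restore_partially(sess, tf.train.latest_checkpoint(ckpt_dir))` once before the training loop would replace the usual `saver.restore(...)` call. Checking the shape as well as the name means a layer whose input size changed is re-initialized rather than triggering a shape-mismatch error on restore. For saving during continued training, a fresh `tf.train.Saver()` over all variables can then be used as before.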

0 answers:
