如何将Tensorflow图表从使用float32
转换为float16
?目前,存在用于量化和转换为八位整数的图优化。
尝试将float32
权重加载到float16
图表失败,并显示:
DataLossError (see above for traceback): Invalid size in bundle entry: key model/conv5_1/biases; stored size 1536; expected size 768
[[Node: save/RestoreV2_16 = RestoreV2[dtypes=[DT_HALF], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_16/tensor_names, save/RestoreV2_16/shape_and_slices)]]
[[Node: save/RestoreV2_3/_39 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_107_save/RestoreV2_3", tensor_type=DT_HALF, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
答案 0 :(得分:6)
我认为我的解决方案绝对不是最好的,也不是最直接的解决方案,但没有其他人发布任何内容:
我所做的是以完全精确的方式训练网络并将其保存在检查点中。然后我构建了一个网络副本,将所有需要的变量设置为tf.float16的dtype并删除所有训练节点。最后,我按以下方式加载和转换变量:
previous_variables = [
var_name for var_name, _
in tf.contrib.framework.list_variables('path-to-checkpoint-file')]
#print(previous_variables)
sess.run(tf.global_variables_initializer())
restore_map = {}
for variable in tf.global_variables():
if variable.op.name in previous_variables:
var = tf.contrib.framework.load_variable(
'path-to-checkpoint-file', variable.op.name)
if(var.dtype == np.float32):
tf.add_to_collection('assignOps', variable.assign(
tf.cast(var, tf.float16)))
else:
tf.add_to_collection('assignOps', variable.assign(var))
sess.run(tf.get_collection('assignOps'))
如果你不想转换float32的张量,这显然有问题,我很幸运没有,因为我想将所有节点转换为float16精度。如果你有那些你可以进一步过滤其他if语句。我希望这能回答你的问题。
答案 1 :(得分:0)
我遇到了这个问题,但是我正在加载一个子图,其中包含一些需要加载或转换的变量,而有些则不需要。 在@Jendrik的基础上,这是一个返回分配操作的函数,给定一个字典,该字典将已保存的变量映射到新图形:
def assign_and_convert_halfPrecision(restore_dictinary, CHECKPOINT_PATH):
# Iterate over the dictionary containing the variables to load
for variable_name_old, varible_new in restore_dictinary.items():
# Load the variable from the checkpoint
var = tf.contrib.framework.load_variable(CHECKPOINT_PATH, variable_name_old)
# Assign to new graph
if(var.dtype == np.float32) and (varible_new.dtype == np.float16):
# If the variable is float16 in the new graph, we cast it
tf.add_to_collection('assignOps', varible_new.assign(tf.cast(var, tf.float16)))
else:
# If the variable in the old graph is float16 or the new variable is float32,
# we load it directly
tf.add_to_collection('assignOps', varible_new.assign(var))
# Return the operation
return tf.get_collection('assignOps')
要使用它,只需执行以下操作:
# Create a trivial dictionary (all custom loading can be added here, like change of scope names)
restore_dictionary = dict()
for a in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=''):
restore_dictionary[a.name[:-2]] = a
# Create the assignment and conversion op
assign_operation = assign_and_convert_halfPrecision(restore_dictionary, CHECKPOINT_PATH)
# Load
sess.run(assign_operation)
可以通过修改字典来控制加载,避免不应该加载的变量或更改要加载的变量的范围。