Investigating float16 training in TensorFlow

Time: 2018-06-26 19:20:12

Tags: python-3.x tensorflow

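The pretrained TF-Slim models and their weights are saved as tf.float32. To save memory, I would like to load a pretrained model in tf.float16 and then, after appending some other modules that are also in tf.float16, run some training ops.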

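I know this normally cannot be done because the tensor dtypes do not match. Are there any resources that specifically address this problem? Has anyone tried something similar?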

For the record, below is the snippet I used to check whether a pretrained model can be restored in tf.float16. It does not work.

import tensorflow as tf
from inception_v4 import inception_v4, inception_v4_arg_scope
import numpy as np
slim = tf.contrib.slim
dtype = tf.float16  # build the graph in half precision
input = tf.placeholder(dtype, [None, 299,299,3])
with slim.arg_scope(inception_v4_arg_scope()):
    logits, endpoints = inception_v4(input)  # variables are created following the input dtype

saver = tf.train.Saver()  # restore ops take their dtypes from the graph variables

init_op = tf.global_variables_initializer()

img = np.random.randn(10, 299, 299, 3)  # random batch standing in for real images

with tf.Session() as sess:
    sess.run(init_op)
    saver.restore(sess, './inception_v4.ckpt')  # fails: the checkpoint stores float32 tensors
    l = sess.run(logits, feed_dict={input : img})

The output is:

2018-06-26 21:14:51.708190: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
Traceback (most recent call last):
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1361, in _do_call
    return fn(*args)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1340, in _run_fn
    target_list, status, run_metadata)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 516, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected to restore a tensor of type half, got a tensor of type float instead: tensor_name = InceptionV4/AuxLogits/Aux_logits/biases
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_HALF, DT_HALF, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_HALF, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_HALF], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test.py", line 18, in <module>
    saver.restore(sess, './inception_v4.ckpt')
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1755, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 905, in run
    run_metadata_ptr)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1137, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1355, in _do_run
    options, run_metadata)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1374, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected to restore a tensor of type half, got a tensor of type float instead: tensor_name = InceptionV4/AuxLogits/Aux_logits/biases
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_HALF, DT_HALF, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_HALF, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_HALF], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Caused by op 'save/RestoreV2', defined at:
  File "test.py", line 10, in <module>
    saver = tf.train.Saver()
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1293, in __init__
    self.build()
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1302, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1339, in _build
    build_save=build_save, build_restore=build_restore)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 796, in _build_internal
    restore_sequentially, reshape)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 449, in _AddRestoreOps
    restore_sequentially)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 847, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1030, in restore_v2
    shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3271, in create_op
    op_def=op_def)
  File "/Users/meetukme/anaconda/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1650, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Expected to restore a tensor of type half, got a tensor of type float instead: tensor_name = InceptionV4/AuxLogits/Aux_logits/biases
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_HALF, DT_HALF, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_HALF, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_HALF], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
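
The error says that RestoreV2 expects half-precision tensors while the checkpoint holds float32 ones. One possible way around it, shown here only as a minimal sketch and not a tested solution, is to bypass tf.train.Saver: read each tensor from the checkpoint as a NumPy array, cast it to the dtype of the matching graph variable, and assign it by hand. The helper cast_checkpoint_values and the FakeReader class are hypothetical names introduced for illustration. The helper only assumes a reader object that offers get_variable_to_shape_map() and get_tensor(name), which is the interface of the reader returned by tf.train.NewCheckpointReader in TensorFlow 1.x.

```python
import numpy as np


def cast_checkpoint_values(reader, target_dtypes):
    """Read checkpoint tensors and cast each one to the dtype the graph wants.

    reader: any object with get_variable_to_shape_map() and get_tensor(name),
        e.g. the reader returned by tf.train.NewCheckpointReader in TF 1.x.
    target_dtypes: dict mapping variable name -> numpy dtype used in the graph.
    Returns a dict name -> numpy array; names missing from the checkpoint are skipped.
    """
    available = reader.get_variable_to_shape_map()
    values = {}
    for name, dtype in target_dtypes.items():
        if name not in available:
            continue  # e.g. newly appended modules without pretrained weights
        values[name] = np.asarray(reader.get_tensor(name)).astype(dtype)
    return values


class FakeReader(object):
    """Hypothetical stand-in for a checkpoint reader, used only in this demo."""

    def __init__(self, tensors):
        self._tensors = tensors

    def get_variable_to_shape_map(self):
        return {name: list(t.shape) for name, t in self._tensors.items()}

    def get_tensor(self, name):
        return self._tensors[name]


# A float32 "checkpoint" with one tensor, named after the one in the error message.
reader = FakeReader(
    {"InceptionV4/AuxLogits/Aux_logits/biases": np.ones(4, dtype=np.float32)}
)
wanted = {
    "InceptionV4/AuxLogits/Aux_logits/biases": np.float16,
    "new_module/weights": np.float16,  # hypothetical variable absent from the checkpoint
}
values = cast_checkpoint_values(reader, wanted)
print({name: v.dtype.name for name, v in values.items()})
```

With TensorFlow 1.x the pieces would presumably be wired up as follows: create the reader with tf.train.NewCheckpointReader('./inception_v4.ckpt'), build target_dtypes from {v.op.name: v.dtype.base_dtype.as_numpy_dtype for v in tf.global_variables()}, run the initializer, and then call v.load(values[v.op.name], sess) for every variable that has an entry in values. This wiring has not been run against inception_v4 here. Note also that any float32 weight whose magnitude exceeds the float16 range (about 65504) becomes inf after the cast.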

0 answers:

No answers yet.