TF Slim:Fine Tune自定义数据集

时间:2018-04-05 19:51:32

标签: python tensorflow tensorflow-slim

我正在尝试微调我的自定义数据集上的 Mobilenet_v2_1.4_224 模型以进行图像分类任务。 我正在关注本教程TensorFlow-Slim image classification library。 我已经创建了.tfrecord列车和验证文件。当我尝试从现有检查点微调时,我收到以下错误:

  

InvalidArgumentError(参见上面的回溯):Assign要求两个张量的形状匹配。 lhs shape = [1,1,24,144] rhs shape = [1,1,32,192]            [[节点:保存/分配_149 =分配[T = DT_FLOAT,_class = [“loc:@ MobilenetV2 / expanded_conv_2 / expand / weights”],use_locking = true,validate_shape = true,_device =“/ job:localhost / replica:0 / task:0 / device:CPU:0“](MobilenetV2 / expanded_conv_2 / expand / weights,save / RestoreV2:149)]]

我使用的微调脚本是:

DATASET_DIR = G:\数据集

TRAIN_DIR = G:\数据集\情绪模型\ mobilenet_v2

CHECKPOINT_PATH = C:\ Users \用户联想\桌面\ mobilenet_v2 \ mobilenet_v2_1.4_224.ckpt

python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=emotion \
--dataset_split_name=train \
--model_name=mobilenet_v2 \
--train_image_size=224 \
--clone_on_cpu=True \
--checkpoint_path=${CHECKPOINT_PATH} \
--checkpoint_exclude_scopes=MobilenetV2/Logits \
--trainable_scopes=MobilenetV2/Logits

我怀疑错误是由于最后两个参数“checkpoint_exclude_scopes”或“trainable_scopes”造成的。

我知道通过删除最后两层并为自定义数据集类别创建我们自己的softmax层,这两个参数被用于转移学习。但我不确定我是否为他们传递了正确的价值。

1 个答案:

答案 0 :(得分:5)

要重新训练模型,您必须根据自定义的班级数进行微调

  

MobilenetV2 / Predictions和MobilenetV2 / predics

--checkpoint_exclude_scopes=MobilenetV2/Logits,MobilenetV2/Predictions,MobilenetV2/predics \
--trainable_scopes=MobilenetV2/Logits,MobilenetV2/Predictions,MobilenetV2/predics \

在mobilenet_v2.py中,对于mobilenet和mobilenet_base, depth_multiplier = 1 ,您应该将其更改为1.4

@slim.add_arg_scope 
def mobilenet_base(input_tensor, depth_multiplier=1.4, **kwargs): 
"""Creates base of the mobilenet (no pooling and no logits) .""" 
return mobilenet(input_tensor,
                 depth_multiplier=depth_multiplier,
                 base_only=True, **kwargs)

@slim.add_arg_scope 
def mobilenet(input_tensor,
                  num_classes=1001,
                  depth_multiplier=1.4,
                  scope='MobilenetV2',
                  conv_defs=None,
                  finegrain_classification_mode=False,
                  min_depth=None,
                  divisible_by=None,
                  **kwargs):