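I'm using the TensorFlow Object Detection API and would like to edit its config file (shown below) dynamically from Python. I thought of using the Protocol Buffers library for Python, but I'm not sure how to go about it.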
```
model {
  ssd {
    num_classes: 1
    image_resizer {
      fixed_shape_resizer {
        height: 300
        width: 300
      }
    }
    feature_extractor {
      type: "ssd_inception_v2"
      depth_multiplier: 1.0
      min_depth: 16
      conv_hyperparams {
        regularizer {
          l2_regularizer {
            weight: 3.99999989895e-05
          }
        }
        initializer {
          truncated_normal_initializer {
            mean: 0.0
            stddev: 0.0299999993294
          }
        }
        activation: RELU_6
        batch_norm {
          decay: 0.999700009823
          center: true
          scale: true
          epsilon: 0.0010000000475
          train: true
        }
      }
      ...
      ...
}
```
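Is there a simple way to change a specific value, such as `height` under `image_resizer -> fixed_shape_resizer`, from 300 to 500, and write the modified config back to the file without changing anything else?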
Edit: The answer from @DmytroPrylipko works for most parameters in the config, but I still run into a problem with "composite fields".
That is, given a config like this:
```
train_input_reader: {
  label_map_path: "/tensorflow/data/label_map.pbtxt"
  tf_record_input_reader {
    input_path: "/tensorflow/models/data/train.record"
  }
}
```
if I add the following line to edit `input_path`:

```python
pipeline_config.train_input_reader.tf_record_input_reader.input_path = "/tensorflow/models/data/train100.record"
```

it raises an error:

```
TypeError: Can't set composite field
```
Answer 0 (score: 2)
Yes, this is quite easy with the Protobuf Python API:
`edit_pipeline.py`:
```python
import argparse

import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2


def parse_arguments():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('pipeline')
    parser.add_argument('output')
    return parser.parse_args()


def main():
    args = parse_arguments()

    # Parse the text-format config into a TrainEvalPipelineConfig message.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(args.pipeline, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, pipeline_config)

    # Scalar fields are assigned directly through the nested messages.
    pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.height = 500
    pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.width = 500

    # Serialize back to text format. Comments in the original file are not
    # preserved, because parsing discards them.
    config_text = text_format.MessageToString(pipeline_config)
    with tf.gfile.Open(args.output, "wb") as f:
        f.write(config_text)


if __name__ == '__main__':
    main()
```
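When several nested values need changing, writing out each attribute chain gets repetitive. A minimal sketch, assuming you are happy to address fields by a dotted path string: `set_by_path` is a hypothetical helper (not part of protobuf or the Object Detection API) that walks attributes with `getattr`, supports a trailing `[i]` index for repeated fields, and then assigns the final scalar.

```python
import re
from functools import reduce

# One path segment: an attribute name, optionally followed by "[index]".
SEGMENT_RE = re.compile(r"^(\w+)(?:\[(\d+)\])?$")


def _split(segment):
    """Split "input_path[0]" into ("input_path", 0) and "ssd" into ("ssd", None)."""
    match = SEGMENT_RE.match(segment)
    if match is None:
        raise ValueError(f"bad path segment: {segment!r}")
    name, index = match.groups()
    return name, (None if index is None else int(index))


def _step(obj, segment):
    """Descend one segment: read the attribute, then index it if requested."""
    name, index = _split(segment)
    child = getattr(obj, name)
    return child if index is None else child[index]


def set_by_path(message, path, value):
    """Set the scalar at a dotted path such as "model.ssd.num_classes".

    Hypothetical helper. It only uses getattr/setattr and indexing, so it
    works on protobuf messages and on ordinary Python objects alike.
    """
    *parents, last = path.split(".")
    parent = reduce(_step, parents, message)  # walk down to the owning message
    name, index = _split(last)
    if index is None:
        setattr(parent, name, value)           # singular scalar field
    else:
        getattr(parent, name)[index] = value   # element of a repeated field
```

Inside `main()` above it would be called as `set_by_path(pipeline_config, "model.ssd.image_resizer.fixed_shape_resizer.height", 500)`. Because it relies only on attribute access and indexing, it can be tried on plain Python objects before pointing it at a real `TrainEvalPipelineConfig`.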
How I invoke the script:

```shell
TOOL_DIR=tool/tf-models/research

# Compile the Object Detection .proto files so pipeline_pb2 can be imported.
(
    cd "$TOOL_DIR"
    protoc object_detection/protos/*.proto --python_out=.
)

export PYTHONPATH=$PYTHONPATH:$TOOL_DIR:$TOOL_DIR/slim
python3 edit_pipeline.py pipeline.config pipeline_new.config
```
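The protobuf round trip above drops comments and re-prints the whole file with `MessageToString`, so the output may differ from the input in more places than the edited value. If a minimal diff matters, a purely textual edit is an alternative. The sketch below swaps in that different technique: `set_pbtxt_scalar` is a hypothetical helper, not part of protobuf or the Object Detection API, and it assumes one field or brace per line, as in the config in the question.

```python
import re

# A line that opens a nested message: "name {" or "name: {".
OPEN_RE = re.compile(r"^\s*(\w+)\s*:?\s*\{\s*$")
# A line holding a scalar field: "name: value" (indentation kept in group 1).
SCALAR_RE = re.compile(r"^(\s*)(\w+)(\s*:\s*)(.*?)\s*$")


def set_pbtxt_scalar(text, path, new_value):
    """Return `text` with the scalar field at `path` set to `new_value`.

    Hypothetical helper. `path` lists the enclosing message names followed
    by the field name. Assumes one field or brace per line; only the first
    matching field is changed and every other line is copied unchanged.
    """
    *parents, field = path
    stack = []  # names of the messages enclosing the current line
    out = []
    replaced = False
    for line in text.splitlines(keepends=True):
        body = line.rstrip("\r\n")
        ending = line[len(body):]  # preserve the original line ending
        opened = OPEN_RE.match(body)
        if opened:
            stack.append(opened.group(1))
        elif body.strip() == "}":
            if stack:
                stack.pop()
        elif not replaced and stack == parents:
            scalar = SCALAR_RE.match(body)
            if scalar and scalar.group(2) == field:
                indent, name, sep = scalar.group(1, 2, 3)
                body = indent + name + sep + new_value
                replaced = True
        out.append(body + ending)
    if not replaced:
        raise KeyError(".".join(path))
    return "".join(out)
```

For the question's example this would be `set_pbtxt_scalar(text, ["model", "ssd", "image_resizer", "fixed_shape_resizer", "height"], "500")`; string values must include their quotes, e.g. `'"/tensorflow/models/data/train100.record"'`. The trade-off is robustness: it does not understand several fields on one line, braces inside comments or strings, or any match beyond the first, so the protobuf approach remains the safer default.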
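Composite fields

If a field is repeated, it has to be treated like a list (for example, modified with `extend()` or `append()`, or indexed). Indexing only works when the field already has an entry at that position:

```python
pipeline_config.train_input_reader.tf_record_input_reader.input_path[0] = '/tensorflow/models/data/train100.record'
```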
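To replace the whole list regardless of how many entries it currently has, one option is to clear the repeated field in place and then extend it, since the Python protobuf API does not allow assigning a new list to a repeated field directly. `replace_repeated` is a hypothetical helper sketching that; it assumes the container supports slice deletion and `extend()`, which protobuf repeated scalar fields and plain Python lists both do.

```python
def replace_repeated(container, values):
    """Replace every element of a repeated field with `values`, in place.

    Hypothetical helper. The field itself is never reassigned; its existing
    entries are deleted and the new ones appended in order.
    """
    del container[:]          # clear existing entries
    container.extend(values)  # append the new entries
    return container
```

Usage against the config from the question: `replace_repeated(pipeline_config.train_input_reader.tf_record_input_reader.input_path, ["/tensorflow/models/data/train100.record"])`.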
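Answer 1 (score: 0)

```python
# eval_input_reader and input_path are both repeated fields, so index into them.
pipeline_config.eval_input_reader[0].label_map_path = label_map_full_path
pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[0] = val_record_path
```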
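Indexing with `[0]` only touches the first eval reader. If a config defines several, a loop covers all of them. The function name and parameters below are hypothetical; the sketch assumes each reader uses `tf_record_input_reader` with a repeated `input_path`, as in the snippets above.

```python
def point_eval_readers(pipeline_config, label_map_path, record_paths):
    """Point every eval input reader at one label map and set of TFRecords.

    Hypothetical helper. eval_input_reader is iterated rather than indexed,
    and input_path is cleared and refilled because it is a repeated field.
    Returns the number of readers that were updated.
    """
    for reader in pipeline_config.eval_input_reader:
        reader.label_map_path = label_map_path
        paths = reader.tf_record_input_reader.input_path
        del paths[:]               # drop any existing record paths
        paths.extend(record_paths)
    return len(pipeline_config.eval_input_reader)
```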
Answer 2 (score: 0)
This is the same code as above, with a few small changes so it works with TensorFlow 2 (`tf.gfile` became `tf.io.gfile`).
```python
import argparse

import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2


def parse_arguments():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('pipeline')
    parser.add_argument('output')
    return parser.parse_args()


def main():
    args = parse_arguments()

    # Parse the text-format config into a TrainEvalPipelineConfig message.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.io.gfile.GFile(args.pipeline, "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, pipeline_config)

    # Assign the new scalar values through the nested messages.
    pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.height = 500
    pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.width = 500

    # Serialize back to text format and write the new config.
    config_text = text_format.MessageToString(pipeline_config)
    with tf.io.gfile.GFile(args.output, "wb") as f:
        f.write(config_text)


if __name__ == '__main__':
    main()
```