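Simplifying the modification of strings in a dynamic list?

Time: 2018-04-25 21:17:58

Tags: python string replace split

I want to convert a list of strings from one format to another.

An example of the full list can be found here: https://gist.github.com/ProGamerGov/1d728e7ca4cc52abf398277642e4ee78

Some examples of what I want to do: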

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

To:

(2): Conv2d(3 -> 64, 3x3, 1,1, 1,1)

And:

Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

To this:

(8): Conv2d(64 -> 128, 3x3, 1,1, 1,1)

My current code is set up like this:

def modify_text(text, new):
    return text.replace(", ", "*", 1).replace(", ", new, 1)

for i, layer in enumerate(net):
    if "Conv2d" in str(layer):
        layer = str(layer).replace(",", " ->", 1)
        layer = modify_text(layer, "x").replace("kernel_size=(", "").replace("stride=(", "").replace("padding=(", "").replace(")", "", 3)
        layer = modify_text(modify_text(layer, ","), ",").replace("*", ", ")
        print("  (" + str(i+1) + "): " + layer)

But I feel there should be a cleaner, simpler way to do this?

Edit: I have simplified my setup to:

import re

regx_map = r'(2d).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*'
regx_pool = r'(2d).*?(\d+).*?(\d+).*?(\d+).*'
for i, layer in enumerate(net):
    if "Conv2d" in str(layer):
        print("  (" + str(i+1) + "): " + re.sub(regx_map, r'\1(\2 -> \3, \4x\5, \6,\7, \8,\9)', str(layer)))
    elif "MaxPool2d" in str(layer) or "AvgPool2d" in str(layer):
        print("  (" + str(i+1) + "): " + re.sub(regx_pool, r'\1(\2x\2, \3,\3)', str(layer)))
    else:
        print("  (" + str(i+1) + "): " + "nn." + str(layer).split("(", 1)[0])
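To try the loop above without loading a model, here is a minimal sketch that applies the same two patterns to plain strings. The helper name `format_layer` and the sample list `layers` are hypothetical stand-ins for `net`; the patterns and replacement templates are copied from the snippet above.

```python
import re

# Patterns copied from the edited snippet above.
REGX_MAP = r'(2d).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*'
REGX_POOL = r'(2d).*?(\d+).*?(\d+).*?(\d+).*'


def format_layer(index, text):
    """Return one numbered, condensed line for a layer description (hypothetical helper)."""
    prefix = "  ({}): ".format(index)
    if "Conv2d" in text:
        return prefix + re.sub(REGX_MAP, r'\1(\2 -> \3, \4x\5, \6,\7, \8,\9)', text)
    if "MaxPool2d" in text or "AvgPool2d" in text:
        return prefix + re.sub(REGX_POOL, r'\1(\2x\2, \3,\3)', text)
    # Any other layer is reduced to its class name.
    return prefix + "nn." + text.split("(", 1)[0]


# Hypothetical sample input standing in for str(layer) of each module in net.
layers = [
    'Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))',
    'ReLU(inplace)',
    'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)',
]

for i, layer in enumerate(layers, start=1):
    print(format_layer(i, layer))
```

Keeping the formatting in a function that takes a string makes it possible to check the patterns against a handful of sample lines before running them on a real network.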

2 Answers:

Answer 0: (score: 4)

You can use a regular expression to match:

import re
s = 'Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))'
regx = r'(Conv2d).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*?(\d+).*'

print(re.sub(regx, r'\1(\2 -> \3, \4x\5, \6,\7,\8,\9)', s))

Output:

Conv2d(3 -> 64, 3x3, 1,1,1,1)

The regular expression can be made more robust, at the cost of being much longer.
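As an illustration of the longer, more robust variant mentioned above, here is one possible sketch. It is not the answer author's pattern: the group names, `CONV_RE`, and `TEMPLATE` are assumptions made for this example. Each number is tied to the keyword it belongs to, and the template uses the `1,1, 1,1` spacing requested in the question.

```python
import re

# Verbose pattern: each number is anchored to its keyword, so stray digits
# elsewhere in the string cannot shift the captured groups.
CONV_RE = re.compile(r"""
    Conv2d\(
    (?P<cin>\d+),\s*(?P<cout>\d+),\s*
    kernel_size=\((?P<kh>\d+),\s*(?P<kw>\d+)\),\s*
    stride=\((?P<sh>\d+),\s*(?P<sw>\d+)\),\s*
    padding=\((?P<ph>\d+),\s*(?P<pw>\d+)\)
    \)
""", re.VERBOSE)

# Replacement template using named group references.
TEMPLATE = r"Conv2d(\g<cin> -> \g<cout>, \g<kh>x\g<kw>, \g<sh>,\g<sw>, \g<ph>,\g<pw>)"

s = 'Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))'
print(CONV_RE.sub(TEMPLATE, s))
```

A string that does not match this exact layout, for example one with extra keyword arguments after `padding`, is returned unchanged, which is safer than silently picking up the wrong numbers.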

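Answer 1: (score: 0)

You can use regular expressions to parse the function signature in each string, convert the arguments to the desired syntax, and then use re.sub to insert the new signature: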

import re
import itertools
def signature(header):
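  # Capture everything between the first "(" and the final ")".
  s = re.findall(r'(?<=\()[\w\W]+(?=\)$)', header)
  # Split on top-level commas only: ", " followed by two word characters.
  return re.split(r',\s(?=\w\w)', s[0]) if s else ''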

def combine_args(d):
  # Only keyword arguments such as "stride=(1, 1)" are rewritten.
  if '=' in d:
    nums = re.findall(r'\d+', d)
    if len(nums) == 2:
      # kernel_size=(3, 3) becomes "3x3"; any other pair becomes "1,1".
      return ('{}x{}' if 'kernel_size' in d else '{},{}').format(*nums)
  return d

def combine_header(d):
  vals = [[a, list(b)] for a, b in itertools.groupby(d, key=lambda x:x.isdigit())]
  return list(itertools.chain(*[[' -> '.join(b)] if a else [combine_args(i) for i in b] for a, b in vals]))

lines = ['TVLoss()', 'Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'ContentLoss((crit): MSELoss())', 'Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())']
final_lines = [re.sub(r'(?<=\()[\w\W]+(?=\)$)', ', '.join(combine_header(c)), a) for a, c in zip(lines, map(signature, lines))]
print(final_lines)

Output:

['TVLoss()', 'Conv2d(3 -> 64, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(64 -> 64, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(64 -> 128, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(128 -> 128, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(128 -> 256, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(256 -> 256, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'Conv2d(256 -> 256, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'Conv2d(256 -> 256, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(256 -> 512, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())', 'Conv2d(512 -> 512, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'ContentLoss((crit): MSELoss())', 'Conv2d(512 -> 512, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'Conv2d(512 -> 512, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)', 'Conv2d(512 -> 512, 3x3, 1,1, 1,1)', 'ReLU(inplace)', 'StyleLoss((gram): GramMatrix()(crit): MSELoss())']
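The list above does not carry the "(n): " prefixes shown in the question. A minimal sketch for adding them, assuming 1-based numbering as in the question's own loop; `number_lines` is a hypothetical helper and the short `sample` list stands in for `final_lines`.

```python
def number_lines(lines, start=1):
    """Prefix each entry with its position (hypothetical helper)."""
    return ["  ({}): {}".format(i, line) for i, line in enumerate(lines, start=start)]


# Short stand-in for final_lines from the answer above.
sample = ['TVLoss()', 'Conv2d(3 -> 64, 3x3, 1,1, 1,1)', 'ReLU(inplace)']

for line in number_lines(sample):
    print(line)
```

Using `enumerate(..., start=1)` avoids the `i+1` arithmetic from the original loop.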