解析numpy数组的字符串表示

时间:2017-05-09 20:33:32

标签: python string numpy

如果我只有numpy.array的字符串表示形式:

>>> import numpy as np
>>> arr = np.random.randint(0, 10, (10, 10))
>>> print(arr)  # this one!
[[9 4 7 3]
 [1 6 4 2]
 [6 7 6 0]
 [0 5 6 7]]

如何将其转换回numpy数组?实际插入,并不复杂,但我正在寻找一种程序化的方法。

使用,替换空格的简单正则表达式实际上适用于单位数整数:

>>> import re
>>> sub = re.sub('\s+', ',', """[[8 6 2 4 0 2]
...  [3 5 8 4 5 6]
...  [4 6 3 3 0 3]]
... """)
>>> sub
'[[8,6,2,4,0,2],[3,5,8,4,5,6],[4,6,3,3,0,3]],'  # the trailing "," is a bit annoying

它可以被转换为几乎(dtype可能丢失,但没关系)相同的数组:

>>> import ast
>>> np.array(ast.literal_eval(sub)[0])
array([[8, 6, 2, 4, 0, 2],
       [3, 5, 8, 4, 5, 6],
       [4, 6, 3, 3, 0, 3]])

但它无法用于多位数整数和浮点数:

>>> re.sub('\s+', ',', """[[ 0.  1.  6.  9.  1.  4.]
... [ 4.  8.  2.  3.  6.  1.]]
... """)
'[[,0.,1.,6.,9.,1.,4.],[,4.,8.,2.,3.,6.,1.]],'

因为它们在开头有一个额外的,

解决方案不一定需要基于正则表达式,任何其他适用于 unabriged 的方法(不会缩短为...)bool / int / float / complex数组为1 -4维度可以。

2 个答案:

答案 0 :(得分:3)

这是一个非常手动的解决方案:

import re
import numpy

def parse_array_str(array_string):
    tokens = re.findall(r'''             # Find all...
                            \[         | # opening brackets,
                            \]         | # closing brackets, or
                            [^\[\]\s]+   # sequences of other non-whitespace characters''',
                        array_string,
                        flags = re.VERBOSE)
    tokens = iter(tokens)

    # Chomp first [, handle case where it's not a [
    first_token = next(tokens)
    if first_token != '[':
        # Input must represent a scalar
        if next(tokens, None) is not None:
            raise ValueError("Can't parse input.")
        return float(first_token)  # or int(token), but not bool(token) for bools

    list_form = []
    stack = [list_form]

    for token in tokens:
        if token == '[':
            # enter a new list
            stack.append([])
            stack[-2].append(stack[-1])
        elif token == ']':
            # close a list
            stack.pop()
        else:
            stack[-1].append(float(token))  # or int(token), but not bool(token) for bools

    if stack:
        raise ValueError("Can't parse input - it might be missing text at the end.")

    return numpy.array(list_form)

或者基于检测插入逗号的位置的手动解决方案:

import re
import numpy

pattern = r'''# Match (mandatory) whitespace between...
              (?<=\]) # ] and
              \s+
              (?= \[) # [, or
              |
              (?<=[^\[\]\s]) 
              \s+
              (?= [^\[\]\s]) # two non-bracket non-whitespace characters
           '''

# Replace such whitespace with a comma
fixed_string = re.sub(pattern, ',', array_string, flags=re.VERBOSE)

output_array = numpy.array(ast.literal_eval(fixed_string))

答案 1 :(得分:1)

<强>更新

np.array(ast.literal_eval(re.sub(r'\]\s*\[',
                                 r'],[',
                                 re.sub(r'(\d+)\s+(\d+)', 
                                        r'\1,\2', 
                                        a.replace('\n','')))))

<强>测试

In [345]: a = np.random.rand(3,3,20).__str__()

In [346]: np.array(ast.literal_eval(re.sub(r'\]\s*\[',
     ...:                                  r'],[',
     ...:                                  re.sub(r'(\d+)\s+(\d+)',
     ...:                                         r'\1,\2',
     ...:                                         a.replace('\n','')))))
Out[346]:
array([[[  1.61804506e-01,   8.12734833e-01,   6.35872020e-01,   7.45560321e-01,   7.60322379e-01,   1.50271532e-01,   7.43559134e-01,   5.21169923e-
01,   4.10560219e-01,   1.77891635e-01,
           8.77997042e-01,   5.52165694e-02,   4.40322089e-01,   8.82732323e-01,   3.12101843e-01,   9.49019544e-01,   1.69709407e-01,   5.35675968e-
01,   3.53186538e-01,   2.39804555e-01],
        [  2.59834852e-01,   7.13464074e-01,   4.24374709e-01,   7.45214854e-01,   2.54193920e-01,   9.43753568e-01,   3.19657128e-02,   6.04311934e-
01,   4.58913230e-01,   9.21777675e-01,
           7.60741980e-02,   8.25952339e-01,   1.37270639e-01,   7.42065132e-01,   9.05089275e-01,   9.90206513e-02,   2.00671342e-01,   9.29283429e-
01,   8.87469279e-01,   2.78824797e-01],
        [  5.49303597e-01,   1.68139999e-01,   9.52643331e-01,   8.97801805e-01,   8.34317042e-01,   3.61338265e-01,   1.97822206e-01,   1.44672484e-
01,   4.62311800e-01,   6.45563044e-01,
           3.96650080e-01,   9.66557989e-01,   5.55279111e-01,   6.95327885e-01,   8.77989215e-01,   3.09452892e-01,   4.34898544e-02,   6.18538982e-
01,   6.11605477e-03,   5.30348496e-03]],

       [[  4.67741090e-01,   4.18749234e-01,   4.92742479e-01,   3.12952835e-01,   1.66866007e-01,   1.81524074e-01,   3.48737055e-01,   3.96121943e-
01,   7.56894807e-01,   4.99569007e-02,
           9.48425036e-01,   1.30331685e-01,   3.60872691e-01,   4.98930072e-01,   7.14775531e-01,   5.50048525e-01,   6.12293600e-01,   6.24329775e-
01,   3.74200599e-01,   6.77087300e-01],
        [  3.64029724e-01,   5.12225561e-01,   6.52844356e-01,   1.36063860e-01,   5.95311924e-01,   7.31286536e-01,   3.85353941e-01,   1.17983007e-
01,   3.78948410e-01,   3.66223737e-01,
           4.78195933e-01,   3.46903190e-01,   7.59476546e-01,   4.38877386e-01,   7.33342832e-01,   9.38044045e-01,   6.80193266e-01,   1.76191976e-
01,   2.84027688e-01,   8.85565762e-01],
        [  1.25801396e-01,   7.62014084e-01,   7.57817614e-01,   5.44511396e-01,   2.77615151e-01,   6.94968328e-01,   9.64537639e-01,   7.79804895e-
01,   8.45911428e-01,   1.59562236e-01,
           7.14207030e-01,   9.26019437e-01,   1.84258959e-01,   8.37627772e-01,   9.72586483e-01,   3.87408269e-01,   1.95596555e-01,   3.51684372e-
02,   7.14297398e-02,   2.70039164e-01]],

       [[  3.03855673e-01,   7.72762928e-01,   5.63591643e-01,   7.58142274e-01,   4.71340149e-01,   1.50447988e-01,   4.24416607e-01,   3.53647908e-
01,   4.83022443e-01,   7.72650844e-01,
           4.05579568e-01,   4.64825394e-01,   3.74864150e-01,   8.04635163e-01,   3.29960889e-01,   8.82488417e-01,   6.05332753e-01,   1.84514406e-
01,   4.47145930e-01,   6.96907260e-01],
        [  1.54041028e-01,   2.33380875e-01,   7.34935729e-01,   8.13397766e-01,   6.26194271e-02,   9.40103450e-01,   6.24356287e-01,   2.26074683e-
01,   5.43054373e-01,   7.03495296e-02,
           4.68091539e-02,   7.30366454e-01,   5.27159134e-01,   1.33293015e-01,   4.68391358e-01,   8.25307079e-01,   9.74953928e-01,   2.20242983e-
01,   3.42050900e-01,   7.86851567e-01],
        [  4.49176834e-01,   2.77129577e-01,   1.18051369e-01,   4.99016389e-01,   4.54702611e-04,   2.17932718e-01,   8.83065335e-01,   9.58966789e-
02,   1.52448380e-01,   7.18588641e-01,
           3.73546613e-01,   1.66186769e-01,   5.80381932e-01,   3.42510041e-01,   6.75739930e-01,   1.85382205e-01,   3.26533424e-01,   7.35004900e-
01,   9.22527439e-01,   9.96079190e-01]]])

旧回答:

我们可以尝试使用Pandas:

import io
import pandas as pd

In [294]: pd.read_csv(io.StringIO(a.replace('\n', '').replace(']', '\n').replace('[','')),
                      delim_whitespace=True, header=None).values
Out[294]:
array([[ 0.96725219,  0.01808783,  0.63087793,  0.45407222,  0.30586779,  0.04848813,  0.01797095],
       [ 0.87762897,  0.07705762,  0.33049588,  0.91429797,  0.5776607 ,  0.18207652,  0.2355932 ],
       [ 0.68803166,  0.31540537,  0.92606902,  0.83542726,  0.43457601,  0.44952604,  0.35121332],
       [ 0.14366487,  0.23486924,  0.16421432,  0.27709387,  0.19646975,  0.8243488 ,  0.37708642],
       [ 0.07594925,  0.36608386,  0.02087877,  0.07507932,  0.40005067,  0.84625563,  0.62827931],
       [ 0.63662663,  0.41408688,  0.43447501,  0.22135816,  0.58944708,  0.66456168,  0.5871466 ],
       [ 0.16807584,  0.70981667,  0.18597074,  0.02034372,  0.94706437,  0.61333699,  0.8444439 ]])

注意:它可能仅适用于没有...(省略号)

的2D数组