开发一种启发式方法来测试简单的匿名Python函数的等价性

时间:2012-04-01 08:56:40

标签: python function python-3.x bytecode

我知道Python 3中的函数比较是如何工作的(只是比较内存中的地址),我理解为什么。

我也理解" true"比较(函数fg返回给定相同参数的相同结果,对于任何参数?)实际上是不可能的。

我正在寻找介于两者之间的东西。我希望比较能够处理相同函数的最简单的情况,也可能是一些不那么简单的情况:

lambda x : x == lambda x : x # True
lambda x : 2 * x == lambda y : 2 * y # True
lambda x : 2 * x == lambda x : x * 2 # True or False is fine, but must be stable
lambda x : 2 * x == lambda x : x + x # True or False is fine, but must be stable

请注意,我有兴趣为匿名函数(lambda)解决此问题,但如果解决方案也适用于命名函数,则不会介意。

这样做的动机是在blist模块内部,在对它们执行union等之前验证两个sortedset实例具有相同的排序函数会很好。

命名函数的兴趣不大,因为我可以假设它们不相同时它们是不同的。毕竟,假设有人在key参数中创建了两个带有命名函数的sortsets。如果他们打算这些实例兼容"出于集合操作的目的,它们可能使用相同的函数,而不是执行相同操作的两个单独的命名函数。

我只能想到三种方法。他们似乎都很难,所以任何想法都值得赞赏。

  1. 比较字节码可能会有效,但它的实现依赖可能很烦人(因此,在一个Python上运行的代码会在另一个Python上运行)。

  2. 比较标记化的源代码似乎合理且便携。当然,它的功能不那么强大(因为相同的功能更容易被拒绝)。

  3. 从一些符号计算教科书中借用的实体启发式理论上是最好的方法。它对我来说似乎太沉重了,但它实际上可能是一个很好的选择,因为lambda函数通常很小,因此运行速度很快。

  4. 修改

    一个更复杂的例子,基于@delnan的评论:

    # global variable
    fields = ['id', 'name']
    
    def my_function():
      global fields
      s1 = sortedset(key = lambda x : x[fields[0].lower()])
      # some intervening code here
      # ...
      s2 = sortedset(key = lambda x : x[fields[0].lower()])
    

    我希望s1s2的关键功能评估为相同吗?

    如果插入代码完全包含任何函数调用,则fields的值可能会被修改,从而导致s1s2的不同关键函数。由于我们显然没有进行控制流分析来解决这个问题,因此我们必须将这两个lambda函数评估为不同,如果我们试图在运行时之前执行此评估。 (即使fields不是全局的,也可能有另一个名字绑定它等等。)这会严重限制整个练习的用处,因为很少有lambda函数不依赖于环境。

    编辑2:

    我意识到比较运行时存在的函数对象非常重要。没有它,所有依赖外部范围变量的函数都无法比较;并且大多数有用的函数都有这样的依赖关系。在运行时考虑,具有相同签名的所有函数都可以以干净,逻辑的方式进行比较,无论它们依赖于什么,是否不纯,等等。

    因此,我不仅需要字节码,还需要创建函数对象时的全局状态(大概是__globals__)。然后我必须匹配外部范围的所有变量和__globals__的值。

2 个答案:

答案 0 :(得分:8)

编辑以检查外部状态是否会影响排序功能以及两个功能是否相同。


我砍了dis.dis和朋友输出到类似全局文件的对象。然后我删除行号和规范化变量名(不触及常量)并比较结果。

你可以将其清理干净,以便dis.dis和朋友yield排除这些行,这样你就不必记录他们的输出。但这是使用dis.dis进行功能比较和最小变化的概念验证概念。

import types
from opcode import *
_have_code = (types.MethodType, types.FunctionType, types.CodeType,
              types.ClassType, type)

def dis(x):
    """Disassemble classes, methods, functions, or code.

    With no argument, disassemble the last traceback.

    """
    if isinstance(x, types.InstanceType):
        x = x.__class__
    if hasattr(x, 'im_func'):
        x = x.im_func
    if hasattr(x, 'func_code'):
        x = x.func_code
    if hasattr(x, '__dict__'):
        items = x.__dict__.items()
        items.sort()
        for name, x1 in items:
            if isinstance(x1, _have_code):
                print >> out,  "Disassembly of %s:" % name
                try:
                    dis(x1)
                except TypeError, msg:
                    print >> out,  "Sorry:", msg
                print >> out
    elif hasattr(x, 'co_code'):
        disassemble(x)
    elif isinstance(x, str):
        disassemble_string(x)
    else:
        raise TypeError, \
              "don't know how to disassemble %s objects" % \
              type(x).__name__

def disassemble(co, lasti=-1):
    """Disassemble a code object."""
    code = co.co_code
    labels = findlabels(code)
    linestarts = dict(findlinestarts(co))
    n = len(code)
    i = 0
    extended_arg = 0
    free = None
    while i < n:
        c = code[i]
        op = ord(c)
        if i in linestarts:
            if i > 0:
                print >> out
            print >> out,  "%3d" % linestarts[i],
        else:
            print >> out,  '   ',

        if i == lasti: print >> out,  '-->',
        else: print >> out,  '   ',
        if i in labels: print >> out,  '>>',
        else: print >> out,  '  ',
        print >> out,  repr(i).rjust(4),
        print >> out,  opname[op].ljust(20),
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
            extended_arg = 0
            i = i+2
            if op == EXTENDED_ARG:
                extended_arg = oparg*65536L
            print >> out,  repr(oparg).rjust(5),
            if op in hasconst:
                print >> out,  '(' + repr(co.co_consts[oparg]) + ')',
            elif op in hasname:
                print >> out,  '(' + co.co_names[oparg] + ')',
            elif op in hasjrel:
                print >> out,  '(to ' + repr(i + oparg) + ')',
            elif op in haslocal:
                print >> out,  '(' + co.co_varnames[oparg] + ')',
            elif op in hascompare:
                print >> out,  '(' + cmp_op[oparg] + ')',
            elif op in hasfree:
                if free is None:
                    free = co.co_cellvars + co.co_freevars
                print >> out,  '(' + free[oparg] + ')',
        print >> out

def disassemble_string(code, lasti=-1, varnames=None, names=None,
                       constants=None):
    labels = findlabels(code)
    n = len(code)
    i = 0
    while i < n:
        c = code[i]
        op = ord(c)
        if i == lasti: print >> out,  '-->',
        else: print >> out,  '   ',
        if i in labels: print >> out,  '>>',
        else: print >> out,  '  ',
        print >> out,  repr(i).rjust(4),
        print >> out,  opname[op].ljust(15),
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256
            i = i+2
            print >> out,  repr(oparg).rjust(5),
            if op in hasconst:
                if constants:
                    print >> out,  '(' + repr(constants[oparg]) + ')',
                else:
                    print >> out,  '(%d)'%oparg,
            elif op in hasname:
                if names is not None:
                    print >> out,  '(' + names[oparg] + ')',
                else:
                    print >> out,  '(%d)'%oparg,
            elif op in hasjrel:
                print >> out,  '(to ' + repr(i + oparg) + ')',
            elif op in haslocal:
                if varnames:
                    print >> out,  '(' + varnames[oparg] + ')',
                else:
                    print >> out,  '(%d)' % oparg,
            elif op in hascompare:
                print >> out,  '(' + cmp_op[oparg] + ')',
        print >> out

def findlabels(code):
    """Detect all offsets in a byte code which are jump targets.

    Return the list of offsets.

    """
    labels = []
    n = len(code)
    i = 0
    while i < n:
        c = code[i]
        op = ord(c)
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256
            i = i+2
            label = -1
            if op in hasjrel:
                label = i+oparg
            elif op in hasjabs:
                label = oparg
            if label >= 0:
                if label not in labels:
                    labels.append(label)
    return labels

def findlinestarts(code):
    """Find the offsets in a byte code which are start of lines in the source.

    Generate pairs (offset, lineno) as described in Python/compile.c.

    """
    byte_increments = [ord(c) for c in code.co_lnotab[0::2]]
    line_increments = [ord(c) for c in code.co_lnotab[1::2]]

    lastlineno = None
    lineno = code.co_firstlineno
    addr = 0
    for byte_incr, line_incr in zip(byte_increments, line_increments):
        if byte_incr:
            if lineno != lastlineno:
                yield (addr, lineno)
                lastlineno = lineno
            addr += byte_incr
        lineno += line_incr
    if lineno != lastlineno:
        yield (addr, lineno)

class FakeFile(object):
    def __init__(self):
        self.store = []
    def write(self, data):
        self.store.append(data)

a = lambda x : x
b = lambda x : x # True
c = lambda x : 2 * x
d = lambda y : 2 * y # True
e = lambda x : 2 * x
f = lambda x : x * 2 # True or False is fine, but must be stable
g = lambda x : 2 * x
h = lambda x : x + x # True or False is fine, but must be stable

funcs = a, b, c, d, e, f, g, h

outs = []
for func in funcs:
    out = FakeFile()
    dis(func)
    outs.append(out.store)

import ast

def outfilter(out):
    for i in out:
        if i.strip().isdigit():
            continue
        if '(' in i:
            try:
                ast.literal_eval(i)
            except ValueError:
                i = "(x)"
        yield i

processed_outs = [(out, 'LOAD_GLOBAL' in out or 'LOAD_DECREF' in out)
                            for out in (''.join(outfilter(out)) for out in outs)]

for (out1, polluted1), (out2, polluted2) in zip(processed_outs[::2], processed_outs[1::2]):
    print 'Bytecode Equivalent:', out1 == out2, '\nPolluted by state:', polluted1 or polluted2

输出为TrueTrueFalseFalse并且稳定。如果输出将取决于外部状态 - 全局状态或闭包,则“污染”布尔值为真。

答案 1 :(得分:6)

所以,让我们首先解决一些技术问题。

1)字节代码:它可能不是问题,因为您可以使用dis模块来获取“字节码”,而不是检查pyc(二进制文件)。 e.g。

>>> f = lambda x, y : x+y
>>> dis.dis(f)
  1           0 LOAD_FAST                0 (x)
              3 LOAD_FAST                1 (y)
              6 BINARY_ADD          
              7 RETURN_VALUE 

无需担心平台。

2)标记化的源代码。 python再次拥有完成这项工作所需的一切。您可以使用ast模块来解析代码并获取ast。

>>> a = ast.parse("f = lambda x, y : x+y")
>>> ast.dump(a)
"Module(body=[Assign(targets=[Name(id='f', ctx=Store())], value=Lambda(args=arguments(args=[Name(id='x', ctx=Param()), Name(id='y', ctx=Param())], vararg=None, kwarg=None, defaults=[]), body=BinOp(left=Name(id='x', ctx=Load()), op=Add(), right=Name(id='y', ctx=Load()))))])"

因此,我们应该真正解决的问题是:确定两个函数是等效的分析是否可行?

人类很容易说2*x等于x+x,但我们如何创建算法来证明呢?

如果这是您想要实现的目标,您可能需要检查一下:http://en.wikipedia.org/wiki/Computer-assisted_proof

但是,如果最终你只想断言两个不同的数据集按相同的顺序排序,你只需要在数据集B上运行排序函数A,反之亦然,然后检查结果。如果它们相同,则功能可能在功能上相同。当然,检查仅对所述数据集有效。