我正在创建一个方法来构造一个匿名方法来返回多个变量的函数,例如f(x,y,z)= b。我希望用户能够传递变量列表:
def get_multivar_lambda(expression, variables=["x"])
然后我希望返回的匿名函数能够准确地获取len(variables)
个参数(基于列表索引的位置或基于列表中字符串的关键字)。我知道我可以使用*args
并查看长度,但这似乎不够优雅。
这可能吗?我怎么能这样做?
以下是我如何为一个变量(其中seval
来自模块simple_eval
)执行此操作的示例:
def get_lambda(expression, variable="x"):
return lambda arg: seval(expression.replace(variable, str(arg)))
这就是我如何通过检查传递的arguments*
的长度来实现的:
def get_multivar_lambda(expression, variables=["x"]):
def to_return(*arguments):
if len(variables) != len(arguments):
raise Exception("Number of arguments != number of variables")
for v, a in zip(variables, arguments):
expression.replace(v, a)
return seval(expression)
return to_return
编辑:我正在从用户输入中获取表达式和变量,因此一种安全的方法是最好的。
答案 0 :(得分:6)
如果您可以使用Python 3,那么新引入的(Python 3.3+)inspect.Signature
和inspect.Parameter
可以使您的代码非常干净(PEP 362 - Function Signature Object)。这些在装饰器中也非常方便:
from inspect import Parameter, signature, Signature
def get_multivar_lambda(expression, variables=["x"]):
params = [Parameter(v, Parameter.POSITIONAL_OR_KEYWORD) for v in variables]
sig = Signature(params)
def to_return(*args, **kwargs):
values = sig.bind(*args, **kwargs)
for name, val in values.arguments.items():
print (name, val)
to_return.__signature__ = signature(to_return).replace(parameters=params)
return to_return
<强>演示:强>
>>> f = get_multivar_lambda('foo')
>>> f(1)
x 1
>>> f(1, 2)
Traceback (most recent call last):
File "<pyshell#43>", line 1, in <module>
...
raise TypeError('too many positional arguments') from None
TypeError: too many positional arguments
>>> f(x=100)
x 100
也会为用户产生有用的错误消息:
>>> g = get_multivar_lambda('foo', variables=['x', 'y', 'z'])
>>> g(20, 30, x=1000)
Traceback (most recent call last):
File "<pyshell#48>", line 1, in <module>
....
TypeError: multiple values for argument 'x'
>>> g(1000, y=2000, z=500)
x 1000
y 2000
z 500
内省目的的功能签名:
>>> inspect.getargspec(g)
ArgSpec(args=['x', 'y', 'z'], varargs=None, keywords=None, defaults=None)
答案 1 :(得分:2)
这样的东西绝对是可能的。我使用ast
编写了一个解决方案。它比其他解决方案更冗长,但返回的对象是一个无需任何中间编译步骤的函数,例如simple_eval
解决方案。
import ast
def get_multi_lambda(expr, args=()):
code_stmt = ast.parse(expr, mode='eval')
collector = NameCollector()
collector.visit(code_stmt)
arg_set = set(args)
if arg_set - collector.names:
raise TypeError("unused args", arg_set - collector.names)
elif collector.names - arg_set:
# very zealous, meant to stop execution of arbitrary code
# -- prevents use of *any* name that is not an argument to the function
# -- unfortunately this naive approach also stops things like sum
raise TypeError("attempted nonlocal name access",
collector.names - arg_set)
func_node = create_func_node(args, code_stmt)
code_obj = compile(func_node, "<generated>", "eval")
return eval(code_obj, {}, {})
def create_func_node(args, code_stmt):
lambda_args = ast.arguments(
args=[ast.arg(name, None) for name in args],
vararg=None, varargannotation=None, kwonlyargs=[], kwarg=None,
kwargannotation=None, defaults=[], kw_defaults=[]
)
func = ast.Lambda(args=lambda_args, body=code_stmt.body)
expr = ast.Expression(func)
ast.fix_missing_locations(expr)
return expr
class NameCollector(ast.NodeVisitor):
"""Finds all the names used by an ast node tree."""
def __init__(self):
self.names = set()
def visit_Name(self, node):
self.names.add(node.id)
# example usage
func = get_multi_lambda('a / b + 1', ['a', 'b'])
print(func(3, 4)) # prints 1.75 in python 3
您可以选择排除第二个名称检查是否可以信任这些多lambda表达式的来源,或者您可以为您认为合适的某些名称添加例外。例如。 min
,max
,sum
等等......
答案 2 :(得分:1)
您可以将表达式解析为AST。然后,您可以浏览AST以评估表达式。如果您明确列出了您希望处理的节点类型,这可以是安全的。
例如,使用J.F. Sebastian's AST evaluator,您可以执行类似
的操作import ast
import operator as op
import textwrap
def make_func(expression, variables):
template = textwrap.dedent('''\
def func({}):
return eval_expr({!r}, locals())
''').format(','.join(variables), expression)
namespace = {'eval_expr':eval_expr}
exec template in namespace
return namespace['func']
def eval_expr(expr, namespace):
"""
>>> eval_expr('2^6')
4
>>> eval_expr('2**6')
64
>>> eval_expr('1 + 2*3**(4^5) / (6 + -7)')
-5.0
"""
# Module(body=[Expr(value=...)])
return eval_(ast.parse(expr).body[0].value, namespace)
def eval_(node, namespace=None):
"""
https://stackoverflow.com/a/9558001/190597 (J.F. Sebastian)
"""
if namespace is None:
namespace = dict()
if isinstance(node, ast.Num): # <number>
return node.n
elif isinstance(node, ast.operator): # <operator>
return operators[type(node)]
elif isinstance(node, ast.BinOp): # <left> <operator> <right>
return eval_(node.op, namespace)(eval_(node.left, namespace),
eval_(node.right, namespace))
elif isinstance(node, ast.Name):
return namespace[node.id]
else:
raise TypeError(node)
operators = {ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul,
ast.Div: op.truediv, ast.Pow: op.pow, ast.BitXor: op.xor,
ast.USub: op.neg}
f = make_func('x', ['x'])
print(f(2))
# 2
g = make_func('x+y+z', ['x','y','z'])
print(g(1,2,3))
# 6
可以这样使用:
f = make_func('x', ['x'])
print(f(2))
# 2
g = make_func('x+y+z', ['x','y','z'])
print(g(1,2,3))
# 6
答案 3 :(得分:1)
我认为你不能完全按照自己的意愿行事(通常用特定数量的参数来定义函数)。
但是simpleeval内置了变量替换:https://pypi.python.org/pypi/simpleeval#names
所以要吸取教训:
答案 4 :(得分:1)
我发现使用类对象而不是标准函数应该更好。
from simpleeval import simple_eval as seval
class MultivarLambda(object):
def __init__(self, expression, variables):
self.__expression = expression
self.__variables = variables
def __call__(self, *args):
line = self.__expression
for v, arg in zip(self.__variables, args):
line = line.replace(v, arg)
return seval(line)
f = MultivarLambda("(A)**2 + (B)**2", ["A", "B"])
print f('3', '4')
print f('5', '-12')
# 25
# 169