我有一个Expression<Func<IQueryable<TIn>,TOut>>
,我想将其转换为Expression<Func<IEnumerable<TIn>,TOut>>
。实际上,我的最终目标是将给定的树编译成Func<IEnumerable<TIn>,TOut>
,但是一旦完成转换,这就变得微不足道了。
我可以将给定的lambda包装成一个在输入序列上首先调用AsQueryable()
的lambda,但是我认为这非常低效:据我了解,它在每次调用时都必须遍历表达式树并重新编译。
有什么想法吗?
编辑:
当然,必须做出某些假设。转换只应知道如何将Queryable
的静态方法的精确匹配转换为Enumerable
的等效静态方法。对于其他情况,它可以直接失败,也可以随便怎么处理,我并不在意。
编辑2:
更多说明:
我要进行的过程的输入采用lambda表达式树。该lambda将IQueryable<T>
作为输入并产生一些输出。我想产生一个具有等效逻辑的新lambda,但是它将IEnumerable<T>
作为输入并产生等效输出。
例如,对Queryable.Where(...)
的所有调用都应替换为Enumerable.Where(...)
,对Queryable.Select(...)
的所有调用都应替换为Enumerable.Select(...)
,依此类推。
编辑3:
一个例子:
// Given: an expression tree whose lambda consumes an IQueryable<int>.
Expression<Func<IQueryable<int>, double>> input = qi => (double)qi.Sum() / qi.Count();

// Wanted: the same logic, but expressed over an IEnumerable<int>.
Expression<Func<IEnumerable<int>, double>> desiredOutput = ei => (double)ei.Sum() / ei.Count();

// Stop-gap: adapt the original expression by wrapping it (see MakeDirtyWorkaround).
Expression<Func<IEnumerable<int>, double>> dirtyWorkaround = MakeDirtyWorkaround(input);
// Wraps an IQueryable-based lambda in an IEnumerable-based one that converts its
// input with Queryable.AsQueryable before invoking the original expression.
// The resulting tree has the shape: ie => original.Invoke(ie.AsQueryable())
Expression<Func<IEnumerable<TIn>, TOut>> MakeDirtyWorkaround<TIn, TOut>(
    Expression<Func<IQueryable<TIn>, TOut>> original)
{
    // Resolve the generic AsQueryable<TIn>(IEnumerable<TIn>) overload through a typed delegate.
    MethodInfo toQueryable =
        new Func<IEnumerable<TIn>, IQueryable<TIn>>(Queryable.AsQueryable).Method;

    // The single parameter of the new lambda.
    ParameterExpression sequence = Expression.Parameter(typeof(IEnumerable<TIn>), "ie");

    // ie.AsQueryable()
    MethodCallExpression queryableView = Expression.Call(toQueryable, sequence);

    // original.Invoke(ie.AsQueryable())
    InvocationExpression wrappedBody = Expression.Invoke(original, queryableView);

    return Expression.Lambda<Func<IEnumerable<TIn>, TOut>>(wrappedBody, sequence);
}
// Demonstration of why the workaround is considered inefficient.
// Both delegates are compiled once up front, so compilation of the outer lambda is not timed.
Func<IEnumerable<int>, double> compiledDesired = desiredOutput.Compile();
Func<IEnumerable<int>, double> compiledDirty = dirtyWorkaround.Compile();
IEnumerable<int> exampleEnumerable = Enumerable.Range(1, 10);
int repetitions = 10_000;

// Baseline: the hand-written IEnumerable expression.
Stopwatch desiredSw = Stopwatch.StartNew();
for (int pass = 0; pass < repetitions; pass++)
{
    double exampleOutput = compiledDesired(exampleEnumerable);
}
desiredSw.Stop();

// Workaround: the wrapped IQueryable expression.
Stopwatch dirtySw = Stopwatch.StartNew();
for (int pass = 0; pass < repetitions; pass++)
{
    // According to the author, each call builds a query on top of exampleEnumerable,
    // which is then adapted to IEnumerable and compiled again; they measured it as
    // roughly 1000 times slower for this example.
    double exampleOutput = compiledDirty(exampleEnumerable);
}
dirtySw.Stop();

Console.WriteLine($"Executed in {dirtySw.ElapsedMilliseconds} ms instead of {desiredSw.ElapsedMilliseconds} ms.");
// Output reported by the author: Executed in 3000 ms instead of 3 ms.
答案 0 :(得分:0)
鉴于.AsQueryable()
几乎可以实现我想要的功能,我在GitHub上查看了它的代码。它生成一个EnumerableQuery
类。查看该类在执行查询时做了什么之后,我找到了EnumerableRewriter.cs,它执行的正是我感兴趣的那种棘手的表达式树转换。
我修改了该类,使其转换IQueryable<T>
类型的参数,而不是EnumerableQuery
常量。可能还有一些小问题需要完善,但这是一个很好的开始。
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
namespace Test
{
/// <summary>
/// Expression visitor that rewrites a tree written against <see cref="IQueryable{T}"/> so that it
/// operates on <see cref="IEnumerable{T}"/> instead: parameters of type <c>IQueryable&lt;T&gt;</c>
/// are swapped for <c>IEnumerable&lt;T&gt;</c> parameters, and method calls whose arguments changed
/// are rebound (calls declared on <see cref="Queryable"/> go to <see cref="Enumerable"/>).
/// </summary>
internal class EnumerableRewriter : ExpressionVisitor
{
    // Maps every IQueryable<T> parameter seen so far to the IEnumerable<T> parameter that replaces it.
    // Created lazily.
    private Dictionary<ParameterExpression, ParameterExpression> _parameterReplacements;

    // We must ensure that if a LabelTarget is rewritten that it is always rewritten to the same new target
    // or otherwise expressions using it won't match correctly.
    private Dictionary<LabelTarget, LabelTarget> _targetCache;

    // Finding equivalent types can be relatively expensive, and hitting with the same types repeatedly is quite likely.
    private Dictionary<Type, Type> _equivalentTypeCache;

    // Static methods of Enumerable grouped by name; built on first use by FindEnumerableMethod.
    private static ILookup<string, MethodInfo> s_seqMethods;

    /// <summary>
    /// Returns a read-only view of the parameter replacements recorded so far
    /// (original <c>IQueryable&lt;T&gt;</c> parameter to its <c>IEnumerable&lt;T&gt;</c> replacement).
    /// </summary>
    /// <returns>
    /// A read-only wrapper over the internal map. It is never null: when no parameter has been
    /// replaced yet the view is empty, so callers can use <c>TryGetValue</c> without a null check.
    /// </returns>
    public ReadOnlyDictionary<ParameterExpression, ParameterExpression> GetParameterReplacements()
    {
        // Create the backing map on demand so the returned view also reflects later replacements.
        if (_parameterReplacements == null)
            _parameterReplacements = new Dictionary<ParameterExpression, ParameterExpression>();
        return new ReadOnlyDictionary<ParameterExpression, ParameterExpression>(_parameterReplacements);
    }

    /// <summary>
    /// Replaces a parameter whose type is exactly a constructed <c>IQueryable&lt;T&gt;</c> with a
    /// parameter of type <c>IEnumerable&lt;T&gt;</c> carrying the same name. The same original
    /// parameter always maps to the same replacement. Other parameters are returned unchanged.
    /// </summary>
    protected override Expression VisitParameter(ParameterExpression par)
    {
        var type = par.Type;
        if (!type.IsGenericType || type.GetGenericTypeDefinition() != typeof(IQueryable<>))
            return par;

        if (_parameterReplacements == null)
            _parameterReplacements = new Dictionary<ParameterExpression, ParameterExpression>();

        if (!_parameterReplacements.TryGetValue(par, out var replacement))
        {
            var elementType = type.GetGenericArguments()[0];
            replacement = Expression.Parameter(
                typeof(IEnumerable<>).MakeGenericType(elementType),
                par.Name);
            _parameterReplacements[par] = replacement;
        }
        return replacement;
    }

    /// <summary>
    /// Visits the target object and arguments of a call. If any of them changed, the call is rebuilt:
    /// with the same method when it still accepts the new arguments, with the matching
    /// <see cref="Enumerable"/> method when the original is declared on <see cref="Queryable"/>,
    /// or otherwise with a same-named static method found on the original declaring type.
    /// </summary>
    /// <exception cref="InvalidOperationException">No suitable replacement method was found.</exception>
    protected override Expression VisitMethodCall(MethodCallExpression m)
    {
        var obj = Visit(m.Object);
        var args = Visit(m.Arguments);

        // check for args changed
        if (obj != m.Object || args != m.Arguments)
        {
            var mInfo = m.Method;
            var typeArgs = (mInfo.IsGenericMethod) ? mInfo.GetGenericArguments() : null;

            if ((mInfo.IsStatic || mInfo.DeclaringType.IsAssignableFrom(obj.Type))
                && ArgsMatch(mInfo, args, typeArgs))
            {
                // current method is still valid
                return Expression.Call(obj, mInfo, args);
            }
            else if (mInfo.DeclaringType == typeof(Queryable))
            {
                // convert Queryable method to Enumerable method
                var seqMethod = FindEnumerableMethod(mInfo.Name, args, typeArgs);
                args = FixupQuotedArgs(seqMethod, args);
                return Expression.Call(obj, seqMethod, args);
            }
            else
            {
                // rebind to new method
                var method = FindMethod(mInfo.DeclaringType, mInfo.Name, args, typeArgs);
                args = FixupQuotedArgs(method, args);
                return Expression.Call(obj, method, args);
            }
        }
        return m;
    }

    /// <summary>
    /// Adapts each argument to the corresponding parameter type of <paramref name="mi"/> via
    /// <see cref="FixupQuotedExpression"/>. The input collection is returned as-is when nothing changed;
    /// a new list is only allocated from the first changed argument onwards.
    /// </summary>
    private ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo mi, ReadOnlyCollection<Expression> argList)
    {
        var pis = mi.GetParameters();
        if (pis.Length > 0)
        {
            List<Expression> newArgs = null;
            for (int i = 0, n = pis.Length; i < n; i++)
            {
                var arg = argList[i];
                var pi = pis[i];
                arg = FixupQuotedExpression(pi.ParameterType, arg);
                if (newArgs == null && arg != argList[i])
                {
                    // First difference: copy the untouched prefix into the new list.
                    newArgs = new List<Expression>(argList.Count);
                    for (var j = 0; j < i; j++)
                    {
                        newArgs.Add(argList[j]);
                    }
                }
                newArgs?.Add(arg);
            }
            if (newArgs != null)
                argList = newArgs.AsReadOnly();
        }
        return argList;
    }

    /// <summary>
    /// Peels <see cref="ExpressionType.Quote"/> wrappers off <paramref name="expression"/> until its type
    /// is assignable to <paramref name="type"/>. For array parameters fed by a NewArrayInit node, the
    /// array is rebuilt with each element fixed up the same way. If no adaptation applies, the
    /// original expression is returned.
    /// </summary>
    private Expression FixupQuotedExpression(Type type, Expression expression)
    {
        var expr = expression;
        while (true)
        {
            if (type.IsAssignableFrom(expr.Type))
                return expr;
            if (expr.NodeType != ExpressionType.Quote)
                break;
            expr = ((UnaryExpression)expr).Operand;
        }
        if (!type.IsAssignableFrom(expr.Type) && type.IsArray && expr.NodeType == ExpressionType.NewArrayInit)
        {
            var strippedType = StripExpression(expr.Type);
            if (type.IsAssignableFrom(strippedType))
            {
                var elementType = type.GetElementType();
                var na = (NewArrayExpression)expr;
                var exprs = new List<Expression>(na.Expressions.Count);
                for (int i = 0, n = na.Expressions.Count; i < n; i++)
                {
                    exprs.Add(FixupQuotedExpression(elementType, na.Expressions[i]));
                }
                expression = Expression.NewArrayInit(elementType, exprs);
            }
        }
        return expression;
    }

    /// <summary>
    /// Returns lambdas untouched; their bodies are not visited.
    /// </summary>
    // NOTE(review): because nested lambdas are skipped, a nested lambda that references a replaced
    // IQueryable<T> parameter keeps pointing at the original parameter - confirm callers never pass such trees.
    protected override Expression VisitLambda<T>(Expression<T> node) => node;

    /// <summary>
    /// Maps a type to a public type that can stand in for it: generic types whose definition lists
    /// <c>IGrouping&lt;,&gt;</c> map to the constructed <c>IGrouping</c>; private nested types map to the first
    /// <c>IEnumerable&lt;T&gt;</c> they implement, or to <see cref="IEnumerable"/>; anything else maps to itself.
    /// </summary>
    private static Type GetPublicType(Type t)
    {
        // If we create a constant explicitly typed to be a private nested type,
        // such as Lookup<,>.Grouping or a compiler-generated iterator class, then
        // we cannot use the expression tree in a context which has only execution
        // permissions. We should endeavour to translate constants into
        // new constants which have public types.
        if (t.IsGenericType && t.GetGenericTypeDefinition().GetInterfaces().Contains(typeof(IGrouping<,>)))
            return typeof(IGrouping<,>).MakeGenericType(t.GetGenericArguments());
        if (!t.IsNestedPrivate)
            return t;
        foreach (var iType in t.GetInterfaces())
        {
            if (iType.IsGenericType && iType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
                return iType;
        }
        if (typeof(IEnumerable).IsAssignableFrom(t))
            return typeof(IEnumerable);
        return t;
    }

    /// <summary>
    /// Returns the enumerable counterpart of a queryable/enumerable type
    /// (<c>IQueryable&lt;T&gt;</c> to <c>IEnumerable&lt;T&gt;</c>, <c>IOrderedQueryable&lt;T&gt;</c> to
    /// <c>IOrderedEnumerable&lt;T&gt;</c>, and so on). Results are cached per instance.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown by the LINQ <c>Single</c>/<c>SingleOrDefault</c> calls when the type does not expose exactly
    /// one element type through its queryable/enumerable interfaces.
    /// </exception>
    private Type GetEquivalentType(Type type)
    {
        if (_equivalentTypeCache == null)
        {
            // Pre-loading with the non-generic IQueryable and IEnumerable not only covers this case
            // without any reflection-based introspection, but also means the slightly different
            // code needed to catch this case can be omitted safely.
            _equivalentTypeCache = new Dictionary<Type, Type>
            {
                { typeof(IQueryable), typeof(IEnumerable) },
                { typeof(IEnumerable), typeof(IEnumerable) }
            };
        }
        if (!_equivalentTypeCache.TryGetValue(type, out var equiv))
        {
            var pubType = GetPublicType(type);
            if (pubType.IsInterface && pubType.IsGenericType)
            {
                // Fast path: the type is itself one of the four well-known generic interfaces.
                var genericType = pubType.GetGenericTypeDefinition();
                if (genericType == typeof(IOrderedEnumerable<>))
                    equiv = pubType;
                else if (genericType == typeof(IOrderedQueryable<>))
                    equiv = typeof(IOrderedEnumerable<>).MakeGenericType(pubType.GenericTypeArguments[0]);
                else if (genericType == typeof(IEnumerable<>))
                    equiv = pubType;
                else if (genericType == typeof(IQueryable<>))
                    equiv = typeof(IEnumerable<>).MakeGenericType(pubType.GenericTypeArguments[0]);
            }
            if (equiv == null)
            {
                // Slow path: inspect the implemented interfaces, preferring the ordered variants.
                var interfacesWithInfo = pubType.GetInterfaces().Select(IntrospectionExtensions.GetTypeInfo).ToArray();
                var singleTypeGenInterfacesWithGetType = interfacesWithInfo
                    .Where(i => i.IsGenericType && i.GenericTypeArguments.Length == 1)
                    .Select(i => new { Info = i, GenType = i.GetGenericTypeDefinition() })
                    .ToArray();
                var typeArg = singleTypeGenInterfacesWithGetType
                    .Where(i => i.GenType == typeof(IOrderedQueryable<>) || i.GenType == typeof(IOrderedEnumerable<>))
                    .Select(i => i.Info.GenericTypeArguments[0])
                    .Distinct()
                    .SingleOrDefault();
                if (typeArg != null)
                    equiv = typeof(IOrderedEnumerable<>).MakeGenericType(typeArg);
                else
                {
                    typeArg = singleTypeGenInterfacesWithGetType
                        .Where(i => i.GenType == typeof(IQueryable<>) || i.GenType == typeof(IEnumerable<>))
                        .Select(i => i.Info.GenericTypeArguments[0])
                        .Distinct()
                        .Single();
                    equiv = typeof(IEnumerable<>).MakeGenericType(typeArg);
                }
            }
            _equivalentTypeCache.Add(type, equiv);
        }
        return equiv;
    }

    /// <summary>
    /// Finds the static <see cref="Enumerable"/> method named <paramref name="name"/> that accepts
    /// <paramref name="args"/>, closing it over <paramref name="typeArgs"/> when those are supplied.
    /// </summary>
    /// <exception cref="InvalidOperationException">No method on <see cref="Enumerable"/> matches.</exception>
    private static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[] typeArgs)
    {
        if (s_seqMethods == null)
        {
            s_seqMethods = typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic)
                .ToLookup(m => m.Name);
        }
        var mi = s_seqMethods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
        if (mi == null)
        {
            // Fail with a descriptive error instead of a NullReferenceException on the next line.
            throw new InvalidOperationException(
                $"No method '{name}{FormatTypeArgs(typeArgs)}' on type '{typeof(Enumerable).FullName}' matches arguments '{FormatArgTypes(args)}'.");
        }
        if (typeArgs != null)
            return mi.MakeGenericMethod(typeArgs);
        return mi;
    }

    /// <summary>
    /// Finds a static method named <paramref name="name"/> on <paramref name="type"/> that accepts
    /// <paramref name="args"/>, closing it over <paramref name="typeArgs"/> when those are supplied.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The type has no static method with that name, or none of them matches the arguments.
    /// </exception>
    private static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[] typeArgs)
    {
        using (var en = type.GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic)
            .Where(m => m.Name == name)
            .GetEnumerator())
        {
            if (!en.MoveNext())
                throw new InvalidOperationException($"No method '{name}' on type '{type.FullName}'.");
            do
            {
                var mi = en.Current;
                if (ArgsMatch(mi, args, typeArgs))
                    return (typeArgs != null) ? mi.MakeGenericMethod(typeArgs) : mi;
            } while (en.MoveNext());
        }
        throw new InvalidOperationException(
            $"No method '{name}{FormatTypeArgs(typeArgs)}' on type '{type.FullName}' matches arguments '{FormatArgTypes(args)}'.");
    }

    // Renders generic type arguments as "<T1, T2>" for error messages; null when there are none.
    private static string FormatTypeArgs(Type[] typeArgs)
        => typeArgs != null ? "<" + string.Join(", ", typeArgs.Select(t => t.ToString())) + ">" : null;

    // Renders the static types of the argument expressions as a comma-separated list for error messages.
    private static string FormatArgTypes(ReadOnlyCollection<Expression> args)
        => string.Join(", ", args.Select(a => a.Type.ToString()));

    /// <summary>
    /// Tells whether method <paramref name="m"/> (closed over <paramref name="typeArgs"/> if it is a
    /// generic definition) can be called with <paramref name="args"/>. An argument also matches when
    /// it fits after removing one Quote node or after <see cref="StripExpression"/> is applied to its type.
    /// </summary>
    private static bool ArgsMatch(MethodInfo m, ReadOnlyCollection<Expression> args, Type[] typeArgs)
    {
        var mParams = m.GetParameters();
        if (mParams.Length != args.Count)
            return false;
        if (!m.IsGenericMethod && typeArgs != null && typeArgs.Length > 0)
        {
            return false;
        }
        if (!m.IsGenericMethodDefinition && m.IsGenericMethod && m.ContainsGenericParameters)
        {
            m = m.GetGenericMethodDefinition();
        }
        if (m.IsGenericMethodDefinition)
        {
            if (typeArgs == null || typeArgs.Length == 0)
                return false;
            if (m.GetGenericArguments().Length != typeArgs.Length)
                return false;
            m = m.MakeGenericMethod(typeArgs);
            mParams = m.GetParameters();
        }
        for (int i = 0, n = args.Count; i < n; i++)
        {
            var parameterType = mParams[i].ParameterType;
            if (parameterType == null)
                return false;
            if (parameterType.IsByRef)
                parameterType = parameterType.GetElementType();
            var arg = args[i];
            if (!parameterType.IsAssignableFrom(arg.Type))
            {
                if (arg.NodeType == ExpressionType.Quote)
                {
                    arg = ((UnaryExpression)arg).Operand;
                }
                if (!parameterType.IsAssignableFrom(arg.Type) &&
                    !parameterType.IsAssignableFrom(StripExpression(arg.Type)))
                {
                    return false;
                }
            }
        }
        return true;
    }

    /// <summary>
    /// For an array type whose element type derives from <c>Expression&lt;TDelegate&gt;</c>, returns the
    /// array type of the same rank over <c>TDelegate</c>. Non-array types are returned unchanged.
    /// </summary>
    // NOTE(review): for a non-array input the unwrapped type (tmp) is computed but discarded and the
    // original type is returned - confirm whether that is intended before relying on it.
    private static Type StripExpression(Type type)
    {
        var isArray = type.IsArray;
        var tmp = isArray ? type.GetElementType() : type;
        var eType = GetExpressionType(tmp);
        if (eType != null)
            tmp = eType.GetGenericArguments()[0];
        if (isArray)
        {
            var rank = type.GetArrayRank();
            return (rank == 1) ? tmp.MakeArrayType() : tmp.MakeArrayType(rank);
        }
        return type;
    }

    /// <summary>
    /// Walks the base-type chain of <paramref name="type"/> and returns the first constructed
    /// <c>Expression&lt;TDelegate&gt;</c> found, or null when there is none.
    /// </summary>
    private static Type GetExpressionType(Type type)
    {
        while (type != null && type != typeof(object))
        {
            if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Expression<>))
                return type;
            type = type.BaseType;
        }
        return null;
    }

    /// <summary>
    /// Rewrites a conditional whose type is assignable to <see cref="IQueryable"/>: the result type becomes
    /// whichever rewritten branch type can hold the other, or else the enumerable equivalent of the original type.
    /// Other conditionals are handled by the base visitor.
    /// </summary>
    protected override Expression VisitConditional(ConditionalExpression c)
    {
        var type = c.Type;
        if (!typeof(IQueryable).IsAssignableFrom(type))
            return base.VisitConditional(c);
        var test = Visit(c.Test);
        var ifTrue = Visit(c.IfTrue);
        var ifFalse = Visit(c.IfFalse);
        var trueType = ifTrue.Type;
        var falseType = ifFalse.Type;
        if (trueType.IsAssignableFrom(falseType))
            return Expression.Condition(test, ifTrue, ifFalse, trueType);
        if (falseType.IsAssignableFrom(trueType))
            return Expression.Condition(test, ifTrue, ifFalse, falseType);
        return Expression.Condition(test, ifTrue, ifFalse, GetEquivalentType(type));
    }

    /// <summary>
    /// Rewrites a block whose type is assignable to <see cref="IQueryable"/>. When the block type was
    /// simply the type of its last expression, the new block is left to infer its type again; otherwise
    /// the enumerable equivalent of the original type is set explicitly. Other blocks are handled by the base visitor.
    /// </summary>
    protected override Expression VisitBlock(BlockExpression node)
    {
        var type = node.Type;
        if (!typeof(IQueryable).IsAssignableFrom(type))
            return base.VisitBlock(node);
        var nodes = Visit(node.Expressions);
        var variables = VisitAndConvert(node.Variables, "EnumerableRewriter.VisitBlock");
        if (type == node.Expressions.Last().Type)
            return Expression.Block(variables, nodes);
        return Expression.Block(GetEquivalentType(type), variables, nodes);
    }

    /// <summary>
    /// Rewrites a goto/break/continue/return that carries a value whose type is assignable to
    /// <see cref="IQueryable"/>. Jumps without a value, or with any other value type, are handled by the base visitor.
    /// </summary>
    protected override Expression VisitGoto(GotoExpression node)
    {
        // A jump may carry no value at all (e.g. Expression.Break(label)); there is nothing to retype then.
        if (node.Value == null)
            return base.VisitGoto(node);
        var type = node.Value.Type;
        if (!typeof(IQueryable).IsAssignableFrom(type))
            return base.VisitGoto(node);
        var target = VisitLabelTarget(node.Target);
        var value = Visit(node.Value);
        // When the original value type derives from EnumerableQuery, base the result type on the rewritten value instead.
        return Expression.MakeGoto(node.Kind, target, value, GetEquivalentType(typeof(EnumerableQuery).IsAssignableFrom(type) ? value.Type : type));
    }

    /// <summary>
    /// Returns the rewritten label for <paramref name="node"/>, creating a label of the enumerable-equivalent
    /// type when the label type is assignable to <see cref="IQueryable"/>. The mapping is cached so the same
    /// original label always yields the same result. A null label is passed through as null.
    /// </summary>
    protected override LabelTarget VisitLabelTarget(LabelTarget node)
    {
        // Optional labels (e.g. a loop without a continue label) reach this method as null.
        if (node == null)
            return null;
        LabelTarget newTarget;
        if (_targetCache == null)
            _targetCache = new Dictionary<LabelTarget, LabelTarget>();
        else if (_targetCache.TryGetValue(node, out newTarget))
            return newTarget;
        var type = node.Type;
        if (!typeof(IQueryable).IsAssignableFrom(type))
            newTarget = base.VisitLabelTarget(node);
        else
            newTarget = Expression.Label(GetEquivalentType(type), node.Name);
        _targetCache.Add(node, newTarget);
        return newTarget;
    }
}
}
快速测试:
// Given: an expression tree whose lambda consumes an IQueryable<int>.
Expression<Func<IQueryable<int>, double>> input = qi => (double)qi.Sum() / qi.Count();

// Reference: the hand-written IEnumerable<int> version used as the baseline.
Expression<Func<IEnumerable<int>, double>> desiredOutput = ei => (double)ei.Sum() / ei.Count();

// Candidate: the input converted by MakeClean.
Expression<Func<IEnumerable<int>, double>> cleanSolution = MakeClean(input);
// Converts a lambda over IQueryable<TIn> into an equivalent lambda over IEnumerable<TIn>
// by running its body through EnumerableRewriter and swapping the parameter accordingly.
// Throws ArgumentException (from Expression.Lambda) when the rewritten body or parameters
// do not fit Func<IEnumerable<TIn>, TOut>.
Expression<Func<IEnumerable<TIn>, TOut>> MakeClean<TIn, TOut>(
    Expression<Func<IQueryable<TIn>, TOut>> original)
{
    var rewriter = new EnumerableRewriter();
    var newBody = rewriter.Visit(original.Body);

    // Visit every parameter explicitly instead of looking it up in the replacement map:
    // a parameter the body never references has no map entry (and the map may be null),
    // but visiting it still yields its IEnumerable<T> replacement - the same instance the
    // body already uses when it does reference the parameter.
    var newParams = original.Parameters
        .Select(p => (ParameterExpression)rewriter.Visit(p))
        .ToArray();

    return Expression.Lambda<Func<IEnumerable<TIn>, TOut>>(newBody, newParams);
}
// Compile both delegates once up front so only invocation cost is timed.
Func<IEnumerable<int>, double> compiledDesired = desiredOutput.Compile();
Func<IEnumerable<int>, double> compiledClean = cleanSolution.Compile();
IEnumerable<int> exampleEnumerable = Enumerable.Range(1, 10);
int repetitions = 10_000;

// Baseline: the hand-written IEnumerable expression.
Stopwatch desiredSw = Stopwatch.StartNew();
for (int pass = 0; pass < repetitions; pass++)
{
    double exampleOutput = compiledDesired(exampleEnumerable);
}
desiredSw.Stop();

// Candidate: the expression produced by MakeClean.
Stopwatch cleanSw = Stopwatch.StartNew();
for (int pass = 0; pass < repetitions; pass++)
{
    double exampleOutput = compiledClean(exampleEnumerable);
}
cleanSw.Stop();

Console.WriteLine($"Executed in {cleanSw.ElapsedMilliseconds} ms instead of {desiredSw.ElapsedMilliseconds} ms.");
// The author reports that both now run at roughly the same speed.