How to make a lambda expression reusable across multiple LINQ queries

Posted: 2015-03-18 16:30:12

Tags: c# linq function linq-to-entities

I have the following LINQ query:

var result = from person in dbContext.Person
             select new
             {
                 FirstName = person.FirstName,
                 LastName = person.LastName,

                 // I want to save this logic
                 JobCount = person.Jobs.Count(x => x.Completed)
             };

To avoid repeating myself in other LINQ queries, I want to make the JobCount lambda logic available to those queries too.

I thought I could use a Func<Person, int>, like this:

public Func<Person, int> GetCompletedJobsForPerson = person => person.Jobs.Count(x => x.Completed);

var result = from person in dbContext.Person
             select new
             {
                 FirstName = person.FirstName,
                 LastName = person.LastName,

                 // Use Invoke to get amount
                 JobCount = GetCompletedJobsForPerson.Invoke(person)
             };

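Problem: this fails because the method can't be mapped to a SQL statement, and it throws a NotSupportedException:

NotSupportedException was unhandled: The LINQ expression node type 'Invoke' is not supported in LINQ to Entities.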
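To see why the Func version fails, it helps to look at the expression tree the compiler builds for the query. The following is a minimal sketch that uses hypothetical stand-in Person/Job classes rather than the real EF entities. Calling a delegate inside an expression lambda produces a node of type Invoke, which is the node LINQ to Entities refuses to translate. An Expression<Func<Person, int>>, by contrast, exposes a body that a provider can read.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical stand-ins for the EF entities in the question.
public class Job { public bool Completed { get; set; } }
public class Person { public List<Job> Jobs { get; set; } = new List<Job>(); }

public static class InvokeNodeDemo
{
    // A compiled delegate: the query provider only sees an opaque object.
    public static readonly Func<Person, int> CompletedJobsFunc =
        p => p.Jobs.Count(j => j.Completed);

    // An expression tree: the provider can inspect (and translate) its body.
    public static readonly Expression<Func<Person, int>> CompletedJobsExpr =
        p => p.Jobs.Count(j => j.Completed);

    // Returns the node type the compiler emits when the delegate is called
    // inside a lambda that is converted to an expression tree.
    public static ExpressionType NodeTypeOfDelegateCall()
    {
        Expression<Func<Person, int>> projection = p => CompletedJobsFunc(p);
        return projection.Body.NodeType;
    }
}
```

The delegate call shows up as ExpressionType.Invoke. The body of CompletedJobsExpr is an ordinary method call (Enumerable.Count), which is something a provider knows how to translate.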

How can I reuse a lambda across multiple LINQ queries?

1 Answer:

Answer 0 (score: 2)

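It can't be done in a simple way (if it could be done easily, someone would already have done it :-)).

What you can do is use the same trick that PredicateBuilder uses: create an AsExpandable that replaces some "tokens" (method calls) in your query with other method calls. But I don't think it's worth it: it's about a hundred lines of code to do it "right".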
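The core of that trick is to replace a lambda's parameter with an argument expression. This splices the reusable body into the outer query instead of calling it. Below is a minimal sketch of just that step, again with hypothetical stand-in entities. ParameterReplacer, Reusable and BuildProjection are illustrative names. KeyValuePair stands in for the anonymous type, because anonymous types are awkward to construct by hand with the Expression API.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical stand-ins for the EF entities in the question.
public class Job { public bool Completed { get; set; } }
public class Person
{
    public string FirstName { get; set; }
    public List<Job> Jobs { get; set; } = new List<Job>();
}

// Replaces one ParameterExpression with another expression everywhere in a tree.
public sealed class ParameterReplacer : ExpressionVisitor
{
    private readonly ParameterExpression _from;
    private readonly Expression _to;

    public ParameterReplacer(ParameterExpression from, Expression to)
    {
        _from = from;
        _to = to;
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        return node == _from ? _to : base.VisitParameter(node);
    }
}

public static class Reusable
{
    // The reusable logic, kept as an expression tree instead of a Func.
    public static readonly Expression<Func<Person, int>> CompletedJobs =
        p => p.Jobs.Count(j => j.Completed);

    // Builds person => new KeyValuePair<string, int>(person.FirstName, <body of CompletedJobs>)
    // with the body inlined, so the final tree contains no Invoke node.
    public static Expression<Func<Person, KeyValuePair<string, int>>> BuildProjection()
    {
        var person = Expression.Parameter(typeof(Person), "person");

        // Substitute the outer parameter for CompletedJobs' own parameter.
        Expression jobCount = new ParameterReplacer(CompletedJobs.Parameters[0], person)
            .Visit(CompletedJobs.Body);

        var ctor = typeof(KeyValuePair<string, int>)
            .GetConstructor(new[] { typeof(string), typeof(int) });
        var body = Expression.New(ctor, Expression.Property(person, nameof(Person.FirstName)), jobCount);

        return Expression.Lambda<Func<Person, KeyValuePair<string, int>>>(body, person);
    }
}
```

The result can be passed to Queryable.Select like any hand-written lambda. Building every projection by hand like this is exactly the tedium that the visitor-based approach automates.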

Another problem is that the query then has to call this special method:

var result = (from person in dbContext.Person
              select new
              {
                  FirstName = person.FirstName,
                  LastName = person.LastName,

                  // Plain method call; FixMethodCalls() would expand it
                  JobCount = GetCompletedJobsForPerson(person)
              }).FixMethodCalls();

OK... it's hard, but doable:

// v0.11 Codename: Handle with Care+ (+ == Plus)

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property)]
public class ExpandableAttribute : Attribute
{
    // Just to know the suffix to use :-)
    public static readonly string ExpandableSuffix = "Expression";
}

// Replaces method and properties calls to "special" method calls that 
// are Expression(s). These method/property calls can be used anywhere in 
// the query (Select, Where, GroupBy, ...)
// Remember to use .AsExpandable2() somewhere in your query (it must be a 
// "top level" part of the query):

// OK:
// var res1 = (from x in table select x).AsExpandable2();
// var res2 = table.AsExpandable2().Where(x => true);
// var res3 = table.Where(x => true).AsExpandable2();
// var res4 = table.Where(x => true).AsExpandable2().Select(x => x);
// var res5 = table.Where(x => true).AsExpandable2().Select(x => x).AsExpandable2();

// Not OK:
// var res1 = table.Select(x => x.subtable.AsExpandable2());

// **Method calls**

// The methods to be expanded can be static or instance. There must be
// a corresponding **static** method with the same name and the suffix 
// "Expression", that doesn't have parameters and returns an Expression 
// with a certain signature.

// Static:
// var res2 = table.AsExpandable2().Select(x => MyClass.StaticMethod(1, x, 2, 3));
// There must be in the class MyClass
// public/private/protected static Expression<Func<int, MyClass, int, int, returnType(StaticMethod)>> StaticMethodExpression()

// Instance:
// var res1 = table.AsExpandable2().Select(x => x.InstanceMethod(1, 2, 3));
// There must be in the class x.GetType() 
// public/private/protected static Expression<Func<x.GetType(), int, int, int, returnType(InstanceMethod)>> InstanceMethodExpression()

// Note that multiple "tables" can be passed as parameters:
// Static:
// var res3 = (from x in table1 from y in table2 select new { x, y }).AsExpandable2().Select(z => MyClass.StaticMethod(1, z.x, z.y, 2, 3));
// There must be in the class MyClass
// public/private/protected/internal static Expression<Func<int, x.GetType(), y.GetType(), int, int, returnType(StaticMethod)>> StaticMethodExpression()

// Instance:
// var res4 = (from x in table1 from y in table2 select new { x, y }).AsExpandable2().Select(z => z.x.InstanceMethod(1, z.y, 2, 3));
// There must be in the class x.GetType() 
// public/private/protected/internal static Expression<Func<x.GetType(), int, y.GetType(), int, int, returnType(InstanceMethod)>> InstanceMethodExpression()

// **Properties**

// Same as with method calls, but with properties :-)
// (useful for things like FullName, where 
// FullName = Name + ' ' + Surname)
// Remember that the *Expression property must be **static**!

// Static (not very useful :-) ):
// var res1 = table.AsExpandable2().Select(x => MyClass.StaticProperty);
// There must be in the class MyClass
// public/private/protected/internal static Expression<Func<MyClass.StaticProperty.GetType()>> StaticPropertyExpression { get; }

// Instance:
// var res2 = table.AsExpandable2().Select(x => x.InstanceProperty);
// There must be in the class x.GetType() 
// public/private/protected/internal static Expression<Func<x.GetType(), x.InstanceProperty.GetType()>> InstancePropertyExpression { get; }

public static class MethodsPropertiesExpander
{
    // Because AsExpandable() is already used by http://www.albahari.com/nutshell/linqkit.aspx
    public static IQueryable<T> AsExpandable2<T>(this IQueryable<T> source)
    {
        if (source is MethodsPropertiesExpander<T>)
        {
            return source;
        }

        return new MethodsPropertiesExpander<T>(source);
    }
}

public interface IMethodsPropertiesExpander
{
}

public class MethodsPropertiesExpander<T> : IOrderedQueryable<T>, IQueryProvider, IMethodsPropertiesExpander
{
    public readonly IQueryable<T> Query;

    public MethodsPropertiesExpander(IQueryable<T> query)
    {
        if (!(query is IMethodsPropertiesExpander))
        {
            Expression expression = MethodsPropertiesReplacer.Default.Visit(query.Expression);
            Query = expression == query.Expression ? query : query.Provider.CreateQuery<T>(expression);
        }
        else
        {
            Query = query;
        }
    }

    /* IQueryable<T> */

    public IEnumerator<T> GetEnumerator()
    {
        return Query.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    public Type ElementType
    {
        get { return Query.ElementType; }
    }

    public Expression Expression
    {
        get { return Query.Expression; }
    }

    public IQueryProvider Provider
    {
        get { return this; }
    }

    /* IQueryProvider */

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        return new MethodsPropertiesExpander<TElement>(Query.Provider.CreateQuery<TElement>(expression));
    }

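    public IQueryable CreateQuery(Expression expression)
    {
        Type iqueryableArgument = GetIQueryableTypeArgument(expression.Type);

        // CreateQuery<TElement> is public and shares its name with this
        // non-generic overload, so select the generic definition explicitly
        // (a NonPublic lookup by name alone returns null)
        MethodInfo createQueryImplMethod = typeof(MethodsPropertiesExpander<T>)
            .GetMethods(BindingFlags.Instance | BindingFlags.Public)
            .Single(m => m.Name == "CreateQuery" && m.IsGenericMethodDefinition)
            .MakeGenericMethod(iqueryableArgument);

        return (IQueryable)createQueryImplMethod.Invoke(this, new object[] { expression });
    }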

    public TResult Execute<TResult>(Expression expression)
    {
        if (!(Query.Provider is IMethodsPropertiesExpander))
        {
            // We want to expand it only once :-)
            expression = MethodsPropertiesReplacer.Default.Visit(expression);
        }

        return Query.Provider.Execute<TResult>(expression);
    }

    public object Execute(Expression expression)
    {
        if (!(Query.Provider is IMethodsPropertiesExpander))
        {
            // We want to expand it only once :-)
            expression = MethodsPropertiesReplacer.Default.Visit(expression);
        }

        return Query.Provider.Execute(expression);
    }

    /* Implementation methods */

    /// <summary>
    /// Gets the T of IQueryable&lt;T&gt;
    /// </summary>
    /// <param name="type"></param>
    /// <returns></returns>
    protected static Type GetIQueryableTypeArgument(Type type)
    {
        IEnumerable<Type> interfaces = type.IsInterface ?
            new[] { type }.Concat(type.GetInterfaces()) :
            type.GetInterfaces();
        Type argument = (from x in interfaces
                         where x.IsGenericType
                         let gt = x.GetGenericTypeDefinition()
                         where gt == typeof(IQueryable<>)
                         select x.GetGenericArguments()[0]).FirstOrDefault();
        return argument;
    }

    /* Utility classes */

    protected sealed class MethodsPropertiesReplacer : ExpressionVisitor
    {
        // Single instance is enough!
        public static readonly MethodsPropertiesReplacer Default = new MethodsPropertiesReplacer();

        private MethodsPropertiesReplacer()
        {
        }

        protected override Expression VisitMember(MemberExpression node)
        {
            PropertyInfo property = node.Member as PropertyInfo;
            MethodInfo getter;

            // We handle only properties (that aren't indexers) that have 
            // a get
            if (property != null && property.GetIndexParameters().Length == 0 && (getter = property.GetGetMethod(true)) != null)
            {
                // We work only on properties marked as [ExpandableAttribute]
                var attribute = property.GetCustomAttributes(typeof(ExpandableAttribute), false).FirstOrDefault();

                if (attribute != null)
                {
                    string name = property.Name + ExpandableAttribute.ExpandableSuffix;

                    var property2 = property.DeclaringType.GetProperty(name, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, null, Type.EmptyTypes, null);

                    if (property2 == null || property2.GetGetMethod(true) == null)
                    {
                        if (property2 == null)
                        {
                            if (property.DeclaringType.GetProperty(name, BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, null, Type.EmptyTypes, null) != null)
                            {
                                throw new NotSupportedException(string.Format("{0}.{1} isn't static!", property.DeclaringType.FullName, name));
                            }

                            throw new NotSupportedException(string.Format("{0}.{1} not found!", property.DeclaringType.FullName, name));
                        }

                        // property2.GetGetMethod(true) == null
                        throw new NotSupportedException(string.Format("{0}.{1} doesn't have a getter!", property.DeclaringType.FullName, name));
                    }

                    // Instance properties have the additional 
                    // "parameter" of the declaring type
                    var argumentsPlusReturnTypes = getter.IsStatic ?
                        new[] { node.Type } :
                        new[] { property.DeclaringType, node.Type };

                    var funcType = typeof(Func<>).Assembly.GetType(string.Format("System.Func`{0}", argumentsPlusReturnTypes.Length));

                    var returnType = typeof(Expression<>).MakeGenericType(funcType.MakeGenericType(argumentsPlusReturnTypes));

                    if (property2.PropertyType != returnType)
                    {
                        throw new NotSupportedException(string.Format("{0}.{1} has wrong return type!", property.DeclaringType.FullName, name));
                    }

                    var expression = (LambdaExpression)property2.GetValue(null, null);

                    // Instance Members have the additional "parameter" 
                    // of the declaring type
                    var arguments2 = getter.IsStatic ? new Expression[0] : new[] { node.Expression };

                    var replacer = new SimpleExpressionReplacer(expression.Parameters, arguments2);
                    var body = replacer.Visit(expression.Body);

                    return this.Visit(body);
                }
            }

            return base.VisitMember(node);
        }

        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            MethodInfo method = node.Method;

            // We work only on methods marked as [ExpandableAttribute]
            var attribute = method.GetCustomAttributes(typeof(ExpandableAttribute), false).FirstOrDefault();

            if (attribute != null)
            {
                string name = method.Name + ExpandableAttribute.ExpandableSuffix;

                var method2 = method.DeclaringType.GetMethod(name, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, Type.EmptyTypes, null);

                if (method2 == null)
                {
                    if (method.DeclaringType.GetMethod(name, BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, Type.EmptyTypes, null) != null)
                    {
                        throw new NotSupportedException(string.Format("{0}.{1} isn't static!", method.DeclaringType.FullName, name));
                    }

                    throw new NotSupportedException(string.Format("{0}.{1} not found!", method.DeclaringType.FullName, name));
                }

                // Instance methods have the additional "parameter" of
                // the declaring type
                var argumentsPlusReturnTypes = method.IsStatic ?
                    node.Arguments.Select(x => x.Type).Concat(new[] { node.Type }).ToArray() :
                    new[] { method.DeclaringType }.Concat(node.Arguments.Select(x => x.Type)).Concat(new[] { node.Type }).ToArray();

                var funcType = typeof(Func<>).Assembly.GetType(string.Format("System.Func`{0}", argumentsPlusReturnTypes.Length));

                var returnType = typeof(Expression<>).MakeGenericType(funcType.MakeGenericType(argumentsPlusReturnTypes));

                if (method2.ReturnType != returnType)
                {
                    throw new NotSupportedException(string.Format("{0}.{1} has wrong return type!", method.DeclaringType.FullName, name));
                }

                var expression = (LambdaExpression)method2.Invoke(null, null);

                // Instance methods have the additional "parameter" of
                // the declaring type
                var arguments2 = method.IsStatic ? node.Arguments : new[] { node.Object }.Concat(node.Arguments);

                var replacer = new SimpleExpressionReplacer(expression.Parameters, arguments2);
                var body = replacer.Visit(expression.Body);

                return this.Visit(body);
            }

            return base.VisitMethodCall(node);
        }
    }
}

// A simple expression visitor to replace some nodes of an expression 
// with some other nodes
public class SimpleExpressionReplacer : ExpressionVisitor
{
    public readonly Dictionary<Expression, Expression> Replaces;

    public SimpleExpressionReplacer(Dictionary<Expression, Expression> replaces)
    {
        Replaces = replaces;
    }

    public SimpleExpressionReplacer(IEnumerable<Expression> from, IEnumerable<Expression> to)
    {
        Replaces = new Dictionary<Expression, Expression>();

        using (var enu1 = from.GetEnumerator())
        using (var enu2 = to.GetEnumerator())
        {
            while (true)
            {
                bool res1 = enu1.MoveNext();
                bool res2 = enu2.MoveNext();

                if (!res1 || !res2)
                {
                    if (!res1 && !res2)
                    {
                        break;
                    }

                    if (!res1)
                    {
                        throw new ArgumentException("from shorter");
                    }

                    throw new ArgumentException("to shorter");
                }

                Replaces.Add(enu1.Current, enu2.Current);
            }
        }
    }

    public SimpleExpressionReplacer(Expression from, Expression to)
    {
        Replaces = new Dictionary<Expression, Expression> { { from, to } };
    }

    public override Expression Visit(Expression node)
    {
        Expression to;

        if (node != null && Replaces.TryGetValue(node, out to))
        {
            return base.Visit(to);
        }

        return base.Visit(node);
    }
}
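The lookup in VisitMethodCall boils down to two steps. First, find a static, parameterless method named <Name>Expression. Second, check that it returns Expression<Func<...>>, with the declaring type prepended for instance methods. Here is a condensed sketch of that check. It assumes a local stand-in ExpandableAttribute and a hypothetical Sample class. It uses Expression.GetFuncType instead of building the "System.Func`N" type name by hand. It also takes the types from the method's declared parameters rather than from the call's argument expressions, as the code above does.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Stand-in for the attribute defined in the answer.
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property)]
public class ExpandableAttribute : Attribute { }

// Hypothetical example type following the naming convention.
public class Sample
{
    public int Value { get; set; }

    [Expandable]
    public int Plus(int num) { throw new NotImplementedException(); }

    // Companion: static, parameterless, named <Method> + "Expression".
    // For an instance method the declaring type is the first Func argument.
    private static Expression<Func<Sample, int, int>> PlusExpression()
    {
        return (x, num) => x.Value + num;
    }
}

public static class CompanionLookup
{
    // Returns the companion lambda for an [Expandable] method, or null if the
    // method isn't marked. Error handling is trimmed compared to the answer.
    public static LambdaExpression Find(MethodInfo method)
    {
        if (!method.IsDefined(typeof(ExpandableAttribute), false))
            return null;

        MethodInfo companion = method.DeclaringType.GetMethod(
            method.Name + "Expression",
            BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic,
            null, Type.EmptyTypes, null);
        if (companion == null)
            throw new NotSupportedException(method.Name + "Expression not found");

        // Instance methods get the declaring type as an extra leading type.
        Type[] types = (method.IsStatic ? new Type[0] : new[] { method.DeclaringType })
            .Concat(method.GetParameters().Select(p => p.ParameterType))
            .Concat(new[] { method.ReturnType })
            .ToArray();

        // GetFuncType picks the Func<...> with the matching arity.
        Type expected = typeof(Expression<>).MakeGenericType(Expression.GetFuncType(types));
        if (companion.ReturnType != expected)
            throw new NotSupportedException(method.Name + "Expression has wrong return type");

        return (LambdaExpression)companion.Invoke(null, null);
    }
}
```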

I've added a bonus: you can even expand "special" properties. The instructions are in the big comment at the beginning. Here are some examples:

// Generated by EF
public partial class MyClass
{
    public int ID { get; set; }

    public string Name { get; set; }

    public string Surname { get; set; }

    public virtual ICollection<MyInnerClass> MyInnerClass { get; set; }
}

// Written by you (remember the partial!)
public partial class MyClass
{
    [Expandable]
    public int CountMyInnerClass()
    {
        // Not necessary to implement, unless you want to use it C#-side
        throw new NotImplementedException();
    }

    [Expandable]
    public int CountMyInnerClassPlus(int num)
    {
        // Not necessary to implement, unless you want to use it C#-side
        throw new NotImplementedException();
    }

    [Expandable]
    public int CountMyInnerClassProperty
    {
        get
        {
            // Not necessary to implement, unless you want to use it C#-side
            throw new NotImplementedException();
        }
    }

    [Expandable]
    public string FullName
    {
        get
        {
            // Not necessary to implement, unless you want to use it C#-side
            return Name + " " + Surname;
        }
    }

    protected static Expression<Func<MyClass, int>> CountMyInnerClassExpression()
    {
        return x => x.MyInnerClass.Count();
    }

    protected static Expression<Func<MyClass, int, int>> CountMyInnerClassPlusExpression()
    {
        return (x, num) => x.MyInnerClass.Count() + num;
    }

    protected static Expression<Func<MyClass, int>> CountMyInnerClassPropertyExpression
    {
        get
        {
            return x => x.MyInnerClass.Count();
        }
    }

    protected static Expression<Func<MyClass, string>> FullNameExpression
    {
        get
        {
            return x => x.Name + " " + x.Surname;
        }
    }
}

Then, in some other class (perhaps the one that contains the query):

[Expandable]
public static int LocalCountMyInnerClassPlus(MyClass x, int num)
{
    // Not necessary to implement, unless you want to use it C#-side
    throw new NotImplementedException();
}

public static Expression<Func<MyClass, int, int>> LocalCountMyInnerClassPlusExpression()
{
    return (x, num) => x.MyInnerClass.Count() + num;
}

Then:

var query = (from x in db.MyClasses
             select new
                 {
                     x.ID,
                     x.FullName,
                     Count1 = x.CountMyInnerClass(),
                     Count2 = x.CountMyInnerClassPlus(5),
                     Count3 = x.CountMyInnerClassProperty,
                     Count4 = LocalCountMyInnerClassPlus(x, 10),
                 }).AsExpandable2().ToList();

It just works :-)
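If you want to try the idea without a database, here is a heavily reduced sketch. It is not the answer's code. It expands only static [Expandable] methods, skips the return-type validation, and rewrites the query's expression once instead of wrapping the provider. Item, ItemRules and StaticMethodExpander are hypothetical names. The in-memory AsQueryable() provider stands in for Entity Framework. The stub method throws, so the query only works if the call has been replaced by the companion expression.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Stand-in for the attribute defined in the answer.
[AttributeUsage(AttributeTargets.Method)]
public class ExpandableAttribute : Attribute { }

public class Item
{
    public List<int> Parts { get; set; } = new List<int>();
}

public static class ItemRules
{
    [Expandable]
    public static int PartCountPlus(Item x, int num)
    {
        // Never executed when the query is expanded first.
        throw new NotImplementedException();
    }

    public static Expression<Func<Item, int, int>> PartCountPlusExpression()
    {
        return (x, num) => x.Parts.Count() + num;
    }
}

// Reduced expander: static [Expandable] methods only, no signature checks.
public sealed class StaticMethodExpander : ExpressionVisitor
{
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        MethodInfo method = node.Method;
        if (method.IsStatic && method.IsDefined(typeof(ExpandableAttribute), false))
        {
            var companion = method.DeclaringType.GetMethod(
                method.Name + "Expression",
                BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic,
                null, Type.EmptyTypes, null);
            var lambda = (LambdaExpression)companion.Invoke(null, null);

            // Map each companion parameter to the matching call argument.
            var map = new Dictionary<Expression, Expression>();
            for (int i = 0; i < lambda.Parameters.Count; i++)
                map[lambda.Parameters[i]] = node.Arguments[i];

            // Substitute, then visit again so nested [Expandable] calls expand too.
            return Visit(new Substitute(map).Visit(lambda.Body));
        }
        return base.VisitMethodCall(node);
    }

    private sealed class Substitute : ExpressionVisitor
    {
        private readonly Dictionary<Expression, Expression> _map;
        public Substitute(Dictionary<Expression, Expression> map) { _map = map; }

        protected override Expression VisitParameter(ParameterExpression node)
        {
            Expression to;
            return _map.TryGetValue(node, out to) ? to : node;
        }
    }

    // Rewrites the query's expression tree and hands it back to the same provider.
    public static IQueryable<T> Expand<T>(IQueryable<T> source)
    {
        Expression rewritten = new StaticMethodExpander().Visit(source.Expression);
        return source.Provider.CreateQuery<T>(rewritten);
    }
}
```

Without Expand, the in-memory provider would simply call the stub and throw NotImplementedException. With LINQ to Entities you would typically get a NotSupportedException saying the method isn't recognized.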