How to convert an expression tree to a partial SQL query?

Asked: 2011-10-11 20:10:37

Tags: c# .net linq entity-framework lambda

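When EF or LINQ to SQL runs a query, it:

  1. Builds an expression tree from the code,
  2. Converts the expression tree into an SQL query,
  3. Executes the query, gets the raw results from the database and converts them into the results used by the application.

Looking at the stack trace, I can't figure out where the second part happens.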
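To make the expression-tree step concrete: when a lambda is assigned to an `Expression<...>` variable, the C# compiler emits a tree of objects describing the code instead of a compiled delegate. That tree can be walked like any other data structure.

The sketch below prints the node types of a simple predicate. `ExpressionTreeDemo` and `Dump` are illustrative names, not framework APIs.

```csharp
using System;
using System.Linq.Expressions;

public static class ExpressionTreeDemo
{
    // Recursively describes a (small) expression tree, e.g.
    // price => price > 1000  gives  GreaterThan(Parameter price, Constant 1000)
    public static string Dump(Expression node) => node switch
    {
        null => "null",
        BinaryExpression b => $"{b.NodeType}({Dump(b.Left)}, {Dump(b.Right)})",
        ParameterExpression p => $"Parameter {p.Name}",
        ConstantExpression c => $"Constant {c.Value}",
        MemberExpression m => $"Member({Dump(m.Expression)}.{m.Member.Name})",
        _ => node.NodeType.ToString()
    };
}
```

A translator to SQL is essentially a `Dump` that emits SQL text instead of node names.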

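In general, is it possible to use an existing part of EF or (preferably) LINQ to SQL to convert an Expression object to a partial SQL query (using Transact-SQL syntax), or do I have to reinvent the wheel?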


Update: a comment asked for an example of what I'm trying to do.

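Actually, the answer by Ryan Wright below illustrates perfectly the result I want to achieve. The difference is that my question is about how to do it by using the existing mechanisms of the .NET Framework that EF and LINQ to SQL actually use, instead of reinventing the wheel and writing thousands of lines of untested code myself to do something similar.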

Here is an example as well. Again, note that no ORM-generated code is involved.

    private class Product
    {
        [DatabaseMapping("ProductId")]
        public int Id { get; set; }
    
        [DatabaseMapping("Price")]
        public int PriceInCents { get; set; }
    }
    
    private string Convert(Expression expression)
    {
        // Some magic calls to .NET Framework code happen here.
        // [...]
    }
    
    private void TestConvert()
    {
        Expression<Func<Product, int, int, bool>> inPriceRange =
            (Product product, int from, int to) =>
                product.PriceInCents >= from && product.PriceInCents <= to;
    
        string actualQueryPart = this.Convert(inPriceRange);
    
        Assert.AreEqual("[Price] between @from and @to", actualQueryPart);
    }
    

Where does the name Price in the expected query come from?

It can be obtained by reading the custom DatabaseMapping attribute on the PriceInCents property of the Product class.

Where do the names @from and @to in the expected query come from?

These are the actual names of the expression's parameters.
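Both answers above can be implemented with nothing but `System.Linq.Expressions` and reflection. Below is a minimal sketch of the `Convert` method from the example, with these assumptions:

- The `DatabaseMappingAttribute` definition is assumed, since the question does not show it.
- `PartialSqlConverter` is a hypothetical name.
- The converter handles only comparisons, `&&` and `||`. A member access on a lambda parameter becomes a mapped `[Column]`, and any other lambda parameter becomes `@name`.

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;

// Assumed definition of the question's attribute; it is not a framework type.
[AttributeUsage(AttributeTargets.Property)]
public sealed class DatabaseMappingAttribute : Attribute
{
    public DatabaseMappingAttribute(string columnName) => ColumnName = columnName;
    public string ColumnName { get; }
}

public class Product
{
    [DatabaseMapping("ProductId")]
    public int Id { get; set; }

    [DatabaseMapping("Price")]
    public int PriceInCents { get; set; }
}

// Minimal sketch: comparisons, && and || over mapped properties and lambda parameters only.
public sealed class PartialSqlConverter : ExpressionVisitor
{
    private readonly StringBuilder _sql = new StringBuilder();

    public static string Convert(LambdaExpression lambda)
    {
        var converter = new PartialSqlConverter();
        converter.Visit(lambda.Body);
        return converter._sql.ToString();
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        string op;
        switch (node.NodeType)
        {
            case ExpressionType.AndAlso: op = "and"; break;
            case ExpressionType.OrElse: op = "or"; break;
            case ExpressionType.Equal: op = "="; break;
            case ExpressionType.NotEqual: op = "<>"; break;
            case ExpressionType.LessThan: op = "<"; break;
            case ExpressionType.LessThanOrEqual: op = "<="; break;
            case ExpressionType.GreaterThan: op = ">"; break;
            case ExpressionType.GreaterThanOrEqual: op = ">="; break;
            default: throw new NotSupportedException($"Operator {node.NodeType} is not supported.");
        }

        _sql.Append('(');
        Visit(node.Left);
        _sql.Append(' ').Append(op).Append(' ');
        Visit(node.Right);
        _sql.Append(')');
        return node;
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        // product.PriceInCents -> [Price], looked up through the mapping attribute.
        if (node.Expression is ParameterExpression && node.Member is PropertyInfo property)
        {
            var mapping = property.GetCustomAttribute<DatabaseMappingAttribute>();
            _sql.Append('[').Append(mapping?.ColumnName ?? property.Name).Append(']');
            return node;
        }

        throw new NotSupportedException($"Member {node.Member.Name} is not supported.");
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        // Any other lambda parameter (from, to) becomes a SQL parameter with the same name.
        _sql.Append('@').Append(node.Name);
        return node;
    }
}
```

Called with the question's `inPriceRange`, `PartialSqlConverter.Convert(inPriceRange)` returns `(([Price] >= @from) and ([Price] <= @to))`. That is the `>=` / `<=` form rather than `between`.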

Where does between … and in the expected query come from?

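It is one possible result of the binary expression. Instead of a between … and statement, EF or LINQ to SQL might produce [Price] >= @from and [Price] <= @to. That is fine too: the result is logically the same (I'm not talking about performance), so it doesn't matter.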
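If the `between` form is preferred, a translator can recognise the `(x >= low) && (x <= high)` shape before falling back to separate comparisons. In T-SQL, BETWEEN is inclusive at both ends, so the two forms are equivalent.

This sketch makes two assumptions:

- The caller supplies a `format` function for operands, for instance one that maps properties to `[Column]` and parameters to `@name`.
- `BetweenFormatter` is an illustrative name only.

```csharp
using System;
using System.Linq.Expressions;

public static class BetweenFormatter
{
    // Returns "<x> between <low> and <high>" when node has the shape
    // (x >= low) && (x <= high) and both sides format x identically;
    // otherwise returns null so the caller can fall back to >= / <=.
    public static string TryFormatBetween(Expression node, Func<Expression, string> format)
    {
        if (node is BinaryExpression both && both.NodeType == ExpressionType.AndAlso
            && both.Left is BinaryExpression lower && lower.NodeType == ExpressionType.GreaterThanOrEqual
            && both.Right is BinaryExpression upper && upper.NodeType == ExpressionType.LessThanOrEqual)
        {
            string operand = format(lower.Left);
            if (operand == format(upper.Left))
            {
                return $"{operand} between {format(lower.Right)} and {format(upper.Right)}";
            }
        }

        return null;
    }
}
```

Comparing the formatted operands, rather than the expression objects, is a deliberate simplification. Two separate `product.PriceInCents` accesses are different node instances but format to the same column.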

    为什么预期查询中没有where

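Because nothing in the Expression says that there must be a where keyword. The actual expression may be just one of several expressions that will later be combined with binary operators into a larger query, which will then be prefixed with where.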
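If the fragments are going to be combined anyway, one option is to combine the expressions first and translate the result once. The sketch below joins two predicates with `Expression.AndAlso`. `PredicateCombiner` is a hypothetical helper.

Each lambda has its own `ParameterExpression`. The second body is therefore rewritten to use the first lambda's parameter before the two are merged; otherwise the combined lambda would reference an unbound parameter.

```csharp
using System;
using System.Linq.Expressions;

public static class PredicateCombiner
{
    // Combines two predicates into one lambda with a single shared parameter,
    // so the result can be passed to an expression-to-SQL translator as a whole.
    public static Expression<Func<T, bool>> And<T>(
        Expression<Func<T, bool>> left, Expression<Func<T, bool>> right)
    {
        ParameterExpression parameter = left.Parameters[0];
        Expression rightBody = new ParameterReplacer(right.Parameters[0], parameter).Visit(right.Body);
        return Expression.Lambda<Func<T, bool>>(Expression.AndAlso(left.Body, rightBody), parameter);
    }

    // Replaces every occurrence of one parameter with another.
    private sealed class ParameterReplacer : ExpressionVisitor
    {
        private readonly ParameterExpression _source;
        private readonly ParameterExpression _target;

        public ParameterReplacer(ParameterExpression source, ParameterExpression target)
        {
            _source = source;
            _target = target;
        }

        protected override Expression VisitParameter(ParameterExpression node) =>
            node == _source ? _target : base.VisitParameter(node);
    }
}
```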

9 answers:

Answer 0 (score: 42)

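Yes, it is possible: you can parse a LINQ expression tree using the visitor pattern. You need to build a query translator by subclassing ExpressionVisitor, as below. By hooking into the right points, you can use the translator to construct a SQL string from your LINQ expression. Note that the code below only handles basic where/orderby/skip/take clauses, but you can fill in more as needed. Hopefully it serves as a good first step.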

using System;
using System.Linq;
using System.Linq.Expressions;
using System.Text;

public class MyQueryTranslator : ExpressionVisitor
{
    private StringBuilder sb;
    private string _orderBy = string.Empty;
    private int? _skip = null;
    private int? _take = null;
    private string _whereClause = string.Empty;

    public int? Skip
    {
        get
        {
            return _skip;
        }
    }

    public int? Take
    {
        get
        {
            return _take;
        }
    }

    public string OrderBy
    {
        get
        {
            return _orderBy;
        }
    }

    public string WhereClause
    {
        get
        {
            return _whereClause;
        }
    }

    public MyQueryTranslator()
    {
    }

    public string Translate(Expression expression)
    {
        this.sb = new StringBuilder();
        this.Visit(expression);
        _whereClause = this.sb.ToString();
        return _whereClause;
    }

    private static Expression StripQuotes(Expression e)
    {
        while (e.NodeType == ExpressionType.Quote)
        {
            e = ((UnaryExpression)e).Operand;
        }
        return e;
    }

    protected override Expression VisitMethodCall(MethodCallExpression m)
    {
        if (m.Method.DeclaringType == typeof(Queryable) && m.Method.Name == "Where")
        {
            this.Visit(m.Arguments[0]);
            LambdaExpression lambda = (LambdaExpression)StripQuotes(m.Arguments[1]);
            this.Visit(lambda.Body);
            return m;
        }
        else if (m.Method.Name == "Take")
        {
            if (this.ParseTakeExpression(m))
            {
                Expression nextExpression = m.Arguments[0];
                return this.Visit(nextExpression);
            }
        }
        else if (m.Method.Name == "Skip")
        {
            if (this.ParseSkipExpression(m))
            {
                Expression nextExpression = m.Arguments[0];
                return this.Visit(nextExpression);
            }
        }
        else if (m.Method.Name == "OrderBy")
        {
            if (this.ParseOrderByExpression(m, "ASC"))
            {
                Expression nextExpression = m.Arguments[0];
                return this.Visit(nextExpression);
            }
        }
        else if (m.Method.Name == "OrderByDescending")
        {
            if (this.ParseOrderByExpression(m, "DESC"))
            {
                Expression nextExpression = m.Arguments[0];
                return this.Visit(nextExpression);
            }
        }

        throw new NotSupportedException(string.Format("The method '{0}' is not supported", m.Method.Name));
    }

    protected override Expression VisitUnary(UnaryExpression u)
    {
        switch (u.NodeType)
        {
            case ExpressionType.Not:
                sb.Append(" NOT ");
                this.Visit(u.Operand);
                break;
            case ExpressionType.Convert:
                this.Visit(u.Operand);
                break;
            default:
                throw new NotSupportedException(string.Format("The unary operator '{0}' is not supported", u.NodeType));
        }
        return u;
    }


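    /// <summary>
    /// Emits a binary node as "(left OP right)", mapping C# operators to their SQL equivalents.
    /// </summary>
    /// <param name="b">The binary expression to translate.</param>
    /// <returns>The original expression, unchanged.</returns>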
    protected override Expression VisitBinary(BinaryExpression b)
    {
        sb.Append("(");
        this.Visit(b.Left);

        switch (b.NodeType)
        {
            case ExpressionType.And:
                sb.Append(" AND ");
                break;

            case ExpressionType.AndAlso:
                sb.Append(" AND ");
                break;

            case ExpressionType.Or:
                sb.Append(" OR ");
                break;

            case ExpressionType.OrElse:
                sb.Append(" OR ");
                break;

            case ExpressionType.Equal:
                if (IsNullConstant(b.Right))
                {
                    sb.Append(" IS ");
                }
                else
                {
                    sb.Append(" = ");
                }
                break;

            case ExpressionType.NotEqual:
                if (IsNullConstant(b.Right))
                {
                    sb.Append(" IS NOT ");
                }
                else
                {
                    sb.Append(" <> ");
                }
                break;

            case ExpressionType.LessThan:
                sb.Append(" < ");
                break;

            case ExpressionType.LessThanOrEqual:
                sb.Append(" <= ");
                break;

            case ExpressionType.GreaterThan:
                sb.Append(" > ");
                break;

            case ExpressionType.GreaterThanOrEqual:
                sb.Append(" >= ");
                break;

            default:
                throw new NotSupportedException(string.Format("The binary operator '{0}' is not supported", b.NodeType));

        }

        this.Visit(b.Right);
        sb.Append(")");
        return b;
    }

    protected override Expression VisitConstant(ConstantExpression c)
    {
        IQueryable q = c.Value as IQueryable;

        if (q == null && c.Value == null)
        {
            sb.Append("NULL");
        }
        else if (q == null)
        {
            switch (Type.GetTypeCode(c.Value.GetType()))
            {
                case TypeCode.Boolean:
                    sb.Append(((bool)c.Value) ? 1 : 0);
                    break;

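                case TypeCode.String:
                    // Double embedded single quotes so a quote in the value cannot end the literal.
                    sb.Append("'");
                    sb.Append(((string)c.Value).Replace("'", "''"));
                    sb.Append("'");
                    break;

                case TypeCode.DateTime:
                    // Culture-independent ISO 8601 format instead of the culture-dependent DateTime.ToString().
                    sb.Append("'");
                    sb.Append(((DateTime)c.Value).ToString("yyyy-MM-dd'T'HH:mm:ss.fff", System.Globalization.CultureInfo.InvariantCulture));
                    sb.Append("'");
                    break;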

                case TypeCode.Object:
                    throw new NotSupportedException(string.Format("The constant for '{0}' is not supported", c.Value));

                default:
                    sb.Append(c.Value);
                    break;
            }
        }

        return c;
    }

    protected override Expression VisitMember(MemberExpression m)
    {
        if (m.Expression != null && m.Expression.NodeType == ExpressionType.Parameter)
        {
            sb.Append(m.Member.Name);
            return m;
        }

        throw new NotSupportedException(string.Format("The member '{0}' is not supported", m.Member.Name));
    }

    protected bool IsNullConstant(Expression exp)
    {
        return (exp.NodeType == ExpressionType.Constant && ((ConstantExpression)exp).Value == null);
    }

    private bool ParseOrderByExpression(MethodCallExpression expression, string order)
    {
        UnaryExpression unary = (UnaryExpression)expression.Arguments[1];
        LambdaExpression lambdaExpression = (LambdaExpression)unary.Operand;

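        // Evaluator.PartialEval is not part of the .NET Framework; it is the helper class from the
        // "LINQ: Building an IQueryable Provider" sample code and has to be added to the project separately.
        lambdaExpression = (LambdaExpression)Evaluator.PartialEval(lambdaExpression);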

        MemberExpression body = lambdaExpression.Body as MemberExpression;
        if (body != null)
        {
            if (string.IsNullOrEmpty(_orderBy))
            {
                _orderBy = string.Format("{0} {1}", body.Member.Name, order);
            }
            else
            {
                _orderBy = string.Format("{0}, {1} {2}", _orderBy, body.Member.Name, order);
            }

            return true;
        }

        return false;
    }

    private bool ParseTakeExpression(MethodCallExpression expression)
    {
        ConstantExpression sizeExpression = (ConstantExpression)expression.Arguments[1];

        int size;
        if (int.TryParse(sizeExpression.Value.ToString(), out size))
        {
            _take = size;
            return true;
        }

        return false;
    }

    private bool ParseSkipExpression(MethodCallExpression expression)
    {
        ConstantExpression sizeExpression = (ConstantExpression)expression.Arguments[1];

        int size;
        if (int.TryParse(sizeExpression.Value.ToString(), out size))
        {
            _skip = size;
            return true;
        }

        return false;
    }
}

Then visit the expression by calling:
var translator = new MyQueryTranslator();
string whereClause = translator.Translate(expression);
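In practice, the `expression` passed to `Translate` is usually `query.Expression`. For a query such as `source.Where(x => x > 1)`, that expression is a `MethodCallExpression` with two arguments:

- The first is the source: a `ConstantExpression` holding the `IQueryable`, which `VisitConstant` skips.
- The second is the lambda wrapped in a `Quote` node, which is why the translator calls `StripQuotes`.

A small self-contained check of that shape follows. `QueryShape` is an illustrative name.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public static class QueryShape
{
    // Returns the predicate of a top-level Where call after removing the Quote
    // wrapper, mirroring what StripQuotes does in the translator.
    public static LambdaExpression GetWherePredicate(IQueryable query)
    {
        if (!(query.Expression is MethodCallExpression call) || call.Method.Name != "Where")
            throw new ArgumentException("Expected a query whose outermost call is Where.");

        Expression argument = call.Arguments[1]; // Quote(lambda)
        while (argument.NodeType == ExpressionType.Quote)
            argument = ((UnaryExpression)argument).Operand;

        return (LambdaExpression)argument;
    }
}
```

With the translator above, `translator.Translate(query.Expression)` for a predicate `p => p.Price > 10` should yield `(Price > 10)`.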

Answer 1 (score: 22)

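The short answer seems to be that you cannot use parts of EF or LINQ to SQL as a shortcut to translation. You need at least a subclass of ObjectContext to get at the internal protected QueryProvider property, and that means all the overhead of creating the context, including all the metadata and so on.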

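Assuming you are OK with that, to get a partial SQL query (for example, just the WHERE clause), you basically need the query provider and to call IQueryProvider.CreateQuery() just as LINQ does in its implementation of Queryable.Where. To get a more complete query, you can use ObjectQuery.ToTraceString().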
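The `CreateQuery` route can be tried without EF. This sketch builds the same `Queryable.Where` call node by hand and passes it to the source's provider. `ManualWhere` is a hypothetical helper.

Which provider receives the call depends on the source:

- With an in-memory `AsQueryable()` source, the provider is LINQ to Objects' `EnumerableQuery`.
- With an EF query, the same call would go to EF's provider.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class ManualWhere
{
    // Builds Queryable.Where(source, Quote(predicate)) and lets the source's own
    // provider wrap it in a new IQueryable, roughly what Queryable.Where does.
    public static IQueryable<T> Where<T>(IQueryable<T> source, Expression<Func<T, bool>> predicate)
    {
        MethodCallExpression call = Expression.Call(
            typeof(Queryable),
            nameof(Queryable.Where),
            new[] { typeof(T) },
            source.Expression,
            Expression.Quote(predicate));

        return source.Provider.CreateQuery<T>(call);
    }
}
```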

As for where this happens, LINQ provider basics states in general that

  

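IQueryProvider returns a reference to IQueryable with the constructed expression tree passed by the LINQ framework, which is used for further calls. In general terms, each query block is converted into a bunch of method calls, and each method call involves some expressions. While creating our provider, in the IQueryProvider.CreateQuery method, we run through the expressions and fill up a filter object, which is used in the IQueryProvider.Execute method to run a query against the data store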

and that

  

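the query can be executed in two ways: either by implementing the GetEnumerator method (defined in the IEnumerable interface) in the Query class (which inherits from IQueryable), or directly by the LINQ runtime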

Inspecting EF under the debugger, it is the former.

If you don't want to completely reinvent the wheel, and neither EF nor LINQ to SQL is an option, this series of articles might help:

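Here are some resources on creating a query provider, though they probably involve much more heavy lifting on your part to achieve what you want: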

Answer 2 (score: 5)

In LINQ to SQL you can use:

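// DataContext.GetCommand takes the IQueryable (e.g. dataContext.Products.Where(...)), not a bare Expression.
var cmd = dataContext.GetCommand(query);
var sqlQuery = cmd.CommandText;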

Answer 3 (score: 5)

It isn't complete, but here are some ideas to consider if you come along later:

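using System;
using System.Linq;
using System.Linq.Expressions;
using System.Text;

public class SqlCommandBuilder<T>
{
    private readonly string _tableName;

    public SqlCommandBuilder(string tableName)
    {
        _tableName = tableName;
    }

    // Turns the debug text of Expression.ToString() into SQL by string replacement.
    // Only simple predicates work: captured variables, method calls and quotes inside
    // string values are not handled.
    private string CreateWhereClause(Expression<Func<T, bool>> predicate)
    {
        StringBuilder p = new StringBuilder(predicate.Body.ToString());
        var pName = predicate.Parameters.First();
        p.Replace(pName.Name + ".", "");
        p.Replace("==", "=");
        p.Replace("AndAlso", "and");
        p.Replace("OrElse", "or");
        p.Replace("\"", "\'");
        return p.ToString();
    }

    public string AddWhereToSelectCommand(Expression<Func<T, bool>> predicate, int maxCount = 0)
    {
        string command = string.Format("{0} where {1}", CreateSelectCommand(maxCount), CreateWhereClause(predicate));
        return command;
    }

    private string CreateSelectCommand(int maxCount = 0)
    {
        string selectMax = maxCount > 0 ? "TOP " + maxCount + " *" : "*";
        string command = string.Format("Select {0} from {1}", selectMax, _tableName);
        return command;
    }
}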

Answer 4 (score: 4)

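You basically have to reinvent the wheel. The QueryProvider is what translates expression trees into its store's native syntax. It handles special cases such as string.Contains() and string.StartsWith(), along with all the specialty functions that support them. It also handles metadata lookups in the various layers of your ORM (*.edmx in the case of database-first or model-first Entity Framework). There are already examples and frameworks for building SQL commands, but what you're looking for sounds like a partial solution.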

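Also be aware that table/view metadata is required to correctly determine what is legal. Query providers are quite complex and do a lot of work for you beyond converting simple expression trees into SQL.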

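As for where the second part happens: it happens during enumeration of the IQueryable. IQueryables are also IEnumerables. When GetEnumerator is eventually called, it in turn calls the query provider with the expression tree, and the provider uses its metadata to generate the SQL command. That is not exactly what happens, but it should get the idea across.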
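The enumeration-time hand-off can be observed with a deliberately tiny provider. This is a sketch with hypothetical names (`RecordingQuery`, `RecordingProvider`), not how EF is structured internally.

- Calling `Where` only records a `CreateQuery` call.
- The provider's `Execute` is reached only when the query is enumerated. That is the point where a real provider would translate the tree to SQL.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical illustration types, not part of EF or LINQ to SQL.
public sealed class RecordingQuery<T> : IQueryable<T>
{
    public RecordingQuery(RecordingProvider provider, Expression expression = null)
    {
        Provider = provider;
        // A fresh query's tree is just a constant pointing at the query object itself.
        Expression = expression ?? System.Linq.Expressions.Expression.Constant(this);
    }

    public Type ElementType => typeof(T);
    public Expression Expression { get; }
    public IQueryProvider Provider { get; }

    // Enumeration is the moment the provider receives the complete expression tree.
    public IEnumerator<T> GetEnumerator() =>
        Provider.Execute<IEnumerable<T>>(Expression).GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

public sealed class RecordingProvider : IQueryProvider
{
    public List<string> Log { get; } = new List<string>();

    // Called by Queryable.Where, OrderBy, ...: nothing is executed yet.
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        Log.Add("CreateQuery: " + expression);
        return new RecordingQuery<TElement>(this, expression);
    }

    public IQueryable CreateQuery(Expression expression) =>
        throw new NotSupportedException("Only the generic overload is used in this sketch.");

    public TResult Execute<TResult>(Expression expression)
    {
        // A real provider would translate `expression` to SQL here, run it and
        // materialize the rows. This sketch only records the call.
        Log.Add("Execute: " + expression);
        Type resultType = typeof(TResult);
        if (!resultType.IsGenericType || resultType.GetGenericTypeDefinition() != typeof(IEnumerable<>))
            throw new NotSupportedException("Only sequence results are supported in this sketch.");

        // Return an empty array of the element type, e.g. int[0] for IEnumerable<int>.
        return (TResult)(object)Array.CreateInstance(resultType.GetGenericArguments()[0], 0);
    }

    public object Execute(Expression expression) =>
        throw new NotSupportedException("Only the generic overload is used in this sketch.");
}
```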

Answer 5 (score: 2)

You can use the following code:

var query = from c in Customers
            select c;

string sql = ((ObjectQuery)query).ToTraceString();

For more information, see: Retrieving the SQL generated by the Entity Provider

Answer 6 (score: 1)

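After searching for hours for an implementation of an expression-tree-to-SQL converter, I found nothing that was useful, free, or working with .NET Core. Then I found this. Thank you, Ryan Wright. I took his code and modified it a bit to fit my needs. Now I am giving it back to the community.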

The current version can do the following:

Bulk update

            int rowCount = context
                .Users
                .Where(x => x.Status == UserStatus.Banned)
                .Update(x => new
                {
                    DisplayName = "Bad Guy"
                });

This produces the following SQL:

DECLARE @p0 NVarChar
DECLARE @p1 Int
SET @p0 = 'Bad Guy'
SET @p1 = 3
UPDATE [Users]
SET [DisplayName] = @p0
WHERE ( [Status] = @p1 )

Bulk delete

            int rowCount = context
                .Users
                .Where(x => x.UniqueName.EndsWith("012"))
                .Delete();

Generated SQL:

DECLARE @p0 NVarChar
SET @p0 = '%012'
DELETE
FROM [Users]
WHERE [UniqueName] LIKE @p0

Output the SQL statement

            string sql = context
                .Users
                .Where(x => x.Status == UserStatus.LockedOut)
                .OrderBy(x => x.UniqueName)
                .ThenByDescending(x => x.LastLogin)
                .Select(x => new
                {
                    x.UniqueName,
                    x.Email
                })
                .ToSqlString();

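This produces the following SQL. Note that the ORDER BY columns come out in reverse order: the LINQ query sorts by UniqueName first and then by LastLogin descending.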

DECLARE @p0 Int
SET @p0 = 4
SELECT [UniqueName], [Email]
FROM [Users]
WHERE ( [Status] = @p0 )
ORDER BY [LastLogin] DESC, [UniqueName] ASC

Another example

            string sql = context
                .Users
                .Where(x => x.Status == UserStatus.LockedOut)
                .OrderBy(x => x.UniqueName)
                .ThenByDescending(x => x.LastLogin)
                .Select(x => new
                {
                    x.UniqueName,
                    x.Email,
                    x.LastLogin
                })
                .Take(4)
                .Skip(3)
                .Distinct()
                .ToSqlString();

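SQL (note that LINQ applies Take(4) before Skip(3), so the LINQ query itself returns at most one row, while the generated OFFSET 3 ROWS FETCH NEXT 4 ROWS ONLY can return up to four):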

DECLARE @p0 Int
SET @p0 = 4
SELECT DISTINCT [UniqueName], [Email], [LastLogin]
FROM [Users]
WHERE ( [Status] = @p0 )
ORDER BY [LastLogin] DESC, [UniqueName] ASC OFFSET 3 ROWS FETCH NEXT 4 ROWS ONLY

Another example, with a local variable

            string name ="venom";

            string sql = context
                .Users
                .Where(x => x.LastLogin == DateTime.UtcNow && x.UniqueName.Contains(name))
                .Select(x => x.Email)
                .ToSqlString();

Generated SQL:

DECLARE @p0 DateTime
DECLARE @p1 NVarChar
SET @p0 = '20.06.2020 19:23:46'
SET @p1 = '%venom%'
SELECT [Email]
FROM [Users]
WHERE ( ( [LastLogin] = @p0 ) AND [UniqueName] LIKE @p1 )
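In these examples, constants and captured locals such as `name` become `@p0`, `@p1`, ... instead of being inlined as literals. Below is a minimal sketch of that idea, with these assumptions:

- A `Dictionary<string, object>` stands in for `SqlParameter`, so the sketch runs without a database.
- `User` and `ParameterizingTranslator` are hypothetical names.
- Captured values are evaluated by compiling the member access into a parameterless lambda. That is one common approach; the answer's own `GetValue` helper may do it differently.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Text;

// Hypothetical entity used only for this illustration.
public class User
{
    public string UniqueName { get; set; }
    public int Status { get; set; }
}

// Sketch: supports == and && only; values become named parameters.
public sealed class ParameterizingTranslator : ExpressionVisitor
{
    private readonly StringBuilder _sql = new StringBuilder();

    public Dictionary<string, object> Parameters { get; } = new Dictionary<string, object>();
    public string Sql => _sql.ToString();

    public static ParameterizingTranslator Translate(LambdaExpression predicate)
    {
        var translator = new ParameterizingTranslator();
        translator.Visit(predicate.Body);
        return translator;
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        if (node.NodeType != ExpressionType.Equal && node.NodeType != ExpressionType.AndAlso)
            throw new NotSupportedException($"Operator {node.NodeType} is not supported in this sketch.");

        _sql.Append("( ");
        Visit(node.Left);
        _sql.Append(node.NodeType == ExpressionType.Equal ? " = " : " AND ");
        Visit(node.Right);
        _sql.Append(" )");
        return node;
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        if (node.Expression is ParameterExpression)
        {
            // x.UniqueName -> [UniqueName]
            _sql.Append('[').Append(node.Member.Name).Append(']');
            return node;
        }

        // Anything else (a captured local, DateTime.UtcNow, ...) is evaluated now.
        // Nested members of the entity (x.Address.City) would fail here.
        object value = Expression.Lambda(node).Compile().DynamicInvoke();
        AddParameter(value);
        return node;
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        AddParameter(node.Value);
        return node;
    }

    private void AddParameter(object value)
    {
        string name = "@p" + Parameters.Count;
        Parameters.Add(name, value);
        _sql.Append(name);
    }
}
```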

The SimpleExpressionToSQL class itself can also be used directly:

var simpleExpressionToSQL = new SimpleExpressionToSQL(queryable);
simpleExpressionToSQL.ExecuteNonQuery(IsolationLevel.Snapshot);

Code

The Evaluator used here comes from here.

SimpleExpressionToSQL

    public class SimpleExpressionToSQL : ExpressionVisitor
    {
        /*
         * Original By Ryan Wright: https://stackoverflow.com/questions/7731905/how-to-convert-an-expression-tree-to-a-partial-sql-query
         */

        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<string> _groupBy = new List<string>();

        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<string> _orderBy = new List<string>();

        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<SqlParameter> _parameters = new List<SqlParameter>();

        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<string> _select = new List<string>();

        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<string> _update = new List<string>();

        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<string> _where = new List<string>();

        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private int? _skip;

        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private int? _take;

        public SimpleExpressionToSQL(IQueryable queryable)
        {
            if (queryable is null)
            {
                throw new ArgumentNullException(nameof(queryable));
            }

            Expression expression = queryable.Expression;
            Visit(expression);
            Type entityType = (GetEntityType(expression) as IQueryable).ElementType;
            TableName = queryable.GetTableName(entityType);
            DbContext = queryable.GetDbContext();
        }

        public string CommandText => BuildSqlStatement().Join(Environment.NewLine);

        public DbContext DbContext { get; private set; }

        public string From => $"FROM [{TableName}]";

        public string GroupBy => _groupBy.Count == 0 ? null : "GROUP BY " + _groupBy.Join(", ");
        public bool IsDelete { get; private set; } = false;
        public bool IsDistinct { get; private set; }
        public string OrderBy => BuildOrderByStatement().Join(" ");
        public SqlParameter[] Parameters => _parameters.ToArray();
        public string Select => BuildSelectStatement().Join(" ");
        public int? Skip => _skip;
        public string TableName { get; private set; }
        public int? Take => _take;
        public string Update => "SET " + _update.Join(", ");

        public string Where => _where.Count == 0 ? null : "WHERE " + _where.Join(" ");

        public static implicit operator string(SimpleExpressionToSQL simpleExpression) => simpleExpression.ToString();

        public int ExecuteNonQuery(IsolationLevel isolationLevel = IsolationLevel.RepeatableRead)
        {
            DbConnection connection = DbContext.Database.GetDbConnection();
            using (DbCommand command = connection.CreateCommand())
            {
                command.CommandText = CommandText;
                command.CommandType = CommandType.Text;
                command.Parameters.AddRange(Parameters);

#if DEBUG
                Debug.WriteLine(ToString());
#endif

                if (command.Connection.State != ConnectionState.Open)
                    command.Connection.Open();

                using (DbTransaction transaction = connection.BeginTransaction(isolationLevel))
                {
                    command.Transaction = transaction;
                    int result = command.ExecuteNonQuery();
                    transaction.Commit();

                    return result;
                }
            }
        }

        public async Task<int> ExecuteNonQueryAsync(IsolationLevel isolationLevel = IsolationLevel.RepeatableRead)
        {
            DbConnection connection = DbContext.Database.GetDbConnection();
            using (DbCommand command = connection.CreateCommand())
            {
                command.CommandText = CommandText;
                command.CommandType = CommandType.Text;
                command.Parameters.AddRange(Parameters);

#if DEBUG
                Debug.WriteLine(ToString());
#endif

                if (command.Connection.State != ConnectionState.Open)
                    await command.Connection.OpenAsync();

                using (DbTransaction transaction = connection.BeginTransaction(isolationLevel))
                {
                    command.Transaction = transaction;
                    int result = await command.ExecuteNonQueryAsync();
                    transaction.Commit();

                    return result;
                }
            }
        }

        public override string ToString() =>
            BuildDeclaration()
                .Concat(BuildSqlStatement()) // Concat, not Union: Union would silently drop duplicate lines
                .Join(Environment.NewLine);

        protected override Expression VisitBinary(BinaryExpression binaryExpression)
        {
            _where.Add("(");
            Visit(binaryExpression.Left);

            switch (binaryExpression.NodeType)
            {
                case ExpressionType.And:
                    _where.Add("AND");
                    break;

                case ExpressionType.AndAlso:
                    _where.Add("AND");
                    break;

                case ExpressionType.Or:
                case ExpressionType.OrElse:
                    _where.Add("OR");
                    break;

                case ExpressionType.Equal:
                    if (IsNullConstant(binaryExpression.Right))
                    {
                        _where.Add("IS");
                    }
                    else
                    {
                        _where.Add("=");
                    }
                    break;

                case ExpressionType.NotEqual:
                    if (IsNullConstant(binaryExpression.Right))
                    {
                        _where.Add("IS NOT");
                    }
                    else
                    {
                        _where.Add("<>");
                    }
                    break;

                case ExpressionType.LessThan:
                    _where.Add("<");
                    break;

                case ExpressionType.LessThanOrEqual:
                    _where.Add("<=");
                    break;

                case ExpressionType.GreaterThan:
                    _where.Add(">");
                    break;

                case ExpressionType.GreaterThanOrEqual:
                    _where.Add(">=");
                    break;

                default:
                    throw new NotSupportedException(string.Format("The binary operator '{0}' is not supported", binaryExpression.NodeType));
            }

            Visit(binaryExpression.Right);
            _where.Add(")");
            return binaryExpression;
        }

        protected override Expression VisitConstant(ConstantExpression constantExpression)
        {
            switch (constantExpression.Value)
            {
                case null:
                    _where.Add("NULL");
                    break;

                default:

                    if (constantExpression.Type.CanConvertToSqlDbType())
                    {
                        _where.Add(CreateParameter(constantExpression.Value).ParameterName);
                    }

                    break;
            }

            return constantExpression;
        }

        protected override Expression VisitMember(MemberExpression memberExpression)
        {
            Expression VisitMemberLocal(Expression expression)
            {
                switch (expression.NodeType)
                {
                    case ExpressionType.Parameter:
                        _where.Add($"[{memberExpression.Member.Name}]");
                        return memberExpression;

                    case ExpressionType.Constant:
                        _where.Add(CreateParameter(GetValue(memberExpression)).ParameterName);

                        return memberExpression;

                    case ExpressionType.MemberAccess:
                        _where.Add(CreateParameter(GetValue(memberExpression)).ParameterName);

                        return memberExpression;
                }

                throw new NotSupportedException(string.Format("The member '{0}' is not supported", memberExpression.Member.Name));
            }

            if (memberExpression.Expression == null)
            {
                return VisitMemberLocal(memberExpression);
            }

            return VisitMemberLocal(memberExpression.Expression);
        }

        protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
        {
            switch (methodCallExpression.Method.Name)
            {
                case nameof(Queryable.Where) when methodCallExpression.Method.DeclaringType == typeof(Queryable):

                    Visit(methodCallExpression.Arguments[0]);
                    var lambda = (LambdaExpression)StripQuotes(methodCallExpression.Arguments[1]);
                    Visit(lambda.Body);

                    return methodCallExpression;

                case nameof(Queryable.Select):
                    return ParseExpression(methodCallExpression, _select);

                case nameof(Queryable.GroupBy):
                    return ParseExpression(methodCallExpression, _groupBy);

                case nameof(Queryable.Take):
                    return ParseExpression(methodCallExpression, ref _take);

                case nameof(Queryable.Skip):
                    return ParseExpression(methodCallExpression, ref _skip);

                case nameof(Queryable.OrderBy):
                case nameof(Queryable.ThenBy):
                    return ParseExpression(methodCallExpression, _orderBy, "ASC");

                case nameof(Queryable.OrderByDescending):
                case nameof(Queryable.ThenByDescending):
                    return ParseExpression(methodCallExpression, _orderBy, "DESC");

                case nameof(Queryable.Distinct):
                    IsDistinct = true;
                    return Visit(methodCallExpression.Arguments[0]);

                case nameof(string.StartsWith):
                    _where.AddRange(ParseExpression(methodCallExpression, methodCallExpression.Object));
                    _where.Add("LIKE");
                    _where.Add(CreateParameter(GetValue(methodCallExpression.Arguments[0]).ToString() + "%").ParameterName);
                    return methodCallExpression.Arguments[0];

                case nameof(string.EndsWith):
                    _where.AddRange(ParseExpression(methodCallExpression, methodCallExpression.Object));
                    _where.Add("LIKE");
                    _where.Add(CreateParameter("%" + GetValue(methodCallExpression.Arguments[0]).ToString()).ParameterName);
                    return methodCallExpression.Arguments[0];

                case nameof(string.Contains):
                    _where.AddRange(ParseExpression(methodCallExpression, methodCallExpression.Object));
                    _where.Add("LIKE");
                    _where.Add(CreateParameter("%" + GetValue(methodCallExpression.Arguments[0]).ToString() + "%").ParameterName);
                    return methodCallExpression.Arguments[0];

                case nameof(Extensions.ToSqlString):
                    return Visit(methodCallExpression.Arguments[0]);

                case nameof(Extensions.Delete):
                case nameof(Extensions.DeleteAsync):
                    IsDelete = true;
                    return Visit(methodCallExpression.Arguments[0]);

                case nameof(Extensions.Update):
                    return ParseExpression(methodCallExpression, _update);

                default:
                    if (methodCallExpression.Object != null)
                    {
                        _where.Add(CreateParameter(GetValue(methodCallExpression)).ParameterName);
                        return methodCallExpression;
                    }
                    break;
            }

            throw new NotSupportedException($"The method '{methodCallExpression.Method.Name}' is not supported");
        }

        protected override Expression VisitUnary(UnaryExpression unaryExpression)
        {
            switch (unaryExpression.NodeType)
            {
                case ExpressionType.Not:
                    _where.Add("NOT");
                    Visit(unaryExpression.Operand);
                    break;

                case ExpressionType.Convert:
                    Visit(unaryExpression.Operand);
                    break;

                default:
                    throw new NotSupportedException($"The unary operator '{unaryExpression.NodeType}' is not supported");
            }
            return unaryExpression;
        }

        private static Expression StripQuotes(Expression expression)
        {
            while (expression.NodeType == ExpressionType.Quote)
            {
                expression = ((UnaryExpression)expression).Operand;
            }
            return expression;
        }

        [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
        private IEnumerable<string> BuildDeclaration()
        {
            if (Parameters.Length == 0)                        /**/    yield break;
            foreach (SqlParameter parameter in Parameters)     /**/    yield return $"DECLARE {parameter.ParameterName} {parameter.SqlDbType}";

            foreach (SqlParameter parameter in Parameters)     /**/
                if (parameter.SqlDbType.RequiresQuotes())      /**/    yield return $"SET {parameter.ParameterName} = '{parameter.SqlValue?.ToString().Replace("'", "''") ?? "NULL"}'";
                else                                           /**/    yield return $"SET {parameter.ParameterName} = {parameter.SqlValue}";
        }

        [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
        private IEnumerable<string> BuildOrderByStatement()
        {
            if (Skip.HasValue && _orderBy.Count == 0)                       /**/   yield return "ORDER BY (SELECT NULL)";
            else if (_orderBy.Count == 0)                                   /**/   yield break;
            else if (_groupBy.Count > 0 && _orderBy[0].StartsWith("[Key]")) /**/   yield return "ORDER BY " + _groupBy.Join(", ");
            else                                                            /**/   yield return "ORDER BY " + _orderBy.Join(", ");

            if (Skip.HasValue && Take.HasValue)                             /**/   yield return $"OFFSET {Skip} ROWS FETCH NEXT {Take} ROWS ONLY";
            else if (Skip.HasValue && !Take.HasValue)                       /**/   yield return $"OFFSET {Skip} ROWS";
        }

        [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
        private IEnumerable<string> BuildSelectStatement()
        {
            yield return "SELECT";

            if (IsDistinct)                                 /**/    yield return "DISTINCT";

            if (Take.HasValue && !Skip.HasValue)            /**/    yield return $"TOP ({Take.Value})";

            if (_select.Count == 0 && _groupBy.Count > 0)   /**/    yield return _groupBy.Select(x => $"MAX({x})").Join(", ");
            else if (_select.Count == 0)                    /**/    yield return "*";
            else                                            /**/    yield return _select.Join(", ");
        }

        [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
        private IEnumerable<string> BuildSqlStatement()
        {
            if (IsDelete)                   /**/   yield return "DELETE";
            else if (_update.Count > 0)     /**/   yield return $"UPDATE [{TableName}]";
            else                            /**/   yield return Select;

            if (_update.Count == 0)         /**/   yield return From;
            else if (_update.Count > 0)     /**/   yield return Update;

            if (Where != null)              /**/   yield return Where;
            if (GroupBy != null)            /**/   yield return GroupBy;
            if (OrderBy != null)            /**/   yield return OrderBy;
        }

        private SqlParameter CreateParameter(object value)
        {
            string parameterName = $"@p{_parameters.Count}";

            var parameter = new SqlParameter()
            {
                ParameterName = parameterName,
                Value = value
            };

            _parameters.Add(parameter);

            return parameter;
        }

        private object GetEntityType(Expression expression)
        {
            while (true)
            {
                switch (expression)
                {
                    case ConstantExpression constantExpression:
                        return constantExpression.Value;

                    case MethodCallExpression methodCallExpression:
                        expression = methodCallExpression.Arguments[0];
                        continue;

                    default:
                        return null;
                }
            }
        }

        private IEnumerable<string> GetNewExpressionString(NewExpression newExpression, string appendString = null)
        {
            for (int i = 0; i < newExpression.Members.Count; i++)
            {
                if (newExpression.Arguments[i].NodeType == ExpressionType.MemberAccess)
                {
                    yield return
                        appendString == null ?
                        $"[{newExpression.Members[i].Name}]" :
                        $"[{newExpression.Members[i].Name}] {appendString}";
                }
                else
                {
                    // Non-member arguments are value assignments (used by Update), so appendString does not apply.
                    yield return $"[{newExpression.Members[i].Name}] = {CreateParameter(GetValue(newExpression.Arguments[i])).ParameterName}";
                }
            }
        }

        private object GetValue(Expression expression)
        {
            object GetMemberValue(MemberInfo memberInfo, object container = null)
            {
                switch (memberInfo)
                {
                    case FieldInfo fieldInfo:
                        return fieldInfo.GetValue(container);

                    case PropertyInfo propertyInfo:
                        return propertyInfo.GetValue(container);

                    default: return null;
                }
            }

            switch (expression)
            {
                case ConstantExpression constantExpression:
                    return constantExpression.Value;

                case MemberExpression memberExpression when memberExpression.Expression is ConstantExpression constantExpression:
                    return GetMemberValue(memberExpression.Member, constantExpression.Value);

                case MemberExpression memberExpression when memberExpression.Expression is null: // static
                    return GetMemberValue(memberExpression.Member);

                case MethodCallExpression methodCallExpression:
                    return Expression.Lambda(methodCallExpression).Compile().DynamicInvoke();

                case null:
                    return null;
            }

            throw new NotSupportedException();
        }

        private bool IsNullConstant(Expression expression) => expression.NodeType == ExpressionType.Constant && ((ConstantExpression)expression).Value == null;

        private IEnumerable<string> ParseExpression(Expression parent, Expression body, string appendString = null)
        {
            switch (body)
            {
                case MemberExpression memberExpression:
                    return appendString == null ?
                        new string[] { $"[{memberExpression.Member.Name}]" } :
                        new string[] { $"[{memberExpression.Member.Name}] {appendString}" };

                case NewExpression newExpression:
                    return GetNewExpressionString(newExpression, appendString);

                case ParameterExpression parameterExpression when parent is LambdaExpression lambdaExpression && lambdaExpression.ReturnType == parameterExpression.Type:
                    return new string[0];

                case ConstantExpression constantExpression:
                    return constantExpression
                        .Type
                        .GetProperties(BindingFlags.Public | BindingFlags.Instance)
                        .Select(x => $"[{x.Name}] = {CreateParameter(x.GetValue(constantExpression.Value)).ParameterName}");
            }

            throw new NotSupportedException();
        }

        private Expression ParseExpression(MethodCallExpression expression, List<string> commandList, string appendString = null)
        {
            var unary = (UnaryExpression)expression.Arguments[1];
            var lambdaExpression = (LambdaExpression)unary.Operand;

            lambdaExpression = (LambdaExpression)Evaluator.PartialEval(lambdaExpression);

            commandList.AddRange(ParseExpression(lambdaExpression, lambdaExpression.Body, appendString));

            return Visit(expression.Arguments[0]);
        }

        private Expression ParseExpression(MethodCallExpression expression, ref int? size)
        {
            var sizeExpression = (ConstantExpression)expression.Arguments[1];

            if (int.TryParse(sizeExpression.Value.ToString(), out int value))
            {
                size = value;
                return Visit(expression.Arguments[0]);
            }

            throw new NotSupportedException();
        }
    }
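The Skip/Take handling above is easier to check in isolation. That covers the OFFSET lines and the TOP line in BuildSelectStatement. The following minimal sketch is not part of the original answer, and PagingSketch and its methods are hypothetical names. It reproduces the same rules. Take without Skip becomes TOP (n) after SELECT. Skip, with or without Take, becomes an OFFSET … FETCH clause. Keep in mind that SQL Server only accepts OFFSET … FETCH after an ORDER BY clause.

```csharp
// Hypothetical helper mirroring the Skip/Take rules used by SimpleExpressionToSQL.
public static class PagingSketch
{
    // Clause placed right after SELECT; only used when Take is set without Skip.
    public static string BuildTopClause(int? skip, int? take) =>
        take.HasValue && !skip.HasValue ? $"TOP ({take.Value})" : null;

    // Clause placed after ORDER BY; SQL Server rejects OFFSET ... FETCH without ORDER BY.
    public static string BuildOffsetClause(int? skip, int? take)
    {
        if (!skip.HasValue)
        {
            return null;
        }

        return take.HasValue
            ? $"OFFSET {skip.Value} ROWS FETCH NEXT {take.Value} ROWS ONLY"
            : $"OFFSET {skip.Value} ROWS";
    }
}
```

For example, BuildTopClause(null, 10) yields TOP (10), and BuildOffsetClause(20, 10) yields OFFSET 20 ROWS FETCH NEXT 10 ROWS ONLY.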

I was going to post the extension methods in a comment. Edit: they are too long for a comment, so I'll add another answer.

Use with caution in production.

Feel free to turn it into a NuGet package :)
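GetValue in the class above depends on how the C# compiler represents captured variables. A local used inside a lambda appears in the tree as a MemberExpression. Its Expression is a ConstantExpression wrapping a compiler-generated closure object, so FieldInfo.GetValue can read the value back. A minimal standalone sketch of that step could look like this. It uses the hypothetical class name CapturedValueSketch.

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical helper: reads a value out of an expression tree the way GetValue does.
public static class CapturedValueSketch
{
    public static object Evaluate(Expression expression)
    {
        switch (expression)
        {
            case ConstantExpression constant:
                return constant.Value;

            // Captured local: field access on the compiler-generated closure constant.
            case MemberExpression member when member.Expression is ConstantExpression closure:
                return ReadMember(member.Member, closure.Value);

            // Static field or property: there is no target instance.
            case MemberExpression member when member.Expression == null:
                return ReadMember(member.Member, null);

            default:
                // Fallback: compile and invoke the sub-tree (general, but slow).
                return Expression.Lambda(expression).Compile().DynamicInvoke();
        }
    }

    private static object ReadMember(MemberInfo member, object target) => member switch
    {
        FieldInfo field => field.GetValue(target),
        PropertyInfo property => property.GetValue(target),
        _ => throw new NotSupportedException(member.GetType().Name)
    };
}
```

The original GetValue compiles only method calls and throws NotSupportedException for anything else. This sketch instead compiles any other sub-tree. That is more general but noticeably slower, so treat it as a design choice rather than a requirement.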

Answer 7 (score: 1)

Extensions for the SimpleExpressionToSQL class:

    public static class Extensions
    {
        private static readonly MethodInfo _deleteMethod;
        private static readonly MethodInfo _deleteMethodAsync;
        private static readonly MethodInfo _toSqlStringMethod;
        private static readonly MethodInfo _updateMethod;
        private static readonly MethodInfo _updateMethodAsync;

        static Extensions()
        {
            Type extensionType = typeof(Extensions);

            _deleteMethod = extensionType.GetMethod(nameof(Extensions.Delete), BindingFlags.Static | BindingFlags.Public);
            _updateMethod = extensionType.GetMethod(nameof(Extensions.Update), BindingFlags.Static | BindingFlags.Public);

            _deleteMethodAsync = extensionType.GetMethod(nameof(Extensions.DeleteAsync), BindingFlags.Static | BindingFlags.Public);
            _updateMethodAsync = extensionType.GetMethod(nameof(Extensions.UpdateAsync), BindingFlags.Static | BindingFlags.Public);

            _toSqlStringMethod = extensionType.GetMethod(nameof(Extensions.ToSqlString), BindingFlags.Static | BindingFlags.Public);
        }

        public static bool CanConvertToSqlDbType(this Type type) => type.ToSqlDbTypeInternal().HasValue;

        public static int Delete<T>(this IQueryable<T> queryable)
        {
            var simpleExpressionToSQL = new SimpleExpressionToSQL(queryable.AppendCall(_deleteMethod));
            return simpleExpressionToSQL.ExecuteNonQuery();
        }

        public static async Task<int> DeleteAsync<T>(this IQueryable<T> queryable)
        {
            var simpleExpressionToSQL = new SimpleExpressionToSQL(queryable.AppendCall(_deleteMethodAsync));
            return await simpleExpressionToSQL.ExecuteNonQueryAsync();
        }

        public static string GetTableName<TEntity>(this DbSet<TEntity> dbSet) where TEntity : class
        {
            DbContext context = dbSet.GetService<ICurrentDbContext>().Context;
            IModel model = context.Model;
            IEntityType entityTypeOfFooBar = model
                .GetEntityTypes()
                .First(t => t.ClrType == typeof(TEntity));

            IAnnotation tableNameAnnotation = entityTypeOfFooBar.GetAnnotation("Relational:TableName");

            return tableNameAnnotation.Value.ToString();
        }

        public static string GetTableName(this IQueryable query, Type entity)
        {
            QueryCompiler compiler = query.Provider.GetValueOfField<QueryCompiler>("_queryCompiler");
            IModel model = compiler.GetValueOfField<IModel>("_model");
            IEntityType entityTypeOfFooBar = model
                .GetEntityTypes()
                .First(t => t.ClrType == entity);

            IAnnotation tableNameAnnotation = entityTypeOfFooBar.GetAnnotation("Relational:TableName");

            return tableNameAnnotation.Value.ToString();
        }

        public static SqlDbType ToSqlDbType(this Type type) =>
            type.ToSqlDbTypeInternal() ?? throw new InvalidCastException($"Unable to cast from '{type}' to '{typeof(SqlDbType)}'.");

        public static string ToSqlString<T>(this IQueryable<T> queryable) => new SimpleExpressionToSQL(queryable.AppendCall(_toSqlStringMethod));

        public static int Update<TSource, TResult>(this IQueryable<TSource> queryable, Expression<Func<TSource, TResult>> selector)
        {
            var simpleExpressionToSQL = new SimpleExpressionToSQL(queryable.AppendCall(_updateMethod, selector));
            return simpleExpressionToSQL.ExecuteNonQuery();
        }

        public static async Task<int> UpdateAsync<TSource, TResult>(this IQueryable<TSource> queryable, Expression<Func<TSource, TResult>> selector)
        {
            var simpleExpressionToSQL = new SimpleExpressionToSQL(queryable.AppendCall(_updateMethodAsync, selector));
            return await simpleExpressionToSQL.ExecuteNonQueryAsync();
        }

        internal static DbContext GetDbContext(this IQueryable query)
        {
            QueryCompiler compiler = query.Provider.GetValueOfField<QueryCompiler>("_queryCompiler");
            RelationalQueryContextFactory queryContextFactory = compiler.GetValueOfField<RelationalQueryContextFactory>("_queryContextFactory");
            QueryContextDependencies dependencies = queryContextFactory.GetValueOfField<QueryContextDependencies>("_dependencies");

            return dependencies.CurrentContext.Context;
        }

        internal static string Join(this IEnumerable<string> values, string separator) => string.Join(separator, values);

        internal static bool RequiresQuotes(this SqlDbType sqlDbType)
        {
            switch (sqlDbType)
            {
                case SqlDbType.Char:
                case SqlDbType.Date:
                case SqlDbType.DateTime:
                case SqlDbType.DateTime2:
                case SqlDbType.DateTimeOffset:
                case SqlDbType.NChar:
                case SqlDbType.NText:
                case SqlDbType.Time:
                case SqlDbType.SmallDateTime:
                case SqlDbType.Text:
                case SqlDbType.UniqueIdentifier:
                case SqlDbType.Timestamp:
                case SqlDbType.VarChar:
                case SqlDbType.Xml:
                case SqlDbType.Variant:
                case SqlDbType.NVarChar:
                    return true;

                default:
                    return false;
            }
        }

        internal static string ToCamelCase(this string value)
        {
            if (string.IsNullOrEmpty(value))
            {
                return value;
            }

            // Lower-case only the first character. Strings are immutable, so build a new one
            // instead of mutating a copy through an unsafe pointer (string.Copy is obsolete).
            return char.ToLowerInvariant(value[0]) + value.Substring(1);
        }

        private static IQueryable<TResult> AppendCall<TSource, TResult>(this IQueryable<TSource> queryable, MethodInfo methodInfo, Expression<Func<TSource, TResult>> selector)
        {
            MethodInfo methodInfoGeneric = methodInfo.MakeGenericMethod(typeof(TSource), typeof(TResult));
            MethodCallExpression methodCallExpression = Expression.Call(methodInfoGeneric, queryable.Expression, selector);

            return new EntityQueryable<TResult>(queryable.Provider as IAsyncQueryProvider, methodCallExpression);
        }

        private static IQueryable<T> AppendCall<T>(this IQueryable<T> queryable, MethodInfo methodInfo)
        {
            MethodInfo methodInfoGeneric = methodInfo.MakeGenericMethod(typeof(T));
            MethodCallExpression methodCallExpression = Expression.Call(methodInfoGeneric, queryable.Expression);

            return new EntityQueryable<T>(queryable.Provider as IAsyncQueryProvider, methodCallExpression);
        }

        private static T GetValueOfField<T>(this object obj, string name)
        {
            FieldInfo field = obj
                .GetType()
                .GetField(name, BindingFlags.NonPublic | BindingFlags.Instance);

            return (T)field.GetValue(obj);
        }

        [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read than with Allman braces")]
        private static SqlDbType? ToSqlDbTypeInternal(this Type type)
        {
            if (Nullable.GetUnderlyingType(type) is Type nullableType)
                return nullableType.ToSqlDbTypeInternal();

            if (type.IsEnum)
                return Enum.GetUnderlyingType(type).ToSqlDbTypeInternal();

            if (type == typeof(long))            /**/                return SqlDbType.BigInt;
            if (type == typeof(byte[]))          /**/                return SqlDbType.VarBinary;
            if (type == typeof(bool))            /**/                return SqlDbType.Bit;
            if (type == typeof(string))          /**/                return SqlDbType.NVarChar;
            if (type == typeof(DateTime))        /**/                return SqlDbType.DateTime2;
            if (type == typeof(decimal))         /**/                return SqlDbType.Decimal;
            if (type == typeof(double))          /**/                return SqlDbType.Float;
            if (type == typeof(int))             /**/                return SqlDbType.Int;
            if (type == typeof(float))           /**/                return SqlDbType.Real;
            if (type == typeof(Guid))            /**/                return SqlDbType.UniqueIdentifier;
            if (type == typeof(short))           /**/                return SqlDbType.SmallInt;
            if (type == typeof(object))          /**/                return SqlDbType.Variant;
            if (type == typeof(DateTimeOffset))  /**/                return SqlDbType.DateTimeOffset;
            if (type == typeof(TimeSpan))        /**/                return SqlDbType.Time;
            if (type == typeof(byte))            /**/                return SqlDbType.TinyInt;

            return null;
        }
    }
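The AppendCall helpers wrap the current query expression in a call to a marker method: Delete, DeleteAsync, Update, UpdateAsync or ToSqlString. The translator can then tell which operation was requested by looking at the outermost MethodCallExpression. The sketch below shows the same idea without EF Core, using the in-memory provider returned by AsQueryable(). MarkerSketch, AsDelete and AppendMarker are hypothetical names. The marker method is only inspected, never executed.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public static class MarkerSketch
{
    // Marker method: it exists only so that it can be recorded in the expression tree.
    public static IQueryable<T> AsDelete<T>(this IQueryable<T> source) =>
        throw new NotSupportedException("Marker method; inspect the expression instead.");

    private static readonly MethodInfo AsDeleteMethod =
        typeof(MarkerSketch).GetMethod(nameof(AsDelete), BindingFlags.Public | BindingFlags.Static);

    // Wraps the query's expression in a call to the marker, similar to Extensions.AppendCall.
    public static IQueryable<T> AppendMarker<T>(this IQueryable<T> source)
    {
        MethodInfo generic = AsDeleteMethod.MakeGenericMethod(typeof(T));
        Expression call = Expression.Call(generic, source.Expression);
        return source.Provider.CreateQuery<T>(call);
    }

    // Name of the outermost method call, which a translator would switch on.
    public static string OutermostMethod(IQueryable query) =>
        query.Expression is MethodCallExpression call ? call.Method.Name : null;
}
```

Calling CreateQuery only builds the query object, so the throwing body of AsDelete is not reached. The marked query is meant to be read by a translator, not enumerated.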

Answer 8 (score: 0)

Not sure if this is exactly what you need, but it looks like it might be close:

using System;
using System.Linq;
using System.Linq.Expressions;

string[] companies = { "Consolidated Messenger", "Alpine Ski House", "Southridge Video", "City Power & Light",
                   "Coho Winery", "Wide World Importers", "Graphic Design Institute", "Adventure Works",
                   "Humongous Insurance", "Woodgrove Bank", "Margie's Travel", "Northwind Traders",
                   "Blue Yonder Airlines", "Trey Research", "The Phone Company",
                   "Wingtip Toys", "Lucerne Publishing", "Fourth Coffee" };

// The IQueryable data to query.
IQueryable<string> queryableData = companies.AsQueryable<string>();

// Compose the expression tree that represents the parameter to the predicate.
ParameterExpression pe = Expression.Parameter(typeof(string), "company");

// ***** Where(company => (company.ToLower() == "coho winery" || company.Length > 16)) *****
// Create an expression tree that represents the expression 'company.ToLower() == "coho winery"'.
Expression left = Expression.Call(pe, typeof(string).GetMethod("ToLower", System.Type.EmptyTypes));
Expression right = Expression.Constant("coho winery");
Expression e1 = Expression.Equal(left, right);

// Create an expression tree that represents the expression 'company.Length > 16'.
left = Expression.Property(pe, typeof(string).GetProperty("Length"));
right = Expression.Constant(16, typeof(int));
Expression e2 = Expression.GreaterThan(left, right);

// Combine the expression trees to create an expression tree that represents the
// expression '(company.ToLower() == "coho winery" || company.Length > 16)'.
Expression predicateBody = Expression.OrElse(e1, e2);

// Create an expression tree that represents the expression
// 'queryableData.Where(company => (company.ToLower() == "coho winery" || company.Length > 16))'
MethodCallExpression whereCallExpression = Expression.Call(
    typeof(Queryable),
    "Where",
    new Type[] { queryableData.ElementType },
    queryableData.Expression,
    Expression.Lambda<Func<string, bool>>(predicateBody, new ParameterExpression[] { pe }));
// ***** End Where *****

// ***** OrderBy(company => company) *****
// Create an expression tree that represents the expression
// 'whereCallExpression.OrderBy(company => company)'
MethodCallExpression orderByCallExpression = Expression.Call(
    typeof(Queryable),
    "OrderBy",
    new Type[] { queryableData.ElementType, queryableData.ElementType },
    whereCallExpression,
    Expression.Lambda<Func<string, string>>(pe, new ParameterExpression[] { pe }));
// ***** End OrderBy *****

// Create an executable query from the expression tree.
IQueryable<string> results = queryableData.Provider.CreateQuery<string>(orderByCallExpression);

// Enumerate the results.
foreach (string company in results)
    Console.WriteLine(company);
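The sample above builds a tree and runs it against the in-memory provider, but it never emits SQL. Here is a rough sketch of that missing step for the question's own inPriceRange example. A small ExpressionVisitor walks the predicate body and writes a T-SQL fragment. The sketch rests on these assumptions:

- DatabaseMappingAttribute is a hypothetical stand-in for the attribute in the question.
- Lambda parameters other than the entity become @name placeholders.
- Only comparisons, AndAlso/OrElse and member access are handled. Constants are rejected rather than parameterised.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;

// Hypothetical stand-in for the DatabaseMapping attribute from the question.
[AttributeUsage(AttributeTargets.Property)]
public sealed class DatabaseMappingAttribute : Attribute
{
    public DatabaseMappingAttribute(string column) => Column = column;
    public string Column { get; }
}

public class Product
{
    [DatabaseMapping("ProductId")] public int Id { get; set; }
    [DatabaseMapping("Price")] public int PriceInCents { get; set; }
}

// Minimal sketch: translates a predicate lambda into a WHERE fragment (without "WHERE").
public class PredicateToSql : ExpressionVisitor
{
    private static readonly Dictionary<ExpressionType, string> Operators = new Dictionary<ExpressionType, string>
    {
        [ExpressionType.AndAlso] = "AND", [ExpressionType.OrElse] = "OR",
        [ExpressionType.Equal] = "=", [ExpressionType.NotEqual] = "<>",
        [ExpressionType.GreaterThan] = ">", [ExpressionType.GreaterThanOrEqual] = ">=",
        [ExpressionType.LessThan] = "<", [ExpressionType.LessThanOrEqual] = "<=",
    };

    private readonly StringBuilder _sql = new StringBuilder();

    public static string Translate(LambdaExpression predicate)
    {
        var visitor = new PredicateToSql();
        visitor.Visit(predicate.Body);
        return visitor._sql.ToString();
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        if (!Operators.TryGetValue(node.NodeType, out string op))
            throw new NotSupportedException($"Operator {node.NodeType} is not supported.");

        _sql.Append('(');
        Visit(node.Left);
        _sql.Append(' ').Append(op).Append(' ');
        Visit(node.Right);
        _sql.Append(')');
        return node;
    }

    // Entity property: use the mapped column name, falling back to the property name.
    protected override Expression VisitMember(MemberExpression node)
    {
        var mapping = node.Member.GetCustomAttribute<DatabaseMappingAttribute>();
        _sql.Append('[').Append(mapping?.Column ?? node.Member.Name).Append(']');
        return node;
    }

    // Any other lambda parameter (e.g. "from", "to") becomes a SQL parameter placeholder.
    protected override Expression VisitParameter(ParameterExpression node)
    {
        _sql.Append('@').Append(node.Name);
        return node;
    }

    protected override Expression VisitConstant(ConstantExpression node) =>
        throw new NotSupportedException("Constants are not handled in this sketch.");
}
```

Every binary node is wrapped in parentheses so precedence stays explicit. For inPriceRange the result is (([Price] >= @from) AND ([Price] <= @to)). Recognising this pattern and emitting BETWEEN would also be possible. As the question notes, the >=/<= form is logically equivalent.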