Exhaustive search / generating every combination of an expression tree

Time: 2015-11-03 13:42:11

Tags: c# tree expression expression-trees visitor-pattern

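I am building query plans with a basic expression tree optimizer. While parsing the tree, I can decide the "best" way to build it, depending on the weights I assign to each operation.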

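If I have a simple tree with two choices for how to perform an action, I would like to generate both variants of the tree and then compare the weight of each one to see which is most efficient.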

For example, the code below should let me build two variants of the expression tree's join operation: one with a MergeJoinExpresion and the other with a NestLoopJoinExpresion.

class Customer
{
        public int Id { get; set; }
}
class Orders
{
        public int Id { get; set; }
        public int CustomerId { get; set; }
}

class MergeJoinExpresion : JoinExpression
{
}

class NestLoopJoinExpresion : JoinExpression
{
}

class Visitor : ExpressionVisitor
{
    public List<Expression> GetPlans(Expression expr)
    {
        // ???
    }

    protected override Expression VisitJoin(JoinExpression join)
    {
        // For this join, I can return the following (trite example)
        // return MergeJoinExpresion
        // return NestLoopJoinExpresion

        return base.VisitJoin(join);
    }
}

How can I build a method that generates every variant of the tree and returns them to me?

class Program
{
        static void Main(string[] args)
        {
             var query = from c in customers
                        join o in orders on c.Id equals o.CustomerId
                        select new
                        {
                            CustomerId = c.Id,
                            OrderId = o.Id
                        };


            var plans = new Visitor().GetPlans(query);
        }
}

Can anyone tell me how to modify the Visitor's GetPlans method so that it generates these variants?

EDIT - something like:

class Visitor : ExpressionVisitor
{
    private List<Expression> exprs = new List<Expression>();

    public List<Expression> GetPlans(Expression expr)
    {
        Visit(expr);    
        return exprs;
    }

    protected override Expression VisitJoin(JoinExpression join)
    {
        // For this join, I can return the following (trite example)
        // return MergeJoinExpresion
        // return NestLoopJoinExpresion      
        var choices = new Expression[] { MergeJoinExpresion.Create(join), NestLoopJoinExpresion.Create(join) };

        foreach(var choice in choices)
        {
             var cloned = Cloner.Clone(choice);
             var newTree = base.VisitJoin(cloned);
             exprs.Add(newTree);
        }

        return base.VisitJoin(join);
    }
}

2 answers:

Answer 0: (score: 2)

To start with, we'll create a visitor whose only job is to extract the list of JoinExpression objects from an Expression:

internal class FindJoinsVisitor : ExpressionVisitor
{
    private List<JoinExpression> expressions = new List<JoinExpression>();
    protected override Expression VisitJoin(JoinExpression join)
    {
        expressions.Add(join);
        return base.VisitJoin(join);
    }
    public IEnumerable<JoinExpression> JoinExpressions
    {
        get
        {
            return expressions;
        }
    }
}
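// Extension methods must be declared in a static class.
public static class FindJoinsExtensions
{
    public static IEnumerable<JoinExpression> FindJoins(
        this Expression expression)
    {
        var visitor = new FindJoinsVisitor();
        visitor.Visit(expression);
        return visitor.JoinExpressions;
    }
}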

Next, we'll use the following method (taken from this blog post) to get the Cartesian product of a sequence of sequences:

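// Extension methods must be declared in a static class.
public static class CartesianProductExtensions
{
    public static IEnumerable<IEnumerable<T>> CartesianProduct<T>(
        this IEnumerable<IEnumerable<T>> sequences)
    {
        // Start with a single empty combination and extend it one sequence at a time.
        IEnumerable<IEnumerable<T>> emptyProduct = new[] { Enumerable.Empty<T>() };
        return sequences.Aggregate(
            emptyProduct,
            (accumulator, sequence) =>
                from accseq in accumulator
                from item in sequence
                select accseq.Concat(new[] { item }));
    }
}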
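To make the shape of the result concrete, here is a minimal self-contained sketch that runs the same product over plain strategy labels instead of expression pairs. The class name `CartesianDemo`, the helper `DescribePlans`, and the label format are hypothetical, introduced only for illustration; the product method is repeated so the snippet stands alone.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class CartesianDemo
{
    // Same aggregation-based product as in the answer, repeated for self-containment.
    public static IEnumerable<IEnumerable<T>> CartesianProduct<T>(
        this IEnumerable<IEnumerable<T>> sequences)
    {
        IEnumerable<IEnumerable<T>> emptyProduct = new[] { Enumerable.Empty<T>() };
        return sequences.Aggregate(
            emptyProduct,
            (accumulator, sequence) =>
                from accseq in accumulator
                from item in sequence
                select accseq.Concat(new[] { item }));
    }

    // Hypothetical helper: builds one list of two strategy labels per join,
    // then renders every combination as a comma-separated string.
    public static List<string> DescribePlans(int joinCount)
    {
        var choicesPerJoin = Enumerable.Range(1, joinCount)
            .Select(i => (IEnumerable<string>)new[] { $"join{i}:NestLoop", $"join{i}:Merge" });

        return choicesPerJoin
            .CartesianProduct()
            .Select(combo => string.Join(", ", combo))
            .ToList();
    }
}
```

With two joins, `CartesianDemo.DescribePlans(2)` yields four combinations; each additional join with two strategies doubles the count.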

Next, we'll create a visitor that takes a set of expression pairs and replaces every instance of the first expression in each pair with the second:

internal class ReplaceVisitor : ExpressionVisitor
{
    private readonly Dictionary<Expression, Expression> lookup;
    public ReplaceVisitor(Dictionary<Expression, Expression> pairsToReplace)
    {
        lookup = pairsToReplace;
    }
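    public override Expression Visit(Expression node)
    {
        // Visit may be called with null (e.g. an absent child); a null dictionary key would throw.
        Expression replacement;
        if (node != null && lookup.TryGetValue(node, out replacement))
            return base.Visit(replacement);
        else
            return base.Visit(node);
    }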
}

// Extension methods must be declared in a static class.
public static class ReplaceExtensions
{
    public static Expression ReplaceAll(this Expression expression,
        Dictionary<Expression, Expression> pairsToReplace)
    {
        return new ReplaceVisitor(pairsToReplace).Visit(expression);
    }

    public static Expression ReplaceAll(this Expression expression,
        IEnumerable<Tuple<Expression, Expression>> pairsToReplace)
    {
        var lookup = pairsToReplace.ToDictionary(pair => pair.Item1, pair => pair.Item2);
        return new ReplaceVisitor(lookup).Visit(expression);
    }
}

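Finally, we put it all together. We find all of the join expressions in the expression and project each one into a sequence of pairs, where the JoinExpression is the first item of each pair and the second item is one of its possible replacements. Taking the Cartesian product of those sequences gives every combination of replacement pairs. We then project each combination into the result of actually performing all of its replacements on the original expression: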

public static IEnumerable<Expression> AllJoinCombinations(Expression expression)
{
    var combinations = expression.FindJoins()
        .Select(join => new Tuple<Expression, Expression>[]
        {
            Tuple.Create<Expression, Expression>(join, new NestLoopJoinExpresion(join)), 
            Tuple.Create<Expression, Expression>(join, new MergeJoinExpresion(join)),
        })
        .CartesianProduct();

    return combinations.Select(combination => expression.ReplaceAll(combination));
}
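Once every variant exists, the comparison described in the question comes down to costing each plan and keeping the cheapest. The following is a minimal sketch on a toy immutable tree, not the JoinExpression hierarchy from the question. It also swaps in a plain recursive enumeration rather than the find/product/replace pipeline above, purely to stay self-contained. All type names (`PlanNode`, `Scan`, `Join`, `PlanSearch`) and the cost weights are assumptions made up for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Toy immutable plan tree; illustrative stand-ins only.
public abstract class PlanNode { }

public sealed class Scan : PlanNode
{
    public string Table { get; }
    public Scan(string table) { Table = table; }
}

public sealed class Join : PlanNode
{
    public string Strategy { get; }   // "Merge" or "NestLoop"
    public PlanNode Left { get; }
    public PlanNode Right { get; }
    public Join(string strategy, PlanNode left, PlanNode right)
    {
        Strategy = strategy; Left = left; Right = right;
    }
}

public static class PlanSearch
{
    static readonly string[] Strategies = { "Merge", "NestLoop" };

    // Returns every variant of the tree: each join is rebuilt once per strategy
    // for every combination of variants of its two children.
    public static IEnumerable<PlanNode> Variants(PlanNode node)
    {
        var join = node as Join;
        if (join == null) return new[] { node };
        return from left in Variants(join.Left)
               from right in Variants(join.Right)
               from strategy in Strategies
               select (PlanNode)new Join(strategy, left, right);
    }

    // Hypothetical weights, chosen only to make the example concrete.
    public static int Cost(PlanNode node)
    {
        var join = node as Join;
        if (join == null) return 1;
        int own = join.Strategy == "Merge" ? 5 : 8;
        return own + Cost(join.Left) + Cost(join.Right);
    }

    // Exhaustive search: cost every variant and keep the cheapest one.
    public static PlanNode Cheapest(PlanNode root)
    {
        return Variants(root).OrderBy(plan => Cost(plan)).First();
    }
}
```

A real optimizer would derive weights from statistics such as row counts and available indexes; the fixed numbers here only show where such a function would plug in.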

Answer 1: (score: 1)

You definitely need immutable trees.

Create a class:

class JoinOptionsExpression: JoinExpression {
    public IEnumerable<JoinExpression> Options {get; private set;}
    private JoinOptionsExpression(){}
    public static JoinOptionsExpression Create(IEnumerable<JoinExpression> options){
        return new JoinOptionsExpression{Options = options.ToList().AsReadOnly()}; // you can improve this probably
    }
}

Then, in your VisitJoin method, return the options node and record every possible selection of options:

// One dictionary per candidate plan: maps each options node to the index of the chosen option.
private List<Dictionary<JoinOptionsExpression, int>> selections =
    new List<Dictionary<JoinOptionsExpression, int>> { new Dictionary<JoinOptionsExpression, int>() };

protected override Expression VisitJoin(JoinExpression join)
{
    // Assumes Create and Cloner.Clone (from the question) return JoinExpression instances.
    var choices = new JoinExpression[] { MergeJoinExpresion.Create(join), NestLoopJoinExpresion.Create(join) };
    var exprs = new List<JoinExpression>();
    foreach (var choice in choices)
    {
        var cloned = Cloner.Clone(choice);
        // Assumes visiting a join yields a JoinExpression; adjust the cast if your visitor differs.
        var newTree = (JoinExpression)base.VisitJoin(cloned);
        exprs.Add(newTree);
    }
    var result = JoinOptionsExpression.Create(exprs);
    // now add all choices
    if (exprs.Count > 0)
        foreach (var selection in selections.ToList()) // iterate over a copy so the list can grow; you can improve this too
        {
            selection.Add(result, 0);
            for (int i = 1; i < exprs.Count; i++)
            {
                var copy = new Dictionary<JoinOptionsExpression, int>(selection);
                copy[result] = i;
                selections.Add(copy);
            }
        }
    return result;
}
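The bookkeeping in that loop is easier to follow in isolation. Below is a minimal sketch of the same idea that uses integer join indices as dictionary keys instead of JoinOptionsExpression nodes. `SelectionDemo` and `BuildSelections` are hypothetical names, and `optionCounts` stands in for the number of alternatives found at each join.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class SelectionDemo
{
    // optionCounts[k] is the number of alternatives for join k (an assumption of this sketch).
    // Returns one dictionary per complete assignment: join index -> chosen option index.
    public static List<Dictionary<int, int>> BuildSelections(IList<int> optionCounts)
    {
        var selections = new List<Dictionary<int, int>> { new Dictionary<int, int>() };

        for (int join = 0; join < optionCounts.Count; join++)
        {
            if (optionCounts[join] == 0) continue; // nothing to choose for this join

            // Snapshot the list so new entries can be appended while iterating.
            foreach (var selection in selections.ToList())
            {
                // The existing entry takes option 0 ...
                selection.Add(join, 0);
                // ... and one copy is added for each remaining option.
                for (int i = 1; i < optionCounts[join]; i++)
                {
                    var copy = new Dictionary<int, int>(selection);
                    copy[join] = i;
                    selections.Add(copy);
                }
            }
        }
        return selections;
    }
}
```

Each join multiplies the number of dictionaries by its option count, so two joins with two and three options produce six selections.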

Then you need a second visitor, derived from the framework's visitor, for no other reason than to extract your options:

class OptionsExtractor:ExpressionVisitor
{
    public IEnumerable<Expression> Extract(Expression expression, List<Dictionary<JoinOptionsExpression,int>> selections)
    {
        foreach(var selection in selections)
        {
            currentSelections = selection;
            yield return Visit(expression);
        }
    }
    private Dictionary<JoinOptionsExpression,int> currentSelections;
    public override Expression Visit(Expression node)
    {
        var opts = node as JoinOptionsExpression;
        if (opts != null)
            return base.Visit(opts.Options.ElementAt(currentSelections[opts]));
        else
            return base.Visit(node);
    }
}

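Anyway, an exhaustive search can quickly blow up in your face, as I guess you know. Disclaimer: I just typed this into the editor and it won't even compile, but you should be able to get the idea.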