EF 4.2 Code First查询拦截和导航属性

时间:2012-01-13 08:26:25

标签: entity-framework-4

注意:我项目中的实际实体与此不同。下面的场景是一个简化的示例;我的实际场景涉及的实体比这里列出的要多。

在我的项目中,有如下的 Member(成员)和 Group(组)类:

/// <summary>
/// A member; linked many-to-many with <see cref="Group"/> through the two collection properties.
/// </summary>
public class Member
{
    /// <summary>
    /// Primary key. EF Code First picks up a property named Id as the key by convention;
    /// without it the entity has no key unless the (omitted) ModelBuilder code configures one.
    /// </summary>
    public int Id { get; set; }

    public string Name { get; set; }

    /// <summary>Groups this member belongs to.</summary>
    public IList<Group> Groups { get; set; }
}

/// <summary>
/// A group; linked many-to-many with <see cref="Member"/> through the two collection properties.
/// </summary>
public class Group
{
    /// <summary>
    /// Primary key (EF Code First "Id" convention). The interception code in the answer
    /// reads Group.Id to collect the ids of loaded groups.
    /// </summary>
    public int Id { get; set; }

    public string Name { get; set; }

    /// <summary>Members of this group.</summary>
    public IList<Member> Members { get; set; }
}

我的 DbContext 实现(省略了 ModelBuilder 配置代码):

/// <summary>
/// Unfiltered context: both entity types are exposed directly as DbSet properties.
/// </summary>
public class Db : DbContext
{
    /// <summary>All members, without any filtering.</summary>
    public DbSet<Member> Members { get; set; }

    /// <summary>All groups, without any filtering.</summary>
    public DbSet<Group> Groups { get; set; }
}

假设我需要让 DbContext 强制只返回名称以“X”开头的组和成员。我可以通过如下修改 Db 类来实现:

/// <summary>
/// Context that only ever hands out groups and members whose name starts with "X".
/// </summary>
public class Db : DbContext
{
    /// <summary>Groups whose Name starts with "X".</summary>
    public IQueryable<Group> Groups
    {
        get
        {
            // IObjectContextAdapter only exposes the ObjectContext property;
            // CreateQuery<T> is a method of ObjectContext, not of the adapter.
            return from g in ((IObjectContextAdapter)this).ObjectContext.CreateQuery<Group>("SELECT VALUE Groups FROM Groups")
                   where g.Name.StartsWith("X")
                   select g;
        }
    }

    /// <summary>Members whose Name starts with "X".</summary>
    public IQueryable<Member> Members
    {
        get
        {
            return from m in ((IObjectContextAdapter)this).ObjectContext.CreateQuery<Member>("SELECT VALUE Members FROM Members")
                   where m.Name.StartsWith("X")
                   select m;
        }
    }
}

现在,以下查询仅返回名称以“X”开头的成员和组:

// Materialize both filtered sets.
List<Member> members = db.Members.ToList();
List<Group> groups = db.Groups.ToList();

这里的问题出在 Include 上:

// Eager-load the navigation collections alongside each filtered set.
List<Member> members = db.Members.Include(member => member.Groups).ToList();
List<Group> groups = db.Groups.Include(group => group.Members).ToList();

虽然“members”列表中只包含名称以“X”开头的成员,但这些成员的 Groups 属性中却包含名称不以“X”开头的 Group 对象。“groups”列表也存在同样的问题:其中各组的 Members 属性包含不符合条件的 Member 对象。

我是否遗漏了 EF 4.2 中的某个功能?

如何影响从导航属性生成的查询?

2 个答案:

答案 0 :(得分:1)

注意:我知道这段代码还有改进空间,但它适用于我的场景。不保证代码能够直接编译。

我最终构建的是一种覆盖查询执行的方法,并结合Tip 37 - How to do a Conditional Include中描述的方法,用所需的相关对象填充DbContext的缓存。

设计目标:

  • 提供拦截查询执行的方法
  • 需要知道查询的编写者想要“包含”(Include)哪些导航属性,同时阻止执行某些“包含”语句
  • 需要填充缓存并将查询重定向到缓存,以防止对远程存储的不必要调用

拦截查询:

  • 创建ReLinqContext<T>
  • 设置包含过滤选项(可选)
  • 定义拦截函数,该函数接受IQueryable<T>并返回IQueryable<T>并将函数分配给ReLinqContext<T>
  • 定义基本查询,并调用 AsReLinqQuery() 扩展方法,传入上面创建的 ReLinqContext<T>

使用下面的其他类,我可以更改我的DbContext属性,如下所示:


DbContext.Groups 属性 get 访问器的示例代码:

// Enables tracking of ReLinq queries
var ctx = new ReLinqContext<Group>();

// Record every path passed to Include(string), but never forward it to the underlying
// query; the related entities are loaded manually inside Intercept instead.
ctx.DisableIncludes();

ctx.Intercept = q =>
{
    // Extract the ObjectQuery<T> out of the DbSet<T> (this must be done for
    // q.ChangeSource to work) and apply the "name starts with X" rule that this
    // property exists to enforce.
    var groups = Set<Group>().AsObjectQuery()
                             .Where(g => g.Name.StartsWith("X"));

    // Rewrite the query to point to the real data source, then load the results into the context.
    var rewrittenQ = q.ChangeSource(groups);
    rewrittenQ.Load();

    // Get group ids from the cache.
    // NOTE(review): Local holds every Group tracked by this context, including any
    // loaded by earlier queries, not only the ones loaded just above.
    var groupIds = (from g in Set<Group>().Local
                    select g.Id).ToList();

    // Load the matching Member objects into the context only if the caller asked for them.
    if (ctx.IncludePaths.Contains("Members"))
    {
        var members = from m in Set<Member>()
                      from g in m.Groups
                      where groupIds.Contains(g.Id) && m.Name.StartsWith("X")
                      select m;

        members.Load();
    }

    // Add additional if (ctx.IncludePaths.Contains("...")) checks here

    // Return a query that will execute against the DbContext cache
    return q.ChangeSource(Set<Group>().Local.AsQueryable());
};

// The call to ChangeSource during interception
// will allow actual data to be returned.
return new Group[0].AsReLinqQuery(ctx);

实现 ReLinq 查询所需的其他代码:

/// <summary>
/// Entry points for building and re-targeting ReLinq queries.
/// </summary>
public static class ReLinqExtensions
{
    /// <summary>
    /// Re-roots <paramref name="oldSource"/>'s expression tree onto <paramref name="newSource"/>
    /// and returns the resulting query, created by the new source's provider.
    /// </summary>
    public static IQueryable<T> ChangeSource<T>(this IQueryable<T> oldSource, IQueryable<T> newSource)
    {
        IQueryable rebound = QueryableRebinder.Rebind(oldSource, newSource);
        return (IQueryable<T>)rebound;
    }

    /// <summary>Wraps an in-memory sequence in a ReLinq query bound to <paramref name="context"/>.</summary>
    public static IReLinqQuery<T> AsReLinqQuery<T>(this IEnumerable<T> enumerable, IReLinqContext<T> context)
    {
        IQueryable<T> queryable = enumerable.AsQueryable();
        return queryable.AsReLinqQuery(context);
    }

    /// <summary>
    /// Wraps <paramref name="query"/> in a ReLinq query. The context must also implement the
    /// non-generic IReLinqContext (ReLinqContext&lt;T&gt; does); otherwise the cast throws.
    /// </summary>
    public static IReLinqQuery<T> AsReLinqQuery<T>(this IQueryable<T> query, IReLinqContext<T> context)
    {
        var untypedContext = (IReLinqContext)context;
        return new ReLinqQuery<T>(query, untypedContext);
    }

    /// <summary>
    /// Makes every Include path rejected: paths are still recorded by ReLinqQuery.Include,
    /// but none reaches the underlying query. Returns the same context for chaining.
    /// </summary>
    public static IReLinqContext<T> DisableIncludes<T>(this IReLinqContext<T> context)
    {
        context.AllowIncludePath = RejectEveryPath;
        return context;
    }

    // Include filter that refuses all paths.
    private static bool RejectEveryPath(string path)
    {
        return false;
    }
}

/// <summary>
/// Helpers for reaching the ObjectQuery that backs a DbSet.
/// </summary>
public static class DbSetExtensions
{
    /// <summary>
    /// Returns the ObjectQuery&lt;T&gt; constant found in the set's expression tree
    /// (null when DbSetUnwrapper finds none).
    /// </summary>
    public static ObjectQuery<T> AsObjectQuery<T>(this DbSet<T> source) where T : class
    {
        IQueryable unwrapped = DbSetUnwrapper.UnWrap(source);
        return (ObjectQuery<T>)unwrapped;
    }
}

/// <summary>
/// Non-generic view of a ReLinq context, used by ReLinqQueryProvider, which does not know
/// the element type statically.
/// </summary>
public interface IReLinqContext
{
    /// <summary>Predicate deciding whether an Include(path) call is forwarded to the underlying query.</summary>
    Func<string, bool> AllowIncludePath { get; }

    /// <summary>Interception function, invoked dynamically with the translated query at execution time.</summary>
    Delegate Intercept { get; }

    /// <summary>Writable list of every path passed to Include, whether or not it was allowed.</summary>
    IList<string> IncludePaths { get; }
}

/// <summary>
/// Strongly typed ReLinq context: configures how a ReLinq query over T is intercepted
/// and which Include paths are allowed through.
/// </summary>
public interface IReLinqContext<T>
{
    /// <summary>Function that receives the translated query at execution time and returns the query to run.</summary>
    Func<IQueryable<T>, IQueryable<T>> Intercept { get; set; }

    /// <summary>Predicate deciding whether an Include(path) call is forwarded to the underlying query.</summary>
    Func<string, bool> AllowIncludePath { get; set; }

    /// <summary>Read-only view of the paths passed to Include so far.</summary>
    IEnumerable<string> IncludePaths { get; }
}

/// <summary>
/// Default ReLinq context. It records Include paths, holds the interception function
/// (identity by default) and holds the Include filter (allows every path by default).
/// </summary>
public class ReLinqContext<T> : IReLinqContext<T>, IReLinqContext
{
    // Single backing list: written through the non-generic interface, exposed read-only publicly.
    private readonly IList<string> _includePaths;
    private readonly ReadOnlyCollection<string> _readOnlyIncludePaths;

    public ReLinqContext()
    {
        _includePaths = new List<string>();
        _readOnlyIncludePaths = new ReadOnlyCollection<string>(_includePaths);
        AllowIncludePath = path => true;
        Intercept = query => query;
    }

    /// <summary>Read-only view of every path passed to Include so far.</summary>
    public IEnumerable<string> IncludePaths
    {
        get { return _readOnlyIncludePaths; }
    }

    /// <summary>Function applied to the translated query when a ReLinq query executes.</summary>
    public Func<IQueryable<T>, IQueryable<T>> Intercept { get; set; }

    /// <summary>Predicate deciding whether an Include path reaches the underlying query.</summary>
    public Func<string, bool> AllowIncludePath { get; set; }

    IList<string> IReLinqContext.IncludePaths
    {
        get { return _includePaths; }
    }

    Delegate IReLinqContext.Intercept
    {
        get { return Intercept; }
    }
}

/// <summary>
/// Queryable that carries a ReLinq context and exposes a string-based Include, so Include
/// calls can be recorded and optionally filtered before they reach the real query.
/// </summary>
public interface IReLinqQuery<T> : IOrderedQueryable<T>
{
    /// <summary>Requests that the navigation path be included; returns the query to keep composing on.</summary>
    IReLinqQuery<T> Include(string path);

    /// <summary>Context that records Include paths and supplies the interception function.</summary>
    IReLinqContext<T> Context { get; }
}

/// <summary>
/// ReLinq query over T. It composes like any IQueryable&lt;T&gt; and records Include paths in
/// its context. Execution is delegated to ReLinqQueryProvider, which hands the translated
/// query to the context's Intercept function.
/// </summary>
public class ReLinqQuery<T> : IReLinqQuery<T>
{
    /// <summary>
    /// Context of this query. It is null only for queries whose element type differs from the
    /// context's (e.g. after a Select projection).
    /// </summary>
    public IReLinqContext<T> Context { get; private set; }
    private Expression expression = null;
    private ReLinqQueryProvider provider = null;

    /// <summary>Creates a root query over <paramref name="source"/>.</summary>
    /// <exception cref="InvalidCastException">context does not implement IReLinqContext&lt;T&gt;.</exception>
    public ReLinqQuery(IQueryable source, IReLinqContext context)
    {
        Context = (IReLinqContext<T>)context;
        expression = Expression.Constant(this);
        provider = new ReLinqQueryProvider(source, context);
    }

    /// <summary>Creates a composed query; the provider uses this for Where, OrderBy, Select, ...</summary>
    /// <exception cref="ArgumentNullException">e is null.</exception>
    public ReLinqQuery(IQueryable source, IReLinqContext context, Expression e)
    {
        if (e == null) throw new ArgumentNullException("e");
        // Keep the context so Include still works after Where/OrderBy. A projection to another
        // element type has no IReLinqContext<T>, hence the soft cast.
        Context = context as IReLinqContext<T>;
        expression = e;
        provider = new ReLinqQueryProvider(source, context);
    }

    public IEnumerator<T> GetEnumerator()
    {
        return ((IEnumerable<T>)provider.Execute(expression)).GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return ((IEnumerable)provider.Execute(expression)).GetEnumerator();
    }

    /// <summary>
    /// Records <paramref name="path"/> in the context. If the context's AllowIncludePath accepts
    /// it and the source is a DbQuery&lt;T&gt;, the path is applied to the source. Otherwise this
    /// query is returned unchanged.
    /// </summary>
    /// <exception cref="NotSupportedException">The query has no context of element type T.</exception>
    public IReLinqQuery<T> Include(String path)
    {
        if (Context == null)
        {
            throw new NotSupportedException(
                "Include is not supported on a query whose element type differs from its ReLinq context.");
        }

        // Paths are recorded even when rejected, so Intercept can act on them itself.
        ((IReLinqContext)Context).IncludePaths.Add(path);

        if (!Context.AllowIncludePath(path))
        {
            return this;
        }

        var possibleObjectQuery = provider.Source as DbQuery<T>;

        if (possibleObjectQuery != null)
        {
            // Pass the current expression along so that operators already composed on this query
            // survive. ReLinqQueryUnwrapper re-roots them onto the new source at execution.
            return new ReLinqQuery<T>(possibleObjectQuery.Include(path), (IReLinqContext)Context, expression);
        }

        return this;
    }

    public Type ElementType
    {
        get { return typeof(T); }
    }

    public Expression Expression
    {
        get { return expression; }
    }

    public IQueryProvider Provider
    {
        get { return provider; }
    }
}

/// <summary>
/// Query provider behind ReLinqQuery&lt;T&gt;. Composing a query only builds expression trees.
/// At execution the tree is re-rooted onto <see cref="Source"/>, and the resulting query is
/// passed to the context's Intercept delegate.
/// </summary>
public class ReLinqQueryProvider : IQueryProvider
{
    internal IQueryable Source { get; private set; }
    internal IReLinqContext Context { get; private set; }

    /// <exception cref="ArgumentNullException">source is null.</exception>
    public ReLinqQueryProvider(IQueryable source, IReLinqContext context)
    {
        if (source == null) throw new ArgumentNullException("source");
        Source = source;
        Context = context;
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        return new ReLinqQuery<TElement>(Source, Context, expression);
    }

    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        Type elementType = GetElementType(expression.Type);
        return (IQueryable)Activator.CreateInstance(typeof(ReLinqQuery<>).MakeGenericType(elementType),
                new object[] { Source, Context, expression });
    }

    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        return (TResult)Execute(expression);
    }

    /// <summary>
    /// Executes <paramref name="expression"/>.
    /// For a sequence query, returns the query produced by the Intercept delegate.
    /// For a scalar operator such as Count, First or Any, the operator's source sequence is
    /// intercepted first, and the operator is then executed against the intercepted query.
    /// </summary>
    /// <exception cref="NotSupportedException">The expression is neither a sequence nor a static operator over one.</exception>
    /// <exception cref="InvalidOperationException">Intercept did not return an IQueryable for a scalar operator.</exception>
    public object Execute(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        // Replace the ReLinqQuery<T> root constant with the real source's expression.
        var translated = ReLinqQueryUnwrapper.UnWrap(expression, Source);

        if (typeof(IQueryable).IsAssignableFrom(translated.Type))
        {
            return InvokeIntercept(Source.Provider.CreateQuery(translated));
        }

        // Scalar results cannot go through CreateQuery, so intercept the sequence the
        // operator runs over and replay the operator on the intercepted query.
        var call = translated as MethodCallExpression;
        if (call == null || call.Object != null || call.Arguments.Count == 0
            || !typeof(IQueryable).IsAssignableFrom(call.Arguments[0].Type))
        {
            throw new NotSupportedException(
                "Only sequence queries and static query operators over a sequence (e.g. Count, First, Any) can be executed.");
        }

        var intercepted = InvokeIntercept(Source.Provider.CreateQuery(call.Arguments[0])) as IQueryable;
        if (intercepted == null)
        {
            throw new InvalidOperationException("The Intercept delegate must return an IQueryable.");
        }

        var arguments = call.Arguments.ToArray();
        arguments[0] = intercepted.Expression;
        return intercepted.Provider.Execute(Expression.Call(call.Method, arguments));
    }

    // Runs the context's Intercept delegate. Its static type is only Delegate here, hence DynamicInvoke.
    private object InvokeIntercept(IQueryable query)
    {
        return Context.Intercept.DynamicInvoke(query);
    }

    // Finds X for a sequence type implementing IEnumerable<X>. This also works for
    // non-generic types and for generic types with several type arguments.
    private static Type GetElementType(Type sequenceType)
    {
        if (sequenceType.IsGenericType && sequenceType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
        {
            return sequenceType.GetGenericArguments()[0];
        }

        Type enumerableInterface = sequenceType.GetInterfaces()
            .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>));
        if (enumerableInterface == null)
        {
            throw new ArgumentException("Expression type " + sequenceType + " is not a sequence type.", "expression");
        }

        return enumerableInterface.GetGenericArguments()[0];
    }
}

/// <summary>
/// Expression visitor that replaces every ReLinqQuery constant of the source's element type
/// with the source's own expression. This turns a ReLinq expression tree into one the
/// source's provider can run.
/// </summary>
public class ReLinqQueryUnwrapper : ExpressionVisitor
{
    private readonly IQueryable _source;

    public ReLinqQueryUnwrapper(IQueryable source)
    {
        _source = source;
    }

    /// <summary>Returns <paramref name="expression"/> re-rooted onto <paramref name="source"/>.</summary>
    public static Expression UnWrap(Expression expression, IQueryable source)
    {
        return new ReLinqQueryUnwrapper(source).Visit(expression);
    }

    protected override Expression VisitConstant(ConstantExpression c)
    {
        Type wrapperType = typeof(ReLinqQuery<>).MakeGenericType(_source.ElementType);
        return c.Type == wrapperType ? _source.Expression : base.VisitConstant(c);
    }
}

/// <summary>
/// Expression visitor that locates the ObjectQuery&lt;T&gt; constant inside a query's
/// expression tree, where T is the query's element type.
/// </summary>
public class DbSetUnwrapper : ExpressionVisitor
{
    /// <summary>
    /// Returns the last ObjectQuery&lt;T&gt; constant found in <paramref name="source"/>'s
    /// expression tree, or null when there is none.
    /// </summary>
    /// <exception cref="ArgumentNullException">source is null.</exception>
    public static IQueryable UnWrap(IQueryable source)
    {
        if (source == null) throw new ArgumentNullException("source");
        var dbSetUnwrapper = new DbSetUnwrapper(source);
        dbSetUnwrapper.Visit(source.Expression);
        return dbSetUnwrapper._target;
    }

    private readonly IQueryable _source;
    private IQueryable _target;

    public DbSetUnwrapper(IQueryable source)
    {
        _source = source;
    }

    public override Expression Visit(Expression node)
    {
        // ExpressionVisitor passes null for absent children (e.g. the instance of a static
        // method call), so guard before touching NodeType.
        if (node != null && node.NodeType == ExpressionType.Constant)
        {
            var c = (ConstantExpression)node;

            if (c.Type == typeof(ObjectQuery<>).MakeGenericType(_source.ElementType))
            {
                _target = (IQueryable)c.Value;
            }
        }

        return base.Visit(node);
    }
}

/// <summary>
/// Expression visitor that swaps every constant assignable to IQueryable&lt;T&gt; (T = the old
/// source's element type) for a constant holding a new source. This re-targets a composed
/// query onto another data source.
/// </summary>
public class QueryableRebinder : ExpressionVisitor
{
    private readonly IQueryable _oldSource;
    private readonly IQueryable _newSource;

    public QueryableRebinder(IQueryable oldSource, IQueryable newSource)
    {
        _oldSource = oldSource;
        _newSource = newSource;
    }

    /// <summary>
    /// Rewrites <paramref name="oldSource"/>'s expression against <paramref name="newSource"/>
    /// and returns the query created by the new source's provider.
    /// </summary>
    public static IQueryable Rebind(IQueryable oldSource, IQueryable newSource)
    {
        IQueryProvider targetProvider = newSource.Provider;
        var rebinder = new QueryableRebinder(oldSource, newSource);
        Expression reboundExpression = rebinder.Visit(oldSource.Expression);
        return targetProvider.CreateQuery(reboundExpression);
    }

    protected override Expression VisitConstant(ConstantExpression c)
    {
        Type sequenceType = typeof(IQueryable<>).MakeGenericType(_oldSource.ElementType);
        if (!sequenceType.IsAssignableFrom(c.Type))
        {
            return base.VisitConstant(c);
        }

        return Expression.Constant(_newSource);
    }
}

答案 1 :(得分:0)