如何拦截和修改批量更新?

时间:2019-05-15 13:18:47

标签: c# expression-trees entity-framework-plus entity-framework-extensions

在我的代码中的多个地方,我正在使用方便的dandy @UseFilters(UnprocessableEntityExceptionFilter) export class EventsController {} 扩展名进行批量更新,例如

Z.EntityFramework.Plus

这将更新await db.Foos .Where(f => f.SomeCondition) .UpdateAsync(f => new Foo { Field1 = "bar", Field2 = f.Field2 + 1 }); 为真的所有Foo记录,将SomeCondition设置为“ bar”,而Field1将增加1。

现在提出了一个新要求,其中某些表(但不是全部)正在跟踪Field2。其中包括我正在进行批量更新的记录。

所以我的方法是这样的。我有一个界面:

ModifiedDate

因此,我跟踪public interface ITrackModifiedDate { DateTime ModifiedDate { get; set; } } 的所有类都可以实现ModifiedDate。然后,我编写一个中间人扩展名以拦截ITrackModifiedDate调用:

.UpdateAsync()

如您所见,除了已经更新的其他任何字段之外,我还不确定如何修改public static async Task<int> UpdateAsync<T>(this IQueryable<T> queryable, Expression<Func<T, T>> updateFactory) where T : class { if (typeof(ITrackModifiedDate).IsAssignableFrom(typeof(T))) { // TODO Now what? } return await BatchUpdateExtensions.UpdateAsync(queryable, updateFactory); } 以将updateFactory设置为ModifiedDate

如何做到?

更新:我不反对更改扩展名,以使其只接受类型为DateTime.UtcNow的{​​{1}},如果可以的话,即

T

1 个答案:

答案 0 :(得分:1)

我通过以下代码使用它:

using System;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using Z.EntityFramework.Plus; 

class Program
{
    static async Task Main(string[] args)
    {
        using (var context = new SomeContext())
        {
            await context
                .Customers
                .Where(c => c.Email.Contains("42"))
                .CustomUpdateAsync((c) => new Customer()
                {
                    Email = "4242"
                });
        }
    }

}

public static class Helper
{
    public static async Task<int> CustomUpdateAsync<T>(this IQueryable<T> queryable, Expression<Func<T, T>> updateFactory)
        where T : class
    {
        var targetType = typeof(T);
        if (typeof(ITrackModifiedDate).IsAssignableFrom(targetType))
        {
            updateFactory = (Expression<Func<T, T>>)new TrackModifiedDateVisitor().Modify(updateFactory);
        }

        return await BatchUpdateExtensions.UpdateAsync(queryable, updateFactory);
    }
}


public class TrackModifiedDateVisitor : ExpressionVisitor
{
    public Expression Modify(Expression expression)
    {
        return Visit(expression);
    }

    public override Expression Visit(Expression node)
    {
        if (node is MemberInitExpression initExpression)
        {
            var existingBindings = initExpression.Bindings.ToList();
            var modifiedProperty = initExpression.NewExpression.Type.GetProperty(nameof(ITrackModifiedDate.ModifiedDate));

            // it will be `some.ModifiedDate = currentDate`
            var modifiedExpression = Expression.Bind(
                modifiedProperty,
                Expression.Constant(DateTime.Now, typeof(DateTime))
                );

            existingBindings.Add(modifiedExpression);

            // and then we just generate new MemberInit expression but with additional property assigment
            return base.Visit(Expression.MemberInit(initExpression.NewExpression, existingBindings));
        }

        return base.Visit(node);
    }
}


public class SomeContext: DbContext
{
    public SomeContext()
        : base("Data Source=.;Initial Catalog=TestDb;Integrated Security=SSPI;")
    {
        Database.SetInitializer(new CreateDatabaseIfNotExists<SomeContext>());
    }

    public DbSet<Customer> Customers { get; set; }
}

public class Customer: ITrackModifiedDate
{
    public int ID { get; set; }
    public string Email { get; set; }
    public DateTime ModifiedDate { get; set; }
}

public interface ITrackModifiedDate
{
    DateTime ModifiedDate { get; set; }
}

需要的部分是TrackModifiedDateVisitor类,它遍历updateFactory表达式并在找到MemberInitExpression并进行更新时进行遍历。最初,它具有属性分配的列表,我们为ModifiedDate生成一个新属性,并使用现有的属性加上生成的属性创建新的MemberInitExpression

执行访问者代码后的结果-updateFactory将具有

c => new Customer() {Email = "4242", ModifiedDate = 5/16/2019 23:19:00}