Create an IQueryable that is case-insensitive for unit testing LINQ to Entities - NSubstitute

Date: 2018-06-18 09:43:24

Tags: c# linq linq-to-sql linq-to-entities nsubstitute

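I'm trying to unit test LINQ to Entities: I want to search for the same word in different casings and get the same results back.

The current scenario is that I'm trying to unit test searching for a word in different casings, e.g. "Hi" and "hi".

LINQ to Entities with Entity Framework currently supports this: I can search for either term in a where clause and it works for me.

Problem: I'm trying to make a mock queryable that behaves the same way: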

    public class SimpleWord
    {
        public string Text;
    }

    [Test]
    public void someTest()
    {
        //arrange
        var lowerWords = new[] { "hi" };
        var upperWords = new[] { "Hi" };

        var wordsList = new List<SimpleWord> {new SimpleWord { Text = "hi" } };
        IDbSet<SimpleWord> wordsDbSet = Substitute.For<DbSet<SimpleWord>, IDbSet<SimpleWord>>();

        //set up the mock dbSet
        var dataAsList = wordsList.ToList();
        var queryable = dataAsList.AsQueryable();
        wordsDbSet.Provider.Returns(queryable.Provider);
        wordsDbSet.Expression.Returns(queryable.Expression);
        wordsDbSet.ElementType.Returns(queryable.ElementType);
        wordsDbSet.GetEnumerator().Returns(queryable.GetEnumerator());

        //act
        var resultLower = wordsDbSet.Where(wrd => lowerWords.Contains(wrd.Text)).ToList();
        var resultHigher = wordsDbSet.Where(wrd => upperWords.Contains(wrd.Text)).ToList();
        //assert
        Assert.That(resultHigher.Count, Is.EqualTo(1), "did not find upper case");
        Assert.That(resultLower.Count, Is.EqualTo(1), "did not find lower case");
    }
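A minimal sketch of why this test fails, assuming the mock's Provider is the LINQ to Objects provider returned by AsQueryable(): the lambda then runs as ordinary .NET code, so Enumerable.Contains uses the default string equality, which is ordinal and case-sensitive, while a real database compares strings according to its collation (often case-insensitive, as with SQL Server's common default). CaseSensitivityDemo, CountMatches and CountMatchesIgnoreCase are hypothetical names used only for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical demo class: reproduces the mock's behaviour with an in-memory IQueryable.
public static class CaseSensitivityDemo
{
    // Same shape as the test's act: AsQueryable() yields an EnumerableQuery, i.e. LINQ to Objects.
    public static int CountMatches(IEnumerable<string> stored, IEnumerable<string> searchTerms)
    {
        return stored.AsQueryable()
            .Where(w => searchTerms.Contains(w)) // default comparer: ordinal, case-sensitive
            .Count();
    }

    // The workaround the question wants to avoid: pass a case-insensitive comparer explicitly in the act.
    public static int CountMatchesIgnoreCase(IEnumerable<string> stored, IEnumerable<string> searchTerms)
    {
        return stored.AsQueryable()
            .Where(w => searchTerms.Contains(w, StringComparer.OrdinalIgnoreCase))
            .Count();
    }
}
```

With stored = { "hi" }, CountMatches returns 1 for { "hi" } and 0 for { "Hi" }, which is the failing "did not find upper case" assertion; CountMatchesIgnoreCase returns 1 for both.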

Question: how can I make wordsDbSet case-insensitive whenever a search is run against it with .Where()?

I don't want to change the act to:

 var resultHigher = wordsDbSet.Where(wrd => 
                    upperWords.Contains(wrd.Text, StringComparer.OrdinalIgnoreCase)).ToList();

The answer I'm looking for is one that changes the arrange instead, something like:

wordsDbSet.When(contains.IsCalled).Return(contains.OrdinalIgnoreCasing)

Thanks in advance!

1 Answer:

Answer 0 (score: 3)

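OK... it's doable, but long (not very complex... just long). The main problem is that implementing IQueryable<> and IQueryProvider is a pain, and it's hard to explain how it works (you can copy some code you find on the internet, but with little understanding of why and how it works).

What I wrote is an IQueryable<> wrapper that "wraps" an IQueryable<> object (like the one returned by AsQueryable()) and, "on the fly", rewrites the expression tree passed to it, replacing all calls to some string methods (plus Enumerable.Contains<string>) with the corresponding overloads that accept a StringComparison / StringComparer. Use it like this: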
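A stripped-down sketch of the core trick, assuming only the == operator needs rewriting: an ExpressionVisitor that turns string == string into string.Equals(a, b, comparison) before the query runs. StringEqualityRewriter and StringEqualityRewriterDemo are hypothetical names for illustration; the full wrapper in this answer handles many more methods and plugs the visitor into an IQueryProvider instead of rewriting a single predicate.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical visitor: rewrites string == string into string.Equals(a, b, comparison).
public sealed class StringEqualityRewriter : ExpressionVisitor
{
    // static bool string.Equals(string a, string b, StringComparison comparisonType)
    private static readonly MethodInfo StringEqualsWithComparison =
        typeof(string).GetMethod(nameof(string.Equals),
            new[] { typeof(string), typeof(string), typeof(StringComparison) });

    private readonly StringComparison comparison;

    public StringEqualityRewriter(StringComparison comparison)
    {
        this.comparison = comparison;
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        // Only string equality is rewritten; every other binary node is visited normally.
        if (node.NodeType == ExpressionType.Equal
            && node.Left.Type == typeof(string)
            && node.Right.Type == typeof(string))
        {
            return Expression.Call(
                StringEqualsWithComparison,
                Visit(node.Left),
                Visit(node.Right),
                Expression.Constant(comparison));
        }

        return base.VisitBinary(node);
    }
}

// Hypothetical demo: applies the rewriter to a predicate before handing it to LINQ.
public static class StringEqualityRewriterDemo
{
    public static int CountEqual(IQueryable<string> source, string word, StringComparison comparison)
    {
        Expression<Func<string, bool>> predicate = x => x == word;
        // Visiting a lambda returns a lambda of the same delegate type with the rewritten body.
        var rewritten = (Expression<Func<string, bool>>)new StringEqualityRewriter(comparison).Visit(predicate);
        return source.Count(rewritten);
    }
}
```

Over { "hi" }, CountEqual(..., "Hi", StringComparison.Ordinal) counts 0, while StringComparison.OrdinalIgnoreCase counts 1.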

var arr = new[] { "foo " };
var query = new[] { "Foo", "Bar", "bar" }
    .AsQueryable()
    .AsStringComparison(StringComparison.CurrentCultureIgnoreCase);

// query is an IQueryable<>

var res = query
    .Where(x => string.Compare(x, "foo") < 0)
    .Where(x => x.CompareTo("foo") < 0)
    .Where(x => string.Compare(x, 0, "foo", 0, 3) < 0)
    .Where(x => x.Contains("foo"))
    .Where(x => string.Equals(x, "foo"))
    .Where(x => x.Equals("foo"))
    .Where(x => arr.Contains(x))
    .Where(x => x == "foo")
    .Where(x => x != "foo")
;

(This is the full list of the methods I replace.)

And the implementation:

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public static class StringComparisonQueryableWrapper
{
    public static IQueryable<T> AsStringComparison<T>(this IQueryable<T> query, StringComparison comparisonType)
    {
        return new StringComparisonQueryableWrapper<T>(query, comparisonType);
    }
}

public class StringComparisonQueryableWrapper<T> : IQueryable<T>, IQueryable, IQueryProvider
{
    private readonly IQueryable<T> baseQuery;
    public readonly StringComparison ComparisonType;

    public StringComparisonQueryableWrapper(IQueryable<T> baseQuery, StringComparison comparisonType)
    {
        this.baseQuery = baseQuery;
        this.ComparisonType = comparisonType;
    }

    Expression IQueryable.Expression => baseQuery.Expression;

    Type IQueryable.ElementType => baseQuery.ElementType;

    IQueryProvider IQueryable.Provider => this;

    IQueryable IQueryProvider.CreateQuery(Expression expression)
    {
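        Type type = expression.Type;
        // expression.Type is usually IQueryable<X> itself, and GetInterfaces() on an interface
        // type does not include that interface, so check the type before its interfaces
        var iqueryableT = type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IQueryable<>)
            ? type
            : type.GetInterfaces().Where(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(IQueryable<>)).Single();
        Type type2 = iqueryableT.GetGenericArguments()[0];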

        var thisType = typeof(StringComparisonQueryableWrapper<>).MakeGenericType(typeof(T));
        var createQueryMethod = thisType.GetMethods(BindingFlags.Instance | BindingFlags.NonPublic).Where(x => x.Name == "System.Linq.IQueryProvider.CreateQuery" && x.IsGenericMethod).Single().MakeGenericMethod(type2);
        var queryable = (IQueryable)createQueryMethod.Invoke(this, new object[] { expression });
        return queryable;
    }

    IQueryable<TElement> IQueryProvider.CreateQuery<TElement>(Expression expression)
    {
        var expression2 = TransformExpression(expression);
        var query = baseQuery.Provider.CreateQuery<TElement>(expression2);
        return new StringComparisonQueryableWrapper<TElement>(query, ComparisonType);
    }

    object IQueryProvider.Execute(Expression expression)
    {
        var expression2 = TransformExpression(expression);
        return baseQuery.Provider.Execute(expression2);
    }

    TResult IQueryProvider.Execute<TResult>(Expression expression)
    {
        var expression2 = TransformExpression(expression);
        return baseQuery.Provider.Execute<TResult>(expression2);
    }

    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        return baseQuery.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return baseQuery.GetEnumerator();
    }

    private Expression TransformExpression(Expression expression)
    {
        Expression expression2 = new StringComparisonExpressionTranformer(ComparisonType).Visit(expression);
        return expression2;
    }

    private class StringComparisonExpressionTranformer : ExpressionVisitor
    {
        private readonly StringComparison comparisonType;

        private static readonly IReadOnlyDictionary<MethodInfo, Func<MethodCallExpression, StringComparison, Expression>> transformers;
        private static readonly IReadOnlyDictionary<MethodInfo, Func<BinaryExpression, StringComparison, Expression>> transformers2;

        // https://stackoverflow.com/a/32764110/613130
        private static readonly IReadOnlyDictionary<StringComparison, StringComparer> comparisonToComparer = new Dictionary<StringComparison, System.StringComparer>
        {
            { StringComparison.CurrentCulture, StringComparer.CurrentCulture },
            { StringComparison.CurrentCultureIgnoreCase, StringComparer.CurrentCultureIgnoreCase },
            { StringComparison.InvariantCulture, StringComparer.InvariantCulture },
            { StringComparison.InvariantCultureIgnoreCase, StringComparer.InvariantCultureIgnoreCase },
            { StringComparison.Ordinal, StringComparer.Ordinal },
            { StringComparison.OrdinalIgnoreCase, StringComparer.OrdinalIgnoreCase }
        };

        static StringComparisonExpressionTranformer()
        {
            var transformers = new Dictionary<MethodInfo, Func<MethodCallExpression, StringComparison, Expression>>();

            {
                // string.Compare("foo", "bar")
                var method = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers.Add(method, Compare);
            }

            {
                // string.Compare("foo", 0, "bar", 0, 3)
                var method = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(int), typeof(string), typeof(int), typeof(int) }, null);
                transformers.Add(method, CompareIndexLength);
            }

            {
                // "foo".CompareTo("bar")
                var method = typeof(string).GetMethod(nameof(string.CompareTo), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, CompareTo);
            }

            {
                // "foo".Contains("bar")
                var method = typeof(string).GetMethod(nameof(string.Contains), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, Contains);
            }

            {
                // string.Equals("foo", "bar")
                var method = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers.Add(method, EqualsStatic);
            }

            {
                // "foo".Equals("bar")
                var method = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, EqualsInstance);
            }

            {
                // Enumerable.Contains<TSource>(source, "foo")
                var method = (from x in typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
                              where x.Name == nameof(Enumerable.Contains)
                              let args = x.GetGenericArguments()
                              where args.Length == 1
                              let pars = x.GetParameters()
                              where pars.Length == 2 &&
                                  pars[0].ParameterType == typeof(IEnumerable<>).MakeGenericType(args[0]) &&
                                  pars[1].ParameterType == args[0]
                              select x).Single();

                // Enumerable.Contains<string>(source, "foo")
                var method2 = method.MakeGenericMethod(typeof(string));

                transformers.Add(method2, EnumerableContains);
            }

            // TODO: all the various Array.Find*, Array.IndexOf

            StringComparisonExpressionTranformer.transformers = transformers;

            var transformers2 = new Dictionary<MethodInfo, Func<BinaryExpression, StringComparison, Expression>>();

            {
                // ==
                var method = typeof(string).GetMethod("op_Equality", BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers2.Add(method, OpEquality);
            }

            {
                // !=
                var method = typeof(string).GetMethod("op_Inequality", BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers2.Add(method, OpInequality);
            }

            StringComparisonExpressionTranformer.transformers2 = transformers2;
        }

        public StringComparisonExpressionTranformer(StringComparison comparisonType)
        {
            this.comparisonType = comparisonType;
        }

        // methods
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            Func<MethodCallExpression, StringComparison, Expression> transformer;

            if (transformers.TryGetValue(node.Method, out transformer))
            {
                Expression node2 = transformer(node, comparisonType);
                return Visit(node2);
            }

            return base.VisitMethodCall(node);
        }

        // operators
        protected override Expression VisitBinary(BinaryExpression node)
        {
            Func<BinaryExpression, StringComparison, Expression> transformer;

            if (node.Method != null && transformers2.TryGetValue(node.Method, out transformer))
            {
                Expression node2 = transformer(node, comparisonType);
                return Visit(node2);
            }

            return base.VisitBinary(node);
        }

        private static readonly MethodInfo StringEqualsStatic = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringEqualsInstance = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string), typeof(StringComparison) }, null);

        private static readonly MethodInfo StringCompareStatic = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringCompareIndexLengthStatic = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(int), typeof(string), typeof(int), typeof(int), typeof(StringComparison) }, null);

        private static readonly MethodInfo StringIndexOfInstance = typeof(string).GetMethod(nameof(string.IndexOf), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string), typeof(StringComparison) }, null);

        private static readonly MethodInfo EnumerableContainsStatic = (from x in typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
                                                                       where x.Name == nameof(Enumerable.Contains)
                                                                       let args = x.GetGenericArguments()
                                                                       where args.Length == 1
                                                                       let pars = x.GetParameters()
                                                                       where pars.Length == 3 &&
                                                                           pars[0].ParameterType == typeof(IEnumerable<>).MakeGenericType(args[0]) &&
                                                                           pars[1].ParameterType == args[0] &&
                                                                           pars[2].ParameterType == typeof(IEqualityComparer<>).MakeGenericType(args[0])
                                                                       select x).Single().MakeGenericMethod(typeof(string));

        private static Expression Compare(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringCompareStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparisonType));
        }

        private static Expression CompareIndexLength(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringCompareIndexLengthStatic, exp.Arguments[0], exp.Arguments[1], exp.Arguments[2], exp.Arguments[3], exp.Arguments[4], Expression.Constant(comparisonType));
        }

        private static Expression CompareTo(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringCompareStatic, exp.Object, exp.Arguments[0], Expression.Constant(comparisonType));
        }

        private static Expression Contains(MethodCallExpression exp, StringComparison comparisonType)
        {
            // .NET Framework has no string.Contains(string, StringComparison) overload, so translate to s.IndexOf(value, StringComparison) != -1
            return Expression.NotEqual(Expression.Call(exp.Object, StringIndexOfInstance, exp.Arguments[0], Expression.Constant(comparisonType)), Expression.Constant(-1));
        }

        private static Expression EqualsStatic(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringEqualsStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparisonType));
        }

        private static Expression EqualsInstance(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(exp.Object, StringEqualsInstance, exp.Arguments[0], Expression.Constant(comparisonType));
        }

        private static Expression EnumerableContains(MethodCallExpression exp, StringComparison comparisonType)
        {
            StringComparer comparer = comparisonToComparer[comparisonType];
            return Expression.Call(EnumerableContainsStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparer));
        }


        private static Expression OpEquality(BinaryExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringEqualsStatic, exp.Left, exp.Right, Expression.Constant(comparisonType));
        }

        private static Expression OpInequality(BinaryExpression exp, StringComparison comparisonType)
        {
            return Expression.Not(Expression.Call(StringEqualsStatic, exp.Left, exp.Right, Expression.Constant(comparisonType)));
        }
    }
}

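If you look at it, there is a very simple proxy IQueryable<> / IQueryProvider implementation (StringComparisonQueryableWrapper<T>) that uses an ExpressionVisitor (StringComparisonExpressionTranformer) to find some specific MethodCallExpressions (method calls) and BinaryExpressions (the == and != operators) and replace them with MethodCallExpressions to the overloads that take a StringComparison / StringComparer. What's missing are replacers for Array.Find*, Array.IndexOf ...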
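The method-call side of the same idea, sketched for the call the question actually uses (lowerWords.Contains(wrd.Text)): a visitor that swaps Enumerable.Contains<string>(source, value) for the overload taking an IEqualityComparer<string>, which mirrors what the EnumerableContains transformer above does. ContainsComparerRewriter and ContainsComparerRewriterDemo are hypothetical names, and the sketch assumes Enumerable declares exactly one three-parameter Contains overload.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical visitor: Enumerable.Contains<string>(source, value) -> Enumerable.Contains<string>(source, value, comparer).
public sealed class ContainsComparerRewriter : ExpressionVisitor
{
    // Closed generic Enumerable.Contains<string>(IEnumerable<string>, string, IEqualityComparer<string>)
    private static readonly MethodInfo ContainsWithComparer =
        typeof(Enumerable).GetMethods()
            .Single(m => m.Name == nameof(Enumerable.Contains) && m.GetParameters().Length == 3)
            .MakeGenericMethod(typeof(string));

    private readonly StringComparer comparer;

    public ContainsComparerRewriter(StringComparer comparer)
    {
        this.comparer = comparer;
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        // Match only the two-argument Enumerable.Contains closed over string.
        if (node.Method.DeclaringType == typeof(Enumerable)
            && node.Method.Name == nameof(Enumerable.Contains)
            && node.Arguments.Count == 2
            && node.Method.GetGenericArguments()[0] == typeof(string))
        {
            return Expression.Call(
                ContainsWithComparer,
                Visit(node.Arguments[0]),
                Visit(node.Arguments[1]),
                Expression.Constant(comparer, typeof(IEqualityComparer<string>)));
        }

        return base.VisitMethodCall(node);
    }
}

// Hypothetical demo: rewrites terms.Contains(w) and runs it against an in-memory IQueryable.
public static class ContainsComparerRewriterDemo
{
    public static int CountContained(IQueryable<string> source, IEnumerable<string> terms, StringComparer comparer)
    {
        Expression<Func<string, bool>> predicate = w => terms.Contains(w);
        var rewritten = (Expression<Func<string, bool>>)new ContainsComparerRewriter(comparer).Visit(predicate);
        return source.Count(rewritten);
    }
}
```

The terms parameter is typed as IEnumerable<string> so that terms.Contains(w) binds to Enumerable.Contains rather than to any array- or span-specific overload the compiler might otherwise prefer.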