我正在尝试对 LINQ to Entities 进行单元测试:我想用不同的大小写搜索同一个单词,并返回相同的结果。
具体来说,我正在尝试对大小写不同的单词搜索进行单元测试,例如“Hi”和“hi”。
使用 Entity Framework 的 LINQ to Entities 目前支持这一点:我可以在 where 子句中搜索这两个词,都能得到正确的结果。
问题:我正在尝试构造一个行为相同的模拟(mock)查询:
/// <summary>
/// Minimal entity used by the test: a single word of text.
/// </summary>
public class SimpleWord
{
    /// <summary>
    /// The word itself. Declared as a property (not a public field) because Entity Framework
    /// Code First maps properties only; a public field would be silently ignored by a real DbContext.
    /// </summary>
    public string Text { get; set; }
}
[Test]
public void someTest()
{
    // Arrange
    var lowerWords = new[] { "hi" };
    var upperWords = new[] { "Hi" };
    var wordsList = new List<SimpleWord> { new SimpleWord { Text = "hi" } };
    IDbSet<SimpleWord> wordsDbSet = Substitute.For<DbSet<SimpleWord>, IDbSet<SimpleWord>>();

    // Back the mock DbSet with an in-memory queryable wrapped by AsStringComparison. String
    // comparisons in the Act's expression trees (here Enumerable.Contains<string>) are then
    // rewritten to ignore case. Only Arrange changes; Act is left exactly as it was.
    IQueryable<SimpleWord> queryable = wordsList
        .AsQueryable()
        .AsStringComparison(StringComparison.OrdinalIgnoreCase);
    wordsDbSet.Provider.Returns(queryable.Provider);
    wordsDbSet.Expression.Returns(queryable.Expression);
    wordsDbSet.ElementType.Returns(queryable.ElementType);
    // Hand out a fresh enumerator on every call. Returning one shared instance would leave it
    // exhausted after the DbSet is enumerated once.
    wordsDbSet.GetEnumerator().Returns(callInfo => queryable.GetEnumerator());

    // Act
    var resultLower = wordsDbSet.Where(wrd => lowerWords.Contains(wrd.Text)).ToList();
    var resultHigher = wordsDbSet.Where(wrd => upperWords.Contains(wrd.Text)).ToList();

    // Assert
    Assert.That(resultHigher.Count, Is.EqualTo(1), "did not find upper case");
    Assert.That(resultLower.Count, Is.EqualTo(1), "did not find lower case");
}
问题:如何让 wordsDbSet 在任何地方被 .Where() 搜索时都不区分大小写?
我不想修改 Act(执行)部分,即不想改成下面这样:
// Alternative Act the asker wants to avoid: passing an explicit case-insensitive comparer
// to Contains at every call site.
var resultHigher = wordsDbSet.Where(wrd =>
upperWords.Contains(wrd.Text, StringComparer.OrdinalIgnoreCase)).ToList();
我希望的解决方案是只修改 Arrange(准备)部分,类似于:
// Illustrative pseudocode of the desired Arrange-only setup; `contains` is not defined anywhere.
wordsDbSet.When(contains.IsCalled).Return(contains.OrdinalIgnoreCasing)
提前感谢!
答案 0 :(得分:3)
好的……这是可行的,但代码很长(并不复杂……只是很长)。主要问题在于:实现 IQueryable<> 和 IQueryProvider 是一件痛苦的事,而且很难解释其工作原理。你可以从网上复制一些现成的代码,但很难说清它为什么能工作、又是如何工作的。
我写了一个 IQueryable<> 包装器,它“包装”一个 IQueryable<> 对象(例如 AsQueryable() 返回的对象)。它会在传入的表达式树中“动态”地把对某些 string 方法(以及 Enumerable.Contains<string>)的调用,替换为接受 StringComparison / StringComparer 的相应重载。用法如下:
// Usage example: wrap a queryable once, then every supported string comparison in later
// operators is rewritten to use CurrentCultureIgnoreCase.
// NOTE(review): "foo " has a trailing space, so arr.Contains(x) can never match "Foo";
// presumably a typo. The filters below are also mutually exclusive (== and != "foo"),
// so `res` only demonstrates which calls get rewritten, not a meaningful query.
var arr = new[] { "foo " };
var query = new[] { "Foo", "Bar", "bar" }
.AsQueryable()
.AsStringComparison(StringComparison.CurrentCultureIgnoreCase);
// query is an IQueryable<string> whose provider rewrites the expressions below.
var res = query
// string.Compare(string, string)
.Where(x => string.Compare(x, "foo") < 0)
// string.CompareTo(string)
.Where(x => x.CompareTo("foo") < 0)
// string.Compare(string, int, string, int, int)
.Where(x => string.Compare(x, 0, "foo", 0, 3) < 0)
// string.Contains(string), rewritten to IndexOf(..., comparison) != -1
.Where(x => x.Contains("foo"))
// static string.Equals(string, string)
.Where(x => string.Equals(x, "foo"))
// instance string.Equals(string)
.Where(x => x.Equals("foo"))
// Enumerable.Contains<string>, rewritten to the IEqualityComparer<string> overload
.Where(x => arr.Contains(x))
// operator ==
.Where(x => x == "foo")
// operator !=
.Where(x => x != "foo")
;
(以上就是我所替换的全部方法。)
实现如下:
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
/// <summary>
/// Extension-method entry point for StringComparisonQueryableWrapper&lt;T&gt;.
/// </summary>
public static class StringComparisonQueryableWrapper
{
    /// <summary>
    /// Wraps <paramref name="query"/> in a StringComparisonQueryableWrapper&lt;T&gt; configured with
    /// <paramref name="comparisonType"/>.
    /// </summary>
    /// <typeparam name="T">Element type of the query.</typeparam>
    /// <param name="query">The query to wrap.</param>
    /// <param name="comparisonType">The comparison the wrapper injects into string operations.</param>
    /// <returns>The wrapping queryable.</returns>
    public static IQueryable<T> AsStringComparison<T>(this IQueryable<T> query, StringComparison comparisonType) =>
        new StringComparisonQueryableWrapper<T>(query, comparisonType);
}
/// <summary>
/// A proxy IQueryable&lt;T&gt; that is also its own IQueryProvider. Every expression tree that reaches
/// CreateQuery/Execute, and the wrapped query itself when it is enumerated, is rewritten so that a fixed
/// set of string comparisons use <see cref="ComparisonType"/>. The rewritten tree is then handed to the
/// wrapped query's provider.
/// </summary>
/// <typeparam name="T">Element type of the wrapped query.</typeparam>
public class StringComparisonQueryableWrapper<T> : IQueryable<T>, IQueryable, IQueryProvider
{
    // Open generic IQueryProvider.CreateQuery<TElement>. Invoking it on `this` dispatches to the
    // explicit interface implementation below; no lookup by compiler-generated member name is needed.
    private static readonly MethodInfo GenericCreateQueryDefinition = typeof(IQueryProvider)
        .GetMethods()
        .Single(x => x.Name == nameof(IQueryProvider.CreateQuery) && x.IsGenericMethodDefinition);

    private readonly IQueryable<T> baseQuery;

    /// <summary>The comparison injected into every rewritten string operation.</summary>
    public readonly StringComparison ComparisonType;

    /// <summary>Wraps <paramref name="baseQuery"/>.</summary>
    /// <param name="baseQuery">The query whose provider executes the rewritten expression trees.</param>
    /// <param name="comparisonType">The comparison injected into rewritten string operations.</param>
    /// <exception cref="ArgumentNullException"><paramref name="baseQuery"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="comparisonType"/> is not a defined <see cref="StringComparison"/> value.
    /// </exception>
    public StringComparisonQueryableWrapper(IQueryable<T> baseQuery, StringComparison comparisonType)
    {
        if (baseQuery == null)
        {
            throw new ArgumentNullException(nameof(baseQuery));
        }

        // Fail fast: an undefined value would otherwise surface later as a KeyNotFoundException
        // (comparer lookup) or an ArgumentException at execution time.
        if (!Enum.IsDefined(typeof(StringComparison), comparisonType))
        {
            throw new ArgumentOutOfRangeException(nameof(comparisonType));
        }

        this.baseQuery = baseQuery;
        this.ComparisonType = comparisonType;
    }

    Expression IQueryable.Expression => baseQuery.Expression;

    Type IQueryable.ElementType => baseQuery.ElementType;

    // This object is its own provider, so operators applied to it come back through CreateQuery/Execute.
    IQueryProvider IQueryable.Provider => this;

    IQueryable IQueryProvider.CreateQuery(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        // Forward to the generic overload, closed over the element type of the query expression.
        Type elementType = GetQueryableElementType(expression.Type);
        MethodInfo createQuery = GenericCreateQueryDefinition.MakeGenericMethod(elementType);
        return (IQueryable)createQuery.Invoke(this, new object[] { expression });
    }

    IQueryable<TElement> IQueryProvider.CreateQuery<TElement>(Expression expression)
    {
        var expression2 = TransformExpression(expression);
        var query = baseQuery.Provider.CreateQuery<TElement>(expression2);
        // Re-wrap so that operators chained onto the result are rewritten as well.
        return new StringComparisonQueryableWrapper<TElement>(query, ComparisonType);
    }

    object IQueryProvider.Execute(Expression expression)
    {
        var expression2 = TransformExpression(expression);
        return baseQuery.Provider.Execute(expression2);
    }

    TResult IQueryProvider.Execute<TResult>(Expression expression)
    {
        var expression2 = TransformExpression(expression);
        return baseQuery.Provider.Execute<TResult>(expression2);
    }

    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        // Queries created through CreateQuery already hold a rewritten tree, so the visitor returns
        // the same instance and baseQuery is used as-is. The root wrapper (built directly around the
        // caller's query) may still contain untransformed comparisons; rewrite those before enumerating.
        Expression transformed = TransformExpression(baseQuery.Expression);
        IQueryable<T> query = ReferenceEquals(transformed, baseQuery.Expression)
            ? baseQuery
            : baseQuery.Provider.CreateQuery<T>(transformed);
        return query.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return ((IEnumerable<T>)this).GetEnumerator();
    }

    // Returns X when type is IQueryable<X> itself or implements it. Type.GetInterfaces() on the
    // interface type IQueryable<X> does not list IQueryable<X>, so that case is checked first.
    private static Type GetQueryableElementType(Type type)
    {
        Type iqueryableT = IsQueryableOfT(type)
            ? type
            : type.GetInterfaces().Single(IsQueryableOfT);
        return iqueryableT.GetGenericArguments()[0];
    }

    // True when type is a closed IQueryable<X>.
    private static bool IsQueryableOfT(Type type)
    {
        return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IQueryable<>);
    }

    private Expression TransformExpression(Expression expression)
    {
        Expression expression2 = new StringComparisonExpressionTranformer(ComparisonType).Visit(expression);
        return expression2;
    }

    /// <summary>
    /// Expression visitor that replaces selected string method calls and ==/!= operators with the
    /// overloads taking a StringComparison (or, for Enumerable.Contains, a StringComparer).
    /// </summary>
    private class StringComparisonExpressionTranformer : ExpressionVisitor
    {
        private readonly StringComparison comparisonType;

        // Method being replaced -> factory producing the replacement expression.
        private static readonly IReadOnlyDictionary<MethodInfo, Func<MethodCallExpression, StringComparison, Expression>> transformers;

        // Operator method (op_Equality / op_Inequality) -> factory producing the replacement expression.
        private static readonly IReadOnlyDictionary<MethodInfo, Func<BinaryExpression, StringComparison, Expression>> transformers2;

        // https://stackoverflow.com/a/32764110/613130
        private static readonly IReadOnlyDictionary<StringComparison, StringComparer> comparisonToComparer = new Dictionary<StringComparison, System.StringComparer>
        {
            { StringComparison.CurrentCulture, StringComparer.CurrentCulture },
            { StringComparison.CurrentCultureIgnoreCase, StringComparer.CurrentCultureIgnoreCase },
            { StringComparison.InvariantCulture, StringComparer.InvariantCulture },
            { StringComparison.InvariantCultureIgnoreCase, StringComparer.InvariantCultureIgnoreCase },
            { StringComparison.Ordinal, StringComparer.Ordinal },
            { StringComparison.OrdinalIgnoreCase, StringComparer.OrdinalIgnoreCase }
        };

        static StringComparisonExpressionTranformer()
        {
            var transformers = new Dictionary<MethodInfo, Func<MethodCallExpression, StringComparison, Expression>>();

            {
                // string.Compare("foo", "bar")
                var method = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers.Add(method, Compare);
            }

            {
                // string.Compare("foo", 0, "bar", 0, 3)
                var method = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(int), typeof(string), typeof(int), typeof(int) }, null);
                transformers.Add(method, CompareIndexLength);
            }

            {
                // "foo".CompareTo("bar")
                var method = typeof(string).GetMethod(nameof(string.CompareTo), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, CompareTo);
            }

            {
                // "foo".Contains("bar")
                var method = typeof(string).GetMethod(nameof(string.Contains), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, Contains);
            }

            {
                // string.Equals("foo", "bar")
                var method = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers.Add(method, EqualsStatic);
            }

            {
                // "foo".Equals("bar")
                var method = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, EqualsInstance);
            }

            {
                // Enumerable.Contains<TSource>(source, "foo")
                var method = (from x in typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
                              where x.Name == nameof(Enumerable.Contains)
                              let args = x.GetGenericArguments()
                              where args.Length == 1
                              let pars = x.GetParameters()
                              where pars.Length == 2 &&
                                  pars[0].ParameterType == typeof(IEnumerable<>).MakeGenericType(args[0]) &&
                                  pars[1].ParameterType == args[0]
                              select x).Single();

                // Enumerable.Contains<string>(source, "foo")
                var method2 = method.MakeGenericMethod(typeof(string));
                transformers.Add(method2, EnumerableContains);
            }

            // TODO: all the various Array.Find*, Array.IndexOf

            StringComparisonExpressionTranformer.transformers = transformers;

            var transformers2 = new Dictionary<MethodInfo, Func<BinaryExpression, StringComparison, Expression>>();

            {
                // ==
                var method = typeof(string).GetMethod("op_Equality", BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers2.Add(method, OpEquality);
            }

            {
                // !=
                var method = typeof(string).GetMethod("op_Inequality", BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers2.Add(method, OpInequality);
            }

            StringComparisonExpressionTranformer.transformers2 = transformers2;
        }

        public StringComparisonExpressionTranformer(StringComparison comparisonType)
        {
            this.comparisonType = comparisonType;
        }

        // Method calls: replace known string methods, then keep visiting the replacement so its
        // arguments are rewritten too. Replacements use overloads absent from `transformers`,
        // so this cannot loop.
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            Func<MethodCallExpression, StringComparison, Expression> transformer;

            if (transformers.TryGetValue(node.Method, out transformer))
            {
                Expression node2 = transformer(node, comparisonType);
                return Visit(node2);
            }

            return base.VisitMethodCall(node);
        }

        // Operators: string == / != compile to BinaryExpressions whose Method is op_Equality / op_Inequality.
        protected override Expression VisitBinary(BinaryExpression node)
        {
            Func<BinaryExpression, StringComparison, Expression> transformer;

            if (node.Method != null && transformers2.TryGetValue(node.Method, out transformer))
            {
                Expression node2 = transformer(node, comparisonType);
                return Visit(node2);
            }

            return base.VisitBinary(node);
        }

        private static readonly MethodInfo StringEqualsStatic = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringEqualsInstance = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringCompareStatic = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringCompareIndexLengthStatic = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(int), typeof(string), typeof(int), typeof(int), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringIndexOfInstance = typeof(string).GetMethod(nameof(string.IndexOf), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo EnumerableContainsStatic = (from x in typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
                                                                       where x.Name == nameof(Enumerable.Contains)
                                                                       let args = x.GetGenericArguments()
                                                                       where args.Length == 1
                                                                       let pars = x.GetParameters()
                                                                       where pars.Length == 3 &&
                                                                           pars[0].ParameterType == typeof(IEnumerable<>).MakeGenericType(args[0]) &&
                                                                           pars[1].ParameterType == args[0] &&
                                                                           pars[2].ParameterType == typeof(IEqualityComparer<>).MakeGenericType(args[0])
                                                                       select x).Single().MakeGenericMethod(typeof(string));

        private static Expression Compare(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringCompareStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparisonType));
        }

        private static Expression CompareIndexLength(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringCompareIndexLengthStatic, exp.Arguments[0], exp.Arguments[1], exp.Arguments[2], exp.Arguments[3], exp.Arguments[4], Expression.Constant(comparisonType));
        }

        private static Expression CompareTo(MethodCallExpression exp, StringComparison comparisonType)
        {
            // a.CompareTo(b) becomes string.Compare(a, b, comparison).
            return Expression.Call(StringCompareStatic, exp.Object, exp.Arguments[0], Expression.Constant(comparisonType));
        }

        private static Expression Contains(MethodCallExpression exp, StringComparison comparisonType)
        {
            // string.Contains(string, StringComparison) is not available on .NET Framework, so translate
            // a.Contains(b) to a.IndexOf(b, comparison) != -1, which is available everywhere.
            return Expression.NotEqual(Expression.Call(exp.Object, StringIndexOfInstance, exp.Arguments[0], Expression.Constant(comparisonType)), Expression.Constant(-1));
        }

        private static Expression EqualsStatic(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringEqualsStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparisonType));
        }

        private static Expression EqualsInstance(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(exp.Object, StringEqualsInstance, exp.Arguments[0], Expression.Constant(comparisonType));
        }

        private static Expression EnumerableContains(MethodCallExpression exp, StringComparison comparisonType)
        {
            StringComparer comparer = comparisonToComparer[comparisonType];
            // Type the constant as the parameter type rather than the comparer's (possibly internal) runtime type.
            return Expression.Call(EnumerableContainsStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparer, typeof(IEqualityComparer<string>)));
        }

        private static Expression OpEquality(BinaryExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringEqualsStatic, exp.Left, exp.Right, Expression.Constant(comparisonType));
        }

        private static Expression OpInequality(BinaryExpression exp, StringComparison comparisonType)
        {
            return Expression.Not(Expression.Call(StringEqualsStatic, exp.Left, exp.Right, Expression.Constant(comparisonType)));
        }
    }
}
仔细看的话,这里有一个非常简单的代理 IQueryable / IQueryProvider 实现(StringComparisonQueryableWrapper<T>)。它使用一个 ExpressionVisitor(StringComparisonExpressionTranformer)来查找某些特定的 MethodCallExpression(方法调用)和 BinaryExpression(== 和 != 运算符),并将它们替换为调用接受 StringComparison / StringComparer 的重载方法的 MethodCallExpression。尚未实现的有 Array.Find*、Array.IndexOf……