Unit testing async/await with a mock database

Date: 2014-05-30 11:23:34

Tags: c# unit-testing mocking async-await

I have a method that gets all records from a database table:

public async Task<List<T>> GetAllRecordsAsync<T>(EntitiesNew source) where T : class, IGetAllRecords
{
    if (source != null) 
        return await source.Set<T>().ToListAsync();
    return null;
}

I'm trying to write a unit test for it. My test method is:

public async Task GetAllRecordsAsyncTest()
{
    var data = new List<TABLE_NAME>
    {
        new TABLE_NAME { VALID = 1, NAME = "test 1" },
        new TABLE_NAME { VALID = 1, NAME = "test 2" }
    }.AsQueryable();

    var mockSet = new Mock<DbSet<TABLE_NAME>>();
    mockSet.As<IQueryable<TABLE_NAME>>().Setup(m => m.Provider).Returns(data.Provider);
    mockSet.As<IQueryable<TABLE_NAME>>().Setup(m => m.Expression).Returns(data.Expression);
    mockSet.As<IQueryable<TABLE_NAME>>().Setup(m => m.ElementType).Returns(data.ElementType);
    mockSet.As<IQueryable<TABLE_NAME>>().Setup(m => m.GetEnumerator()).Returns(data.GetEnumerator());

    var mockContext = new Mock<EntitiesNew>();
    mockContext.Setup(x => x.TABLE_NAME).Returns(mockSet.Object);

    var database = new Database();
    var records = await database.GetAllRecordsAsync<TABLE_NAME>(mockContext.Object);
    int numberOfRecords = records.Count;
    Assert.AreEqual(2, numberOfRecords, "Wrong number of records.");
}

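The problem is that I get the actual number of records from the database. How can I get the number of records from the mock object instead?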

4 answers:

Answer 0 (score: 1)

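My guess is that you have to drop mockContext.Setup(x => x.TABLE_NAME) and instead mock the Set<T>() function, which is what your method actually uses to query the data.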
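To make the point concrete: the method under test reads through Set<T>(), so a setup on the TABLE_NAME property does not change what Set<T>() returns. The sketch below illustrates this without Moq, using a hand-written fake instead; IQuerySource, FakeQuerySource, TestRecord and RecordQueries are hypothetical names, and the method returns a completed task built with LINQ-to-Objects rather than calling EF's ToListAsync.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

// Hypothetical abstraction: only the member the method actually calls.
public interface IQuerySource
{
    IQueryable<T> Set<T>() where T : class;
}

// Hand-written fake (not Moq): answers Set<T>() from in-memory rows keyed by entity type.
public class FakeQuerySource : IQuerySource
{
    private readonly Dictionary<Type, object> _sets = new Dictionary<Type, object>();

    public FakeQuerySource With<T>(params T[] rows) where T : class
    {
        _sets[typeof(T)] = rows.AsQueryable();
        return this;
    }

    public IQueryable<T> Set<T>() where T : class
    {
        object set;
        return _sets.TryGetValue(typeof(T), out set)
            ? (IQueryable<T>)set
            : Enumerable.Empty<T>().AsQueryable();
    }
}

// Hypothetical entity standing in for TABLE_NAME.
public class TestRecord
{
    public string Name { get; set; }
}

public static class RecordQueries
{
    // Same shape as the question's method: data comes only from Set<T>(),
    // so Set<T>() is the member a test double has to provide.
    public static Task<List<T>> GetAllRecordsAsync<T>(IQuerySource source) where T : class
    {
        return Task.FromResult(source?.Set<T>().ToList());
    }
}
```

With Moq and the real context, the equivalent is a setup on Set<TABLE_NAME>() returning the mocked DbSet.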

Answer 1 (score: 1)

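I think you could implement the Repository pattern to encapsulate data access and then mock each repository, instead of mocking the ORM. Mocking becomes easier, because you only need to mock GetAllRecordsAsync<T> rather than its internals.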

You could try something like this:

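// Public so Moq can proxy it (an internal interface would need InternalsVisibleTo for DynamicProxyGenAssembly2).
public interface IRepository<T> where T : class, IGetAllRecords
{
    Task<List<T>> GetAllRecordsAsync(EntitiesNew source);
}

public class Repository<T> : IRepository<T> where T : class, IGetAllRecords
{
    public async Task<List<T>> GetAllRecordsAsync(EntitiesNew source)
    {
        // Placeholder: the real data access goes here.
        return await Task.FromResult<List<T>>(null);
    }
}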

public class Foo : IGetAllRecords {}

public class FooRepository : Repository<Foo>
{
}

I'm not a fan of the mocking framework you're using, but you could probably mock IRepository<Foo> like this:

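var mockRepository = new Mock<IRepository<Foo>>();
// It.IsAny matches any context the caller passes, not only a null argument.
mockRepository.Setup(x => x.GetAllRecordsAsync(It.IsAny<EntitiesNew>()))
    .Returns(Task.FromResult<List<Foo>>(/*desired return value*/));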
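If you would rather not depend on a mocking framework at all, the repository interface from this answer can also be satisfied by a hand-written fake. A minimal sketch: FakeRepository<T> is a hypothetical name, and EntitiesNew and IGetAllRecords are declared as empty stand-ins only so the block compiles on its own (in the real project they are the question's context class and marker interface).

```csharp
using System.Collections.Generic;
using System.Threading.Tasks;

// Empty stand-ins for the question's types, for illustration only.
public class EntitiesNew { }
public interface IGetAllRecords { }

public interface IRepository<T> where T : class, IGetAllRecords
{
    Task<List<T>> GetAllRecordsAsync(EntitiesNew source);
}

public class Foo : IGetAllRecords
{
    public string Name { get; set; }
}

// Hand-written fake: returns canned rows and ignores the source argument,
// so neither a database nor a mocking framework is involved.
public class FakeRepository<T> : IRepository<T> where T : class, IGetAllRecords
{
    private readonly List<T> _rows;

    public FakeRepository(IEnumerable<T> rows)
    {
        _rows = new List<T>(rows);
    }

    public Task<List<T>> GetAllRecordsAsync(EntitiesNew source)
    {
        // Hand back a copy so callers cannot modify the fake's data.
        return Task.FromResult(new List<T>(_rows));
    }
}
```

Code that depends on IRepository<Foo> can then be given `new FakeRepository<Foo>(...)` in tests and the real Repository<Foo> in production.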

Answer 2 (score: 1)

Following this answer, I added an interface to my entities class:

public interface IDbContext
{
    DbSet<T> Set<T>() where T: class; 
}
public class EntitiesNew : DbContext, IDbContext
{
    public EntitiesNew()
        : base("name=EntitiesNew")
    {
    }
}

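EntitiesNew needs no extra code to implement IDbContext, because DbContext already has a public Set<T>() method with a matching signature. Then I changed my method in the repository class to take the interface: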

public async Task<List<T>> GetAllRecordsAsync<T>(IDbContext source) where T : class, IGetAllRecords
{
    if (source != null) 
        return await source.Set<T>().ToListAsync();
    return null;
}

Finally, my test method now looks like this:

[TestMethod]
public async Task GetAllRecordsAsyncTest()
{
    var data = new List<TABLE_NAME>
    {
        new TABLE_NAME { VALID = 1, NAME = "test 1" },
        new TABLE_NAME { VALID = 1, NAME = "test 2" }
    }.AsQueryable();

    var mockSet = new Mock<DbSet<TABLE_NAME>>();
    mockSet.As<IQueryable<TABLE_NAME>>().Setup(m => m.Provider).Returns(new TestDbAsyncQueryProvider<TABLE_NAME>(data.Provider));
    mockSet.As<IQueryable<TABLE_NAME>>().Setup(m => m.Expression).Returns(data.Expression);
    mockSet.As<IQueryable<TABLE_NAME>>().Setup(m => m.ElementType).Returns(data.ElementType);
    // Return a fresh enumerator on every call so the set can be enumerated more than once.
    mockSet.As<IQueryable<TABLE_NAME>>().Setup(m => m.GetEnumerator()).Returns(() => data.GetEnumerator());
    mockSet.As<IDbAsyncEnumerable<TABLE_NAME>>().Setup(x => x.GetAsyncEnumerator()).Returns(() => new TestDbAsyncEnumerator<TABLE_NAME>(data.GetEnumerator()));

    // IDbContext has no TABLE_NAME property, and the method under test calls Set<T>(), so stub that.
    var mockContext = new Mock<IDbContext>();
    mockContext.Setup(x => x.Set<TABLE_NAME>()).Returns(mockSet.Object);

    var database = new Database();
    var records = await database.GetAllRecordsAsync<TABLE_NAME>(mockContext.Object);
    int numberOfRecords = records.Count;
    Assert.AreEqual(2, numberOfRecords, "Wrong number of records.");
}
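The test above relies on TestDbAsyncQueryProvider<T> and TestDbAsyncEnumerator<T>, which the answer does not show. The following is a sketch adapted from the helper classes in the Entity Framework 6 documentation on testing with a mocking framework (async queries section); it assumes a reference to the EntityFramework 6 package, and it also needs a third helper, TestDbAsyncEnumerable<T>, so that composed queries such as Where(...) remain usable with ToListAsync.

```csharp
using System.Collections.Generic;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

// Wraps a LINQ-to-Objects provider so EF6 async operators (ToListAsync, CountAsync, ...)
// accept queries built on top of the mocked DbSet.
internal class TestDbAsyncQueryProvider<TEntity> : IDbAsyncQueryProvider
{
    private readonly IQueryProvider _inner;

    internal TestDbAsyncQueryProvider(IQueryProvider inner)
    {
        _inner = inner;
    }

    // Composed queries (Where, OrderBy, ...) are wrapped again so they stay async-capable.
    public IQueryable CreateQuery(Expression expression) =>
        new TestDbAsyncEnumerable<TEntity>(expression);

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression) =>
        new TestDbAsyncEnumerable<TElement>(expression);

    public object Execute(Expression expression) => _inner.Execute(expression);

    public TResult Execute<TResult>(Expression expression) => _inner.Execute<TResult>(expression);

    // "Async" execution just runs the in-memory query synchronously and wraps the result.
    public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken) =>
        Task.FromResult(Execute(expression));

    public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken) =>
        Task.FromResult(Execute<TResult>(expression));
}

// An in-memory queryable that also implements EF6's IDbAsyncEnumerable<T>.
internal class TestDbAsyncEnumerable<T> : EnumerableQuery<T>, IDbAsyncEnumerable<T>, IQueryable<T>
{
    public TestDbAsyncEnumerable(IEnumerable<T> enumerable) : base(enumerable) { }

    public TestDbAsyncEnumerable(Expression expression) : base(expression) { }

    public IDbAsyncEnumerator<T> GetAsyncEnumerator() =>
        new TestDbAsyncEnumerator<T>(this.AsEnumerable().GetEnumerator());

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator() => GetAsyncEnumerator();

    // Re-implemented so queries built on this object go through the async-aware provider.
    IQueryProvider IQueryable.Provider => new TestDbAsyncQueryProvider<T>(this);
}

// Adapts a synchronous IEnumerator<T> to EF6's IDbAsyncEnumerator<T>.
internal class TestDbAsyncEnumerator<T> : IDbAsyncEnumerator<T>
{
    private readonly IEnumerator<T> _inner;

    public TestDbAsyncEnumerator(IEnumerator<T> inner)
    {
        _inner = inner;
    }

    public void Dispose() => _inner.Dispose();

    public Task<bool> MoveNextAsync(CancellationToken cancellationToken) =>
        Task.FromResult(_inner.MoveNext());

    public T Current => _inner.Current;

    object IDbAsyncEnumerator.Current => Current;
}
```

Without these, ToListAsync fails on a plain LINQ-to-Objects queryable, because EF6's async extension methods require the source to implement IDbAsyncEnumerable<T>.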

Answer 3 (score: 0)

Another option is to use an in-memory database:

var data = RETURN_DATA;
var optionsBuilder = new DbContextOptionsBuilder<DB_NAME>()
    .UseInMemoryDatabase(Guid.NewGuid().ToString());

var context = new DB_NAME(optionsBuilder.Options);
context.TABLE_NAME.Add(data);
context.SaveChanges();
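Note that UseInMemoryDatabase belongs to Entity Framework Core (the Microsoft.EntityFrameworkCore.InMemory package), not to EF6, which the question uses, so this option applies only on EF Core. A fuller sketch, assuming that package is referenced; TestDbContext, TestRecord and InMemoryExample are hypothetical stand-ins for DB_NAME, TABLE_NAME and the test itself.

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;

// Hypothetical entity standing in for TABLE_NAME; Id is the key by convention.
public class TestRecord
{
    public int Id { get; set; }
    public int Valid { get; set; }
    public string Name { get; set; }
}

// Hypothetical context standing in for DB_NAME; options are injected so tests can pick the provider.
public class TestDbContext : DbContext
{
    public TestDbContext(DbContextOptions<TestDbContext> options) : base(options) { }

    public DbSet<TestRecord> Records { get; set; }
}

public static class InMemoryExample
{
    public static async Task<int> CountSeededRecordsAsync()
    {
        // A unique database name per test keeps tests from sharing data.
        var options = new DbContextOptionsBuilder<TestDbContext>()
            .UseInMemoryDatabase(Guid.NewGuid().ToString())
            .Options;

        using (var context = new TestDbContext(options))
        {
            context.Records.AddRange(
                new TestRecord { Valid = 1, Name = "test 1" },
                new TestRecord { Valid = 1, Name = "test 2" });
            await context.SaveChangesAsync();
        }

        // Read through a fresh context (same database name) so the rows come from the store.
        using (var context = new TestDbContext(options))
        {
            var records = await context.Records.ToListAsync();
            return records.Count;
        }
    }
}
```

Because the query runs against a real (in-memory) provider, ToListAsync works without any async test doubles. The EF Core documentation cautions that the in-memory provider is not a relational database, so constraints and SQL-specific behavior are not exercised this way.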