如何在EF6 Code First中使用泛型类型与数据库上下文

时间:2014-06-05 19:47:17

标签: c# entity-framework ef-code-first entity-framework-6

例如,假设我有4个不同的实体,每个实体都实现一个将实体添加到数据库的Add()方法:

// One of several entity classes that each carry their own copy of the same Add() logic.
// The "..." lines are placeholders for members the question does not show.
public class Profile
{
    ...

    // Adds this instance to the context's Profile set and immediately saves the context.
    public void Add()
    {
        this._dbContext.Profile.Add(this);
        this._dbContext.SaveChanges();
    }

    ...
}

现在我希望用一个泛型抽象类来统一实现这种行为,而不是在X个类中各写一遍。所以我尝试了以下内容:

// Attempted generic base class meant to give every entity one shared Add() method.
// TEntity is the entity type whose set the instance should be added to.
public abstract class Entity<TEntity> where TEntity : class 
{
    // Context used by Add(); each instance creates its own in the constructor.
    protected DbContext _dbContext;

    // Creates a new SMTDBContext for this instance.
    protected Entity()
    {
        this._dbContext = new SMTDBContext();
    }

    // Intended to add this instance to the TEntity set and save.
    // Does not compile as written: `this` is typed as the base class, and the
    // `where TEntity : class` constraint gives the compiler no conversion to TEntity.
    public void Add()
    {
        this._dbContext.Set<TEntity>().Add(this);
        this._dbContext.SaveChanges();
    }
}

当然,这段代码行不通,因为“this”不是一个TEntity……但将来(在派生类中)它会是!到目前为止,我还没有找到做过类似事情的人。

3 个答案:

答案 0 :(得分:6)

尝试使用通用存储库,最后您将开发类似的东西。 你需要3个接口:

  • IEntity
  • IEntityRepository
  • IEntityContext

这些接口的实现:

  • EntityContext
  • EntityRepository

这里是代码:

**IEntity.cs**

/// <summary>
/// Contract for an entity that exposes a readable and writable identifier.
/// </summary>
/// <typeparam name="TId">Type of the identifier; must implement IComparable.</typeparam>
public interface IEntity<TId> where TId : IComparable
{
    /// <summary>Gets or sets the entity's identifier.</summary>
    TId Id { get; set; }
}

**IEntityContext.cs**

/// <summary>
/// Disposable abstraction over a data context: exposes entity sets, lets callers
/// request an entity state change, and persists pending changes.
/// </summary>
public interface IEntityContext : IDisposable
{
    // The EntityContext implementation in this file attaches the entity if it is
    // detached and then sets its entry state to Added / Modified / Deleted respectively.
    void SetAsAdded<TEntity>(TEntity entity) where TEntity : class;
    void SetAsModified<TEntity>(TEntity entity) where TEntity : class;
    void SetAsDeleted<TEntity>(TEntity entity) where TEntity : class;

    /// <summary>Returns the set for entities of type TEntity.</summary>
    IDbSet<TEntity> Set<TEntity>() where TEntity : class;

    /// <summary>Saves pending changes and returns an int result.</summary>
    // NOTE(review): presumably the number of affected entries, as with DbContext.SaveChanges — confirm.
    int SaveChanges();
}

**IEntityRepository.cs**

/// <summary>
/// Disposable repository contract for querying and changing entities of one type.
/// </summary>
/// <typeparam name="TEntity">Entity type; a reference type exposing an Id of type TId.</typeparam>
/// <typeparam name="TId">Identifier type; must implement IComparable.</typeparam>
public interface IEntityRepository<TEntity, TId>
    : IDisposable
    where TEntity : class, IEntity<TId>
    where TId : IComparable
{
    /// <summary>
    /// Returns a query over all entities, optionally filtered by <paramref name="where"/>
    /// and ordered by <paramref name="orderBy"/>; both default to null (not applied).
    /// </summary>
    IQueryable<TEntity> GetAll(
        Expression<Func<TEntity, bool>> where = null,
        Expression<Func<TEntity, object>> orderBy = null);

    // NOTE(review): no implementation of Paginate is visible in this file; presumably it
    // returns the page at pageIndex with pageSize items — confirm against PaginatedList.
    PaginatedList<TEntity> Paginate(int pageIndex, int pageSize);

    /// <summary>Returns the entity identified by <paramref name="id"/>.</summary>
    // The EntityRepository implementation uses FirstOrDefault, so a missing id yields null.
    TEntity GetSingle(TId id);

    /// <summary>
    /// Same as GetAll, with each expression in <paramref name="includeProperties"/>
    /// applied to the query as an include.
    /// </summary>
    IQueryable<TEntity> GetAllIncluding(
        Expression<Func<TEntity, bool>> where,
        Expression<Func<TEntity, object>> orderBy,
        params Expression<Func<TEntity, object>>[] includeProperties);

    /// <summary>
    /// Same as GetSingle, with each expression in <paramref name="includeProperties"/>
    /// applied to the query as an include.
    /// </summary>
    TEntity GetSingleIncluding(
        TId id, params Expression<Func<TEntity, object>>[] includeProperties);

    // Change operations. In the EntityRepository implementation in this file none of these
    // saves by itself; Save() must be called to persist.
    // Add    -> adds the entity to the context's set.
    // Attach -> asks the context to mark the entity as added (IEntityContext.SetAsAdded).
    // Edit   -> asks the context to mark the entity as modified.
    // Delete -> asks the context to mark the entity as deleted.
    void Add(TEntity entity);
    void Attach(TEntity entity);
    void Edit(TEntity entity);
    void Delete(TEntity entity);

    /// <summary>Saves pending changes through the context and returns its int result.</summary>
    int Save();
}

**EntityRepository.cs**

/// <summary>
/// Generic repository that queries and changes entities of one type through an
/// <see cref="IEntityContext"/>. Changes are only persisted when <see cref="Save"/> is called.
/// </summary>
/// <typeparam name="TEntity">Entity type; a reference type exposing an Id of type TId.</typeparam>
/// <typeparam name="TId">Identifier type; must implement IComparable.</typeparam>
/// <remarks>
/// NOTE(review): this snippet references the events EntityAdded, EntityAttach, EntityModified
/// and EntityDeleted, and the interface requires Paginate, but none of them is declared in
/// the visible code. They must be supplied before the class compiles.
/// </remarks>
public class EntityRepository<TEntity, TId>
    : IEntityRepository<TEntity, TId>
    where TEntity : class, IEntity<TId>
    where TId : IComparable
{
    private readonly IEntityContext _dbContext;

    /// <summary>Creates a repository working against <paramref name="dbContext"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="dbContext"/> is null.</exception>
    public EntityRepository(IEntityContext dbContext)
    {
        if (dbContext == null)
        {
            throw new ArgumentNullException("dbContext");
        }

        _dbContext = dbContext;
    }

    /// <summary>
    /// Returns the query produced by <see cref="GetAll"/> with every expression in
    /// <paramref name="includeProperties"/> applied as an include.
    /// </summary>
    public IQueryable<TEntity> GetAllIncluding(
        Expression<Func<TEntity, bool>> where,
        Expression<Func<TEntity, object>> orderBy,
        params Expression<Func<TEntity, object>>[] includeProperties)
    {
        IQueryable<TEntity> queryable = GetAll(where, orderBy);

        // A params array can be passed as an explicit null; treat that as "no includes"
        // instead of failing with a NullReferenceException in the loop.
        if (includeProperties == null)
        {
            return queryable;
        }

        foreach (Expression<Func<TEntity, object>> includeProperty in includeProperties)
        {
            queryable = queryable.Include<TEntity, object>(includeProperty);
        }

        return queryable;
    }

    /// <summary>
    /// Returns the first entity whose Id equals <paramref name="id"/>, with the given
    /// includes applied, or null when there is none.
    /// </summary>
    public TEntity GetSingleIncluding(
        TId id,
        params Expression<Func<TEntity, object>>[] includeProperties)
    {
        IQueryable<TEntity> entities = GetAllIncluding(null, null, includeProperties);
        return Filter<TId>(entities, x => x.Id, id).FirstOrDefault();
    }

    /// <summary>Adds <paramref name="entity"/> to the context's set, then raises EntityAdded.</summary>
    public void Add(TEntity entity)
    {
        _dbContext.Set<TEntity>().Add(entity);
        if (this.EntityAdded != null)
        {
            this.EntityAdded(this, new EntityAddedEventArgs<TEntity, TId>(entity));
        }
    }

    /// <summary>
    /// Asks the context to mark <paramref name="entity"/> as added, then raises EntityAttach.
    /// </summary>
    // NOTE(review): despite the name this calls SetAsAdded, not a plain attach — confirm intent.
    public void Attach(TEntity entity)
    {
        _dbContext.SetAsAdded(entity);
        if (this.EntityAttach != null)
        {
            this.EntityAttach(this, new EntityAddedEventArgs<TEntity, TId>(entity));
        }
    }

    /// <summary>
    /// Asks the context to mark <paramref name="entity"/> as modified, then raises EntityModified.
    /// </summary>
    public void Edit(TEntity entity)
    {
        _dbContext.SetAsModified(entity);
        if (this.EntityModified != null)
        {
            this.EntityModified(this, new EntityModifiedEventArgs<TEntity, TId>(entity));
        }
    }

    /// <summary>
    /// Asks the context to mark <paramref name="entity"/> as deleted, then raises EntityDeleted.
    /// </summary>
    public void Delete(TEntity entity)
    {
        _dbContext.SetAsDeleted(entity);
        if (this.EntityDeleted != null)
        {
            this.EntityDeleted(this, new EntityDeletedEventArgs<TEntity, TId>(entity));
        }
    }

    /// <summary>Saves pending changes through the context and returns its result.</summary>
    public int Save()
    {
        return _dbContext.SaveChanges();
    }

    /// <summary>
    /// Returns a query over the entity set, filtered by <paramref name="where"/> and then
    /// ordered by <paramref name="orderBy"/> when those are not null.
    /// </summary>
    public IQueryable<TEntity> GetAll(
        Expression<Func<TEntity, bool>> where = null,
        Expression<Func<TEntity, object>> orderBy = null)
    {
        IQueryable<TEntity> queryable = _dbContext.Set<TEntity>();

        if (where != null)
        {
            queryable = queryable.Where(where);
        }

        if (orderBy != null)
        {
            queryable = queryable.OrderBy(orderBy);
        }

        return queryable;
    }

    /// <summary>
    /// Returns the first entity whose Id equals <paramref name="id"/>, or null when there is none.
    /// </summary>
    public TEntity GetSingle(TId id)
    {
        return Filter<TId>(GetAll(), x => x.Id, id).FirstOrDefault();
    }

    /// <summary>Disposes the context this repository was constructed with.</summary>
    public void Dispose()
    {
        _dbContext.Dispose();
    }

    #region Private

    /// <summary>
    /// Restricts <paramref name="dbSet"/> to the entities whose <paramref name="property"/>
    /// equals <paramref name="value"/>, building the predicate as an expression tree.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="property"/> is not a direct property access.
    /// </exception>
    private IQueryable<TEntity> Filter<TProperty>(
        IQueryable<TEntity> dbSet,
        Expression<Func<TEntity, TProperty>> property, TProperty value)
        where TProperty : IComparable
    {
        var memberExpression = property.Body as MemberExpression;

        if (memberExpression == null ||
            !(memberExpression.Member is PropertyInfo))
        {
            throw new ArgumentException("Property expected", "property");
        }

        // Build "entity => entity.Property == value", reusing the selector's own parameter
        // so the new lambda body and its parameter list refer to the same node.
        Expression right = Expression.Constant(value, typeof(TProperty));
        Expression searchExpression = Expression.Equal(property.Body, right);

        Expression<Func<TEntity, bool>> lambda =
            Expression.Lambda<Func<TEntity, bool>>(
                searchExpression,
                property.Parameters);

        return dbSet.Where(lambda);
    }

    #endregion
}

**EntityContext.cs**

/// <summary>
/// Base <see cref="DbContext"/> that also implements <see cref="IEntityContext"/>, adding
/// helpers that attach an entity if necessary and set its tracked state.
/// </summary>
/// <remarks>
/// The SaveChanges and Dispose members required by IEntityContext are satisfied by the
/// public members inherited from DbContext, so this class neither overrides nor hides them.
/// All constructors are protected because the class is abstract and can only be
/// constructed through a derived type.
/// </remarks>
public abstract class EntityContext : DbContext, IEntityContext
{
    /// <summary>
    /// Constructs a new context instance using conventions to create the name of
    /// the database to which a connection will be made. The by-convention name is
    /// the full name (namespace + class name) of the derived context class.
    /// </summary>
    protected EntityContext() : base() { }

    /// <summary>
    /// Constructs a new context instance using conventions to create the name of
    /// the database to which a connection will be made, and initializes it from
    /// the given model.
    /// </summary>
    /// <param name="model">The model that will back this context.</param>
    protected EntityContext(DbCompiledModel model) : base(model) { }

    /// <summary>
    /// Constructs a new context instance using the given string as the name or connection
    /// string for the database to which a connection will be made.
    /// </summary>
    /// <param name="nameOrConnectionString">Either the database name or a connection string.</param>
    protected EntityContext(string nameOrConnectionString)
        : base(nameOrConnectionString) { }

    /// <summary>
    /// Constructs a new context instance using the existing connection to connect
    /// to a database.
    /// </summary>
    /// <param name="existingConnection">An existing connection to use for the new context.</param>
    /// <param name="contextOwnsConnection">
    /// If set to true the connection is disposed when the context is disposed, otherwise
    /// the caller must dispose the connection.
    /// </param>
    protected EntityContext(
        DbConnection existingConnection, bool contextOwnsConnection)
        : base(existingConnection, contextOwnsConnection) { }

    /// <summary>
    /// Constructs a new context instance around an existing ObjectContext.
    /// </summary>
    /// <param name="objectContext">An existing ObjectContext to wrap with the new context.</param>
    /// <param name="EntityContextOwnsObjectContext">
    /// If set to true the ObjectContext is disposed when this context is disposed,
    /// otherwise the caller must dispose it.
    /// </param>
    protected EntityContext(
        ObjectContext objectContext,
        bool EntityContextOwnsObjectContext)
        : base(objectContext, EntityContextOwnsObjectContext)
    { }

    /// <summary>
    /// Constructs a new context instance using the given string as the name or connection
    /// string for the database to which a connection will be made, and initializes
    /// it from the given model.
    /// </summary>
    /// <param name="nameOrConnectionString">Either the database name or a connection string.</param>
    /// <param name="model">The model that will back this context.</param>
    protected EntityContext(
        string nameOrConnectionString,
        DbCompiledModel model)
        : base(nameOrConnectionString, model)
    { }

    /// <summary>
    /// Constructs a new context instance using the existing connection to connect
    /// to a database, and initializes it from the given model.
    /// </summary>
    /// <param name="existingConnection">An existing connection to use for the new context.</param>
    /// <param name="model">The model that will back this context.</param>
    /// <param name="contextOwnsConnection">
    /// If set to true the connection is disposed when the context is disposed, otherwise
    /// the caller must dispose the connection.
    /// </param>
    protected EntityContext(
        DbConnection existingConnection,
        DbCompiledModel model, bool contextOwnsConnection)
        : base(existingConnection, model, contextOwnsConnection)
    { }

    /// <summary>
    /// Returns the set for <typeparamref name="TEntity"/> typed as IDbSet, which is the
    /// return type IEntityContext requires; it hides the DbContext method of the same name.
    /// </summary>
    public new IDbSet<TEntity> Set<TEntity>() where TEntity : class
    {
        return base.Set<TEntity>();
    }

    /// <summary>Attaches <paramref name="entity"/> if it is detached and marks it as Added.</summary>
    public void SetAsAdded<TEntity>(TEntity entity) where TEntity : class
    {
        GetDbEntityEntrySafely(entity).State = EntityState.Added;
    }

    /// <summary>Attaches <paramref name="entity"/> if it is detached and marks it as Modified.</summary>
    public void SetAsModified<TEntity>(TEntity entity) where TEntity : class
    {
        GetDbEntityEntrySafely(entity).State = EntityState.Modified;
    }

    /// <summary>Attaches <paramref name="entity"/> if it is detached and marks it as Deleted.</summary>
    public void SetAsDeleted<TEntity>(TEntity entity) where TEntity : class
    {
        GetDbEntityEntrySafely(entity).State = EntityState.Deleted;
    }

    #region Private
    /// <summary>
    /// Returns the change-tracking entry for <paramref name="entity"/>, attaching the entity
    /// to its set first when the entry reports it as detached.
    /// </summary>
    private DbEntityEntry GetDbEntityEntrySafely<TEntity>(
        TEntity entity) where TEntity : class
    {
        DbEntityEntry dbEntityEntry = base.Entry<TEntity>(entity);
        if (dbEntityEntry.State == EntityState.Detached)
        {
            Set<TEntity>().Attach(entity);
        }

        return dbEntityEntry;
    }
    #endregion
}

答案很长,但值得...祝你有美好的一天:) 它是个人巨大项目的一部分:D

答案 1 :(得分:6)

您的问题的解决方案是把泛型约束定义得更明确。将约束定义为TEntity必须是Entity<TEntity>的子类,即使用where TEntity : Entity<TEntity>而不是where TEntity : class

// Generic base class that gives every entity a shared Add() method.
// The self-referencing constraint (TEntity must derive from this base class closed over
// itself) is what lets the compiler accept the cast of `this` to TEntity in Add().
public abstract class Entity<TEntity> where TEntity : Entity<TEntity>
{
    // Each instance creates and holds its own SMTDBContext.
    protected DbContext _dbContext = new SMTDBContext();

    protected Entity()
    {
    }

    // Adds this instance to the TEntity set and saves the context immediately.
    public void Add()
    {
        var entities = _dbContext.Set<TEntity>();
        entities.Add((TEntity)this);
        _dbContext.SaveChanges();
    }
}

答案 2 :(得分:0)

你可以用下面的写法绕过这个问题;只需要在运行时确保它确实是一个TEntity:

// Adds this instance to the TEntity set and saves the context immediately.
// Going through `object` defers the type check to run time: the cast throws if this
// instance is not actually a TEntity.
public void Add()
{
    var entities = this._dbContext.Set<TEntity>();
    object obj = this;
    entities.Add((TEntity)obj);
    this._dbContext.SaveChanges();
}

这样做之所以能通过编译,是因为一旦使用object类型,编译器就无法再跟踪this的实际类型。如果你在运行时收到错误,那是因为obj并不是真正的TEntity。不过,您可能更希望使用工厂、存储库或其他设计模式来处理实体框架的DbSet。