创建一个接受两个具有相同方法的不同对象的函数

时间:2016-05-28 23:35:41

标签: go

我正在学习Golang,并希望了解" Go way"解决这个问题。

具体来说,我使用的是sql软件包,并且我在代码中看到了一些冗余功能,我希望将其引入功能。

我的原始代码

我有 1) 用户结构:

type User struct {
  ID        int
  FirstName string
  LastName  string
}

2) 一个从数据库中获取ID的一个用户的函数(Postgresql):

func GetUserById(id int) (user User) {
  sql := `
    SELECT id, first_name, last_name
    FROM users
    WHERE id = $1
  `
  row := db.QueryRow(sql, id)
  err := row.Scan(&user.ID, &user.FirstName, &user.LastName)
  if err != nil {
    panic(err)
  }
  return
}

3) 一个从数据库中获取所有用户的功能:

func GetUsers() (users []User) {
  sql := `
    SELECT id, first_name, last_name
    FROM users
    ORDER BY last_name
  `
  rows, err := db.Query(sql)
  if err != nil {
    panic(err)
  }
  for rows.Next() {
    user := User{}
    err := rows.Scan(&user.ID, &user.FirstName, &user.LastName)
    if err != nil {
      panic(err)
    }
    users = append(users, user)
  }
  rows.Close()
  return
}

期望的重构

用户记录中只有3个字段,这是一个简单的例子。但是,有了更多的字段,rows.Scan(...)这两个数据访问函数都可以很好地转移到一个可以调用的函数:

func ScanUserFromRow(row *sql.Row) (user User) {
  err := row.Scan(&user.ID, &user.FirstName, &user.LastName)
  if err != nil {
    panic(err)
  }
  return
}

然后更新的数据库访问功能类似于:

func GetUserById(id int) (user User) {
  sql := `
    SELECT id, first_name, last_name
    FROM users
    WHERE id = $1
  `
  row := db.QueryRow(sql, id)
  user = ScanUserFromRow(row)
  return
}

func GetUsers() (users []User) {
  sql := `
    SELECT id, first_name, last_name
    FROM users
    ORDER BY last_name
  `
  rows, err := db.Query(sql)
  if err != nil {
    panic(err)
  }
  for rows.Next() {
    user := ScanUserFromRow(rows)
    users = append(users, user)
  }
  rows.Close()
  return
}

但是,在GetUserById函数的情况下,我正在处理*sql.Row结构指针。在GetUsers函数的情况下,我处理*sql.Rows结构指针。两者是不同的......显然,类似,因为它们都有Scan方法。

似乎类型系统不允许我创建一个接受其中一个的方法。有没有办法利用interface{}来实现这一点,还是有其他一些更惯用的Go解决方案呢?

更新/解决方案

有了这个问题,我说sql.Rowsql.Rows都是鸭子,而且#34; quack"与Scan。如何使用允许两者的函数参数?

@seh provided an answer below允许通过使参数成为自定义接口来实现我希望的那种鸭子类型。以下是生成的代码:

type rowScanner interface {
  Scan(dest ...interface{}) error
}

func ScanPlayerFromRow(rs rowScanner) (u User) {
  err := rs.Scan(&u.ID, &u.FirstName, &u.LastName)
  if err != nil {
    panic(err)
  }
  return
}

...or, as @Kaveh pointed out below,接口的定义可以在函数参数中内联:

func ScanPlayerFromRow(rs interface {
  Scan(des ...interface{}) error
}) (u User) {
  err := rs.Scan(&u.ID, &u.FirstName, &u.LastName)
  if err != nil {
    panic(err)
  }
  return
}

3 个答案:

答案 0 :(得分:3)

sql.Rowssql.Row都有a Scan method。标准库中没有包含该方法的接口,但可以自己定义:

type rowScanner interface {
    Scan(dest ...interface{}) error
}

然后,您可以编写一个对rowScanner而不是*sql.Row*sql.Rows进行操作的函数:

 import "database/sql"

 type rowScanner interface {
    Scan(dest ...interface{}) error
 }

 func handleRow(scanner rowScanner) error {
    var i int
    return scanner.Scan(&i)
 }

 func main() {
    var row *sql.Row
    handleRow(row) // Crashes due to calling on a nil pointer.
    var rows *sql.Rows
    handleRow(rows)  // Crashes due to calling on a nil pointer.
 }

我没有使用真实*sql.Row*sql.Rows进行模拟,但这应该会给你一个想法。您所需的ScanUserFromRow功能需要rowScanner而不是*sql.Row

答案 1 :(得分:2)

在这种情况下,我会使用Query()重构sql.rows并同时返回users[]。如果我们知道id是唯一的,那么添加一个便利函数来返回该数组中的第一项。

类似的东西:

type User struct {
  ID        int
  FirstName string
  LastName  string
}
func GetUser(id int) (user User) {
  users := getUsers(id)
  if (len(users) > 0) {
      user = users[0]
  }
  return
}
func GetUsers() []User {
  return getUsers(0)
}
func getUsers(id int) (users []User) {
  rows := getUserRows(id)
  for rows.Next() {
    user := User{}
    err := rows.Scan(&user.ID, &user.FirstName, &user.LastName)
    if err != nil {
      panic(err)
    }
    users = append(users, user)
  }
  if err := rows.Err(); err != nil {
      log.Fatal(err)
  }
  rows.Close()
  return
}

func getUserRows(id int) (rows sql.Rows) {
  sqlSelect := `
    SELECT id, first_name, last_name
    FROM users`
  sqlWhere := `
    WHERE id = $1`
  sqlOrder := `
    ORDER BY last_name`
  var err error
  if (0 == id) {
    rows, err = db.Query(sqlSelect + sqlOrder)
  } else {
    rows, err = db.Query(sqlSelect + sqlWhere + sqlOrder, id)
  }
  if err != nil {
    panic(err)
  }
  return
}

答案 2 :(得分:1)

您可以强制执行函数的参数以遵守特定的接口 - 在这种情况下,使用Scan(dest ...interface{}) error方法:

func sampleHandler(ru interface {
    Scan(dest ...interface{}) error
}) error {
    var data []interface{}
    // result of some action in your logic
    return ru.Scan(data)
}