mirror of
https://github.com/kataras/iris.git
synced 2025-01-23 02:31:04 +01:00
215 lines
5.5 KiB
Go
215 lines
5.5 KiB
Go
package sqlx
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"unsafe"
|
|
|
|
"github.com/kataras/iris/v12/x/reflex"
|
|
)
|
|
|
|
type (
|
|
// Schema holds the row definitions.
|
|
Schema struct {
|
|
Name string
|
|
Rows map[reflect.Type]*Row
|
|
ColumnNameFunc ColumnNameFunc
|
|
AutoCloseRows bool
|
|
}
|
|
|
|
// Row holds the column definitions and the struct type & name.
|
|
Row struct {
|
|
Schema string // e.g. public
|
|
Name string // e.g. users. Must set to a custom one if the select query contains AS names.
|
|
StructType reflect.Type
|
|
Columns map[string]*Column // e.g. "id":{"id", 0, [0]}
|
|
}
|
|
|
|
// Column holds the database column name and other properties extracted by a struct's field.
|
|
Column struct {
|
|
Name string
|
|
Index int
|
|
FieldIndex []int
|
|
}
|
|
)
|
|
|
|
// NewSchema returns a new Schema. Use its Register() method to cache
|
|
// a structure value so Bind() can fill all struct's fields based on a query.
|
|
func NewSchema() *Schema {
|
|
return &Schema{
|
|
Name: "public",
|
|
Rows: make(map[reflect.Type]*Row),
|
|
ColumnNameFunc: snakeCase,
|
|
AutoCloseRows: true,
|
|
}
|
|
}
|
|
|
|
// DefaultSchema initializes a common Schema.
|
|
var DefaultSchema = NewSchema()
|
|
|
|
// Register caches a struct value to the default schema.
|
|
func Register(tableName string, value interface{}) *Schema {
|
|
return DefaultSchema.Register(tableName, value)
|
|
}
|
|
|
|
// Query is a shortcut of executing a query and bind the result to "dst".
|
|
func Query(ctx context.Context, db *sql.DB, dst interface{}, query string, args ...interface{}) error {
|
|
return DefaultSchema.Query(ctx, db, dst, query, args...)
|
|
}
|
|
|
|
// Bind sets "dst" to the result of "src" and reports any errors.
|
|
func Bind(dst interface{}, src *sql.Rows) error {
|
|
return DefaultSchema.Bind(dst, src)
|
|
}
|
|
|
|
// Register caches a struct value to the schema.
|
|
func (s *Schema) Register(tableName string, value interface{}) *Schema {
|
|
typ := reflect.TypeOf(value)
|
|
for typ.Kind() == reflect.Ptr {
|
|
typ = typ.Elem()
|
|
}
|
|
|
|
if tableName == "" {
|
|
// convert to a human name, e.g. sqlx.Food -> food.
|
|
typeName := typ.String()
|
|
if idx := strings.LastIndexByte(typeName, '.'); idx > 0 && len(typeName) > idx {
|
|
typeName = typeName[idx+1:]
|
|
}
|
|
tableName = snakeCase(typeName)
|
|
}
|
|
|
|
columns, err := convertStructToColumns(typ, s.ColumnNameFunc)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("sqlx: register: %q: %s", reflect.TypeOf(value).String(), err.Error()))
|
|
}
|
|
|
|
s.Rows[typ] = &Row{
|
|
Schema: s.Name,
|
|
Name: tableName,
|
|
StructType: typ,
|
|
Columns: columns,
|
|
}
|
|
|
|
return s
|
|
}
|
|
|
|
// Query is a shortcut of executing a query and bind the result to "dst".
|
|
func (s *Schema) Query(ctx context.Context, db *sql.DB, dst interface{}, query string, args ...interface{}) error {
|
|
rows, err := db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !s.AutoCloseRows { // if not close on bind, we must close it here.
|
|
defer rows.Close()
|
|
}
|
|
|
|
err = s.Bind(dst, rows)
|
|
return err
|
|
}
|
|
|
|
// Bind sets "dst" to the result of "src" and reports any errors.
|
|
func (s *Schema) Bind(dst interface{}, src *sql.Rows) error {
|
|
typ := reflect.TypeOf(dst)
|
|
if typ.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("sqlx: bind: destination not a pointer")
|
|
}
|
|
|
|
typ = typ.Elem()
|
|
|
|
originalKind := typ.Kind()
|
|
if typ.Kind() == reflect.Slice {
|
|
typ = typ.Elem()
|
|
}
|
|
|
|
r, ok := s.Rows[typ]
|
|
if !ok {
|
|
return fmt.Errorf("sqlx: bind: unregistered type: %q", typ.String())
|
|
}
|
|
|
|
columnTypes, err := src.ColumnTypes()
|
|
if err != nil {
|
|
return fmt.Errorf("sqlx: bind: table: %q: %w", r.Name, err)
|
|
}
|
|
|
|
if expected, got := len(r.Columns), len(columnTypes); expected != got {
|
|
return fmt.Errorf("sqlx: bind: table: %q: unexpected number of result columns: %d: expected: %d", r.Name, got, expected)
|
|
}
|
|
|
|
val := reflex.IndirectValue(reflect.ValueOf(dst))
|
|
if s.AutoCloseRows {
|
|
defer src.Close()
|
|
}
|
|
|
|
switch originalKind {
|
|
case reflect.Struct:
|
|
if src.Next() {
|
|
if err = r.bindSingle(typ, val, columnTypes, src); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
return sql.ErrNoRows
|
|
}
|
|
|
|
return src.Err()
|
|
case reflect.Slice:
|
|
for src.Next() {
|
|
elem := reflect.New(typ).Elem()
|
|
if err = r.bindSingle(typ, elem, columnTypes, src); err != nil {
|
|
return err
|
|
}
|
|
|
|
val = reflect.Append(val, elem)
|
|
}
|
|
|
|
if err = src.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
reflect.ValueOf(dst).Elem().Set(val)
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("sqlx: bind: table: %q: unexpected destination kind: %q", r.Name, typ.Kind().String())
|
|
}
|
|
}
|
|
|
|
func (r *Row) bindSingle(typ reflect.Type, val reflect.Value, columnTypes []*sql.ColumnType, scanner interface{ Scan(...interface{}) error }) error {
|
|
fieldPtrs, err := r.lookupStructFieldPtrs(typ, val, columnTypes)
|
|
if err != nil {
|
|
return fmt.Errorf("sqlx: bind: table: %q: %w", r.Name, err)
|
|
}
|
|
|
|
return scanner.Scan(fieldPtrs...)
|
|
}
|
|
|
|
func (r *Row) lookupStructFieldPtrs(typ reflect.Type, val reflect.Value, columnTypes []*sql.ColumnType) ([]interface{}, error) {
|
|
fieldPtrs := make([]interface{}, 0, len(columnTypes))
|
|
|
|
for _, columnType := range columnTypes {
|
|
columnName := columnType.Name()
|
|
tableColumn, ok := r.Columns[columnName]
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
// TODO: when go 1.18 released, replace with that:
|
|
/*
|
|
tableColumnField, err := val.FieldByIndexErr(tableColumn.FieldIndex)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("column: %q: %w", tableColumn.Name, err)
|
|
}
|
|
*/
|
|
tableColumnField := val.FieldByIndex(tableColumn.FieldIndex)
|
|
|
|
tableColumnFieldType := tableColumnField.Type()
|
|
|
|
fieldPtr := reflect.NewAt(tableColumnFieldType, unsafe.Pointer(tableColumnField.UnsafeAddr())).Elem().Addr().Interface()
|
|
fieldPtrs = append(fieldPtrs, fieldPtr)
|
|
}
|
|
|
|
return fieldPtrs, nil
|
|
}
|