From 4ecc9e3831ac6dd9df63da258afbec059dcd8ab4 Mon Sep 17 00:00:00 2001 From: "Gerasimos (Makis) Maropoulos" Date: Tue, 8 Mar 2022 00:33:08 +0200 Subject: [PATCH] add a new x/sqlx sub-package and example --- HISTORY.md | 2 + _examples/README.md | 3 +- _examples/database/sqlx/main.go | 181 ++++++++++++++++++++++++++++++ x/reflex/struct.go | 13 ++- x/sqlx/sqlx.go | 189 ++++++++++++++++++++++++++++++++ x/sqlx/sqlx_test.go | 75 +++++++++++++ x/sqlx/struct_row.go | 92 ++++++++++++++++ x/sqlx/util.go | 42 +++++++ 8 files changed, 595 insertions(+), 2 deletions(-) create mode 100644 _examples/database/sqlx/main.go create mode 100644 x/sqlx/sqlx.go create mode 100644 x/sqlx/sqlx_test.go create mode 100644 x/sqlx/struct_row.go create mode 100644 x/sqlx/util.go diff --git a/HISTORY.md b/HISTORY.md index 63cb2985..4f3bfa01 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -28,6 +28,8 @@ The codebase for Dependency Injection, Internationalization and localization and ## Fixes and Improvements +- Add a new [x/sqlx](/x/sqlx/) sub-package ([example](_examples/database/sqlx/main.go)). + - Add a new [x/reflex](/x/reflex) sub-package. - Add `Context.ReadMultipartRelated` as requested at: [issues/#1787](https://github.com/kataras/iris/issues/1787). diff --git a/_examples/README.md b/_examples/README.md index ab61519b..32be6859 100644 --- a/_examples/README.md +++ b/_examples/README.md @@ -8,7 +8,7 @@ * [Bootstrapper](bootstrapper) * [Project Structure](project) :fire: * Monitor - * [Simple Proccess Monitor (includes UI)](monitor/monitor-middleware/main.go) **NEW** + * [Simple Process Monitor (includes UI)](monitor/monitor-middleware/main.go) **NEW** * [Heap, MSpan/MCache, Size Classes, Objects, Goroutines, GC/CPU fraction (includes UI)](monitor/statsviz/main.go) **NEW** * Database * [MySQL, Groupcache & Docker](database/mysql) @@ -16,6 +16,7 @@ * [Sqlx](database/orm/sqlx/main.go) * [Gorm](database/orm/gorm/main.go) * [Reform](database/orm/reform/main.go) + * [x/sqlx](database/sqlx/main.go) **NEW** * HTTP Server * [HOST:PORT](http-server/listen-addr/main.go) * [Public Test Domain](http-server/listen-addr-public/main.go) diff --git a/_examples/database/sqlx/main.go b/_examples/database/sqlx/main.go new file mode 100644 index 00000000..4ed847e1 --- /dev/null +++ b/_examples/database/sqlx/main.go @@ -0,0 +1,181 @@ +package main + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "time" + + "github.com/kataras/iris/v12" + "github.com/kataras/iris/v12/x/errors" + "github.com/kataras/iris/v12/x/sqlx" + + _ "github.com/lib/pq" +) + +const ( + host = "localhost" + port = 5432 + user = "postgres" + password = "admin!123" + dbname = "test" +) + +func main() { + app := iris.New() + + db := mustConnectDB() + mustCreateExtensions(context.Background(), db) + mustCreateTables(context.Background(), db) + + app.Post("/", insert(db)) + app.Get("/", list(db)) + app.Get("/{event_id:uuid}", getByID(db)) + + /* + curl --location --request POST 'http://localhost:8080' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "name": "second_test_event", + "data": { + "key": "value", + "year": 2022 + } + }' + + curl --location --request GET 'http://localhost:8080' + + curl --location --request GET 'http://localhost:8080/4fc0363f-1d1f-4a43-8608-5ed266485645' + */ + app.Listen(":8080") +} + +func mustConnectDB() *sql.DB { + connString := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + host, port, user, password, dbname) + db, err := sql.Open("postgres", connString) + if err != nil { + panic(err) + } + + err = db.Ping() + if err != nil { + panic(err) + } + + return db +} + +func mustCreateExtensions(ctx context.Context, db *sql.DB) { + query := `CREATE EXTENSION IF NOT EXISTS pgcrypto;` + _, err := db.ExecContext(ctx, query) + if err != nil { + panic(err) + } +} + +func mustCreateTables(ctx context.Context, db *sql.DB) { + query := `CREATE TABLE IF NOT EXISTS "events" ( + "id" uuid PRIMARY KEY NOT NULL DEFAULT gen_random_uuid(), + "created_at" timestamp(6) DEFAULT now(), + "name" text COLLATE "pg_catalog"."default", + "data" jsonb + );` + + _, err := db.ExecContext(ctx, query) + if err != nil { + panic(err) + } + + sqlx.Register("events", Event{}) +} + +type Event struct { + ID string `json:"id"` + CreatedAt time.Time `json:"created_at"` + Name string `json:"name"` + Data json.RawMessage `json:"data"` + + Presenter string `db:"-" json:"-"` +} + +func insertEvent(ctx context.Context, db *sql.DB, evt Event) (id string, err error) { + query := `INSERT INTO events(name,data) VALUES($1,$2) RETURNING id;` + err = db.QueryRowContext(ctx, query, evt.Name, evt.Data).Scan(&id) + return +} + +func listEvents(ctx context.Context, db *sql.DB) ([]Event, error) { + list := make([]Event, 0) + query := `SELECT * FROM events ORDER BY created_at;` + rows, err := db.QueryContext(ctx, query) + if err != nil { + return nil, err + } + // Not required. See sqlx.DefaultSchema.AutoCloseRows field. + // defer rows.Close() + + if err = sqlx.Bind(&list, rows); err != nil { + return nil, err + } + + return list, nil +} + +func getEvent(ctx context.Context, db *sql.DB, id string) (Event, error) { + query := `SELECT * FROM events WHERE id = $1 LIMIT 1;` + rows, err := db.QueryContext(ctx, query, id) + if err != nil { + return Event{}, err + } + + var evt Event + err = sqlx.Bind(&evt, rows) + + return evt, err +} + +func insert(db *sql.DB) iris.Handler { + return func(ctx iris.Context) { + var evt Event + if err := ctx.ReadJSON(&evt); err != nil { + errors.InvalidArgument.Details(ctx, "unable to read body", err.Error()) + return + } + + id, err := insertEvent(ctx, db, evt) + if err != nil { + errors.Internal.LogErr(ctx, err) + return + } + + ctx.JSON(iris.Map{"id": id}) + } +} + +func list(db *sql.DB) iris.Handler { + return func(ctx iris.Context) { + events, err := listEvents(ctx, db) + if err != nil { + errors.Internal.LogErr(ctx, err) + return + } + + ctx.JSON(events) + } +} + +func getByID(db *sql.DB) iris.Handler { + return func(ctx iris.Context) { + eventID := ctx.Params().Get("event_id") + + evt, err := getEvent(ctx, db, eventID) + if err != nil { + errors.Internal.LogErr(ctx, err) + return + } + + ctx.JSON(evt) + } +} diff --git a/x/reflex/struct.go b/x/reflex/struct.go index ec901f41..406d29a5 100644 --- a/x/reflex/struct.go +++ b/x/reflex/struct.go @@ -2,7 +2,7 @@ package reflex import "reflect" -// LookupFields returns a slice of all fields containg a struct field +// LookupFields returns a slice of all fields containing a struct field // of the given "fieldTag" of the "typ" struct. The fields returned // are flatted and reclusive over fields with value of struct. // Panics if "typ" is not a type of Struct. @@ -54,3 +54,14 @@ func lookupFields(typ reflect.Type, fieldTag string, parentIndex []int) []reflec return fields } + +// LookupUnderlineValueType returns the underline type of "v". +func LookupUnderlineValueType(v reflect.Value) (reflect.Value, reflect.Type) { + typ := v.Type() + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + v = reflect.New(typ).Elem() + } + + return v, typ +} diff --git a/x/sqlx/sqlx.go b/x/sqlx/sqlx.go new file mode 100644 index 00000000..91cb6d93 --- /dev/null +++ b/x/sqlx/sqlx.go @@ -0,0 +1,189 @@ +package sqlx + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + "unsafe" + + "github.com/kataras/iris/v12/x/reflex" +) + +type ( + // Schema holds the row definitions. + Schema struct { + Name string + Rows map[reflect.Type]*Row + ColumnNameFunc ColumnNameFunc + AutoCloseRows bool + } + + // Row holds the column definitions and the struct type & name. + Row struct { + Schema string // e.g. public + Name string // e.g. users. Must set to a custom one if the select query contains AS names. + StructType reflect.Type + Columns map[string]*Column // e.g. "id":{"id", 0, [0]} + } + + // Column holds the database column name and other properties extracted by a struct's field. + Column struct { + Name string + Index int + FieldIndex []int + } +) + +// NewSchema returns a new Schema. Use its Register() method to cache +// a structure value so Bind() can fill all struct's fields based on a query. +func NewSchema() *Schema { + return &Schema{ + Name: "public", + Rows: make(map[reflect.Type]*Row), + ColumnNameFunc: snakeCase, + AutoCloseRows: true, + } +} + +// DefaultSchema initializes a common Schema. +var DefaultSchema = NewSchema() + +// Register caches a struct value to the default schema. +func Register(tableName string, value interface{}) *Schema { + return DefaultSchema.Register(tableName, value) +} + +// Bind sets "dst" to the result of "src" and reports any errors. +func Bind(dst interface{}, src *sql.Rows) error { + return DefaultSchema.Bind(dst, src) +} + +// Register caches a struct value to the schema. +func (s *Schema) Register(tableName string, value interface{}) *Schema { + typ := reflect.TypeOf(value) + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + + if tableName == "" { + // convert to a human name, e.g. sqlx.Food -> food. + typeName := typ.String() + if idx := strings.LastIndexByte(typeName, '.'); idx > 0 && len(typeName) > idx { + typeName = typeName[idx+1:] + } + tableName = snakeCase(typeName) + } + + columns, err := convertStructToColumns(typ, s.ColumnNameFunc) + if err != nil { + panic(fmt.Sprintf("sqlx: register: %q: %s", reflect.TypeOf(value).String(), err.Error())) + } + + s.Rows[typ] = &Row{ + Schema: s.Name, + Name: tableName, + StructType: typ, + Columns: columns, + } + + return s +} + +// Bind sets "dst" to the result of "src" and reports any errors. +func (s *Schema) Bind(dst interface{}, src *sql.Rows) error { + typ := reflect.TypeOf(dst) + if typ.Kind() != reflect.Ptr { + return fmt.Errorf("sqlx: bind: destination not a pointer") + } + + typ = typ.Elem() + + originalKind := typ.Kind() + if typ.Kind() == reflect.Slice { + typ = typ.Elem() + } + + r, ok := s.Rows[typ] + if !ok { + return fmt.Errorf("sqlx: bind: unregistered type: %q", typ.String()) + } + + columnTypes, err := src.ColumnTypes() + if err != nil { + return fmt.Errorf("sqlx: bind: table: %q: %w", r.Name, err) + } + + if expected, got := len(r.Columns), len(columnTypes); expected != got { + return fmt.Errorf("sqlx: bind: table: %q: unexpected number of result columns: %d: expected: %d", r.Name, got, expected) + } + + val := reflex.IndirectValue(reflect.ValueOf(dst)) + if s.AutoCloseRows { + defer src.Close() + } + + switch originalKind { + case reflect.Struct: + if src.Next() { + if err = r.bindSingle(typ, val, columnTypes, src); err != nil { + return err + } + } else { + return sql.ErrNoRows + } + + return src.Err() + case reflect.Slice: + for src.Next() { + elem := reflect.New(typ).Elem() + if err = r.bindSingle(typ, elem, columnTypes, src); err != nil { + return err + } + + val = reflect.Append(val, elem) + } + + if err = src.Err(); err != nil { + return err + } + + reflect.ValueOf(dst).Elem().Set(val) + return nil + default: + return fmt.Errorf("sqlx: bind: table: %q: unexpected destination kind: %q", r.Name, typ.Kind().String()) + } +} + +func (r *Row) bindSingle(typ reflect.Type, val reflect.Value, columnTypes []*sql.ColumnType, scanner interface{ Scan(...interface{}) error }) error { + fieldPtrs, err := r.lookupStructFieldPtrs(typ, val, columnTypes) + if err != nil { + return fmt.Errorf("sqlx: bind: table: %q: %w", r.Name, err) + } + + return scanner.Scan(fieldPtrs...) +} + +func (r *Row) lookupStructFieldPtrs(typ reflect.Type, val reflect.Value, columnTypes []*sql.ColumnType) ([]interface{}, error) { + fieldPtrs := make([]interface{}, 0, len(columnTypes)) + + for _, columnType := range columnTypes { + columnName := columnType.Name() + tableColumn, ok := r.Columns[columnName] + if !ok { + continue + } + + tableColumnField, err := val.FieldByIndexErr(tableColumn.FieldIndex) + if err != nil { + return nil, fmt.Errorf("column: %q: %w", tableColumn.Name, err) + } + + tableColumnFieldType := tableColumnField.Type() + + fieldPtr := reflect.NewAt(tableColumnFieldType, unsafe.Pointer(tableColumnField.UnsafeAddr())).Elem().Addr().Interface() + fieldPtrs = append(fieldPtrs, fieldPtr) + } + + return fieldPtrs, nil +} diff --git a/x/sqlx/sqlx_test.go b/x/sqlx/sqlx_test.go new file mode 100644 index 00000000..b8567cf7 --- /dev/null +++ b/x/sqlx/sqlx_test.go @@ -0,0 +1,75 @@ +package sqlx + +/* +import ( + "reflect" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" +) + +type food struct { + ID string + Name string + Presenter bool `db:"-"` +} + +func TestTableBind(t *testing.T) { + Register("foods", food{}) + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatal(err) + } + + mock.ExpectQuery("SELECT .* FROM foods WHERE id = ?"). + WithArgs("42"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow("42", "banana"). + AddRow("43", "broccoli")) + + rows, err := db.Query("SELECT .* FROM foods WHERE id = ? LIMIT 1", "42") + if err != nil { + t.Fatal(err) + } + + var f food + err = Bind(&f, rows) + if err != nil { + t.Fatal(err) + } + + expectedSingle := food{"42", "banana", false} + if !reflect.DeepEqual(f, expectedSingle) { + t.Fatalf("expected value: %#+v but got: %#+v", expectedSingle, f) + } + + mock.ExpectQuery("SELECT .* FROM foods"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow("42", "banana"). + AddRow("43", "broccoli"). + AddRow("44", "chicken")) + rows, err = db.Query("SELECT .* FROM foods") + if err != nil { + t.Fatal(err) + } + + var foods []food + err = Bind(&foods, rows) + if err != nil { + t.Fatal(err) + } + + expectedMany := []food{ + {"42", "banana", false}, + {"43", "broccoli", false}, + {"44", "chicken", false}, + } + + for i := range foods { + if !reflect.DeepEqual(foods[i], expectedMany[i]) { + t.Fatalf("[%d] expected: %#+v but got: %#+v", i, expectedMany[i], foods[i]) + } + } +} +*/ diff --git a/x/sqlx/struct_row.go b/x/sqlx/struct_row.go new file mode 100644 index 00000000..1246a5b0 --- /dev/null +++ b/x/sqlx/struct_row.go @@ -0,0 +1,92 @@ +package sqlx + +import ( + "fmt" + "reflect" + "strings" + + "github.com/kataras/iris/v12/x/reflex" +) + +// DefaultTag is the default struct field tag. +var DefaultTag = "db" + +type ColumnNameFunc = func(string) string + +func convertStructToColumns(typ reflect.Type, nameFunc ColumnNameFunc) (map[string]*Column, error) { + if kind := typ.Kind(); kind != reflect.Struct { + return nil, fmt.Errorf("convert struct: invalid type: expected a struct value but got: %q", kind.String()) + } + + // Retrieve only fields valid for database. + fields := reflex.LookupFields(typ, "") + + columns := make(map[string]*Column, len(fields)) + for i, field := range fields { + column, ok, err := convertStructFieldToColumn(field, DefaultTag, nameFunc) + if !ok { + continue + } + + if err != nil { + return nil, fmt.Errorf("convert struct: field name: %q: %w", field.Name, err) + } + + column.Index = i + columns[column.Name] = column + } + + return columns, nil +} + +func convertStructFieldToColumn(field reflect.StructField, optionalTag string, nameFunc ColumnNameFunc) (*Column, bool, error) { + c := &Column{ + Name: nameFunc(field.Name), + FieldIndex: field.Index, + } + + fieldTag, ok := field.Tag.Lookup(optionalTag) + if ok { + if fieldTag == "-" { + return nil, false, nil + } + + if err := parseOptions(fieldTag, c); err != nil { + return nil, false, err + } + } + + return c, true, nil +} + +func parseOptions(fieldTag string, c *Column) error { + options := strings.Split(fieldTag, ",") + for _, opt := range options { + if opt == "" { + continue // skip empty. + } + + var key, value string + + kv := strings.Split(opt, "=") // When more options come to play. + switch len(kv) { + case 2: + key = kv[0] + value = kv[1] + case 1: + c.Name = kv[0] + return nil + default: + return fmt.Errorf("option: %s: expected key value separated by '='", opt) + } + + switch key { + case "name": + c.Name = value + default: + return fmt.Errorf("unexpected tag option: %s", key) + } + } + + return nil +} diff --git a/x/sqlx/util.go b/x/sqlx/util.go new file mode 100644 index 00000000..536bc2af --- /dev/null +++ b/x/sqlx/util.go @@ -0,0 +1,42 @@ +package sqlx + +import "strings" + +// snakeCase converts a given string to a friendly snake case, e.g. +// - userId to user_id +// - ID to id +// - ProviderAPIKey to provider_api_key +// - Option to option +func snakeCase(camel string) string { + var ( + b strings.Builder + prevWasUpper bool + ) + + for i, c := range camel { + if isUppercase(c) { // it's upper. + if b.Len() > 0 && !prevWasUpper { // it's not the first and the previous was not uppercased too (e.g "ID"). + b.WriteRune('_') + } else { // check for XxxAPIKey, it should be written as xxx_api_key. + next := i + 1 + if next > 1 && len(camel)-1 > next { + if !isUppercase(rune(camel[next])) { + b.WriteRune('_') + } + } + } + + b.WriteRune(c - 'A' + 'a') // write its lowercase version. + prevWasUpper = true + } else { + b.WriteRune(c) // write it as it is, it's already lowercased. + prevWasUpper = false + } + } + + return b.String() +} + +func isUppercase(c rune) bool { + return 'A' <= c && c <= 'Z' +}