package sqlx

import (
	"context"
	"database/sql"
	"fmt"
	"reflect"
	"strings"
	"unsafe"

	"github.com/kataras/iris/v12/x/reflex"
)

// Schema holds the row definitions.
type Schema struct {
	// Name is copied into Row.Schema of every type registered afterwards
	// (NewSchema sets it to "public").
	Name string
	// Rows maps each registered type (with pointers dereferenced) to its
	// row definition.
	Rows map[reflect.Type]*Row
	// ColumnNameFunc is handed to convertStructToColumns by Register to
	// derive column names from struct fields.
	ColumnNameFunc ColumnNameFunc
	// AutoCloseRows makes Bind defer closing the *sql.Rows passed to it.
	AutoCloseRows bool
}

// Row holds the column definitions and the struct type & name.
type Row struct {
	// Schema is the owning Schema's Name at registration time, e.g. public.
	Schema string
	// Name is the table name, e.g. users. It must be set to a custom one if
	// the select query contains AS names.
	Name string
	// StructType is the registered (non-pointer) type.
	StructType reflect.Type
	// Columns maps a result column name to its definition,
	// e.g. "id":{"id", 0, [0]}.
	Columns map[string]*Column
}

// Column holds the database column name and other properties extracted by a struct's field.
type Column struct {
	// Name is the database column name.
	Name string
	// Index is the ordinal assigned by convertStructToColumns
	// (presumably the column position within the struct — TODO confirm).
	Index int
	// FieldIndex is the field index path passed to reflect.Value.FieldByIndex
	// when binding a row.
	FieldIndex []int
}

// NewSchema returns a new, empty Schema named "public" that uses snakeCase as
// its column name function and has AutoCloseRows enabled. Call its Register
// method to cache a struct type so Bind can fill that struct's fields from a
// query result.
func NewSchema() *Schema {
	s := new(Schema)
	s.Name = "public"
	s.Rows = make(map[reflect.Type]*Row)
	s.ColumnNameFunc = snakeCase
	s.AutoCloseRows = true
	return s
}

var (
	// DefaultSchema is the Schema, created by NewSchema, that backs the
	// package-level Register, Query and Bind helpers.
	DefaultSchema = NewSchema()
)

// Register caches the type of "value" in DefaultSchema under "tableName"
// (see Schema.Register) and returns DefaultSchema.
func Register(tableName string, value interface{}) *Schema {
	schema := DefaultSchema
	return schema.Register(tableName, value)
}

// Query runs "query" with "args" on "db" and binds the result to "dst"
// using DefaultSchema (see Schema.Query).
func Query(ctx context.Context, db *sql.DB, dst interface{}, query string, args ...interface{}) error {
	schema := DefaultSchema
	return schema.Query(ctx, db, dst, query, args...)
}

// Bind fills "dst" from the rows in "src" using DefaultSchema
// (see Schema.Bind) and reports any errors.
func Bind(dst interface{}, src *sql.Rows) error {
	schema := DefaultSchema
	return schema.Bind(dst, src)
}

// Register caches the type of "value" (any pointer levels are dereferenced)
// so that Bind can later fill values of that type.
//
// When "tableName" is empty, the table name is the type name with its package
// qualifier removed, passed through snakeCase (e.g. sqlx.Food -> food).
// Register panics if convertStructToColumns rejects the type. It returns the
// schema itself so calls can be chained.
func (s *Schema) Register(tableName string, value interface{}) *Schema {
	structType := reflect.TypeOf(value)
	for structType.Kind() == reflect.Ptr {
		structType = structType.Elem()
	}

	name := tableName
	if name == "" {
		name = structType.String()
		// Drop the package qualifier: "sqlx.Food" -> "Food".
		if dot := strings.LastIndexByte(name, '.'); dot > 0 {
			name = name[dot+1:]
		}
		name = snakeCase(name)
	}

	columns, err := convertStructToColumns(structType, s.ColumnNameFunc)
	if err != nil {
		panic(fmt.Sprintf("sqlx: register: %q: %s", reflect.TypeOf(value).String(), err.Error()))
	}

	row := &Row{
		Schema:     s.Name,
		Name:       name,
		StructType: structType,
		Columns:    columns,
	}
	s.Rows[structType] = row

	return s
}

// Query is a shortcut of executing a query and bind the result to "dst".
//
// The rows produced by the query are always closed before Query returns,
// whether or not Bind succeeds and regardless of AutoCloseRows.
func (s *Schema) Query(ctx context.Context, db *sql.DB, dst interface{}, query string, args ...interface{}) error {
	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return err
	}
	// Always close: Bind may return early on a validation error without
	// closing "rows", which would leak the underlying connection.
	// sql.Rows.Close is idempotent, so this is harmless when Bind
	// (with AutoCloseRows) has already closed them.
	defer rows.Close()

	return s.Bind(dst, rows)
}

// Bind sets "dst" to the result of "src" and reports any errors.
//
// "dst" must be a non-nil pointer to either:
//   - a registered struct type: the first row of "src" is bound to it and
//     sql.ErrNoRows is returned when "src" yields no rows (an iteration error
//     reported by src.Err is returned instead, if any);
//   - a slice of a registered struct type, or of pointers to it: every row of
//     "src" is appended to the slice.
//
// The number of result columns must equal the number of registered columns.
// When s.AutoCloseRows is true, "src" is closed before Bind returns, on both
// success and error paths.
func (s *Schema) Bind(dst interface{}, src *sql.Rows) error {
	if src == nil {
		return fmt.Errorf("sqlx: bind: nil source rows")
	}

	if s.AutoCloseRows {
		// Deferred before any validation so the early returns below do not
		// leak the rows (and the connection they hold).
		defer src.Close()
	}

	typ := reflect.TypeOf(dst)
	// typ is nil for an untyped nil "dst"; calling Kind on it would panic.
	if typ == nil || typ.Kind() != reflect.Ptr {
		return fmt.Errorf("sqlx: bind: destination not a pointer")
	}

	if reflect.ValueOf(dst).IsNil() {
		return fmt.Errorf("sqlx: bind: destination is a nil pointer")
	}

	typ = typ.Elem()

	originalKind := typ.Kind()
	elemIsPtr := false
	if originalKind == reflect.Slice {
		typ = typ.Elem()
		// Accept []*T as well as []T; rows are registered by the
		// dereferenced type.
		if typ.Kind() == reflect.Ptr {
			elemIsPtr = true
			typ = typ.Elem()
		}
	}

	r, ok := s.Rows[typ]
	if !ok {
		return fmt.Errorf("sqlx: bind: unregistered type: %q", typ.String())
	}

	columnTypes, err := src.ColumnTypes()
	if err != nil {
		return fmt.Errorf("sqlx: bind: table: %q: %w", r.Name, err)
	}

	if expected, got := len(r.Columns), len(columnTypes); expected != got {
		return fmt.Errorf("sqlx: bind: table: %q: unexpected number of result columns: %d: expected: %d", r.Name, got, expected)
	}

	val := reflex.IndirectValue(reflect.ValueOf(dst))

	switch originalKind {
	case reflect.Struct:
		if !src.Next() {
			// Next also returns false when iteration failed; don't mask
			// that error as "no rows".
			if err = src.Err(); err != nil {
				return err
			}
			return sql.ErrNoRows
		}

		if err = r.bindSingle(typ, val, columnTypes, src); err != nil {
			return err
		}

		return src.Err()
	case reflect.Slice:
		for src.Next() {
			elemPtr := reflect.New(typ)
			if err = r.bindSingle(typ, elemPtr.Elem(), columnTypes, src); err != nil {
				return err
			}

			if elemIsPtr {
				val = reflect.Append(val, elemPtr)
			} else {
				val = reflect.Append(val, elemPtr.Elem())
			}
		}

		if err = src.Err(); err != nil {
			return err
		}

		reflect.ValueOf(dst).Elem().Set(val)
		return nil
	default:
		return fmt.Errorf("sqlx: bind: table: %q: unexpected destination kind: %q", r.Name, originalKind.String())
	}
}

// bindSingle scans the current row of "scanner" into "val", using r's column
// mapping for "columnTypes" to pick the destination fields. "val" must be
// addressable, because lookupStructFieldPtrs takes the fields' addresses.
// Lookup errors are wrapped with the table name; Scan errors are returned
// unchanged.
func (r *Row) bindSingle(typ reflect.Type, val reflect.Value, columnTypes []*sql.ColumnType, scanner interface{ Scan(...interface{}) error }) error {
	dest, lookupErr := r.lookupStructFieldPtrs(typ, val, columnTypes)
	if lookupErr == nil {
		return scanner.Scan(dest...)
	}

	return fmt.Errorf("sqlx: bind: table: %q: %w", r.Name, lookupErr)
}

// lookupStructFieldPtrs returns, in result-column order, a pointer to the
// field of "val" mapped to each entry of "columnTypes". Result columns with no
// entry in r.Columns are skipped, so the returned slice can be shorter than
// columnTypes. The "typ" parameter is currently unused, and the returned error
// is always nil.
//
// Pointers are built with reflect.NewAt over each field's address so that
// Scan can write into unexported fields too. "val" must therefore be
// addressable, because reflect.Value.UnsafeAddr panics otherwise.
func (r *Row) lookupStructFieldPtrs(typ reflect.Type, val reflect.Value, columnTypes []*sql.ColumnType) ([]interface{}, error) {
	ptrs := make([]interface{}, 0, len(columnTypes))

	for i := range columnTypes {
		col, found := r.Columns[columnTypes[i].Name()]
		if !found {
			continue
		}

		// TODO: when go 1.18 released, replace with val.FieldByIndexErr and
		// return fmt.Errorf("column: %q: %w", col.Name, err) on failure;
		// FieldByIndex panics when it has to step through a nil embedded pointer.
		field := val.FieldByIndex(col.FieldIndex)

		addr := unsafe.Pointer(field.UnsafeAddr())
		ptrs = append(ptrs, reflect.NewAt(field.Type(), addr).Interface())
	}

	return ptrs, nil
}