diff --git a/x/client/client.go b/x/client/client.go index 3d6cc5aa..5069b85a 100644 --- a/x/client/client.go +++ b/x/client/client.go @@ -3,14 +3,15 @@ package client import ( "bytes" "context" - "crypto/tls" "encoding/json" "errors" "fmt" "io" "io/ioutil" + "mime/multipart" "net/http" "net/url" + "os" "strconv" "strings" ) @@ -28,15 +29,7 @@ type Client struct { func New(opts ...Option) *Client { c := &Client{ - HTTPClient: &http.Client{ - // Timeout: 15 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 200, - MaxConnsPerHost: 200, - MaxIdleConnsPerHost: 200, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, - }, + HTTPClient: &http.Client{}, PersistentRequestOptions: defaultRequestOptions, } @@ -115,7 +108,7 @@ func RequestParam(key string, values ...string) RequestOption { // // Any HTTP returned error will be of type APIError // or a timeout error if the given context was canceled. -func (c *Client) Do(ctx context.Context, method, url string, payload interface{}, opts ...RequestOption) (*http.Response, error) { +func (c *Client) Do(ctx context.Context, method, urlpath string, payload interface{}, opts ...RequestOption) (*http.Response, error) { if ctx == nil { ctx = context.Background() } @@ -137,6 +130,8 @@ func (c *Client) Do(ctx context.Context, method, url string, payload interface{} body = bytes.NewBuffer(v) case string: body = strings.NewReader(v) + case url.Values: + body = strings.NewReader(v.Encode()) default: w := new(bytes.Buffer) // We assume it's a struct, we wont make use of reflection to find out though. @@ -149,11 +144,11 @@ func (c *Client) Do(ctx context.Context, method, url string, payload interface{} } if c.BaseURL != "" { - url = c.BaseURL + url // note that we don't do any special checks here, the caller is responsible. + urlpath = c.BaseURL + urlpath // note that we don't do any special checks here, the caller is responsible. } // Initialize the request. - req, err := http.NewRequestWithContext(ctx, method, url, body) + req, err := http.NewRequestWithContext(ctx, method, urlpath, body) if err != nil { return nil, err } @@ -188,27 +183,93 @@ func (c *Client) Do(ctx context.Context, method, url string, payload interface{} const ( acceptKey = "Accept" contentTypeKey = "Content-Type" + contentLengthKey = "Content-Length" contentTypePlainText = "plain/text" contentTypeJSON = "application/json" contentTypeFormURLEncoded = "application/x-www-form-urlencoded" ) -func (c *Client) JSON(ctx context.Context, method, url string, payload interface{}, opts ...RequestOption) (*http.Response, error) { +func (c *Client) JSON(ctx context.Context, method, urlpath string, payload interface{}, opts ...RequestOption) (*http.Response, error) { opts = append(opts, RequestHeader(true, contentTypeKey, contentTypeJSON)) - return c.Do(ctx, method, url, payload, opts...) + return c.Do(ctx, method, urlpath, payload, opts...) } -func (c *Client) Form(ctx context.Context, method, url string, payload interface{}, opts ...RequestOption) (*http.Response, error) { - opts = append(opts, RequestHeader(true, contentTypeKey, contentTypeFormURLEncoded)) - return c.Do(ctx, method, url, payload, opts...) +func (c *Client) Form(ctx context.Context, method, urlpath string, formValues url.Values, opts ...RequestOption) (*http.Response, error) { + payload := formValues.Encode() + + opts = append(opts, + RequestHeader(true, contentTypeKey, contentTypeFormURLEncoded), + RequestHeader(true, contentLengthKey, strconv.Itoa(len(payload))), + ) + + return c.Do(ctx, method, urlpath, payload, opts...) } -func (c *Client) ReadJSON(ctx context.Context, dest interface{}, method, url string, payload interface{}, opts ...RequestOption) error { +type Uploader struct { + client *Client + + body *bytes.Buffer + Writer *multipart.Writer +} + +func (u *Uploader) AddField(key, value string) error { + f, err := u.Writer.CreateFormField(key) + if err != nil { + return err + } + + _, err = io.Copy(f, strings.NewReader(value)) + return err +} + +func (u *Uploader) AddFileSource(key, filename string, source io.Reader) error { + f, err := u.Writer.CreateFormFile(key, filename) + if err != nil { + return err + } + + _, err = io.Copy(f, source) + return err +} + +func (u *Uploader) AddFile(key, filename string) error { + source, err := os.Open(filename) + if err != nil { + return err + } + + return u.AddFileSource(key, filename, source) +} + +func (u *Uploader) Upload(ctx context.Context, method, urlpath string, opts ...RequestOption) (*http.Response, error) { + err := u.Writer.Close() + if err != nil { + return nil, err + } + + payload := bytes.NewReader(u.body.Bytes()) + opts = append(opts, RequestHeader(true, contentTypeKey, u.Writer.FormDataContentType())) + + return u.client.Do(ctx, method, urlpath, payload, opts...) +} + +func (c *Client) NewUploader() *Uploader { + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + + return &Uploader{ + client: c, + body: body, + Writer: writer, + } +} + +func (c *Client) ReadJSON(ctx context.Context, dest interface{}, method, urlpath string, payload interface{}, opts ...RequestOption) error { if payload != nil { opts = append(opts, RequestHeader(true, contentTypeKey, contentTypeJSON)) } - resp, err := c.Do(ctx, method, url, payload, opts...) + resp, err := c.Do(ctx, method, urlpath, payload, opts...) if err != nil { return err } @@ -228,8 +289,8 @@ func (c *Client) ReadJSON(ctx context.Context, dest interface{}, method, url str // ReadPlain like ReadJSON but it accepts a pointer to a string or byte slice or integer // and it reads the body as plain text. -func (c *Client) ReadPlain(ctx context.Context, dest interface{}, method, url string, payload interface{}, opts ...RequestOption) error { - resp, err := c.Do(ctx, method, url, payload, opts...) +func (c *Client) ReadPlain(ctx context.Context, dest interface{}, method, urlpath string, payload interface{}, opts ...RequestOption) error { + resp, err := c.Do(ctx, method, urlpath, payload, opts...) if err != nil { return err } @@ -262,9 +323,9 @@ func (c *Client) ReadPlain(ctx context.Context, dest interface{}, method, url st // GetPlainUnquote reads the response body as raw text and tries to unquote it, // useful when the remote server sends a single key as a value but due to backend mistake // it sends it as JSON (quoted) instead of plain text. -func (c *Client) GetPlainUnquote(ctx context.Context, method, url string, payload interface{}, opts ...RequestOption) (string, error) { +func (c *Client) GetPlainUnquote(ctx context.Context, method, urlpath string, payload interface{}, opts ...RequestOption) (string, error) { var bodyStr string - if err := c.ReadPlain(ctx, &bodyStr, method, url, payload, opts...); err != nil { + if err := c.ReadPlain(ctx, &bodyStr, method, urlpath, payload, opts...); err != nil { return "", err } @@ -281,12 +342,12 @@ func (c *Client) GetPlainUnquote(ctx context.Context, method, url string, payloa // content-type and content-length of the original request. // // Returns the amount of bytes written to "dest". -func (c *Client) WriteTo(ctx context.Context, dest io.Writer, method, url string, payload interface{}, opts ...RequestOption) (int64, error) { +func (c *Client) WriteTo(ctx context.Context, dest io.Writer, method, urlpath string, payload interface{}, opts ...RequestOption) (int64, error) { if payload != nil { opts = append(opts, RequestHeader(true, contentTypeKey, contentTypeJSON)) } - resp, err := c.Do(ctx, method, url, payload, opts...) + resp, err := c.Do(ctx, method, urlpath, payload, opts...) if err != nil { return 0, err }