diff --git a/_examples/compression/main_test.go b/_examples/compression/main_test.go index e63f0134..62cbb459 100644 --- a/_examples/compression/main_test.go +++ b/_examples/compression/main_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "github.com/kataras/iris/v12" "github.com/kataras/iris/v12/context" "github.com/kataras/iris/v12/httptest" ) @@ -15,7 +16,49 @@ func TestCompression(t *testing.T) { e := httptest.New(t, app) var expectedReply = payload{Username: "Makis"} - body := e.GET("/").WithHeader(context.AcceptEncodingHeaderKey, context.GZIP).Expect(). + testBody(t, e.GET("/"), expectedReply) +} + +func TestCompressionAfterRecorder(t *testing.T) { + var expectedReply = payload{Username: "Makis"} + + app := iris.New() + app.Use(func(ctx iris.Context) { + ctx.Record() + ctx.Next() + }) + app.Use(iris.Compression) + + app.Get("/", func(ctx iris.Context) { + ctx.JSON(expectedReply) + }) + + e := httptest.New(t, app) + testBody(t, e.GET("/"), expectedReply) +} + +func TestCompressionBeforeRecorder(t *testing.T) { + var expectedReply = payload{Username: "Makis"} + + app := iris.New() + app.Use(iris.Compression) + app.Use(func(ctx iris.Context) { + ctx.Record() + ctx.Next() + }) + + app.Get("/", func(ctx iris.Context) { + ctx.JSON(expectedReply) + }) + + e := httptest.New(t, app) + testBody(t, e.GET("/"), expectedReply) +} + +func testBody(t *testing.T, req *httptest.Request, expectedReply interface{}) { + t.Helper() + + body := req.WithHeader(context.AcceptEncodingHeaderKey, context.GZIP).Expect(). Status(httptest.StatusOK). ContentEncoding(context.GZIP). ContentType(context.ContentJSONHeaderValue).Body().Raw() diff --git a/context/compress.go b/context/compress.go index 41f91b6f..d719116f 100644 --- a/context/compress.go +++ b/context/compress.go @@ -243,12 +243,24 @@ func releaseCompressResponseWriter(w *CompressResponseWriter) { func (w *CompressResponseWriter) FlushResponse() { w.FlushHeaders() + /* this should NEVER happen, see `context.CompressWriter` method. + if rec, ok := w.ResponseWriter.(*ResponseRecorder); ok { + // Usecase: record, then compression. + w.CompressWriter.Close() // flushes and closes. + rec.FlushResponse() + return + } + */ + // write the status, after header set and before any flushed content sent. w.ResponseWriter.FlushResponse() w.CompressWriter.Close() // flushes and closes. } +// FlushHeaders deletes the encoding headers if +// the compressed writer was disabled otherwise +// removes the content-length so next callers can re-calculate the correct length. func (w *CompressResponseWriter) FlushHeaders() { if w.Disabled { w.Header().Del(VaryHeaderKey) @@ -294,3 +306,18 @@ func (w *CompressResponseWriter) Flush() { w.ResponseWriter.Flush() } + +// WriteTo writes the "p" to "dest" Writer using the compression that this compress writer was made of. +func (w *CompressResponseWriter) WriteTo(dest io.Writer, p []byte) (int, error) { + if w.Disabled { + return dest.Write(p) + } + + cw, err := NewCompressWriter(dest, w.Encoding, w.Level) + if err != nil { + return 0, err + } + n, err := cw.Write(p) + cw.Close() + return n, err +} diff --git a/context/context.go b/context/context.go index e48ce4a2..f80a0b5f 100644 --- a/context/context.go +++ b/context/context.go @@ -2297,22 +2297,42 @@ func (ctx *Context) ClientSupportsEncoding(encodings ...string) bool { // Sometimes, using additional compression doesn't reduce payload size and // can even make the payload longer. func (ctx *Context) CompressWriter(enable bool) error { - cw, ok := ctx.writer.(*CompressResponseWriter) - if enable { - if ok { - // already a compress writer. + switch w := ctx.writer.(type) { + case *CompressResponseWriter: + if enable { return nil } - w, err := AcquireCompressResponseWriter(ctx.writer, ctx.request, -1) + w.Disabled = true + case *ResponseRecorder: + if enable { + // Keep the Recorder as ctx.writer. + // Wrap the existing net/http response writer + // with the compressed writer and + // replace the recorder's response writer + // reference with that compressed one. + // Fixes an issue when Record is called before CompressWriter. + cw, err := AcquireCompressResponseWriter(w.ResponseWriter, ctx.request, -1) + if err != nil { + return err + } + w.ResponseWriter = cw + } else { + cw, ok := w.ResponseWriter.(*CompressResponseWriter) + if ok { + cw.Disabled = true + } + } + default: + if !enable { + return nil + } + + cw, err := AcquireCompressResponseWriter(w, ctx.request, -1) if err != nil { return err } - ctx.writer = w - } else { - if ok { - cw.Disabled = true - } + ctx.writer = cw } return nil @@ -4341,7 +4361,7 @@ func (ctx *Context) BeginTransaction(pipe func(t *Transaction)) { } // write the temp contents to the original writer - t.Context().ResponseWriter().WriteTo(ctx.writer) + t.Context().ResponseWriter().CopyTo(ctx.writer) // give back to the transaction the original writer (SetBeforeFlush works this way and only this way) // this is tricky but nessecery if we want ctx.FireStatusCode to work inside transactions t.Context().ResetResponseWriter(ctx.writer) diff --git a/context/response_recorder.go b/context/response_recorder.go index 42831a18..9da74609 100644 --- a/context/response_recorder.go +++ b/context/response_recorder.go @@ -192,8 +192,8 @@ func (w *ResponseRecorder) Clone() ResponseWriter { return wc } -// WriteTo writes a response writer (temp: status code, headers and body) to another response writer -func (w *ResponseRecorder) WriteTo(res ResponseWriter) { +// CopyTo writes a response writer (temp: status code, headers and body) to another response writer +func (w *ResponseRecorder) CopyTo(res ResponseWriter) { if to, ok := res.(*ResponseRecorder); ok { // set the status code, to is first ( probably an error? (context.StatusCodeNotSuccessful, defaults to >=400). diff --git a/context/response_writer.go b/context/response_writer.go index 4caa9b94..e52f4e13 100644 --- a/context/response_writer.go +++ b/context/response_writer.go @@ -3,6 +3,7 @@ package context import ( "bufio" "errors" + "io" "net" "net/http" "sync" @@ -68,8 +69,8 @@ type ResponseWriter interface { // it copies the header, status code, headers and the beforeFlush finally returns a new ResponseRecorder. Clone() ResponseWriter - // WiteTo writes a response writer (temp: status code, headers and body) to another response writer - WriteTo(ResponseWriter) + // CopyTo writes a response writer (temp: status code, headers and body) to another response writer + CopyTo(ResponseWriter) // Flusher indicates if `Flush` is supported by the client. // @@ -112,6 +113,16 @@ type ResponseWriterReseter interface { Reset() bool } +// ResponseWriterWriteTo can be implemented +// by response writers that needs a special +// encoding before writing to their buffers. +// E.g. a custom recorder that wraps a custom compressed one. +// +// Not used by the framework itself. +type ResponseWriterWriteTo interface { + WriteTo(dest io.Writer, p []byte) +} + // +------------------------------------------------------------+ // | Response Writer Implementation | // +------------------------------------------------------------+ @@ -300,8 +311,8 @@ func (w *responseWriter) Clone() ResponseWriter { return wc } -// WriteTo writes a response writer (temp: status code, headers and body) to another response writer. -func (w *responseWriter) WriteTo(to ResponseWriter) { +// CopyTo writes a response writer (temp: status code, headers and body) to another response writer. +func (w *responseWriter) CopyTo(to ResponseWriter) { // set the status code, failure status code are first class if w.statusCode >= 400 { to.WriteHeader(w.statusCode)