diff --git a/NOTICE b/NOTICE index 72189f00..44bfc0ff 100644 --- a/NOTICE +++ b/NOTICE @@ -44,6 +44,10 @@ Revision ID: 5fc50a00491616d5cd0cbce3abd8b699838e25ca easyjson 8ab5ff9cd8e4e43 https://github.com/mailru/easyjson 2e8b79f6c47d324 a31dd803cf + + securecookie e59506cc896acb7 https://github.com/gorilla/securecookie + f7bf732d4fdf5e2 + 5f7ccd8983 semver 4487282d78122a2 https://github.com/blang/semver 45e413d7515e7c5 16b70c33fd diff --git a/_examples/README.md b/_examples/README.md index 32be6859..78905cd2 100644 --- a/_examples/README.md +++ b/_examples/README.md @@ -124,6 +124,7 @@ * [Embedded Single Page Application with other routes](file-server/single-page-application/embedded-single-page-application-with-other-routes/main.go) * [Upload File](file-server/upload-file/main.go) * [Upload Multiple Files](file-server/upload-files/main.go) + * [WebDAV](file-server/webdav/main.go) * View * [Overview](view/overview/main.go) * [Layout](view/layout) @@ -212,7 +213,7 @@ * [Basic](i18n/basic) * [Ttemplates and Functions](i18n/template) * [Pluralization and Variables](i18n/plurals) -* Authentication, Authorization & Bot Detection +* Authentication, Authorization & Bot Detection * Basic Authentication * [Basic](auth/basicauth/basic) * [Load from a slice of Users](auth/basicauth/users_list) @@ -225,6 +226,7 @@ * [Blocklist](auth/jwt/blocklist/main.go) * [Refresh Token](auth/jwt/refresh-token/main.go) * [Tutorial](auth/jwt/tutorial) + * [SSO](auth/sso) **NEW (GO 1.18 Generics required)** * [JWT (community edition)](https://github.com/iris-contrib/middleware/tree/v12/jwt/_example/main.go) * [OAUth2](auth/goth/main.go) * [Manage Permissions](auth/permissions/main.go) @@ -277,6 +279,7 @@ * [Authenticated Controller](mvc/authenticated-controller/main.go) * [Versioned Controller](mvc/versioned-controller/main.go) * [Websocket Controller](mvc/websocket) + * [Websocket + Authentication (SSO)](mvc/websocket-sso) **NEW (GO 1.18 Generics required)** * [Register Middleware](mvc/middleware) * [gRPC](mvc/grpc-compatible) * [gRPC Bidirectional Stream](mvc/grpc-compatible-bidirectional-stream) diff --git a/_examples/auth/sso/README.md b/_examples/auth/sso/README.md new file mode 100644 index 00000000..ffa7a823 --- /dev/null +++ b/_examples/auth/sso/README.md @@ -0,0 +1,12 @@ +# SSO (Single Sign On) + +```sh +$ go run . +``` + +1. GET/POST: http://localhost:8080/signin +2. GET: http://localhost:8080/member +3. GET: http://localhost:8080/owner +4. POST: http://localhost:8080/refresh +5. GET: http://localhost:8080/signout +6. GET: http://localhost:8080/signout-all \ No newline at end of file diff --git a/_examples/auth/sso/main.go b/_examples/auth/sso/main.go new file mode 100644 index 00000000..f09236a2 --- /dev/null +++ b/_examples/auth/sso/main.go @@ -0,0 +1,135 @@ +//go:build go1.18 + +package main + +import ( + "fmt" + + "github.com/kataras/iris/v12" + "github.com/kataras/iris/v12/sso" +) + +func allowRole(role AccessRole) sso.TVerify[User] { + return func(u User) error { + if !u.Role.Allow(role) { + return fmt.Errorf("invalid role") + } + + return nil + } +} + +const configFilename = "./sso.yml" + +func main() { + app := iris.New() + app.RegisterView(iris.Blocks(iris.Dir("./views"), ".html"). + LayoutDir("layouts"). + Layout("main")) + + /* + // Easiest 1-liner way, load from configuration and initialize a new sso instance: + s := sso.MustLoad[User]("./sso.yml") + // Bind a configuration from file: + var c sso.Configuration + c.BindFile("./sso.yml") + s, err := sso.New[User](c) + // OR create new programmatically configuration: + config := sso.Configuration{ + ...fields + } + s, err := sso.New[User](config) + // OR generate a new configuration: + config := sso.MustGenerateConfiguration() + s, err := sso.New[User](config) + // OR generate a new config and save it if cannot open the config file. + if _, err := os.Stat(configFilename); err != nil { + generatedConfig := sso.MustGenerateConfiguration() + configContents, err := generatedConfig.ToYAML() + if err != nil { + panic(err) + } + + err = os.WriteFile(configFilename, configContents, 0600) + if err != nil { + panic(err) + } + } + */ + + // 1. Load configuration from a file. + ssoConfig, err := sso.LoadConfiguration(configFilename) + if err != nil { + panic(err) + } + + // 2. Initialize a new sso instance for "User" claims (generics: go1.18 +). + s, err := sso.New[User](ssoConfig) + if err != nil { + panic(err) + } + + // 3. Add a custom provider, in our case is just a memory-based one. + s.AddProvider(NewProvider()) + // 3.1. Optionally set a custom error handler. + // s.SetErrorHandler(new(sso.DefaultErrorHandler)) + + app.Get("/signin", renderSigninForm) + // 4. generate token pairs. + app.Post("/signin", s.SigninHandler) + // 5. refresh token pairs. + app.Post("/refresh", s.RefreshHandler) + // 6. calls the provider's InvalidateToken method. + app.Get("/signout", s.SignoutHandler) + // 7. calls the provider's InvalidateTokens method. + app.Get("/signout-all", s.SignoutAllHandler) + + // 8.1. allow access for users with "Member" role. + app.Get("/member", s.VerifyHandler(allowRole(Member)), renderMemberPage(s)) + // 8.2. allow access for users with "Owner" role. + app.Get("/owner", s.VerifyHandler(allowRole(Owner)), renderOwnerPage(s)) + + /* Subdomain user verify: + app.Subdomain("owner", s.VerifyHandler(allowRole(Owner))).Get("/", renderOwnerPage(s)) + */ + app.Listen(":8080", iris.WithOptimizations) // Setup HTTPS/TLS for production instead. + /* Test subdomain user verify, one way is ingrok, + add the below to the arguments above: + , iris.WithConfiguration(iris.Configuration{ + EnableOptmizations: true, + Tunneling: iris.TunnelingConfiguration{ + AuthToken: "YOUR_AUTH_TOKEN", + Region: "us", + Tunnels: []tunnel.Tunnel{ + { + Name: "Iris SSO (Test)", + Addr: ":8080", + Hostname: "YOUR_DOMAIN", + }, + { + Name: "Iris SSO (Test Subdomain)", + Addr: ":8080", + Hostname: "owner.YOUR_DOMAIN", + }, + }, + }, + })*/ +} + +func renderSigninForm(ctx iris.Context) { + ctx.View("signin", iris.Map{"Title": "Signin Page"}) +} + +func renderMemberPage(s *sso.SSO[User]) iris.Handler { + return func(ctx iris.Context) { + user := s.GetUser(ctx) + ctx.Writef("Hello member: %s\n", user.Email) + } +} + +func renderOwnerPage(s *sso.SSO[User]) iris.Handler { + return func(ctx iris.Context) { + user := s.GetUser(ctx) + ctx.Writef("Hello owner: %s\n", user.Email) + } +} diff --git a/_examples/auth/sso/sso.yml b/_examples/auth/sso/sso.yml new file mode 100644 index 00000000..c25a22bc --- /dev/null +++ b/_examples/auth/sso/sso.yml @@ -0,0 +1,32 @@ +Cookie: # optional. + Name: "iris_sso" + Hash: "D*G-KaPdSgUkXp2s5v8y/B?E(H+MbQeThWmYq3t6w9z$C&F)J@NcRfUjXn2r4u7x" # length of 64 characters (512-bit). + Block: "VkYp3s6v9y$B&E)H@McQfTjWmZq4t7w!" # length of 32 characters (256-bit). +Keys: + - ID: IRIS_SSO_ACCESS # required. + Alg: EdDSA + MaxAge: 2h # 2 hours lifetime for access tokens. + Private: |+ + -----BEGIN PRIVATE KEY----- + MC4CAQAwBQYDK2VwBCIEIFdZWoDdFny5SMnP9Fyfr8bafi/B527EVZh8JJjDTIFO + -----END PRIVATE KEY----- + Public: |+ + -----BEGIN PUBLIC KEY----- + MCowBQYDK2VwAyEAzpgjKSr9E032DX+foiOxq1QDsbzjLxagTN+yVpGWZB4= + -----END PUBLIC KEY----- + - ID: IRIS_SSO_REFRESH # optional. Good practise to have it though. + Alg: EdDSA + # 1 month lifetime for refresh tokens, + # after that period the user has to signin again. + MaxAge: 720h + Private: |+ + -----BEGIN PRIVATE KEY----- + MC4CAQAwBQYDK2VwBCIEIHJ1aoIjA2sRp5eqGjGR3/UMucrHbBdBv9p8uwfzZ1KZ + -----END PRIVATE KEY----- + Public: |+ + -----BEGIN PUBLIC KEY----- + MCowBQYDK2VwAyEAsKKAr+kDtfAqwG7cZdoEAfh9jHt9W8qi9ur5AA1KQAQ= + -----END PUBLIC KEY----- + # Example of setting a binary form of the encryption key for refresh tokens, + # it could be a "string" as well. + EncryptionKey: !!binary stSNLTu91YyihPxzeEOXKwGVMG00CjcC/68G8nMgmqA= diff --git a/_examples/auth/sso/user.go b/_examples/auth/sso/user.go new file mode 100644 index 00000000..a2e93017 --- /dev/null +++ b/_examples/auth/sso/user.go @@ -0,0 +1,33 @@ +//go:build go1.18 + +package main + +type AccessRole uint16 + +func (r AccessRole) Is(v AccessRole) bool { + return r&v != 0 +} + +func (r AccessRole) Allow(v AccessRole) bool { + return r&v >= v +} + +const ( + InvalidAccessRole AccessRole = 1 << iota + Read + Write + Delete + + Owner = Read | Write | Delete + Member = Read | Write +) + +type User struct { + ID string `json:"id"` + Email string `json:"email"` + Role AccessRole `json:"role"` +} + +func (u User) GetID() string { + return u.ID +} diff --git a/_examples/auth/sso/user_provider.go b/_examples/auth/sso/user_provider.go new file mode 100644 index 00000000..151e426b --- /dev/null +++ b/_examples/auth/sso/user_provider.go @@ -0,0 +1,100 @@ +//go:build go1.18 + +package main + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/kataras/iris/v12/sso" +) + +type Provider struct { + dataset []User + + invalidated map[string]struct{} // key = token. Entry is blocked. + invalidatedAll map[string]int64 // key = user id, value = timestamp. Issued before is consider invalid. + mu sync.RWMutex +} + +func NewProvider() *Provider { + return &Provider{ + dataset: []User{ + { + ID: "id-1", + Email: "kataras2006@hotmail.com", + Role: Owner, + }, + { + ID: "id-2", + Email: "example@example.com", + Role: Member, + }, + }, + invalidated: make(map[string]struct{}), + invalidatedAll: make(map[string]int64), + } +} + +func (p *Provider) Signin(ctx context.Context, username, password string) (User, error) { // fired on SigninHandler. + // your database... + for _, user := range p.dataset { + if user.Email == username { + return user, nil + } + } + + return User{}, fmt.Errorf("user not found") +} + +func (p *Provider) ValidateToken(ctx context.Context, standardClaims sso.StandardClaims, u User) error { // fired on VerifyHandler. + // your database and checks of blocked tokens... + + // check for specific token ids. + p.mu.RLock() + _, tokenBlocked := p.invalidated[standardClaims.ID] + if !tokenBlocked { + // this will disallow refresh tokens with origin jwt token id as the blocked access token as well. + if standardClaims.OriginID != "" { + _, tokenBlocked = p.invalidated[standardClaims.OriginID] + } + } + p.mu.RUnlock() + + if tokenBlocked { + return fmt.Errorf("token was invalidated") + } + // + + // check all tokens issuet before the "InvalidateToken" method was fired for this user. + p.mu.RLock() + ts, oldUserBlocked := p.invalidatedAll[u.ID] + p.mu.RUnlock() + + if oldUserBlocked && standardClaims.IssuedAt <= ts { + return fmt.Errorf("token was invalidated") + } + // + + return nil // else valid. +} + +func (p *Provider) InvalidateToken(ctx context.Context, standardClaims sso.StandardClaims, u User) error { // fired on SignoutHandler. + // invalidate this specific token. + p.mu.Lock() + p.invalidated[standardClaims.ID] = struct{}{} + p.mu.Unlock() + + return nil +} + +func (p *Provider) InvalidateTokens(ctx context.Context, u User) error { // fired on SignoutAllHandler. + // invalidate all previous tokens came from "u". + p.mu.Lock() + p.invalidatedAll[u.ID] = time.Now().Unix() + p.mu.Unlock() + + return nil +} diff --git a/_examples/auth/sso/views/layouts/main.html b/_examples/auth/sso/views/layouts/main.html new file mode 100644 index 00000000..c6efd56d --- /dev/null +++ b/_examples/auth/sso/views/layouts/main.html @@ -0,0 +1,30 @@ + + + + + + {{ if .Title }}{{ .Title }}{{ else }}Default Main Title{{ end }} + + + +
+
{{ template "content" . }}
+ +
+ + \ No newline at end of file diff --git a/_examples/auth/sso/views/partials/footer.html b/_examples/auth/sso/views/partials/footer.html new file mode 100644 index 00000000..69b8f9b8 --- /dev/null +++ b/_examples/auth/sso/views/partials/footer.html @@ -0,0 +1 @@ +Iris Web Framework © 2022 \ No newline at end of file diff --git a/_examples/auth/sso/views/signin.html b/_examples/auth/sso/views/signin.html new file mode 100644 index 00000000..57f7514c --- /dev/null +++ b/_examples/auth/sso/views/signin.html @@ -0,0 +1,9 @@ +
+
+ + + + + +
+
\ No newline at end of file diff --git a/_examples/file-server/webdav/main.go b/_examples/file-server/webdav/main.go new file mode 100644 index 00000000..0db07792 --- /dev/null +++ b/_examples/file-server/webdav/main.go @@ -0,0 +1,57 @@ +package main + +import ( + "net/http" + "os" + "strings" + + "github.com/kataras/iris/v12" + "github.com/kataras/iris/v12/middleware/accesslog" + "github.com/kataras/iris/v12/middleware/recover" + + "golang.org/x/net/webdav" +) + +func main() { + app := iris.New() + + app.Logger().SetLevel("debug") + app.Use(recover.New()) + app.Use(accesslog.New(os.Stdout).Handler) + + webdavHandler := &webdav.Handler{ + FileSystem: webdav.Dir("./"), + LockSystem: webdav.NewMemLS(), + Logger: func(r *http.Request, err error) { + if err != nil { + app.Logger().Error(err) + } + }, + } + + app.HandleMany(strings.Join(iris.WebDAVMethods, " "), "/{p:path}", iris.FromStd(webdavHandler)) + + app.Listen(":8080", + iris.WithoutServerError(iris.ErrServerClosed), + iris.WithoutPathCorrection, + ) +} + +/* Test with cURL or postman: + +* List files: + curl --location --request PROPFIND 'http://localhost:8080' +* Get File: + curl --location --request GET 'http://localhost:8080/test.txt' +* Upload File: + curl --location --request PUT 'http://localhost:8080/newfile.txt' \ + --header 'Content-Type: text/plain' \ + --data-raw 'This is a new file!' +* Copy File: + curl --location --request COPY 'http://localhost:8080/test.txt' \ + --header 'Destination: newdir/test.txt' +* Create New Directory: + curl --location --request MKCOL 'http://localhost:8080/anewdir/' + +And e.t.c. +*/ diff --git a/_examples/file-server/webdav/newdir/.gitkeep b/_examples/file-server/webdav/newdir/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/_examples/file-server/webdav/test.txt b/_examples/file-server/webdav/test.txt new file mode 100644 index 00000000..af5626b4 --- /dev/null +++ b/_examples/file-server/webdav/test.txt @@ -0,0 +1 @@ +Hello, world! diff --git a/_examples/mvc/grpc-compatible/main.go b/_examples/mvc/grpc-compatible/main.go index f93bf7a1..3ef16174 100644 --- a/_examples/mvc/grpc-compatible/main.go +++ b/_examples/mvc/grpc-compatible/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "log" pb "github.com/kataras/iris/v12/_examples/mvc/grpc-compatible/helloworld" @@ -47,17 +48,32 @@ func newApp() *iris.Application { // Register MVC application controller for gRPC services. // You can bind as many mvc gRpc services in the same Party or app, // as the ServiceName differs. - mvc.New(app).Handle(ctrl, mvc.GRPC{ - Server: grpcServer, // Required. - ServiceName: "helloworld.Greeter", // Required. - Strict: false, - }) + mvc.New(app). + Register(new(myService)). + Handle(ctrl, mvc.GRPC{ + Server: grpcServer, // Required. + ServiceName: "helloworld.Greeter", // Required. + Strict: false, + }) return app } +type service interface { + DoSomething() error +} + +type myService struct{} + +func (s *myService) DoSomething() error { + log.Println("service: DoSomething") + return nil +} + type myController struct { // Ctx iris.Context + + SingletonDependency service } // SayHello implements helloworld.GreeterServer. @@ -70,5 +86,10 @@ type myController struct { // @Success 200 {string} string "Hello {name}" // @Router /helloworld.Greeter/SayHello [post] func (c *myController) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { + err := c.SingletonDependency.DoSomething() + if err != nil { + return nil, err + } + return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil } diff --git a/_examples/mvc/websocket-sso/browser/index.html b/_examples/mvc/websocket-sso/browser/index.html new file mode 100644 index 00000000..09bd4498 --- /dev/null +++ b/_examples/mvc/websocket-sso/browser/index.html @@ -0,0 +1,106 @@ + + + + Online visitors MVC example + + + + +
+ 1 online visitor +
+ + + + + + + + +

+    
+    
+    
+
+
+
+
\ No newline at end of file
diff --git a/_examples/mvc/websocket-sso/main.go b/_examples/mvc/websocket-sso/main.go
new file mode 100644
index 00000000..1e59e264
--- /dev/null
+++ b/_examples/mvc/websocket-sso/main.go
@@ -0,0 +1,72 @@
+//go:build go1.18
+
+package main
+
+import (
+	"fmt"
+
+	"github.com/kataras/iris/v12"
+	"github.com/kataras/iris/v12/mvc"
+	"github.com/kataras/iris/v12/sso"
+	"github.com/kataras/iris/v12/websocket"
+)
+
+// $ go run .
+func main() {
+	app := newApp()
+
+	// http://localhost:8080/signin (creds: kataras2006@hotmail.com 123456)
+	// http://localhost:8080/protected
+	// http://localhost:8080/signout
+	app.Listen(":8080")
+}
+
+func newApp() *iris.Application {
+	app := iris.New()
+
+	// Auth part.
+	app.RegisterView(iris.Blocks(iris.Dir("./views"), ".html").
+		LayoutDir("layouts").
+		Layout("main"))
+
+	s := sso.MustLoad[User]("./sso.yml")
+	s.AddProvider(NewProvider())
+
+	app.Get("/signin", renderSigninForm)
+	app.Post("/signin", s.SigninHandler)
+	app.Get("/signout", s.SignoutHandler)
+	//
+
+	websocketAPI := app.Party("/protected")
+	websocketAPI.Use(s.VerifyHandler())
+	websocketAPI.HandleDir("/", iris.Dir("./browser")) // render the ./browser/index.html.
+
+	websocketMVC := mvc.New(websocketAPI)
+	websocketMVC.HandleWebsocket(new(websocketController))
+	websocketServer := websocket.New(websocket.DefaultGorillaUpgrader, websocketMVC)
+	websocketAPI.Get("/ws", s.VerifyHandler() /* optional */, websocket.Handler(websocketServer))
+
+	return app
+}
+
+func renderSigninForm(ctx iris.Context) {
+	ctx.View("signin", iris.Map{"Title": "Signin Page"})
+}
+
+type websocketController struct {
+	*websocket.NSConn `stateless:"true"`
+}
+
+func (c *websocketController) Namespace() string {
+	return "default"
+}
+
+func (c *websocketController) OnChat(msg websocket.Message) error {
+	ctx := websocket.GetContext(c.Conn)
+	user := sso.GetUser[User](ctx)
+
+	msg.Body = []byte(fmt.Sprintf("%s: %s", user.Email, string(msg.Body)))
+	c.Conn.Server().Broadcast(c, msg)
+
+	return nil
+}
diff --git a/_examples/mvc/websocket-sso/sso.yml b/_examples/mvc/websocket-sso/sso.yml
new file mode 100644
index 00000000..c25a22bc
--- /dev/null
+++ b/_examples/mvc/websocket-sso/sso.yml
@@ -0,0 +1,32 @@
+Cookie: # optional.
+  Name: "iris_sso"
+  Hash: "D*G-KaPdSgUkXp2s5v8y/B?E(H+MbQeThWmYq3t6w9z$C&F)J@NcRfUjXn2r4u7x" # length of 64 characters (512-bit).
+  Block: "VkYp3s6v9y$B&E)H@McQfTjWmZq4t7w!" # length of 32 characters (256-bit).
+Keys:
+  - ID: IRIS_SSO_ACCESS # required.
+    Alg: EdDSA
+    MaxAge: 2h # 2 hours lifetime for access tokens. 
+    Private: |+
+      -----BEGIN PRIVATE KEY-----
+      MC4CAQAwBQYDK2VwBCIEIFdZWoDdFny5SMnP9Fyfr8bafi/B527EVZh8JJjDTIFO
+      -----END PRIVATE KEY-----
+    Public: |+
+      -----BEGIN PUBLIC KEY-----
+      MCowBQYDK2VwAyEAzpgjKSr9E032DX+foiOxq1QDsbzjLxagTN+yVpGWZB4=
+      -----END PUBLIC KEY-----
+  - ID: IRIS_SSO_REFRESH # optional. Good practise to have it though.
+    Alg: EdDSA
+    # 1 month lifetime for refresh tokens,
+    # after that period the user has to signin again.
+    MaxAge: 720h
+    Private: |+
+      -----BEGIN PRIVATE KEY-----
+      MC4CAQAwBQYDK2VwBCIEIHJ1aoIjA2sRp5eqGjGR3/UMucrHbBdBv9p8uwfzZ1KZ
+      -----END PRIVATE KEY-----
+    Public: |+
+      -----BEGIN PUBLIC KEY-----
+      MCowBQYDK2VwAyEAsKKAr+kDtfAqwG7cZdoEAfh9jHt9W8qi9ur5AA1KQAQ=
+      -----END PUBLIC KEY-----
+    # Example of setting a binary form of the encryption key for refresh tokens,
+    # it could be a "string" as well.
+    EncryptionKey: !!binary stSNLTu91YyihPxzeEOXKwGVMG00CjcC/68G8nMgmqA=
diff --git a/_examples/mvc/websocket-sso/user.go b/_examples/mvc/websocket-sso/user.go
new file mode 100644
index 00000000..a2e93017
--- /dev/null
+++ b/_examples/mvc/websocket-sso/user.go
@@ -0,0 +1,33 @@
+//go:build go1.18
+
+package main
+
+type AccessRole uint16
+
+func (r AccessRole) Is(v AccessRole) bool {
+	return r&v != 0
+}
+
+func (r AccessRole) Allow(v AccessRole) bool {
+	return r&v >= v
+}
+
+const (
+	InvalidAccessRole AccessRole = 1 << iota
+	Read
+	Write
+	Delete
+
+	Owner  = Read | Write | Delete
+	Member = Read | Write
+)
+
+type User struct {
+	ID    string     `json:"id"`
+	Email string     `json:"email"`
+	Role  AccessRole `json:"role"`
+}
+
+func (u User) GetID() string {
+	return u.ID
+}
diff --git a/_examples/mvc/websocket-sso/user_provider.go b/_examples/mvc/websocket-sso/user_provider.go
new file mode 100644
index 00000000..a1846c15
--- /dev/null
+++ b/_examples/mvc/websocket-sso/user_provider.go
@@ -0,0 +1,100 @@
+//go:build go1.18
+
+package main
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/kataras/iris/v12/sso"
+)
+
+type Provider struct {
+	dataset []User
+
+	invalidated    map[string]struct{} // key = token. Entry is blocked.
+	invalidatedAll map[string]int64    // key = user id, value = timestamp. Issued before is consider invalid.
+	mu             sync.RWMutex
+}
+
+func NewProvider() *Provider {
+	return &Provider{
+		dataset: []User{
+			{
+				ID:    "id-1",
+				Email: "kataras2006@hotmail.com",
+				Role:  Owner,
+			},
+			{
+				ID:    "id-2",
+				Email: "example@example.com",
+				Role:  Member,
+			},
+		},
+		invalidated:    make(map[string]struct{}),
+		invalidatedAll: make(map[string]int64),
+	}
+}
+
+func (p *Provider) Signin(ctx context.Context, username, password string) (User, error) { // fired on SigninHandler.
+	// your database...
+	for _, user := range p.dataset {
+		if user.Email == username {
+			return user, nil
+		}
+	}
+
+	return User{}, fmt.Errorf("user not found")
+}
+
+func (p *Provider) ValidateToken(ctx context.Context, standardClaims sso.StandardClaims, u User) error { // fired on VerifyHandler.
+	// your database and checks of blocked tokens...
+
+	// check for specific token ids.
+	p.mu.RLock()
+	_, tokenBlocked := p.invalidated[standardClaims.ID]
+	if !tokenBlocked {
+		// this will disallow refresh tokens with issuer as the blocked access token as well.
+		if standardClaims.Issuer != "" {
+			_, tokenBlocked = p.invalidated[standardClaims.Issuer]
+		}
+	}
+	p.mu.RUnlock()
+
+	if tokenBlocked {
+		return fmt.Errorf("token was invalidated")
+	}
+	//
+
+	// check all tokens issuet before the "InvalidateToken" method was fired for this user.
+	p.mu.RLock()
+	ts, oldUserBlocked := p.invalidatedAll[u.ID]
+	p.mu.RUnlock()
+
+	if oldUserBlocked && standardClaims.IssuedAt <= ts {
+		return fmt.Errorf("token was invalidated")
+	}
+	//
+
+	return nil // else valid.
+}
+
+func (p *Provider) InvalidateToken(ctx context.Context, standardClaims sso.StandardClaims, u User) error { // fired on SignoutHandler.
+	// invalidate this specific token.
+	p.mu.Lock()
+	p.invalidated[standardClaims.ID] = struct{}{}
+	p.mu.Unlock()
+
+	return nil
+}
+
+func (p *Provider) InvalidateTokens(ctx context.Context, u User) error { // fired on SignoutAllHandler.
+	// invalidate all previous tokens came from "u".
+	p.mu.Lock()
+	p.invalidatedAll[u.ID] = time.Now().Unix()
+	p.mu.Unlock()
+
+	return nil
+}
diff --git a/_examples/mvc/websocket-sso/views/layouts/main.html b/_examples/mvc/websocket-sso/views/layouts/main.html
new file mode 100644
index 00000000..c6efd56d
--- /dev/null
+++ b/_examples/mvc/websocket-sso/views/layouts/main.html
@@ -0,0 +1,30 @@
+
+
+
+    
+    
+    {{ if .Title }}{{ .Title }}{{ else }}Default Main Title{{ end }}
+
+
+
+    
+
{{ template "content" . }}
+ +
+ + \ No newline at end of file diff --git a/_examples/mvc/websocket-sso/views/partials/footer.html b/_examples/mvc/websocket-sso/views/partials/footer.html new file mode 100644 index 00000000..69b8f9b8 --- /dev/null +++ b/_examples/mvc/websocket-sso/views/partials/footer.html @@ -0,0 +1 @@ +Iris Web Framework © 2022 \ No newline at end of file diff --git a/_examples/mvc/websocket-sso/views/signin.html b/_examples/mvc/websocket-sso/views/signin.html new file mode 100644 index 00000000..57f7514c --- /dev/null +++ b/_examples/mvc/websocket-sso/views/signin.html @@ -0,0 +1,9 @@ +
+
+ + + + + +
+
\ No newline at end of file diff --git a/aliases.go b/aliases.go index da362586..6c18ebcb 100644 --- a/aliases.go +++ b/aliases.go @@ -660,8 +660,41 @@ const ( StatusNetworkReadTimeout = context.StatusNetworkReadTimeout ) -// StatusText returns a text for the HTTP status code. It returns the empty -// string if the code is unknown. -// -// Shortcut for core/router#StatusText. -var StatusText = context.StatusText +var ( + // StatusText returns a text for the HTTP status code. It returns the empty + // string if the code is unknown. + // + // Shortcut for core/router#StatusText. + StatusText = context.StatusText + // RegisterMethods adds custom http methods to the "AllMethods" list. + // Use it on initialization of your program. + // + // Shortcut for core/router#RegisterMethods. + RegisterMethods = router.RegisterMethods + + // WebDAVMethods contains a list of WebDAV HTTP Verbs. + // Register using RegiterMethods package-level function or + // through HandleMany party-level method. + WebDAVMethods = []string{ + MethodGet, + MethodHead, + MethodPatch, + MethodPut, + MethodPost, + MethodDelete, + MethodOptions, + MethodConnect, + MethodTrace, + "MKCOL", + "COPY", + "MOVE", + "LOCK", + "UNLOCK", + "PROPFIND", + "PROPPATCH", + "LINK", + "UNLINK", + "PURGE", + "VIEW", + } +) diff --git a/context/context.go b/context/context.go index ca479f36..e1944f4e 100644 --- a/context/context.go +++ b/context/context.go @@ -5314,7 +5314,11 @@ type SecureCookie interface { // // Example: https://github.com/kataras/iris/tree/master/_examples/cookies/securecookie func CookieEncoding(encoding SecureCookie, cookieNames ...string) CookieOption { - return func(_ *Context, c *http.Cookie, op uint8) { + if encoding == nil { + return func(_ *Context, _ *http.Cookie, _ uint8) {} + } + + return func(ctx *Context, c *http.Cookie, op uint8) { if op == OpCookieDel { return } @@ -5328,10 +5332,12 @@ func CookieEncoding(encoding SecureCookie, cookieNames ...string) CookieOption { // Should encode, it's a write to the client operation. newVal, err := encoding.Encode(c.Name, c.Value) if err != nil { + ctx.Application().Logger().Error(err) c.Value = "" } else { c.Value = newVal } + return case OpCookieGet: // Should decode, it's a read from the client operation. diff --git a/core/router/api_builder.go b/core/router/api_builder.go index aa220470..f0a14874 100644 --- a/core/router/api_builder.go +++ b/core/router/api_builder.go @@ -39,6 +39,13 @@ var AllMethods = []string{ http.MethodTrace, } +// RegisterMethods adds custom http methods to the "AllMethods" list. +// Use it on initialization of your program. +func RegisterMethods(newCustomHTTPVerbs ...string) { + newMethods := append(AllMethods, newCustomHTTPVerbs...) + AllMethods = removeDuplicates(newMethods) +} + // repository passed to all parties(subrouters), it's the object witch keeps // all the routes. type repository struct { diff --git a/go.mod b/go.mod index a1d99da5..30959e4c 100644 --- a/go.mod +++ b/go.mod @@ -18,13 +18,14 @@ require ( github.com/goccy/go-json v0.9.5 github.com/golang/snappy v0.0.4 github.com/google/uuid v1.3.0 + github.com/gorilla/securecookie v1.1.1 github.com/iris-contrib/httpexpect/v2 v2.3.1 github.com/iris-contrib/jade v1.1.4 github.com/iris-contrib/schema v0.0.6 github.com/json-iterator/go v1.1.12 github.com/kataras/blocks v0.0.5 github.com/kataras/golog v0.1.7 - github.com/kataras/jwt v0.1.5 + github.com/kataras/jwt v0.1.8 github.com/kataras/neffos v0.0.19 github.com/kataras/pio v0.0.10 github.com/kataras/sitemap v0.0.5 diff --git a/go.sum b/go.sum index 379cb4aa..d117201f 100644 --- a/go.sum +++ b/go.sum @@ -85,6 +85,8 @@ github.com/gopherjs/gopherjs v0.0.0-20220221023154-0b2280d3ff96 h1:QJq7UBOuoynsy github.com/gopherjs/gopherjs v0.0.0-20220221023154-0b2280d3ff96/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= @@ -109,8 +111,8 @@ github.com/kataras/blocks v0.0.5 h1:jFrsHEDfXZhHTbhkNWgMgpfEQNj1Bwr1IYEYZ9Xxoxg= github.com/kataras/blocks v0.0.5/go.mod h1:kcJIuvuA8QmGKFLHIZHdCAPCjcE85IhttzXd6W+ayfE= github.com/kataras/golog v0.1.7 h1:0TY5tHn5L5DlRIikepcaRR/6oInIr9AiWsxzt0vvlBE= github.com/kataras/golog v0.1.7/go.mod h1:jOSQ+C5fUqsNSwurB/oAHq1IFSb0KI3l6GMa7xB6dZA= -github.com/kataras/jwt v0.1.5 h1:3UScbsLyo7fsKP6IRPzySH0mcAdTsEu104iWMjGqEyE= -github.com/kataras/jwt v0.1.5/go.mod h1:4ss3aGJi58q3YGmhLUiOvNJnL7UlTXD7+Wf+skgsTmQ= +github.com/kataras/jwt v0.1.8 h1:u71baOsYD22HWeSOg32tCHbczPjdCk7V4MMeJqTtmGk= +github.com/kataras/jwt v0.1.8/go.mod h1:Q5j2IkcIHnfwy+oNY3TVWuEBJNw0ADgCcXK9CaZwV4o= github.com/kataras/neffos v0.0.19 h1:j3jp/hzvGFQjnkkLWGNjae5qMSdpMYr66Lxgf8CgcAw= github.com/kataras/neffos v0.0.19/go.mod h1:CAAuFqHYX5t0//LLMiVWooOSp5FPeBRD8cn/892P1JE= github.com/kataras/pio v0.0.10 h1:b0qtPUqOpM2O+bqa5wr2O6dN4cQNwSmFd6HQqgVae0g= diff --git a/sso/configuration.go b/sso/configuration.go new file mode 100644 index 00000000..e3100788 --- /dev/null +++ b/sso/configuration.go @@ -0,0 +1,162 @@ +//go:build go1.18 + +package sso + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/gorilla/securecookie" + "github.com/kataras/jwt" + "gopkg.in/yaml.v3" +) + +const ( + KIDAccess = "IRIS_SSO_ACCESS" + KIDRefresh = "IRIS_SSO_REFRESH" +) + +type ( + Configuration struct { + Cookie CookieConfiguration `json:"cookie" yaml:"Cookie" toml:"Cookie" ini:"cookie"` + // keep it to always renew the refresh token. RefreshStrategy string `json:"refresh_strategy" yaml:"RefreshStrategy" toml:"RefreshStrategy" ini:"refresh_strategy"` + Keys jwt.KeysConfiguration `json:"keys" yaml:"Keys" toml:"Keys" ini:"keys"` + } + + CookieConfiguration struct { + Name string `json:"cookie" yaml:"Name" toml:"Name" ini:"name"` + Hash string `json:"hash" yaml:"Hash" toml:"Hash" ini:"hash"` + Block string `json:"block" yaml:"Block" toml:"Block" ini:"block"` + } +) + +func (c *Configuration) validate() (jwt.Keys, error) { + if c.Cookie.Name != "" { + if c.Cookie.Hash == "" || c.Cookie.Block == "" { + return nil, fmt.Errorf("cookie block and cookie hash are required for security reasons when cookie is used") + } + } + + keys, err := c.Keys.Load() + if err != nil { + return nil, fmt.Errorf("sso: %w", err) + } + + if _, ok := keys[KIDAccess]; !ok { + return nil, fmt.Errorf("sso: %s access token is missing from the configuration", KIDAccess) + } + + // Let's keep refresh optional. + // if _, ok := keys[KIDRefresh]; !ok { + // return nil, fmt.Errorf("sso: %s refresh token is missing from the configuration", KIDRefresh) + // } + return keys, nil +} + +// BindRandom binds the "c" configuration to random values for keys and cookie security. +// Keys will not be persisted between restarts, +// a more persistent storage should be considered for production applications. +func (c *Configuration) BindRandom() error { + accessPublic, accessPrivate, err := jwt.GenerateEdDSA() + if err != nil { + return err + } + + refreshPublic, refreshPrivate, err := jwt.GenerateEdDSA() + if err != nil { + return err + } + + *c = Configuration{ + Cookie: CookieConfiguration{ + Name: "iris_sso", + Hash: string(securecookie.GenerateRandomKey(64)), + Block: string(securecookie.GenerateRandomKey(32)), + }, + Keys: jwt.KeysConfiguration{ + { + ID: KIDAccess, + Alg: jwt.EdDSA.Name(), + MaxAge: 2 * time.Hour, + Public: string(accessPublic), + Private: string(accessPrivate), + }, + { + ID: KIDRefresh, + Alg: jwt.EdDSA.Name(), + MaxAge: 720 * time.Hour, + Public: string(refreshPublic), + Private: string(refreshPrivate), + EncryptionKey: string(jwt.MustGenerateRandom(32)), + }, + }, + } + + return nil +} + +func (c *Configuration) BindFile(filename string) error { + switch filepath.Ext(filename) { + case ".json": + contents, err := os.ReadFile(filename) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + generatedConfig := MustGenerateConfiguration() + if generatedYAML, gErr := generatedConfig.ToJSON(); gErr == nil { + err = fmt.Errorf("%w: example:\n\n%s", err, generatedYAML) + } + } + return err + } + + return json.Unmarshal(contents, c) + default: + contents, err := os.ReadFile(filename) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + generatedConfig := MustGenerateConfiguration() + if generatedYAML, gErr := generatedConfig.ToYAML(); gErr == nil { + err = fmt.Errorf("%w: example:\n\n%s", err, generatedYAML) + } + } + return err + } + + return yaml.Unmarshal(contents, c) + } + +} + +func (c *Configuration) ToYAML() ([]byte, error) { + return yaml.Marshal(c) +} + +func (c *Configuration) ToJSON() ([]byte, error) { + return json.Marshal(c) +} + +func MustGenerateConfiguration() (c Configuration) { + if err := c.BindRandom(); err != nil { + panic(err) + } + + return +} + +func LoadConfiguration(filename string) (c Configuration, err error) { + err = c.BindFile(filename) + return +} + +func MustLoadConfiguration(filename string) Configuration { + c, err := LoadConfiguration(filename) + if err != nil { + panic(err) + } + + return c +} diff --git a/sso/provider.go b/sso/provider.go new file mode 100644 index 00000000..816c2d79 --- /dev/null +++ b/sso/provider.go @@ -0,0 +1,83 @@ +//go:build go1.18 + +package sso + +import ( + stdContext "context" + "fmt" + + "github.com/kataras/iris/v12/context" + "github.com/kataras/iris/v12/middleware/jwt" + "github.com/kataras/iris/v12/x/errors" +) + +type VerifiedToken = jwt.VerifiedToken + +type Provider[T User] interface { // A provider can implement Transformer and ErrorHandler as well. + Signin(ctx stdContext.Context, username, password string) (T, error) + + // We could do this instead of transformer below but let's keep separated logic methods: + // ValidateToken(ctx context.Context, tok *VerifiedToken, t *T) error + ValidateToken(ctx stdContext.Context, standardClaims StandardClaims, t T) error + + InvalidateToken(ctx stdContext.Context, standardClaims StandardClaims, t T) error + InvalidateTokens(ctx stdContext.Context, t T) error +} + +// ClaimsProvider is an optional interface, which may not be used at all. +// If completed by a Provider, it signs the jwt token +// using these claims to each of the following token types. +type ClaimsProvider interface { + GetAccessTokenClaims() StandardClaims + GetRefreshTokenClaims(accessClaims StandardClaims) StandardClaims +} + +type Transformer[T User] interface { + Transform(ctx stdContext.Context, tok *VerifiedToken) (T, error) +} + +type TransformerFunc[T User] func(ctx stdContext.Context, tok *VerifiedToken) (T, error) + +func (fn TransformerFunc[T]) Transform(ctx stdContext.Context, tok *VerifiedToken) (T, error) { + return fn(ctx, tok) +} + +type ErrorHandler interface { + InvalidArgument(ctx *context.Context, err error) + Unauthenticated(ctx *context.Context, err error) +} + +type DefaultErrorHandler struct{} + +func (e *DefaultErrorHandler) InvalidArgument(ctx *context.Context, err error) { + errors.InvalidArgument.Details(ctx, "unable to parse body", err.Error()) +} + +func (e *DefaultErrorHandler) Unauthenticated(ctx *context.Context, err error) { + errors.Unauthenticated.Err(ctx, err) +} + +type provider[T User] struct{} + +func newProvider[T User]() *provider[T] { + return new(provider[T]) +} + +func (p *provider[T]) Signin(ctx stdContext.Context, username, password string) (T, error) { // fired on SigninHandler. + // your database... + var t T + return t, fmt.Errorf("user not found") +} + +func (p *provider[T]) ValidateToken(ctx stdContext.Context, standardClaims StandardClaims, t T) error { // fired on VerifyHandler. + // your database and checks of blocked tokens... + return nil +} + +func (p *provider[T]) InvalidateToken(ctx stdContext.Context, standardClaims StandardClaims, t T) error { // fired on SignoutHandler. + return nil +} + +func (p *provider[T]) InvalidateTokens(ctx stdContext.Context, t T) error { // fired on SignoutAllHandler. + return nil +} diff --git a/sso/sso.go b/sso/sso.go new file mode 100644 index 00000000..047b4756 --- /dev/null +++ b/sso/sso.go @@ -0,0 +1,568 @@ +//go:build go1.18 + +package sso + +import ( + stdContext "context" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/kataras/iris/v12/context" + + "github.com/google/uuid" + "github.com/gorilla/securecookie" + "github.com/kataras/jwt" +) + +type ( + SSO[T User] struct { + config Configuration + + keys jwt.Keys + securecookie context.SecureCookie + + providers []Provider[T] // at least one. + errorHandler ErrorHandler + transformer Transformer[T] + claimsProvider ClaimsProvider + refreshEnabled bool // if KIDRefresh exists in keys. + } + + TVerify[T User] func(t T) error + + SigninRequest struct { + Username string `json:"username" form:"username,omitempty"` // username OR email, username has priority over email. + Email string `json:"email" form:"email,omitempty"` // email OR username. + Password string `json:"password" form:"password"` + } + + SigninResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + } + + RefreshRequest struct { + RefreshToken string `json:"refresh_token"` + } +) + +func MustLoad[T User](filename string) *SSO[T] { + var config Configuration + if err := config.BindFile(filename); err != nil { + panic(err) + } + + s, err := New[T](config) + if err != nil { + panic(err) + } + + return s +} + +func Must[T User](s *SSO[T], err error) *SSO[T] { + if err != nil { + panic(err) + } + + return s +} + +func New[T User](config Configuration) (*SSO[T], error) { + keys, err := config.validate() + if err != nil { + return nil, err + } + _, refreshEnabled := keys[KIDRefresh] + + s := &SSO[T]{ + config: config, + keys: keys, + securecookie: securecookie.New([]byte(config.Cookie.Hash), []byte(config.Cookie.Block)), + refreshEnabled: refreshEnabled, + // providers: []Provider[T]{newProvider[T]()}, + errorHandler: new(DefaultErrorHandler), + } + + return s, nil +} + +func (s *SSO[T]) WithProviderAndErrorHandler(provider Provider[T], errHandler ErrorHandler) *SSO[T] { + if provider != nil { + for i := range s.providers { + s.providers[i] = nil + } + s.providers = nil + + s.providers = make([]Provider[T], 0, 1) + s.AddProvider(provider) + } + + if errHandler != nil { + s.SetErrorHandler(errHandler) + } + + return s +} + +func (s *SSO[T]) AddProvider(providers ...Provider[T]) *SSO[T] { + // defaultProviderTypename := strings.Replace(fmt.Sprintf("%T", s), "SSO", "provider", 1) + // if len(s.providers) == 1 && fmt.Sprintf("%T", s.providers[0]) == defaultProviderTypename { + // s.providers = append(s.providers[1:], p...) + + // A provider can also implement both transformer and + // error handler if that's the design option of the end-developer. + for _, p := range providers { + if s.transformer == nil { + if transformer, ok := p.(Transformer[T]); ok { + s.SetTransformer(transformer) + } + } + + if errHandler, ok := p.(ErrorHandler); ok { + s.SetErrorHandler(errHandler) + } + + if s.claimsProvider == nil { + if claimsProvider, ok := p.(ClaimsProvider); ok { + s.claimsProvider = claimsProvider + } + } + } + + s.providers = append(s.providers, providers...) + return s +} + +func (s *SSO[T]) SetErrorHandler(errHandler ErrorHandler) *SSO[T] { + s.errorHandler = errHandler + return s +} + +func (s *SSO[T]) SetTransformer(transformer Transformer[T]) *SSO[T] { + s.transformer = transformer + return s +} + +func (s *SSO[T]) SetTransformerFunc(transfermerFunc func(ctx stdContext.Context, tok *VerifiedToken) (T, error)) *SSO[T] { + s.transformer = TransformerFunc[T](transfermerFunc) + return s +} + +func (s *SSO[T]) Signin(ctx stdContext.Context, username, password string) ([]byte, []byte, error) { + var t T + + // get "t" from a valid provider. + if n := len(s.providers); n > 0 { + for i := 0; i < n; i++ { + p := s.providers[i] + + v, err := p.Signin(ctx, username, password) + if err != nil { + if i == n-1 { // last provider errored. + return nil, nil, fmt.Errorf("sso: signin: %w", err) + } + // keep searching. + continue + } + + // found. + t = v + break + } + } else { + return nil, nil, fmt.Errorf("sso: signin: no provider") + } + + // sign the tokens. + accessToken, refreshToken, err := s.sign(t) + if err != nil { + return nil, nil, fmt.Errorf("sso: signin: %w", err) + } + + return accessToken, refreshToken, nil +} + +func (s *SSO[T]) sign(t T) ([]byte, []byte, error) { + // sign the tokens. + var ( + accessStdClaims StandardClaims + refreshStdClaims StandardClaims + ) + + if s.claimsProvider != nil { + accessStdClaims = s.claimsProvider.GetAccessTokenClaims() + refreshStdClaims = s.claimsProvider.GetRefreshTokenClaims(accessStdClaims) + } + + iat := jwt.Clock().Unix() + + if accessStdClaims.IssuedAt == 0 { + accessStdClaims.IssuedAt = iat + } + + if accessStdClaims.ID == "" { + accessStdClaims.ID = uuid.NewString() + } + + if refreshStdClaims.IssuedAt == 0 { + refreshStdClaims.IssuedAt = iat + } + + if refreshStdClaims.ID == "" { + refreshStdClaims.ID = uuid.NewString() + } + + if refreshStdClaims.OriginID == "" { + // keep a reference of the access token the refresh token is created, + // if that access token is invalidated then + // its refresh token should be too so the user can force-login. + refreshStdClaims.OriginID = accessStdClaims.ID + } + + accessToken, err := s.keys.SignToken(KIDAccess, t, accessStdClaims) + if err != nil { + return nil, nil, fmt.Errorf("access: %w", err) + } + + var refreshToken []byte + if s.refreshEnabled { + refreshToken, err = s.keys.SignToken(KIDRefresh, t, refreshStdClaims) + if err != nil { + return nil, nil, fmt.Errorf("refresh: %w", err) + } + } + + return accessToken, refreshToken, nil +} + +func (s *SSO[T]) SigninHandler(ctx *context.Context) { + // No, let the developer decide it based on a middleware, e.g. iris.LimitRequestBodySize. + // ctx.SetMaxRequestBodySize(s.maxRequestBodySize) + + var ( + req SigninRequest + err error + ) + + switch ctx.GetContentTypeRequested() { + case context.ContentFormHeaderValue, context.ContentFormMultipartHeaderValue: + err = ctx.ReadForm(&req) + default: + err = ctx.ReadJSON(&req) + } + + if err != nil { + s.errorHandler.InvalidArgument(ctx, err) + return + } + + if req.Username == "" { + req.Username = req.Email + } + + accessTokenBytes, refreshTokenBytes, err := s.Signin(ctx, req.Username, req.Password) + if err != nil { + s.tryRemoveCookie(ctx) // remove cookie on invalidated. + + s.errorHandler.Unauthenticated(ctx, err) + return + } + accessToken := jwt.BytesToString(accessTokenBytes) + refreshToken := jwt.BytesToString(refreshTokenBytes) + + s.trySetCookie(ctx, accessToken) + + resp := SigninResponse{ + AccessToken: accessToken, + RefreshToken: refreshToken, + } + ctx.JSON(resp) +} + +func (s *SSO[T]) Verify(ctx stdContext.Context, token []byte) (T, StandardClaims, error) { + t, claims, err := s.verify(ctx, token) + if err != nil { + return t, StandardClaims{}, fmt.Errorf("sso: verify: %w", err) + } + + return t, claims, nil +} + +func (s *SSO[T]) verify(ctx stdContext.Context, token []byte) (T, StandardClaims, error) { + var t T + + if len(token) == 0 { // should never happen at this state. + return t, StandardClaims{}, jwt.ErrMissing + } + + verifiedToken, err := jwt.VerifyWithHeaderValidator(nil, nil, token, s.keys.ValidateHeader, jwt.Leeway(time.Minute)) + if err != nil { + return t, StandardClaims{}, err + } + + if s.transformer != nil { + if t, err = s.transformer.Transform(ctx, verifiedToken); err != nil { + return t, StandardClaims{}, err + } + } else { + if err = verifiedToken.Claims(&t); err != nil { + return t, StandardClaims{}, err + } + } + + standardClaims := verifiedToken.StandardClaims + + if n := len(s.providers); n > 0 { + for i := 0; i < n; i++ { + p := s.providers[i] + + err := p.ValidateToken(ctx, standardClaims, t) + if err != nil { + if i == n-1 { // last provider errored. + return t, StandardClaims{}, err + } + // keep searching. + continue + } + + // token is allowed. + break + } + } else { + // return t, StandardClaims{}, fmt.Errorf("no provider") + } + + return t, standardClaims, nil +} + +/* Good idea but not practical. +func Transform[T User, V User](transformer Transformer[T, V]) context.Handler { + return func(ctx *context.Context) { + existingUserValue := GetUser[T](ctx) + newUserValue, err := transformer.Transform(ctx, existingUserValue) + if err != nil { + ctx.SetErr(err) + return + } + + ctx.Values().Set(userContextKey, newUserValue) + ctx.Next() + } +} +*/ + +func (s *SSO[T]) VerifyHandler(verifyFuncs ...TVerify[T]) context.Handler { + return func(ctx *context.Context) { + accessToken := s.extractAccessToken(ctx) + + if accessToken == "" { // if empty, fire 401. + s.errorHandler.Unauthenticated(ctx, jwt.ErrMissing) + return + } + + t, claims, err := s.Verify(ctx, []byte(accessToken)) + if err != nil { + s.errorHandler.Unauthenticated(ctx, err) + return + } + + for _, verify := range verifyFuncs { + if verify == nil { + continue + } + + if err = verify(t); err != nil { + err = fmt.Errorf("sso: verify: %v", err) + s.errorHandler.Unauthenticated(ctx, err) + return + } + } + + ctx.SetUser(t) + + // store the user to the request. + ctx.Values().Set(accessTokenContextKey, accessToken) + + ctx.Values().Set(userContextKey, t) + ctx.Values().Set(standardClaimsContextKey, claims) + + ctx.Next() + } +} + +func (s *SSO[T]) extractAccessToken(ctx *context.Context) string { + // first try from authorization: bearer header. + accessToken := extractTokenFromHeader(ctx) + + // then if no header, try try extract from cookie. + if accessToken == "" { + if cookieName := s.config.Cookie.Name; cookieName != "" { + accessToken = ctx.GetCookie(cookieName, context.CookieEncoding(s.securecookie)) + } + } + + return accessToken +} + +func (s *SSO[T]) Refresh(ctx stdContext.Context, refreshToken []byte) ([]byte, []byte, error) { + if !s.refreshEnabled { + return nil, nil, fmt.Errorf("sso: refresh: disabled") + } + + t, _, err := s.verify(ctx, refreshToken) + if err != nil { + return nil, nil, fmt.Errorf("sso: refresh: %w", err) + } + + // refresh the tokens, both refresh & access tokens will be renew to prevent + // malicious 😈 users that may hold a refresh token. + accessTok, refreshTok, err := s.sign(t) + if err != nil { + return nil, nil, fmt.Errorf("sso: refresh: %w", err) + } + + return accessTok, refreshTok, nil +} + +func (s *SSO[T]) RefreshHandler(ctx *context.Context) { + var req RefreshRequest + err := ctx.ReadJSON(&req) + if err != nil { + s.errorHandler.InvalidArgument(ctx, err) + return + } + + accessTokenBytes, refreshTokenBytes, err := s.Refresh(ctx, []byte(req.RefreshToken)) + if err != nil { + // s.tryRemoveCookie(ctx) + s.errorHandler.Unauthenticated(ctx, err) + return + } + + accessToken := jwt.BytesToString(accessTokenBytes) + refreshToken := jwt.BytesToString(refreshTokenBytes) + + s.trySetCookie(ctx, accessToken) + + resp := SigninResponse{ + AccessToken: accessToken, + RefreshToken: refreshToken, + } + ctx.JSON(resp) +} + +func (s *SSO[T]) Signout(ctx stdContext.Context, token []byte, all bool) error { + t, standardClaims, err := s.verify(ctx, token) + if err != nil { + return fmt.Errorf("sso: signout: verify: %w", err) + } + + for i, n := 0, len(s.providers)-1; i <= n; i++ { + p := s.providers[i] + + if all { + err = p.InvalidateTokens(ctx, t) + } else { + err = p.InvalidateToken(ctx, standardClaims, t) + } + + if err != nil { + if i == n { // last provider errored. + return err + } + // keep trying. + continue + } + + // token is marked as invalidated by a provider. + break + } + + return nil +} + +func (s *SSO[T]) SignoutHandler(ctx *context.Context) { + s.signoutHandler(ctx, false) +} + +func (s *SSO[T]) SignoutAllHandler(ctx *context.Context) { + s.signoutHandler(ctx, true) +} + +func (s *SSO[T]) signoutHandler(ctx *context.Context, all bool) { + accessToken := s.extractAccessToken(ctx) + if accessToken == "" { + s.errorHandler.Unauthenticated(ctx, jwt.ErrMissing) + return + } + + err := s.Signout(ctx, []byte(accessToken), all) + if err != nil { + s.errorHandler.Unauthenticated(ctx, err) + return + } + + s.tryRemoveCookie(ctx) + + ctx.SetUser(nil) + + ctx.Values().Remove(accessTokenContextKey) + ctx.Values().Remove(userContextKey) + ctx.Values().Remove(standardClaimsContextKey) +} + +var headerKeys = [...]string{ + "Authorization", + "X-Authorization", +} + +func extractTokenFromHeader(ctx *context.Context) string { + for _, headerKey := range headerKeys { + headerValue := ctx.GetHeader(headerKey) + if headerValue == "" { + continue + } + + // pure check: authorization header format must be Bearer {token} + authHeaderParts := strings.Split(headerValue, " ") + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + continue + } + + return authHeaderParts[1] + } + + return "" +} + +func (s *SSO[T]) trySetCookie(ctx *context.Context, accessToken string) { + if cookieName := s.config.Cookie.Name; cookieName != "" { + maxAge := s.keys[KIDAccess].MaxAge + if maxAge == 0 { + maxAge = context.SetCookieKVExpiration + } + + cookie := &http.Cookie{ + Path: "/", + Name: cookieName, + Value: url.QueryEscape(accessToken), + HttpOnly: true, + Domain: ctx.Domain(), + SameSite: http.SameSiteLaxMode, + Expires: time.Now().Add(maxAge), + MaxAge: int(maxAge.Seconds()), + } + + ctx.SetCookie(cookie, context.CookieEncoding(s.securecookie)) + } +} + +func (s *SSO[T]) tryRemoveCookie(ctx *context.Context) { + if cookieName := s.config.Cookie.Name; cookieName != "" { + ctx.RemoveCookie(cookieName) + } +} diff --git a/sso/user.go b/sso/user.go new file mode 100644 index 00000000..5cd0b010 --- /dev/null +++ b/sso/user.go @@ -0,0 +1,53 @@ +//go:build go1.18 + +package sso + +import ( + "github.com/kataras/iris/v12/context" + + "github.com/kataras/jwt" +) + +type ( + StandardClaims = jwt.Claims + User = interface{} // any type. +) + +const accessTokenContextKey = "iris.sso.context.access_token" + +func GetAccessToken(ctx *context.Context) string { + return ctx.Values().GetString(accessTokenContextKey) +} + +const standardClaimsContextKey = "iris.sso.context.standard_claims" + +func GetStandardClaims(ctx *context.Context) StandardClaims { + if v := ctx.Values().Get(standardClaimsContextKey); v != nil { + if c, ok := v.(StandardClaims); ok { + return c + } + } + + return StandardClaims{} +} + +func (s *SSO[T]) GetStandardClaims(ctx *context.Context) StandardClaims { + return GetStandardClaims(ctx) +} + +const userContextKey = "iris.sso.context.user" + +func GetUser[T User](ctx *context.Context) T { + if v := ctx.Values().Get(userContextKey); v != nil { + if t, ok := v.(T); ok { + return t + } + } + + var empty T + return empty +} + +func (s *SSO[T]) GetUser(ctx *context.Context) T { + return GetUser[T](ctx) +} diff --git a/x/client/client_test.go b/x/client/client_test.go index 6d4563d4..c835b664 100644 --- a/x/client/client_test.go +++ b/x/client/client_test.go @@ -1,15 +1,15 @@ package client import ( - stdContext "context" + "context" + "encoding/json" + "net/http" + "net/http/httptest" "reflect" "testing" - - "github.com/kataras/iris/v12" - "github.com/kataras/iris/v12/httptest" ) -var defaultCtx = stdContext.Background() +var defaultCtx = context.Background() type testValue struct { Firstname string `json:"firstname"` @@ -18,40 +18,41 @@ type testValue struct { func TestClientJSON(t *testing.T) { expectedJSON := testValue{Firstname: "Makis"} - app := iris.New() - app.Get("/", sendJSON(t, expectedJSON)) + app := http.NewServeMux() + app.HandleFunc("/send", sendJSON(t, expectedJSON)) var irisGotJSON testValue - app.Post("/", readJSON(t, &irisGotJSON, &expectedJSON)) + app.HandleFunc("/read", readJSON(t, &irisGotJSON, &expectedJSON)) - srv := httptest.NewServer(t, app) + srv := httptest.NewServer(app) client := New(BaseURL(srv.URL)) // Test ReadJSON (read from server). var got testValue - if err := client.ReadJSON(defaultCtx, &got, iris.MethodGet, "/", nil); err != nil { + if err := client.ReadJSON(defaultCtx, &got, http.MethodGet, "/send", nil); err != nil { t.Fatal(err) } // Test JSON (send to server). - resp, err := client.JSON(defaultCtx, iris.MethodPost, "/", expectedJSON) + resp, err := client.JSON(defaultCtx, http.MethodPost, "/read", expectedJSON) if err != nil { t.Fatal(err) } client.DrainResponseBody(resp) } -func sendJSON(t *testing.T, v interface{}) iris.Handler { - return func(ctx iris.Context) { - if _, err := ctx.JSON(v); err != nil { +func sendJSON(t *testing.T, v interface{}) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + if err := json.NewEncoder(w).Encode(v); err != nil { t.Fatal(err) } } } -func readJSON(t *testing.T, ptr interface{}, expected interface{}) iris.Handler { - return func(ctx iris.Context) { - if err := ctx.ReadJSON(ptr); err != nil { +func readJSON(t *testing.T, ptr interface{}, expected interface{}) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(ptr); err != nil { t.Fatal(err) } diff --git a/x/errors/errors.go b/x/errors/errors.go index 5534d3b3..be9ee8c3 100644 --- a/x/errors/errors.go +++ b/x/errors/errors.go @@ -3,8 +3,9 @@ package errors import ( "encoding/json" "fmt" + "net/http" - "github.com/kataras/iris/v12" + "github.com/kataras/iris/v12/context" "github.com/kataras/iris/v12/x/client" ) @@ -13,11 +14,11 @@ import ( // // See "OnErrorLog" variable to change the way an error is logged, // by default the error is logged using the Application's Logger's Error method. -type LogErrorFunc = func(ctx iris.Context, err error) +type LogErrorFunc = func(ctx *context.Context, err error) // LogError can be modified to customize the way an error is logged to the server (most common: internal server errors, database errors et.c.). // Can be used to customize the error logging, e.g. using Sentry (cloud-based error console). -var LogError LogErrorFunc = func(ctx iris.Context, err error) { +var LogError LogErrorFunc = func(ctx *context.Context, err error) { ctx.Application().Logger().Error(err) } @@ -56,7 +57,7 @@ var errorCodeMap = make(map[ErrorCodeName]ErrorCode) // // Example: // var ( -// NotFound = errors.E("NOT_FOUND", iris.StatusNotFound) +// NotFound = errors.E("NOT_FOUND", http.StatusNotFound) // ) // ... // NotFound.Details(ctx, "resource not found", "user with id: %q was not found", userID) @@ -96,57 +97,57 @@ func RegisterErrorCodeMap(errorMap map[ErrorCodeName]int) { // List of default error codes a server should follow and send back to the client. var ( - Cancelled ErrorCodeName = E("CANCELLED", iris.StatusTokenRequired) - Unknown ErrorCodeName = E("UNKNOWN", iris.StatusInternalServerError) - InvalidArgument ErrorCodeName = E("INVALID_ARGUMENT", iris.StatusBadRequest) - DeadlineExceeded ErrorCodeName = E("DEADLINE_EXCEEDED", iris.StatusGatewayTimeout) - NotFound ErrorCodeName = E("NOT_FOUND", iris.StatusNotFound) - AlreadyExists ErrorCodeName = E("ALREADY_EXISTS", iris.StatusConflict) - PermissionDenied ErrorCodeName = E("PERMISSION_DENIED", iris.StatusForbidden) - Unauthenticated ErrorCodeName = E("UNAUTHENTICATED", iris.StatusUnauthorized) - ResourceExhausted ErrorCodeName = E("RESOURCE_EXHAUSTED", iris.StatusTooManyRequests) - FailedPrecondition ErrorCodeName = E("FAILED_PRECONDITION", iris.StatusBadRequest) - Aborted ErrorCodeName = E("ABORTED", iris.StatusConflict) - OutOfRange ErrorCodeName = E("OUT_OF_RANGE", iris.StatusBadRequest) - Unimplemented ErrorCodeName = E("UNIMPLEMENTED", iris.StatusNotImplemented) - Internal ErrorCodeName = E("INTERNAL", iris.StatusInternalServerError) - Unavailable ErrorCodeName = E("UNAVAILABLE", iris.StatusServiceUnavailable) - DataLoss ErrorCodeName = E("DATA_LOSS", iris.StatusInternalServerError) + Cancelled ErrorCodeName = E("CANCELLED", context.StatusTokenRequired) + Unknown ErrorCodeName = E("UNKNOWN", http.StatusInternalServerError) + InvalidArgument ErrorCodeName = E("INVALID_ARGUMENT", http.StatusBadRequest) + DeadlineExceeded ErrorCodeName = E("DEADLINE_EXCEEDED", http.StatusGatewayTimeout) + NotFound ErrorCodeName = E("NOT_FOUND", http.StatusNotFound) + AlreadyExists ErrorCodeName = E("ALREADY_EXISTS", http.StatusConflict) + PermissionDenied ErrorCodeName = E("PERMISSION_DENIED", http.StatusForbidden) + Unauthenticated ErrorCodeName = E("UNAUTHENTICATED", http.StatusUnauthorized) + ResourceExhausted ErrorCodeName = E("RESOURCE_EXHAUSTED", http.StatusTooManyRequests) + FailedPrecondition ErrorCodeName = E("FAILED_PRECONDITION", http.StatusBadRequest) + Aborted ErrorCodeName = E("ABORTED", http.StatusConflict) + OutOfRange ErrorCodeName = E("OUT_OF_RANGE", http.StatusBadRequest) + Unimplemented ErrorCodeName = E("UNIMPLEMENTED", http.StatusNotImplemented) + Internal ErrorCodeName = E("INTERNAL", http.StatusInternalServerError) + Unavailable ErrorCodeName = E("UNAVAILABLE", http.StatusServiceUnavailable) + DataLoss ErrorCodeName = E("DATA_LOSS", http.StatusInternalServerError) ) // Message sends an error with a simple message to the client. -func (e ErrorCodeName) Message(ctx iris.Context, format string, args ...interface{}) { +func (e ErrorCodeName) Message(ctx *context.Context, format string, args ...interface{}) { fail(ctx, e, sprintf(format, args...), "", nil, nil) } // Details sends an error with a message and details to the client. -func (e ErrorCodeName) Details(ctx iris.Context, msg, details string, detailsArgs ...interface{}) { +func (e ErrorCodeName) Details(ctx *context.Context, msg, details string, detailsArgs ...interface{}) { fail(ctx, e, msg, sprintf(details, detailsArgs...), nil, nil) } // Data sends an error with a message and json data to the client. -func (e ErrorCodeName) Data(ctx iris.Context, msg string, data interface{}) { +func (e ErrorCodeName) Data(ctx *context.Context, msg string, data interface{}) { fail(ctx, e, msg, "", nil, data) } // DataWithDetails sends an error with a message, details and json data to the client. -func (e ErrorCodeName) DataWithDetails(ctx iris.Context, msg, details string, data interface{}) { +func (e ErrorCodeName) DataWithDetails(ctx *context.Context, msg, details string, data interface{}) { fail(ctx, e, msg, details, nil, data) } // Validation sends an error which renders the invalid fields to the client. -func (e ErrorCodeName) Validation(ctx iris.Context, validationErrors ...ValidationError) { +func (e ErrorCodeName) Validation(ctx *context.Context, validationErrors ...ValidationError) { e.validation(ctx, validationErrors) } -func (e ErrorCodeName) validation(ctx iris.Context, validationErrors interface{}) { +func (e ErrorCodeName) validation(ctx *context.Context, validationErrors interface{}) { fail(ctx, e, "validation failure", "fields were invalid", validationErrors, nil) } // Err sends the error's text as a message to the client. // In exception, if the given "err" is a type of validation error // then the Validation method is called instead. -func (e ErrorCodeName) Err(ctx iris.Context, err error) { +func (e ErrorCodeName) Err(ctx *context.Context, err error) { if err == nil { return } @@ -163,7 +164,7 @@ func (e ErrorCodeName) Err(ctx iris.Context, err error) { // error using the "LogError" package-level function, which can be customized. // // See "LogErr" too. -func (e ErrorCodeName) Log(ctx iris.Context, format string, args ...interface{}) { +func (e ErrorCodeName) Log(ctx *context.Context, format string, args ...interface{}) { if SkipCanceled { if ctx.IsCanceled() { return @@ -171,7 +172,7 @@ func (e ErrorCodeName) Log(ctx iris.Context, format string, args ...interface{}) for _, arg := range args { if err, ok := arg.(error); ok { - if iris.IsErrCanceled(err) { + if context.IsErrCanceled(err) { return } } @@ -184,8 +185,8 @@ func (e ErrorCodeName) Log(ctx iris.Context, format string, args ...interface{}) // LogErr sends the given "err" as message to the client and prints that // error to using the "LogError" package-level function, which can be customized. -func (e ErrorCodeName) LogErr(ctx iris.Context, err error) { - if SkipCanceled && (ctx.IsCanceled() || iris.IsErrCanceled(err)) { +func (e ErrorCodeName) LogErr(ctx *context.Context, err error) { + if SkipCanceled && (ctx.IsCanceled() || context.IsErrCanceled(err)) { return } @@ -204,7 +205,7 @@ func (e ErrorCodeName) LogErr(ctx iris.Context, err error) { // the error will be sent using the "Internal.LogErr" method which sends // HTTP internal server error to the end-client and // prints the "err" using the "LogError" package-level function. -func HandleAPIError(ctx iris.Context, err error) { +func HandleAPIError(ctx *context.Context, err error) { // Error expected and came from the external server, // save its body so we can forward it to the end-client. if apiErr, ok := client.GetError(err); ok { @@ -228,7 +229,7 @@ var ( // The server fails to send an error on two cases: // 1. when the provided error code name is not registered (the error value is the ErrUnexpectedErrorCode) // 2. when the error contains data but cannot be encoded to json (the value of the error is the result error of json.Marshal). - ErrUnexpected = E("UNEXPECTED_ERROR", iris.StatusInternalServerError) + ErrUnexpected = E("UNEXPECTED_ERROR", http.StatusInternalServerError) // ErrUnexpectedErrorCode is the error which logged // when the given error code name is not registered. ErrUnexpectedErrorCode = New("unexpected error code name") @@ -261,13 +262,13 @@ func (err Error) Error() string { } if err.ErrorCode.Status <= 0 { - err.ErrorCode.Status = iris.StatusInternalServerError + err.ErrorCode.Status = http.StatusInternalServerError } return sprintf("iris http wire error: canonical name: %s, http status code: %d, message: %s, details: %s", err.ErrorCode.CanonicalName, err.ErrorCode.Status, err.Message, err.Details) } -func fail(ctx iris.Context, codeName ErrorCodeName, msg, details string, validationErrors interface{}, dataValue interface{}) { +func fail(ctx *context.Context, codeName ErrorCodeName, msg, details string, validationErrors interface{}, dataValue interface{}) { errorCode, ok := errorCodeMap[codeName] if !ok { // This SHOULD NEVER happen, all ErrorCodeNames MUST be registered.