package main

import (
	"fmt"
	"time"

	"github.com/kataras/iris/v12"
	"github.com/kataras/iris/v12/middleware/jwt"
)

// Token lifetimes used by this example.
const (
	// accessTokenMaxAge is the max age handed to the signer, so it applies
	// to the access tokens it produces.
	accessTokenMaxAge time.Duration = 10 * time.Minute
	// refreshTokenMaxAge is the max age passed explicitly when a token pair
	// is generated; refresh tokens live longer than access tokens.
	refreshTokenMaxAge time.Duration = 60 * time.Minute
)

// privateKey and publicKey hold the RSA key pair read from the two PEM files
// (paths are relative). NOTE(review): going by the Must prefix, loading
// presumably panics at start-up when a file is missing or invalid — confirm.
var privateKey, publicKey = jwt.MustLoadRSA("rsa_private_key.pem", "rsa_public_key.pem")

// signer creates RS256 tokens with the private key; accessTokenMaxAge is
// the max age it is configured with.
var signer = jwt.NewSigner(jwt.RS256, privateKey, accessTokenMaxAge)

// verifier checks RS256 tokens against the public key.
var verifier = jwt.NewVerifier(jwt.RS256, publicKey)

// UserClaims is the custom claims structure carried by access tokens.
type UserClaims struct {
	// ID is encoded under the "user_id" JSON key.
	ID string `json:"user_id"`
	// Username is encoded under the "username" JSON key.
	// To make it mandatory, either tag it as `json:"username,required"`
	// or rely on the Validate method of this type.
	Username string `json:"username"`
}

// GetID returns the user's ID, implementing the ID part of the
// (partial) context user interface.
//
// Per the original example notes, map-based claims would not need any
// such method. Implementing it is useful when several auth methods
// (e.g. basic auth, jwt) are in use and all share a couple of methods.
func (c *UserClaims) GetID() string { return c.ID }

// GetUsername returns the user's name, implementing the Username part of
// the (partial) context user interface.
func (c *UserClaims) GetUsername() string { return c.Username }

// Validate serves as the middleware's custom ClaimsValidator.
// It returns an error when the username claim is empty, so tokens that
// carry no username are rejected (useful for refusing refresh tokens
// that were generated by the same algorithm).
func (c *UserClaims) Validate() error {
	if c.Username != "" {
		return nil
	}

	return fmt.Errorf("username field is missing")
}

// For the refresh token we simply use the jwt.Claims structure,
// which contains the standard JWT fields.

// main wires up the example application: a custom 401 handler, the two
// public token endpoints, a JWT-protected route group, and finally the
// HTTP server on port 8080.
func main() {
	app := iris.New()
	app.OnErrorCode(iris.StatusUnauthorized, handleUnauthorized)

	// Public endpoints: issue a token pair and exchange a refresh token.
	app.Get("/authenticate", generateTokenPair)
	app.Get("/refresh", refreshToken)

	// Every route under /protected first passes through the verifier, which
	// decodes the token claims into a fresh *UserClaims per request.
	protected := app.Party("/protected")
	protected.Use(verifier.Verify(func() interface{} {
		return &UserClaims{}
	}))
	protected.Get("/", func(ctx iris.Context) {
		// The claims are also reachable with jwt.Get:
		//   claims := jwt.Get(ctx).(*UserClaims)
		//   ctx.Writef("Username: %s\n", claims.Username)
		//
		// Here they are read through the context's user instead, which works
		// when UserClaims implements at least one of the user methods.
		// The returned errors are ignored in this example.
		current := ctx.User()
		userID, _ := current.GetID()
		name, _ := current.GetUsername()
		ctx.Writef("ID: %s\nUsername: %s\n", userID, name)
	})

	// Walkthrough:
	//   http://localhost:8080/protected (401)
	//   http://localhost:8080/authenticate (200) (response JSON {access_token, refresh_token})
	//   http://localhost:8080/protected?token={access_token} (200)
	//   http://localhost:8080/protected?token={refresh_token} (401)
	//   http://localhost:8080/refresh?refresh_token={refresh_token}
	//   OR http://localhost:8080/refresh (request JSON{refresh_token = {refresh_token}}) (200) (response JSON {access_token, refresh_token})
	//   http://localhost:8080/refresh?refresh_token={access_token} (401)
	app.Listen(":8080")
}

// generateTokenPair signs a new access/refresh token pair for a simulated
// user and writes it to the client as JSON. When signing fails, the error
// is logged and the request is stopped with a 500 status.
func generateTokenPair(ctx iris.Context) {
	// Simulated user.
	const userID = "53afcf05-38a3-43c3-82af-8bbbe0e4a149"

	// NewTokenPair takes, in order: the access claims, the refresh claims
	// and the refresh token max age (long-lived, e.g. 1 hour).
	//
	// The refresh claims carry the user ID as their subject, so the refresh
	// route can make sure the refresh token belongs to that user before it
	// re-generates a pair.
	pair, err := signer.NewTokenPair(
		UserClaims{ID: userID, Username: "kataras"},
		jwt.Claims{Subject: userID},
		refreshTokenMaxAge,
	)
	if err != nil {
		ctx.Application().Logger().Errorf("token pair: %v", err)
		ctx.StopWithStatus(iris.StatusInternalServerError)
		return
	}

	// The client receives: {"access_token": $token, "refresh_token": $token}
	ctx.JSON(pair)
}

// refreshToken accepts a refresh token, verifies it and replies with a
// brand new token pair.
//
// Refresh strategies vary with the application's requirements. This example
// accepts and verifies a refresh token only, then re-generates a whole pair.
// An alternative is to accept both tokens, verify the refresh token, verify
// the access token with a Leeway time to see whether it is about to expire,
// and then generate a single access token.
func refreshToken(ctx iris.Context) {
	// Assume the current user is known here (e.g. through sessions); this
	// simulates a database lookup whose result must match the token subject,
	// ensuring the refresh token was issued for this user.
	// * Note: the expected-subject check can be dropped in favor of a manual
	// validation after the verification, see the commented snippet below.
	const currentUserID = "53afcf05-38a3-43c3-82af-8bbbe0e4a149"

	// The token comes from ?refresh_token=$token or, when that is empty,
	// from the request body's JSON {"refresh_token": "$token"}.
	token := []byte(ctx.URLParam("refresh_token"))
	if len(token) == 0 {
		// The whole body could be read with ctx.GetBody/ReadBody as well.
		var body jwt.TokenPair
		if err := ctx.ReadJSON(&body); err != nil {
			ctx.StopWithError(iris.StatusBadRequest, err)
			return
		}

		token = body.RefreshToken
	}

	// The token's subject MUST equal currentUserID for verification to pass.
	expected := jwt.Expected{Subject: currentUserID}
	if _, err := verifier.VerifyToken(token, expected); err != nil {
		ctx.Application().Logger().Errorf("verify refresh token: %v", err)
		ctx.StatusCode(iris.StatusUnauthorized)
		return
	}

	/* Custom validation can also run after the verify call, using its result:
	currentUserID := "53afcf05-38a3-43c3-82af-8bbbe0e4a149"
	userID := verifiedToken.StandardClaims.Subject
	if userID != currentUserID {
		ctx.StopWithStatus(iris.StatusUnauthorized)
		return
	}
	*/

	// Verified: issue a fresh pair for the client
	// (issuing only an access token would be possible too).
	generateTokenPair(ctx)
}

// handleUnauthorized is the handler main registers for the 401 status code.
// It logs the error stored on the context, when there is one, and always
// writes the plain "Unauthorized" text as the response body.
func handleUnauthorized(ctx iris.Context) {
	cause := ctx.GetErr()
	if cause != nil {
		ctx.Application().Logger().Errorf("unauthorized: %v", cause)
	}

	ctx.WriteString("Unauthorized")
}