diff --git a/_examples/miscellaneous/jwt/main.go b/_examples/miscellaneous/jwt/main.go index 930a86f5..8e4d6c9a 100644 --- a/_examples/miscellaneous/jwt/main.go +++ b/_examples/miscellaneous/jwt/main.go @@ -14,19 +14,12 @@ type UserClaims struct { } func main() { - // hmac - key := []byte("secret") - j, err := jwt.New(1*time.Minute, jwt.HS256, key) - if err != nil { - panic(err) - } - - // OPTIONAL encryption: - encryptionKey := []byte("itsa16bytesecret") - err = j.WithEncryption(jwt.A128GCM, jwt.DIRECT, encryptionKey) - if err != nil { - panic(err) - } + // Get keys from system's environment variables + // JWT_SECRET (for signing and verification) and JWT_SECRET_ENC(for encryption and decryption), + // or defaults to "secret" and "itsa16bytesecret" respectfully. + // + // Use the `jwt.New` instead for more flexibility, if necessary. + j := jwt.DefaultHMAC(15*time.Minute, "secret", "itsa16bytesecret") app := iris.New() app.Logger().SetLevel("debug") @@ -74,12 +67,62 @@ func main() { } /* -func load_From_File_Example() { +func default_RSA_Example() { + j := jwt.DefaultRSA(1 * time.Minute) +} + +Same as: + +func load_File_Or_Generate_RSA_Example() { + signKey, err := jwt.LoadRSA("jwt_sign.key", 2048) + if err != nil { + panic(err) + } + + j, err := jwt.New(15*time.Minute, jwt.RS256, signKey) + if err != nil { + panic(err) + } + + encKey, err := jwt.LoadRSA("jwt_enc.key", 2048) + if err != nil { + panic(err) + } + err = j.WithEncryption(jwt.A128CBCHS256, jwt.RSA15, encKey) + if err != nil { + panic(err) + } +} +*/ + +/* +func hmac_Example() { + // hmac + key := []byte("secret") + j, err := jwt.New(15*time.Minute, jwt.HS256, key) + if err != nil { + panic(err) + } + + // OPTIONAL encryption: + encryptionKey := []byte("itsa16bytesecret") + err = j.WithEncryption(jwt.A128GCM, jwt.DIRECT, encryptionKey) + if err != nil { + panic(err) + } +} +*/ + +/* +func load_From_File_With_Password_Example() { b, err := ioutil.ReadFile("./private_rsa.pem") if err != nil { panic(err) } - signKey := jwt.MustParseRSAPrivateKey(b, []byte("pass")) + signKey,err := jwt.ParseRSAPrivateKey(b, []byte("pass")) + if err != nil { + panic(err) + } j, err := jwt.New(15*time.Minute, jwt.RS256, signKey) if err != nil { @@ -89,23 +132,18 @@ func load_From_File_Example() { */ /* -func random_RSA_Sign_And_Encrypt_Example() { - j := jwt.Random(1 * time.Minute) -} -*/ - -/* -func random_manually_generate_RSA_Example() { - signey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - panic(err) - } - encryptionKey, err := rsa.GenerateKey(rand.Reader, 2048) +func generate_RSA_Example() { + signKey, err := rsa.GenerateKey(rand.Reader, 4096) if err != nil { panic(err) } - j, err := jwt.New(1*time.Minute, jwt.RS256, signey) + encryptionKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + panic(err) + } + + j, err := jwt.New(15*time.Minute, jwt.RS512, signKey) if err != nil { panic(err) } diff --git a/_examples/miscellaneous/jwt/private_rsa.pem b/_examples/miscellaneous/jwt/rsa_password_protected.key similarity index 100% rename from _examples/miscellaneous/jwt/private_rsa.pem rename to _examples/miscellaneous/jwt/rsa_password_protected.key diff --git a/context/context.go b/context/context.go index 8e78d1ea..8eed6d78 100644 --- a/context/context.go +++ b/context/context.go @@ -1946,7 +1946,7 @@ func (ctx *context) GetDomain() string { if host, _, err := net.SplitHostPort(hostport); err == nil { // has port. switch host { - case "127.0.0.1", "0.0.0.0", "::1", "[::1]", "0:0:0:0:0:0:0", "0:0:0:0:0:0:1": + case "127.0.0.1", "0.0.0.0", "::1", "[::1]", "0:0:0:0:0:0:0:0", "0:0:0:0:0:0:0:1": // loopback. return "localhost" default: diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 3565ec31..5ad52d79 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -2,10 +2,9 @@ package jwt import ( "crypto" - "crypto/rand" - "crypto/rsa" "encoding/json" "errors" + "os" "strings" "time" @@ -112,35 +111,6 @@ type JWT struct { DecriptionKey interface{} } -// Random returns a new `JWT` instance -// with in-memory generated rsa256 signing and encryption keys (development). -// It panics on errors. Next server ran will invalidate all request tokens. -// -// Use the `New` package-level function for production use. -func Random(maxAge time.Duration) *JWT { - sigKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - panic(err) - } - - j, err := New(maxAge, RS256, sigKey) - if err != nil { - panic(err) - } - - encKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - panic(err) - } - - err = j.WithEncryption(A128CBCHS256, RSA15, encKey) - if err != nil { - panic(err) - } - - return j -} - type privateKey interface{ Public() crypto.PublicKey } // New returns a new JWT instance. @@ -157,7 +127,7 @@ type privateKey interface{ Public() crypto.PublicKey } // 1. Generate key file, e.g: // $ openssl genrsa -des3 -out private.pem 2048 // 2. Read file contents with io.ReadFile("./private.pem") -// 3. Pass the []byte result to the `MustParseRSAPrivateKey(contents, password)` package-level helper +// 3. Pass the []byte result to the `ParseRSAPrivateKey(contents, password)` package-level helper // 4. Use the result *rsa.PrivateKey as "key" input parameter of this `New` function. // // See aliases.go file for available algorithms. @@ -185,6 +155,99 @@ func New(maxAge time.Duration, alg SignatureAlgorithm, key interface{}) (*JWT, e return j, nil } +// RSA filenames for `DefaultRSA`. +const ( + SignFilename = "jwt_sign.key" + EncFilename = "jwt_enc.key" +) + +// DefaultRSA returns a new `JWT` instance. +// It tries to parse RSA256 keys from "jwt_sign.key" and (optionally) "jwt_enc.key" files +// in the current working directory, and if not found, it generates and exports the keys. +// +// It panics on errors. +// Use the `New` package-level function instead for more options. +func DefaultRSA(maxAge time.Duration) *JWT { + // Do not try to create or load enc key if only sign key already exists. + withEncryption := true + if fileExists(SignFilename) { + withEncryption = fileExists(EncFilename) + } + + sigKey, err := LoadRSA(SignFilename, 2048) + if err != nil { + panic(err) + } + + j, err := New(maxAge, RS256, sigKey) + if err != nil { + panic(err) + } + + if withEncryption { + encKey, err := LoadRSA(EncFilename, 2048) + if err != nil { + panic(err) + } + err = j.WithEncryption(A128CBCHS256, RSA15, encKey) + if err != nil { + panic(err) + } + } + + return j +} + +const ( + signEnv = "JWT_SECRET" + encEnv = "JWT_SECRET_ENC" +) + +func getenv(key string, def string) string { + v := os.Getenv(key) + if v == "" { + return def + } + + return v +} + +// DefaultHMAC returns a new `JWT` instance. +// It tries to read hmac256 secret keys from system environment variables: +// * JWT_SECRET for signing and verification key and +// * JWT_SECRET_ENC for encryption and decryption key +// and defaults them to the given "keys" respectfully. +// +// It panics on errors. +// Use the `New` package-level function instead for more options. +func DefaultHMAC(maxAge time.Duration, keys ...string) *JWT { + var defaultSignSecret, defaultEncSecret string + + switch len(keys) { + case 1: + defaultSignSecret = keys[0] + case 2: + defaultEncSecret = keys[1] + } + + signSecret := getenv(signEnv, defaultSignSecret) + encSecret := getenv(encEnv, defaultEncSecret) + + j, err := New(maxAge, HS256, []byte(signSecret)) + if err != nil { + panic(err) + } + + if encSecret != "" { + err = j.WithEncryption(A128GCM, DIRECT, []byte(encSecret)) + if err != nil { + panic(err) + } + } + + return j +} + // WithEncryption method enables encryption and decryption of the token. // It sets an appropriate encrypter(`Encrypter` and the `DecriptionKey` fields) based on the key type. func (j *JWT) WithEncryption(contentEncryption ContentEncryption, alg KeyAlgorithm, key interface{}) error { diff --git a/middleware/jwt/jwt_test.go b/middleware/jwt/jwt_test.go index e580db94..75ebbe91 100644 --- a/middleware/jwt/jwt_test.go +++ b/middleware/jwt/jwt_test.go @@ -2,6 +2,7 @@ package jwt_test import ( + "os" "testing" "time" @@ -18,12 +19,21 @@ type userClaims struct { const testMaxAge = 3 * time.Second // Random RSA verification and encryption. -func TestRSA(t *testing.T) { - j := jwt.Random(testMaxAge) +func TestDefaultRSA(t *testing.T) { + j := jwt.DefaultRSA(testMaxAge) + t.Cleanup(func() { + os.Remove(jwt.SignFilename) + os.Remove(jwt.EncFilename) + }) testWriteVerifyToken(t, j) } // HMAC verification and encryption. +func TestDefaultHMAC(t *testing.T) { + j := jwt.DefaultHMAC(testMaxAge, "secret", "itsa16bytesecret") + testWriteVerifyToken(t, j) +} + func TestHMAC(t *testing.T) { j, err := jwt.New(testMaxAge, jwt.HS256, []byte("secret")) if err != nil { diff --git a/middleware/jwt/rsa_util.go b/middleware/jwt/rsa_util.go new file mode 100644 index 00000000..e68b19f8 --- /dev/null +++ b/middleware/jwt/rsa_util.go @@ -0,0 +1,106 @@ +package jwt + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "io/ioutil" + "os" +) + +// LoadRSA tries to read RSA Private Key from "fname" system file, +// if does not exist then it generates a new random one based on "bits" (e.g. 2048, 4096) +// and exports it to a new "fname" file. +func LoadRSA(fname string, bits int) (key *rsa.PrivateKey, err error) { + exists := fileExists(fname) + if exists { + key, err = importFromFile(fname) + } else { + key, err = rsa.GenerateKey(rand.Reader, bits) + } + + if err != nil { + return + } + + if !exists { + err = exportToFile(key, fname) + } + + return +} + +func exportToFile(key *rsa.PrivateKey, filename string) error { + b := x509.MarshalPKCS1PrivateKey(key) + encoded := pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: b, + }, + ) + + return ioutil.WriteFile(filename, encoded, 0644) +} + +func importFromFile(filename string) (*rsa.PrivateKey, error) { + b, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + + return ParseRSAPrivateKey(b, nil) +} + +func fileExists(filename string) bool { + info, err := os.Stat(filename) + if os.IsNotExist(err) { + return false + } + return !info.IsDir() +} + +var ( + // ErrNotPEM is an error type of the `ParseXXX` function(s) fired + // when the data are not PEM-encoded. + ErrNotPEM = errors.New("key must be PEM encoded") + // ErrInvalidKey is an error type of the `ParseXXX` function(s) fired + // when the contents are not type of rsa private key. + ErrInvalidKey = errors.New("key is not of type *rsa.PrivateKey") +) + +// ParseRSAPrivateKey encodes a PEM-encoded PKCS1 or PKCS8 private key protected with a password. +func ParseRSAPrivateKey(key, password []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(key) + if block == nil { + return nil, ErrNotPEM + } + + var ( + parsedKey interface{} + err error + ) + + var blockDecrypted []byte + if len(password) > 0 { + if blockDecrypted, err = x509.DecryptPEMBlock(block, password); err != nil { + return nil, err + } + } else { + blockDecrypted = block.Bytes + } + + if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil { + if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil { + return nil, err + } + } + + privateKey, ok := parsedKey.(*rsa.PrivateKey) + if !ok { + return nil, ErrInvalidKey + } + + return privateKey, nil +} diff --git a/middleware/jwt/util.go b/middleware/jwt/util.go deleted file mode 100644 index 87159ad6..00000000 --- a/middleware/jwt/util.go +++ /dev/null @@ -1,98 +0,0 @@ -package jwt - -import ( - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" -) - -// ErrNotPEM is a panic error of the MustParseXXX functions when the data are not PEM-encoded. -var ErrNotPEM = errors.New("key must be PEM encoded") - -// MustParseRSAPrivateKey encodes a PEM-encoded PKCS1 or PKCS8 private key protected with a password. -func MustParseRSAPrivateKey(key, password []byte) *rsa.PrivateKey { - block, _ := pem.Decode(key) - if block == nil { - panic(ErrNotPEM) - } - - var ( - parsedKey interface{} - err error - ) - - var blockDecrypted []byte - if blockDecrypted, err = x509.DecryptPEMBlock(block, password); err != nil { - panic(err) - } - - if parsedKey, err = x509.ParsePKCS1PrivateKey(blockDecrypted); err != nil { - if parsedKey, err = x509.ParsePKCS8PrivateKey(blockDecrypted); err != nil { - panic(err) - } - } - - privateKey, ok := parsedKey.(*rsa.PrivateKey) - if !ok { - panic("key is not of type *rsa.PrivateKey") - } - - return privateKey -} - -// MustParseRSAPublicKey encodes a PEM encoded PKCS1 or PKCS8 public key. -func MustParseRSAPublicKey(key []byte) *rsa.PublicKey { - var err error - - // Parse PEM block - var block *pem.Block - if block, _ = pem.Decode(key); block == nil { - panic(ErrNotPEM) - } - - // Parse the key - var parsedKey interface{} - if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil { - if cert, err := x509.ParseCertificate(block.Bytes); err == nil { - parsedKey = cert.PublicKey - } else { - panic(err) - } - } - - var pkey *rsa.PublicKey - var ok bool - if pkey, ok = parsedKey.(*rsa.PublicKey); !ok { - panic("key is not of type *rsa.PublicKey") - } - - return pkey -} - -/* -// MustParseEd25519 PEM encoded Ed25519. -func MustParseEd25519(key []byte) ed25519.PrivateKey { - // Parse PEM block - block, _ := pem.Decode(key) - if block == nil { - panic(ErrNotPEM) - } - - type ed25519PrivKey struct { - Version int - ObjectIdentifier struct { - ObjectIdentifier asn1.ObjectIdentifier - } - PrivateKey []byte - } - - var asn1PrivKey ed25519PrivKey - if _, err := asn1.Unmarshal(block.Bytes, &asn1PrivKey); err != nil { - panic(err) - } - - privateKey := ed25519.NewKeyFromSeed(asn1PrivKey.PrivateKey[2:]) - return privateKey -} -*/