package endpoints import ( "crypto/rand" "crypto/subtle" "encoding/base64" "encoding/json" "errors" "fmt" "net/http" "strings" "github.com/imosed/signet/auth" . "github.com/imosed/signet/data" "github.com/rs/zerolog/log" "github.com/spf13/viper" "golang.org/x/crypto/argon2" ) const ( SuperUser uint = iota AdminPlus Admin ) type Params struct { Iterations uint32 Memory uint32 Parallelism uint8 SaltLength uint32 KeyLength uint32 } func GenerateRandomBytes(n uint32) ([]byte, error) { var b = make([]byte, n) _, err := rand.Read(b) if err != nil { return nil, err } return b, nil } func GenerateHash(password string, p *Params) (encodedHash string, err error) { salt, err := GenerateRandomBytes(p.SaltLength) if err != nil { return "", err } var hash = argon2.IDKey([]byte(password), salt, p.Iterations, p.Memory, p.Parallelism, p.KeyLength) var bSalt = base64.RawStdEncoding.EncodeToString(salt) var bHash = base64.RawStdEncoding.EncodeToString(hash) encodedHash = fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, p.Memory, p.Iterations, p.Parallelism, bSalt, bHash) return encodedHash, nil } func ComparePasswordAndHash(password, encodedHash string) (match bool, err error) { p, salt, hash, err := DecodeHash(encodedHash) if err != nil { return false, err } var otherHash = argon2.IDKey([]byte(password), salt, p.Iterations, p.Memory, p.Parallelism, p.KeyLength) if subtle.ConstantTimeCompare(hash, otherHash) == 1 { return true, nil } return false, nil } func DecodeHash(encodedHash string) (p *Params, salt, hash []byte, err error) { var details = strings.Split(encodedHash, "$") if len(details) != 6 { return nil, nil, nil, errors.New("the encoded hash is not in the correct format") } var version int _, err = fmt.Sscanf(details[2], "v=%d", &version) if err != nil { return nil, nil, nil, err } if version != argon2.Version { return nil, nil, nil, errors.New("the version of argon2 does not match the hash string") } p = &Params{} _, err = fmt.Sscanf(details[3], "m=%d,t=%d,p=%d", &p.Memory, &p.Iterations, &p.Parallelism) if err != nil { return nil, nil, nil, err } salt, err = base64.RawStdEncoding.DecodeString(details[4]) if err != nil { return nil, nil, nil, err } p.SaltLength = uint32(len(salt)) hash, err = base64.RawStdEncoding.DecodeString(details[5]) if err != nil { return nil, nil, nil, err } p.KeyLength = uint32(len(hash)) return p, salt, hash, nil } type AuthenticationRequest struct { Username string `json:"username"` Password string `json:"password"` } func noUsersRegistered() bool { results := Db.Find(&User{}) return results.RowsAffected == 0 } func determinePrivileges() uint { if noUsersRegistered() { return SuperUser } else { return Admin } } func Register(w http.ResponseWriter, r *http.Request) { var req AuthenticationRequest err := json.NewDecoder(r.Body).Decode(&req) if err != nil { log.Error().Err(err).Msg("Could not decode body in Register call") return } var claims *auth.Claims claims, err = auth.GetUserClaims(r) if err != nil { log.Error().Err(err).Msg("Could not determine if user is authenticated") return } if claims == nil { return } if noUsersRegistered() || claims.Privileges <= AdminPlus { hash, err := GenerateHash(req.Password, &Params{ Memory: uint32(viper.GetInt("hashing.memory")), Iterations: uint32(viper.GetInt("hashing.iterations")), Parallelism: uint8(viper.GetInt("hashing.parallelism")), SaltLength: uint32(viper.GetInt("hashing.saltLength")), KeyLength: uint32(viper.GetInt("hashing.keyLength")), }) if err != nil { log.Error().Err(err).Msg("Could not generate hash for registration") return } Db.Create(&User{ Username: req.Username, Password: hash, Privileges: determinePrivileges(), }) err = json.NewEncoder(w).Encode(SuccessResponse{Success: true}) if err != nil { log.Error().Err(err).Msg("Could not deliver successful account creation response") } } else if !noUsersRegistered() { err = json.NewEncoder(w).Encode(SuccessResponse{Success: false}) if err != nil { log.Error().Err(err).Msg("Could not deliver unsuccessful account creation response") } } else if claims.Privileges > SuperUser { w.WriteHeader(403) } }