diff --git a/auth/auth.go b/auth/auth.go index 58274d3..f4fa65e 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -10,6 +10,7 @@ import ( ) type Claims struct { + ID uint `json:"id"` Username string `json:"username"` Privileges uint `json:"privileges"` jwt.RegisteredClaims diff --git a/data/context.go b/data/context.go index f694c91..0a03fb4 100644 --- a/data/context.go +++ b/data/context.go @@ -66,9 +66,10 @@ type Contribution struct { type User struct { ModelBase - Username string `json:"username"` - Password string `json:"password"` - Privileges uint `json:"admin"` + Username string `json:"username"` + Password string `json:"password"` + Privileges uint `json:"admin"` + LastLogin *time.Time `json:"lastLogin"` } var Db *gorm.DB diff --git a/endpoints/changepassword.go b/endpoints/changepassword.go new file mode 100644 index 0000000..f5fbb11 --- /dev/null +++ b/endpoints/changepassword.go @@ -0,0 +1,45 @@ +package endpoints + +import ( + "encoding/json" + "net/http" + "time" + + . "github.com/imosed/signet/data" +) + +type ChangePasswordRequest struct { + UserID uint `json:"userID"` + Password string `json:"password"` +} + +func ChangePassword(w http.ResponseWriter, r *http.Request) { + var req ChangePasswordRequest + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + panic("Could not decode body") + } + + var user User + Db.Table("users").First(&user, req.UserID) + + var password string + password, err = GetHashedPassword(req.Password) + if err != nil { + panic("Could not get password") + } + + if user.LastLogin == nil { + Db.Table("users").Where("id = ?", req.UserID).Updates(map[string]interface{}{"last_login": time.Now(), "password": password}) + } else { + Db.Table("users").Where("id = ?", req.UserID).Update("password = ?", password) + } + + var resp SuccessResponse + resp.Success = true + + err = json.NewEncoder(w).Encode(resp) + if err != nil { + panic("Could not deliver response") + } +} diff --git a/endpoints/escalateprivileges.go b/endpoints/escalateprivileges.go index feed4e2..14f7996 100644 --- a/endpoints/escalateprivileges.go +++ b/endpoints/escalateprivileges.go @@ -10,14 +10,15 @@ import ( ) type EscalatePrivilegesRequest struct { - Username string + UserID uint `json:"userID"` + Privileges uint `json:"privileges"` } -func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { +func ChangePrivileges(w http.ResponseWriter, r *http.Request) { var req EscalatePrivilegesRequest err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - log.Error().Err(err).Msg("Could not decode body in EscalatePrivileges call") + log.Error().Err(err).Msg("Could not decode body in ChangePrivileges call") return } @@ -28,8 +29,8 @@ func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { claims, err = auth.GetUserClaims(r) if claims.Privileges < 2 { - Db.Table("users").Where("username = ?", req.Username).Find(&user) - if user.Privileges < 2 { + Db.Table("users").Where("id = ?", req.UserID).Find(&user) + if req.Privileges == SuperUser { resp.Success = false err = json.NewEncoder(w).Encode(resp) @@ -39,7 +40,8 @@ func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { return } - user.Privileges = AdminPlus + user.Privileges = req.Privileges + Db.Save(user) resp.Success = true } else { resp.Success = false diff --git a/endpoints/login.go b/endpoints/login.go index 414d3d9..e996eff 100644 --- a/endpoints/login.go +++ b/endpoints/login.go @@ -13,7 +13,8 @@ import ( ) type LoginResponse struct { - Token *string `json:"token"` + Token *string `json:"token"` + LastLogin *time.Time `json:"lastLogin"` } func Login(w http.ResponseWriter, r *http.Request) { @@ -24,14 +25,12 @@ func Login(w http.ResponseWriter, r *http.Request) { return } - var userData struct { - ID uint - Password string - Privileges uint - } + var userData User + + var loginTime = time.Now() var resp LoginResponse - Db.Table("users").Select("id, password, privileges"). + Db.Table("users").Select("id, password, privileges, last_login"). Where("username = ?", req.Username).First(&userData) var passwordMatches bool passwordMatches, err = ComparePasswordAndHash(req.Password, userData.Password) @@ -40,7 +39,6 @@ func Login(w http.ResponseWriter, r *http.Request) { return } if !passwordMatches { - resp.Token = nil err = json.NewEncoder(w).Encode(resp) if err != nil { log.Error().Err(err).Msg("Failed to deliver failed login attempt response") @@ -49,10 +47,11 @@ func Login(w http.ResponseWriter, r *http.Request) { } token := jwt.NewWithClaims(jwt.SigningMethodHS256, &auth.Claims{ + ID: userData.ID, Username: req.Username, Privileges: userData.Privileges, RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Hour)), + ExpiresAt: jwt.NewNumericDate(loginTime.Add(10 * time.Hour)), }, }) @@ -63,6 +62,12 @@ func Login(w http.ResponseWriter, r *http.Request) { return } resp.Token = &tokenString + resp.LastLogin = userData.LastLogin + + if userData.LastLogin != nil { + // need to set this after the user changes their password + Db.Table("users").Where("id = ?", userData.ID).Update("last_login", loginTime) + } err = json.NewEncoder(w).Encode(resp) if err != nil { diff --git a/endpoints/register.go b/endpoints/register.go index 4698770..3c1e536 100644 --- a/endpoints/register.go +++ b/endpoints/register.go @@ -124,6 +124,17 @@ func determinePrivileges() uint { } } +func GetHashedPassword(password string) (encodedHash string, err error) { + hash, err := GenerateHash(password, &Params{ + Memory: uint32(viper.GetInt("hashing.memory")), + Iterations: uint32(viper.GetInt("hashing.iterations")), + Parallelism: uint8(viper.GetInt("hashing.parallelism")), + SaltLength: uint32(viper.GetInt("hashing.saltLength")), + KeyLength: uint32(viper.GetInt("hashing.keyLength")), + }) + return hash, err +} + func Register(w http.ResponseWriter, r *http.Request) { var req AuthenticationRequest err := json.NewDecoder(r.Body).Decode(&req) @@ -144,13 +155,7 @@ func Register(w http.ResponseWriter, r *http.Request) { } if noUsersRegistered() || claims.Privileges <= AdminPlus { - hash, err := GenerateHash(req.Password, &Params{ - Memory: uint32(viper.GetInt("hashing.memory")), - Iterations: uint32(viper.GetInt("hashing.iterations")), - Parallelism: uint8(viper.GetInt("hashing.parallelism")), - SaltLength: uint32(viper.GetInt("hashing.saltLength")), - KeyLength: uint32(viper.GetInt("hashing.keyLength")), - }) + hash, err := GetHashedPassword(req.Password) if err != nil { log.Error().Err(err).Msg("Could not generate hash for registration") return diff --git a/main.go b/main.go index 625bdea..eb9a0f6 100644 --- a/main.go +++ b/main.go @@ -48,8 +48,9 @@ func main() { router.HandleFunc("/ContributorStream", endpoints.ContributorStream) router.HandleFunc("/Login", endpoints.Login) router.HandleFunc("/Register", endpoints.Register) + router.HandleFunc("/ChangePassword", endpoints.ChangePassword) router.HandleFunc("/NearlyCompleteFunds", endpoints.NearlyCompleteFunds) - router.HandleFunc("/EscalatePrivileges", endpoints.EscalatePrivileges) + router.HandleFunc("/ChangePrivileges", endpoints.ChangePrivileges) router.HandleFunc("/UsersExist", endpoints.UsersExist) port := viper.GetInt("app.port")