@@ -10,6 +10,7 @@ import ( | |||||
) | ) | ||||
type Claims struct { | type Claims struct { | ||||
ID uint `json:"id"` | |||||
Username string `json:"username"` | Username string `json:"username"` | ||||
Privileges uint `json:"privileges"` | Privileges uint `json:"privileges"` | ||||
jwt.RegisteredClaims | jwt.RegisteredClaims | ||||
@@ -66,9 +66,10 @@ type Contribution struct { | |||||
type User struct { | type User struct { | ||||
ModelBase | ModelBase | ||||
Username string `json:"username"` | |||||
Password string `json:"password"` | |||||
Privileges uint `json:"admin"` | |||||
Username string `json:"username"` | |||||
Password string `json:"password"` | |||||
Privileges uint `json:"admin"` | |||||
LastLogin *time.Time `json:"lastLogin"` | |||||
} | } | ||||
var Db *gorm.DB | var Db *gorm.DB | ||||
@@ -0,0 +1,45 @@ | |||||
package endpoints | |||||
import ( | |||||
"encoding/json" | |||||
"net/http" | |||||
"time" | |||||
. "github.com/imosed/signet/data" | |||||
) | |||||
type ChangePasswordRequest struct { | |||||
UserID uint `json:"userID"` | |||||
Password string `json:"password"` | |||||
} | |||||
func ChangePassword(w http.ResponseWriter, r *http.Request) { | |||||
var req ChangePasswordRequest | |||||
err := json.NewDecoder(r.Body).Decode(&req) | |||||
if err != nil { | |||||
panic("Could not decode body") | |||||
} | |||||
var user User | |||||
Db.Table("users").First(&user, req.UserID) | |||||
var password string | |||||
password, err = GetHashedPassword(req.Password) | |||||
if err != nil { | |||||
panic("Could not get password") | |||||
} | |||||
if user.LastLogin == nil { | |||||
Db.Table("users").Where("id = ?", req.UserID).Updates(map[string]interface{}{"last_login": time.Now(), "password": password}) | |||||
} else { | |||||
Db.Table("users").Where("id = ?", req.UserID).Update("password = ?", password) | |||||
} | |||||
var resp SuccessResponse | |||||
resp.Success = true | |||||
err = json.NewEncoder(w).Encode(resp) | |||||
if err != nil { | |||||
panic("Could not deliver response") | |||||
} | |||||
} |
@@ -10,14 +10,15 @@ import ( | |||||
) | ) | ||||
type EscalatePrivilegesRequest struct { | type EscalatePrivilegesRequest struct { | ||||
Username string | |||||
UserID uint `json:"userID"` | |||||
Privileges uint `json:"privileges"` | |||||
} | } | ||||
func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { | |||||
func ChangePrivileges(w http.ResponseWriter, r *http.Request) { | |||||
var req EscalatePrivilegesRequest | var req EscalatePrivilegesRequest | ||||
err := json.NewDecoder(r.Body).Decode(&req) | err := json.NewDecoder(r.Body).Decode(&req) | ||||
if err != nil { | if err != nil { | ||||
log.Error().Err(err).Msg("Could not decode body in EscalatePrivileges call") | |||||
log.Error().Err(err).Msg("Could not decode body in ChangePrivileges call") | |||||
return | return | ||||
} | } | ||||
@@ -28,8 +29,8 @@ func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { | |||||
claims, err = auth.GetUserClaims(r) | claims, err = auth.GetUserClaims(r) | ||||
if claims.Privileges < 2 { | if claims.Privileges < 2 { | ||||
Db.Table("users").Where("username = ?", req.Username).Find(&user) | |||||
if user.Privileges < 2 { | |||||
Db.Table("users").Where("id = ?", req.UserID).Find(&user) | |||||
if req.Privileges == SuperUser { | |||||
resp.Success = false | resp.Success = false | ||||
err = json.NewEncoder(w).Encode(resp) | err = json.NewEncoder(w).Encode(resp) | ||||
@@ -39,7 +40,8 @@ func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { | |||||
return | return | ||||
} | } | ||||
user.Privileges = AdminPlus | |||||
user.Privileges = req.Privileges | |||||
Db.Save(user) | |||||
resp.Success = true | resp.Success = true | ||||
} else { | } else { | ||||
resp.Success = false | resp.Success = false | ||||
@@ -13,7 +13,8 @@ import ( | |||||
) | ) | ||||
type LoginResponse struct { | type LoginResponse struct { | ||||
Token *string `json:"token"` | |||||
Token *string `json:"token"` | |||||
LastLogin *time.Time `json:"lastLogin"` | |||||
} | } | ||||
func Login(w http.ResponseWriter, r *http.Request) { | func Login(w http.ResponseWriter, r *http.Request) { | ||||
@@ -24,14 +25,12 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||||
return | return | ||||
} | } | ||||
var userData struct { | |||||
ID uint | |||||
Password string | |||||
Privileges uint | |||||
} | |||||
var userData User | |||||
var loginTime = time.Now() | |||||
var resp LoginResponse | var resp LoginResponse | ||||
Db.Table("users").Select("id, password, privileges"). | |||||
Db.Table("users").Select("id, password, privileges, last_login"). | |||||
Where("username = ?", req.Username).First(&userData) | Where("username = ?", req.Username).First(&userData) | ||||
var passwordMatches bool | var passwordMatches bool | ||||
passwordMatches, err = ComparePasswordAndHash(req.Password, userData.Password) | passwordMatches, err = ComparePasswordAndHash(req.Password, userData.Password) | ||||
@@ -40,7 +39,6 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||||
return | return | ||||
} | } | ||||
if !passwordMatches { | if !passwordMatches { | ||||
resp.Token = nil | |||||
err = json.NewEncoder(w).Encode(resp) | err = json.NewEncoder(w).Encode(resp) | ||||
if err != nil { | if err != nil { | ||||
log.Error().Err(err).Msg("Failed to deliver failed login attempt response") | log.Error().Err(err).Msg("Failed to deliver failed login attempt response") | ||||
@@ -49,10 +47,11 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||||
} | } | ||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &auth.Claims{ | token := jwt.NewWithClaims(jwt.SigningMethodHS256, &auth.Claims{ | ||||
ID: userData.ID, | |||||
Username: req.Username, | Username: req.Username, | ||||
Privileges: userData.Privileges, | Privileges: userData.Privileges, | ||||
RegisteredClaims: jwt.RegisteredClaims{ | RegisteredClaims: jwt.RegisteredClaims{ | ||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Hour)), | |||||
ExpiresAt: jwt.NewNumericDate(loginTime.Add(10 * time.Hour)), | |||||
}, | }, | ||||
}) | }) | ||||
@@ -63,6 +62,12 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||||
return | return | ||||
} | } | ||||
resp.Token = &tokenString | resp.Token = &tokenString | ||||
resp.LastLogin = userData.LastLogin | |||||
if userData.LastLogin != nil { | |||||
// need to set this after the user changes their password | |||||
Db.Table("users").Where("id = ?", userData.ID).Update("last_login", loginTime) | |||||
} | |||||
err = json.NewEncoder(w).Encode(resp) | err = json.NewEncoder(w).Encode(resp) | ||||
if err != nil { | if err != nil { | ||||
@@ -124,6 +124,17 @@ func determinePrivileges() uint { | |||||
} | } | ||||
} | } | ||||
func GetHashedPassword(password string) (encodedHash string, err error) { | |||||
hash, err := GenerateHash(password, &Params{ | |||||
Memory: uint32(viper.GetInt("hashing.memory")), | |||||
Iterations: uint32(viper.GetInt("hashing.iterations")), | |||||
Parallelism: uint8(viper.GetInt("hashing.parallelism")), | |||||
SaltLength: uint32(viper.GetInt("hashing.saltLength")), | |||||
KeyLength: uint32(viper.GetInt("hashing.keyLength")), | |||||
}) | |||||
return hash, err | |||||
} | |||||
func Register(w http.ResponseWriter, r *http.Request) { | func Register(w http.ResponseWriter, r *http.Request) { | ||||
var req AuthenticationRequest | var req AuthenticationRequest | ||||
err := json.NewDecoder(r.Body).Decode(&req) | err := json.NewDecoder(r.Body).Decode(&req) | ||||
@@ -144,13 +155,7 @@ func Register(w http.ResponseWriter, r *http.Request) { | |||||
} | } | ||||
if noUsersRegistered() || claims.Privileges <= AdminPlus { | if noUsersRegistered() || claims.Privileges <= AdminPlus { | ||||
hash, err := GenerateHash(req.Password, &Params{ | |||||
Memory: uint32(viper.GetInt("hashing.memory")), | |||||
Iterations: uint32(viper.GetInt("hashing.iterations")), | |||||
Parallelism: uint8(viper.GetInt("hashing.parallelism")), | |||||
SaltLength: uint32(viper.GetInt("hashing.saltLength")), | |||||
KeyLength: uint32(viper.GetInt("hashing.keyLength")), | |||||
}) | |||||
hash, err := GetHashedPassword(req.Password) | |||||
if err != nil { | if err != nil { | ||||
log.Error().Err(err).Msg("Could not generate hash for registration") | log.Error().Err(err).Msg("Could not generate hash for registration") | ||||
return | return | ||||
@@ -48,8 +48,9 @@ func main() { | |||||
router.HandleFunc("/ContributorStream", endpoints.ContributorStream) | router.HandleFunc("/ContributorStream", endpoints.ContributorStream) | ||||
router.HandleFunc("/Login", endpoints.Login) | router.HandleFunc("/Login", endpoints.Login) | ||||
router.HandleFunc("/Register", endpoints.Register) | router.HandleFunc("/Register", endpoints.Register) | ||||
router.HandleFunc("/ChangePassword", endpoints.ChangePassword) | |||||
router.HandleFunc("/NearlyCompleteFunds", endpoints.NearlyCompleteFunds) | router.HandleFunc("/NearlyCompleteFunds", endpoints.NearlyCompleteFunds) | ||||
router.HandleFunc("/EscalatePrivileges", endpoints.EscalatePrivileges) | |||||
router.HandleFunc("/ChangePrivileges", endpoints.ChangePrivileges) | |||||
router.HandleFunc("/UsersExist", endpoints.UsersExist) | router.HandleFunc("/UsersExist", endpoints.UsersExist) | ||||
port := viper.GetInt("app.port") | port := viper.GetInt("app.port") | ||||