@@ -10,6 +10,7 @@ import ( | |||
) | |||
type Claims struct { | |||
ID uint `json:"id"` | |||
Username string `json:"username"` | |||
Privileges uint `json:"privileges"` | |||
jwt.RegisteredClaims | |||
@@ -66,9 +66,10 @@ type Contribution struct { | |||
type User struct { | |||
ModelBase | |||
Username string `json:"username"` | |||
Password string `json:"password"` | |||
Privileges uint `json:"admin"` | |||
Username string `json:"username"` | |||
Password string `json:"password"` | |||
Privileges uint `json:"admin"` | |||
LastLogin *time.Time `json:"lastLogin"` | |||
} | |||
var Db *gorm.DB | |||
@@ -0,0 +1,45 @@ | |||
package endpoints | |||
import ( | |||
"encoding/json" | |||
"net/http" | |||
"time" | |||
. "github.com/imosed/signet/data" | |||
) | |||
type ChangePasswordRequest struct { | |||
UserID uint `json:"userID"` | |||
Password string `json:"password"` | |||
} | |||
func ChangePassword(w http.ResponseWriter, r *http.Request) { | |||
var req ChangePasswordRequest | |||
err := json.NewDecoder(r.Body).Decode(&req) | |||
if err != nil { | |||
panic("Could not decode body") | |||
} | |||
var user User | |||
Db.Table("users").First(&user, req.UserID) | |||
var password string | |||
password, err = GetHashedPassword(req.Password) | |||
if err != nil { | |||
panic("Could not get password") | |||
} | |||
if user.LastLogin == nil { | |||
Db.Table("users").Where("id = ?", req.UserID).Updates(map[string]interface{}{"last_login": time.Now(), "password": password}) | |||
} else { | |||
Db.Table("users").Where("id = ?", req.UserID).Update("password = ?", password) | |||
} | |||
var resp SuccessResponse | |||
resp.Success = true | |||
err = json.NewEncoder(w).Encode(resp) | |||
if err != nil { | |||
panic("Could not deliver response") | |||
} | |||
} |
@@ -10,14 +10,15 @@ import ( | |||
) | |||
type EscalatePrivilegesRequest struct { | |||
Username string | |||
UserID uint `json:"userID"` | |||
Privileges uint `json:"privileges"` | |||
} | |||
func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { | |||
func ChangePrivileges(w http.ResponseWriter, r *http.Request) { | |||
var req EscalatePrivilegesRequest | |||
err := json.NewDecoder(r.Body).Decode(&req) | |||
if err != nil { | |||
log.Error().Err(err).Msg("Could not decode body in EscalatePrivileges call") | |||
log.Error().Err(err).Msg("Could not decode body in ChangePrivileges call") | |||
return | |||
} | |||
@@ -28,8 +29,8 @@ func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { | |||
claims, err = auth.GetUserClaims(r) | |||
if claims.Privileges < 2 { | |||
Db.Table("users").Where("username = ?", req.Username).Find(&user) | |||
if user.Privileges < 2 { | |||
Db.Table("users").Where("id = ?", req.UserID).Find(&user) | |||
if req.Privileges == SuperUser { | |||
resp.Success = false | |||
err = json.NewEncoder(w).Encode(resp) | |||
@@ -39,7 +40,8 @@ func EscalatePrivileges(w http.ResponseWriter, r *http.Request) { | |||
return | |||
} | |||
user.Privileges = AdminPlus | |||
user.Privileges = req.Privileges | |||
Db.Save(user) | |||
resp.Success = true | |||
} else { | |||
resp.Success = false | |||
@@ -13,7 +13,8 @@ import ( | |||
) | |||
type LoginResponse struct { | |||
Token *string `json:"token"` | |||
Token *string `json:"token"` | |||
LastLogin *time.Time `json:"lastLogin"` | |||
} | |||
func Login(w http.ResponseWriter, r *http.Request) { | |||
@@ -24,14 +25,12 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||
return | |||
} | |||
var userData struct { | |||
ID uint | |||
Password string | |||
Privileges uint | |||
} | |||
var userData User | |||
var loginTime = time.Now() | |||
var resp LoginResponse | |||
Db.Table("users").Select("id, password, privileges"). | |||
Db.Table("users").Select("id, password, privileges, last_login"). | |||
Where("username = ?", req.Username).First(&userData) | |||
var passwordMatches bool | |||
passwordMatches, err = ComparePasswordAndHash(req.Password, userData.Password) | |||
@@ -40,7 +39,6 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||
return | |||
} | |||
if !passwordMatches { | |||
resp.Token = nil | |||
err = json.NewEncoder(w).Encode(resp) | |||
if err != nil { | |||
log.Error().Err(err).Msg("Failed to deliver failed login attempt response") | |||
@@ -49,10 +47,11 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||
} | |||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &auth.Claims{ | |||
ID: userData.ID, | |||
Username: req.Username, | |||
Privileges: userData.Privileges, | |||
RegisteredClaims: jwt.RegisteredClaims{ | |||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Hour)), | |||
ExpiresAt: jwt.NewNumericDate(loginTime.Add(10 * time.Hour)), | |||
}, | |||
}) | |||
@@ -63,6 +62,12 @@ func Login(w http.ResponseWriter, r *http.Request) { | |||
return | |||
} | |||
resp.Token = &tokenString | |||
resp.LastLogin = userData.LastLogin | |||
if userData.LastLogin != nil { | |||
// need to set this after the user changes their password | |||
Db.Table("users").Where("id = ?", userData.ID).Update("last_login", loginTime) | |||
} | |||
err = json.NewEncoder(w).Encode(resp) | |||
if err != nil { | |||
@@ -124,6 +124,17 @@ func determinePrivileges() uint { | |||
} | |||
} | |||
func GetHashedPassword(password string) (encodedHash string, err error) { | |||
hash, err := GenerateHash(password, &Params{ | |||
Memory: uint32(viper.GetInt("hashing.memory")), | |||
Iterations: uint32(viper.GetInt("hashing.iterations")), | |||
Parallelism: uint8(viper.GetInt("hashing.parallelism")), | |||
SaltLength: uint32(viper.GetInt("hashing.saltLength")), | |||
KeyLength: uint32(viper.GetInt("hashing.keyLength")), | |||
}) | |||
return hash, err | |||
} | |||
func Register(w http.ResponseWriter, r *http.Request) { | |||
var req AuthenticationRequest | |||
err := json.NewDecoder(r.Body).Decode(&req) | |||
@@ -144,13 +155,7 @@ func Register(w http.ResponseWriter, r *http.Request) { | |||
} | |||
if noUsersRegistered() || claims.Privileges <= AdminPlus { | |||
hash, err := GenerateHash(req.Password, &Params{ | |||
Memory: uint32(viper.GetInt("hashing.memory")), | |||
Iterations: uint32(viper.GetInt("hashing.iterations")), | |||
Parallelism: uint8(viper.GetInt("hashing.parallelism")), | |||
SaltLength: uint32(viper.GetInt("hashing.saltLength")), | |||
KeyLength: uint32(viper.GetInt("hashing.keyLength")), | |||
}) | |||
hash, err := GetHashedPassword(req.Password) | |||
if err != nil { | |||
log.Error().Err(err).Msg("Could not generate hash for registration") | |||
return | |||
@@ -48,8 +48,9 @@ func main() { | |||
router.HandleFunc("/ContributorStream", endpoints.ContributorStream) | |||
router.HandleFunc("/Login", endpoints.Login) | |||
router.HandleFunc("/Register", endpoints.Register) | |||
router.HandleFunc("/ChangePassword", endpoints.ChangePassword) | |||
router.HandleFunc("/NearlyCompleteFunds", endpoints.NearlyCompleteFunds) | |||
router.HandleFunc("/EscalatePrivileges", endpoints.EscalatePrivileges) | |||
router.HandleFunc("/ChangePrivileges", endpoints.ChangePrivileges) | |||
router.HandleFunc("/UsersExist", endpoints.UsersExist) | |||
port := viper.GetInt("app.port") | |||