diff --git a/endpoints/authorization.go b/endpoints/authorization.go index 3aa1400..76e4033 100644 --- a/endpoints/authorization.go +++ b/endpoints/authorization.go @@ -15,26 +15,37 @@ const ( ) func getAuthorization(c *gin.Context) (AuthorizationScope, string) { + //get auth header header := c.GetHeader("Authorization") if header == "" { return AuthorizationScopeNone, "" } + + //check if user is authorized headerSpl := strings.Split(header, " ") if len(headerSpl) != 2 { return AuthorizationScopeNone, "" } prefix := headerSpl[0] token := strings.ToLower(headerSpl[1]) - if prefix == "Bearer" { - if storage.CheckLoginToken(token, c.ClientIP()) { - return AuthorizationScopeUser, token - } - } if prefix == "Bot" { + //attempt to authorize as bot if found, _ := storage.BotTokenFromToken(token); found { return AuthorizationScopeBot, token } } + if prefix == "Bearer" { + //attempt to authorize as user + userAgentString := c.GetHeader("User-Agent") + if userAgentString == "" { + return AuthorizationScopeNone, "" + } + ua := storage.ParseUA(userAgentString) + + if storage.CheckLoginToken(token, c.ClientIP(), ua) { + return AuthorizationScopeUser, token + } + } return AuthorizationScopeNone, "" } diff --git a/endpoints/endpoints.go b/endpoints/endpoints.go index 18979d8..4ec2f05 100644 --- a/endpoints/endpoints.go +++ b/endpoints/endpoints.go @@ -24,7 +24,8 @@ func Run() { public := r.Group("/") - public.POST("/login", login) //web login + public.POST("/login/password", loginPassword) //web login + public.POST("/login/token", loginToken) //web login public.GET("/access", access) //access token private := r.Group("/") diff --git a/endpoints/login.go b/endpoints/login.go index c72a02b..d8d4697 100644 --- a/endpoints/login.go +++ b/endpoints/login.go @@ -7,19 +7,30 @@ import ( "github.com/gin-gonic/gin" ) -type LoginBody struct { + +type LoginPasswordBody struct { Username string `json:"username" binding:"required"` Password string `json:"password" binding:"required"` } -func login(c *gin.Context) { - var loginBody LoginBody +type LoginTokenBody struct { + Token string `json:"token" binding:"required"` +} + +func loginPassword(c *gin.Context) { + var loginBody LoginPasswordBody if err := c.BindJSON(&loginBody); err != nil { fmt.Println(err) return } - - loggedIn, token := storage.CheckLogin(loginBody.Username, loginBody.Password, c.ClientIP()) + + userAgentString := c.GetHeader("User-Agent") + if userAgentString == "" { + return + } + ua := storage.ParseUA(userAgentString) + + loggedIn, token := storage.CheckLoginPassword(loginBody.Username, loginBody.Password, c.ClientIP(), ua) if loggedIn { c.JSON(200, gin.H{ @@ -32,8 +43,35 @@ func login(c *gin.Context) { } } + +func loginToken(c *gin.Context) { + var loginBody LoginTokenBody + if err := c.BindJSON(&loginBody); err != nil { + fmt.Println(err) + return + } + + userAgentString := c.GetHeader("User-Agent") + if userAgentString == "" { + return + } + ua := storage.ParseUA(userAgentString) + + loggedIn := storage.CheckLoginToken(loginBody.Token, c.ClientIP(), ua) + + if loggedIn { + c.JSON(200, gin.H{ + "token": loginBody.Token, + }) + } else { + c.JSON(401, gin.H{ + "error": "invalid username or password", + }) + } +} + func updateLogin(c *gin.Context) { - var updateLogin LoginBody + var updateLogin LoginPasswordBody if err := c.BindJSON(&updateLogin); err != nil { fmt.Println(err) return diff --git a/go.mod b/go.mod index 9b0196b..f65f767 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/mattn/go-isatty v0.0.16 // indirect + github.com/mileusna/useragent v1.2.1 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.6 // indirect diff --git a/go.sum b/go.sum index 5438e6e..8581985 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ic github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mileusna/useragent v1.2.1 h1:p3RJWhi3LfuI6BHdddojREyK3p6qX67vIfOVMnUIVr0= +github.com/mileusna/useragent v1.2.1/go.mod h1:3d8TOmwL/5I8pJjyVDteHtgDGcefrFUX4ccGOMKNYYc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/storage/login.go b/storage/login.go index 6f38f6f..336e8a5 100644 --- a/storage/login.go +++ b/storage/login.go @@ -6,10 +6,41 @@ import ( "time" "git.zomo.dev/zomo/discord-retokenizer/util" - "github.com/go-redis/redis/v9" + "github.com/mileusna/useragent" "golang.org/x/crypto/bcrypt" ) +type UserAgentSimple struct { + Name string `json:"name"` + Version string `json:"version"` + OS string `json:"os"` + OSVersion string `json:"os_version"` + Mobile bool `json:"mobile"` + Tablet bool `json:"tablet"` + Desktop bool `json:"desktop"` +} + +func (ua UserAgentSimple) Compare(ua2 UserAgentSimple) bool { + return ua.Name == ua2.Name && + ua.OS == ua2.OS && + ua.Mobile == ua2.Mobile && + ua.Tablet == ua2.Tablet && + ua.Desktop == ua2.Desktop +} + +func ParseUA(userAgentString string) UserAgentSimple { + ua := useragent.Parse(userAgentString) + return UserAgentSimple{ + Name: ua.Name, + Version: ua.Version, + OS: ua.OS, + OSVersion: ua.OSVersion, + Mobile: ua.Mobile, + Tablet: ua.Tablet, + Desktop: ua.Desktop, + } +} + func UpdateUsername(username string) { if username != "" { client.Set(ctx, "username", username, 0) @@ -26,7 +57,7 @@ func UpdatePassword(password string) { } } -func CheckLogin(username string, password string, ip string) (bool, string) { +func CheckLoginPassword(username string, password string, ip string, userAgent UserAgentSimple) (bool, string) { if username == "" || password == "" { return false, "" } @@ -50,21 +81,40 @@ func CheckLogin(username string, password string, ip string) (bool, string) { return false, "" } - return true, createLoginToken(ip) + //return existing token if it exists + tokens := getLoginTokens() + fmt.Printf("There are %d tokens", len(tokens)) + for _, token := range tokens { + // fmt.Printf("Checking token %s\n", token.ID) + // fmt.Printf("IP:\n Given: %s\n Token: %s\n", ip, token.IP) + // fmt.Printf("UA:\n Given: %+v\n Token: %+v\n", userAgent, token.UserAgent) + // fmt.Printf("Compare UA: %t\n", token.UserAgent.Compare(userAgent)) + if token.IP == ip && userAgent.Compare(token.UserAgent) { + return true, token.Token + } + } + + return true, createLoginToken(ip, userAgent) } type LoginToken struct { - ID string `json:"id"` - TokenHash string `json:"token"` - IP string `json:"ip"` - End string `json:"end"` + ID string `json:"id"` + Token string `json:"token"` + IP string `json:"ip"` + End string `json:"end"` + UserAgent UserAgentSimple `json:"user_agent"` + CreatedAt string `json:"created_at"` + LastLogin string `json:"last_login"` } type LoginTokenSimple struct { - ID string `json:"id"` - IP string `json:"ip"` - End string `json:"end"` + ID string `json:"id"` + IP string `json:"ip"` + End string `json:"end"` + UserAgent UserAgentSimple `json:"user_agent"` + CreatedAt string `json:"created_at"` + LastLogin string `json:"last_login"` } func (t LoginToken) Simplify() LoginTokenSimple { @@ -72,6 +122,9 @@ func (t LoginToken) Simplify() LoginTokenSimple { ID: t.ID, IP: t.IP, End: t.End, + UserAgent: t.UserAgent, + CreatedAt: t.CreatedAt, + LastLogin: t.LastLogin, } } @@ -83,27 +136,25 @@ func (t *LoginToken) UnmarshalBinary(data []byte) error { return json.Unmarshal(data, t) } -func createLoginToken(ip string) string { +func createLoginToken(ip string, ua UserAgentSimple) string { token := util.GenerateToken() - - tokenHash, err := bcrypt.GenerateFromPassword([]byte(token), bcrypt.DefaultCost) - if err != nil { - panic(err) - } tokenData := LoginToken{ ID: util.GenerateID(), - TokenHash: string(tokenHash), + Token: token, IP: ip, End: util.GetEnd(token), + UserAgent: ua, + CreatedAt: time.Now().Format(time.RFC3339), + LastLogin: time.Now().Format(time.RFC3339), } - member := redis.Z{ - Score: float64(time.Now().Unix() + 4 * 60 * 60), - Member: tokenData, + err := client.RPush(ctx, "loginTokens", tokenData.ID).Err() + if err != nil { + panic(err) } - err = client.ZAdd(ctx, "loginTokens", member).Err() + err = client.Set(ctx, "loginToken:"+tokenData.ID, tokenData, 0).Err() if err != nil { panic(err) } @@ -112,27 +163,30 @@ func createLoginToken(ip string) string { } func getLoginTokens() []LoginToken { - expired, err := client.ZRangeByScore(ctx, "loginTokens", &redis.ZRangeBy{ - Min: "-inf", - Max: fmt.Sprintf("%d", time.Now().Unix()), - }).Result() - + var ids []string + err := client.LRange(ctx, "loginTokens", 0, -1).ScanSlice(&ids) if err != nil { panic(err) } - for _, e := range expired { - client.ZRem(ctx, "loginTokens", e) + var tokens []LoginToken + for _, id := range ids { + var token LoginToken + err = client.Get(ctx, "loginToken:"+id).Scan(&token) + if err != nil { + panic(err) + } + tokens = append(tokens, token) } + return tokens +} - var current []LoginToken - err = client.ZRange(ctx, "loginTokens", 0, -1).ScanSlice(¤t) - +func updateLastLoginToken(token LoginToken) { + token.CreatedAt = time.Now().Format(time.RFC3339) + err := client.Set(ctx, "loginToken:"+token.ID, token, 0).Err() if err != nil { panic(err) } - - return current } func GetLoginTokensSimple() []LoginTokenSimple { @@ -146,15 +200,34 @@ func GetLoginTokensSimple() []LoginTokenSimple { } func ClearLoginTokens() { - client.Del(ctx, "loginTokens") + var ids []string + err := client.LRange(ctx, "loginTokens", 0, -1).ScanSlice(&ids) + if err != nil { + panic(err) + } + + err = client.Del(ctx, "loginTokens").Err() + if err != nil { + panic(err) + } + + for _, id := range ids { + err = client.Del(ctx, "loginToken:"+id).Err() + if err != nil { + panic(err) + } + } } -func CheckLoginToken(token string, ip string) bool { +func CheckLoginToken(token string, ip string, userAgent UserAgentSimple) bool { current := getLoginTokens() for _, c := range current { - err := bcrypt.CompareHashAndPassword([]byte(c.TokenHash), []byte(token)) - if err == nil && ip == c.IP { + if token == c.Token && ip == c.IP && userAgent.Compare(c.UserAgent) { + fmt.Printf("Checking token %s\n", c.ID) + fmt.Printf("IP:\n Given: %s\n Token: %s\n", ip, c.IP) + fmt.Printf("UA:\n Given: %+v\n Token: %+v\n", userAgent, c.UserAgent) + updateLastLoginToken(c) return true } }