continued auth system
This commit is contained in:
7
.env.example
Normal file
7
.env.example
Normal file
@@ -0,0 +1,7 @@
|
||||
CLIENT_ID= # Twitch Client ID
|
||||
# Required
|
||||
REDIR_URI= # Twitch OAuth Redirect URI
|
||||
# Required
|
||||
|
||||
SQLITE_DB= # SQlite DB location
|
||||
# Default: ./db.sqlite
|
||||
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -49,8 +50,26 @@ func (server *ApiServer) loadEndpoints() {
|
||||
})
|
||||
|
||||
server.engine.GET("/auth", func(c *gin.Context) {
|
||||
q := c.Request.URL.Query()
|
||||
if resp := loadAuthQueryOk(q); resp != nil {
|
||||
// ok
|
||||
// TODO check state (need state system)
|
||||
// TODO POST https://id.twitch.tv/oauth2/token - returns TwitchAuthTokenResp
|
||||
// convert expiresIn to time.Time (minus like 15 minutes as a buffer period)
|
||||
// UpdateUserAuth()
|
||||
// TODO return twitch ok (or err if can't POST)
|
||||
} else if resp := loadAuthQueryErr(q); resp != nil {
|
||||
// err from twitch
|
||||
// TODO check state (need state system)
|
||||
// TODO return twitch err
|
||||
} else {
|
||||
// err in params
|
||||
// TODO return param err
|
||||
}
|
||||
|
||||
// TODO auth response from twitch
|
||||
// parse args as TwitchAuthRespOk or TwitchAuthRespErr
|
||||
// verify state with db and client id with config
|
||||
c.JSON(http.StatusOK, serverInfo)
|
||||
})
|
||||
}
|
||||
@@ -69,6 +88,28 @@ type TwitchAuthParams struct {
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
func loadAuthQueryOk(query url.Values) *TwitchAuthRespOk {
|
||||
if query.Has("code") && query.Has("scope") && query.Has("state") {
|
||||
return &TwitchAuthRespOk{
|
||||
Code: query.Get("code"),
|
||||
Scope: query.Get("scope"),
|
||||
State: query.Get("state"),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadAuthQueryErr(query url.Values) *TwitchAuthRespErr {
|
||||
if query.Has("error") && query.Has("error_description") && query.Has("state") {
|
||||
return &TwitchAuthRespErr{
|
||||
Err: query.Get("error"),
|
||||
ErrDesc: query.Get("error_description"),
|
||||
State: query.Get("state"),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type TwitchAuthRespOk struct {
|
||||
Code string `json:"code"`
|
||||
Scope string `json:"scope"`
|
||||
@@ -80,3 +121,11 @@ type TwitchAuthRespErr struct {
|
||||
ErrDesc string `json:"error_description"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
type TwitchAuthTokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Scope []string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
64
db/db_cold/authTokens.go
Normal file
64
db/db_cold/authTokens.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package db_cold
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserAuth struct {
|
||||
gorm.Model
|
||||
UserID string `gorm:"primarykey"`
|
||||
UserName string
|
||||
UserLogin string
|
||||
UserEmail string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
TokenExpires time.Time
|
||||
}
|
||||
|
||||
func (db *DBColdConn) initUserAuth() {
|
||||
db.Gorm.AutoMigrate(&UserAuth{})
|
||||
}
|
||||
|
||||
func (db *DBColdConn) GetAllUserAuth() ([]UserAuth, error) {
|
||||
var userAuths []UserAuth
|
||||
res := db.Gorm.Find(&userAuths)
|
||||
if res.Error != nil {
|
||||
return nil, res.Error
|
||||
}
|
||||
return userAuths, nil
|
||||
}
|
||||
|
||||
// add or update user auth, based on ID
|
||||
func (db *DBColdConn) UpdateUserAuth(userID, userName, userLogin, accessToken, refreshToken string, tokenExpires time.Time) error {
|
||||
userAuth := UserAuth{
|
||||
UserID: userID,
|
||||
UserName: userName,
|
||||
UserLogin: userLogin,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
TokenExpires: tokenExpires,
|
||||
}
|
||||
|
||||
rows, err := gorm.G[UserAuth](db.Gorm).Where("id = ?", userID).Find(db.Ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(rows) > 0 {
|
||||
// update
|
||||
_, err := gorm.G[UserAuth](db.Gorm).Where("id = ?", userID).Updates(db.Ctx, userAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// add
|
||||
err := gorm.G[UserAuth](db.Gorm).Create(db.Ctx, &userAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,10 +1,29 @@
|
||||
package db_cold
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"zomo.dev/largehadroncollider/util"
|
||||
)
|
||||
|
||||
// sqlite file
|
||||
|
||||
func InitDBColdConn() (DBColdConn, error) {
|
||||
return DBColdConn{}, nil
|
||||
func InitDBColdConn(conf *util.Config) (*DBColdConn, error) {
|
||||
db, err := gorm.Open(sqlite.Open(conf.SQliteDB), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
cold := &DBColdConn{ db, ctx }
|
||||
cold.initUserAuth()
|
||||
|
||||
return cold, nil
|
||||
}
|
||||
|
||||
type DBColdConn struct {
|
||||
Gorm *gorm.DB
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package db_hot
|
||||
|
||||
import "zomo.dev/largehadroncollider/util"
|
||||
|
||||
// redis connection
|
||||
|
||||
func InitDBHotConn() (DBHotConn, error) {
|
||||
return DBHotConn{}, nil
|
||||
func InitDBHotConn(conf *util.Config) (*DBHotConn, error) {
|
||||
return &DBHotConn{}, nil
|
||||
}
|
||||
|
||||
type DBHotConn struct {
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
)
|
||||
|
||||
func InitDBConn(conf *util.Config) (*DBConn, error) {
|
||||
hot, err := db_hot.InitDBHotConn()
|
||||
hot, err := db_hot.InitDBHotConn(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cold, err := db_cold.InitDBColdConn()
|
||||
cold, err := db_cold.InitDBColdConn(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -19,6 +19,6 @@ func InitDBConn(conf *util.Config) (*DBConn, error) {
|
||||
}
|
||||
|
||||
type DBConn struct {
|
||||
hot db_hot.DBHotConn
|
||||
cold db_cold.DBColdConn
|
||||
Hot *db_hot.DBHotConn
|
||||
Cold *db_cold.DBColdConn
|
||||
}
|
||||
|
||||
50
ttv/auth.go
50
ttv/auth.go
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/adeithe/go-twitch/api"
|
||||
"zomo.dev/largehadroncollider/db"
|
||||
"zomo.dev/largehadroncollider/db/db_cold"
|
||||
"zomo.dev/largehadroncollider/util"
|
||||
)
|
||||
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
func initAuth(conf *util.Config, dbConn *db.DBConn) (*TwitchAuth, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
tokens, err := getTokensFromDB(dbConn)
|
||||
tokens, err := dbConn.Cold.GetAllUserAuth()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -26,27 +27,24 @@ func initAuth(conf *util.Config, dbConn *db.DBConn) (*TwitchAuth, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range accounts {
|
||||
err := dbConn.Cold.UpdateUserAuth(account.UserID, account.UserName, account.UserLogin, account.AccessToken, account.RefreshToken, account.TokenExpires)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &TwitchAuth{ ctx, client, accounts }, nil
|
||||
}
|
||||
|
||||
type TwitchAuth struct {
|
||||
Ctx context.Context
|
||||
Client *api.Client
|
||||
Accounts []TwitchAuthAccount
|
||||
Accounts []db_cold.UserAuth
|
||||
}
|
||||
|
||||
type TwitchAuthAccount struct {
|
||||
api.User
|
||||
Token string
|
||||
}
|
||||
|
||||
func getTokensFromDB(dbConn *db.DBConn) ([]string, error) {
|
||||
// TODO db cold
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
func testTokens(ctx context.Context, client *api.Client, tokens []string) ([]TwitchAuthAccount, error) {
|
||||
accounts := make([]TwitchAuthAccount, 0)
|
||||
func testTokens(ctx context.Context, client *api.Client, tokens []db_cold.UserAuth) ([]db_cold.UserAuth, error) {
|
||||
accounts := make([]db_cold.UserAuth, 0)
|
||||
for _, token := range tokens {
|
||||
account, err := testToken(ctx, client, token)
|
||||
if err != nil {
|
||||
@@ -57,22 +55,30 @@ func testTokens(ctx context.Context, client *api.Client, tokens []string) ([]Twi
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
func testToken(ctx context.Context, client *api.Client, token string) (TwitchAuthAccount, error) {
|
||||
users, err := client.Users.List().Do(ctx, api.WithBearerToken(token))
|
||||
func testToken(ctx context.Context, client *api.Client, token db_cold.UserAuth) (db_cold.UserAuth, error) {
|
||||
// TODO check refresh time, refresh token if needed
|
||||
|
||||
users, err := client.Users.List().Do(ctx, api.WithBearerToken(token.AccessToken))
|
||||
if err != nil {
|
||||
return TwitchAuthAccount{}, err
|
||||
return db_cold.UserAuth{}, err
|
||||
}
|
||||
|
||||
usersData := users.Data
|
||||
|
||||
if len(usersData) <= 0 {
|
||||
return TwitchAuthAccount{}, errors.New("user data returned an empty array")
|
||||
return db_cold.UserAuth{}, errors.New("user data returned an empty array")
|
||||
}
|
||||
|
||||
// from twitch
|
||||
mainUser := usersData[0]
|
||||
token.UserLogin = mainUser.UserLogin
|
||||
token.UserName = mainUser.UserName
|
||||
token.UserEmail = mainUser.Email
|
||||
|
||||
return TwitchAuthAccount{
|
||||
mainUser,
|
||||
token,
|
||||
}, nil
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func refreshToken(token db_cold.UserAuth) (db_cold.UserAuth, error) {
|
||||
// TODO get new access token using refresh token
|
||||
// TODO this should be called regularly, as needed based on Expires
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ func LoadConfig() (*Config, error) {
|
||||
}
|
||||
// other sources?
|
||||
|
||||
config.def()
|
||||
config.verify()
|
||||
return &config, nil
|
||||
}
|
||||
@@ -24,6 +25,11 @@ func LoadConfig() (*Config, error) {
|
||||
type Config struct {
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
SQliteDB string
|
||||
}
|
||||
|
||||
func (c *Config) def() {
|
||||
c.SQliteDB = "./db.sqlite"
|
||||
}
|
||||
|
||||
func (c *Config) loadEnv() error {
|
||||
@@ -38,6 +44,9 @@ func (c *Config) loadEnv() error {
|
||||
if str, found := os.LookupEnv("REDIR_URI"); found {
|
||||
c.RedirectURI = strings.TrimSpace(str)
|
||||
}
|
||||
if str, found := os.LookupEnv("SQLITE_DB"); found {
|
||||
c.SQliteDB = strings.TrimSpace(str)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user