diff --git a/src/gin-cpularp.go b/src/gin-cpularp.go index 17ad1fa..56649fa 100755 --- a/src/gin-cpularp.go +++ b/src/gin-cpularp.go @@ -21,21 +21,21 @@ func main() { ExposeHeaders: []string{"Content-Length"}, AllowCredentials: true, })) + // Authentication Workflow router.GET("/auth/callback", api.AuthCallback) router.GET("/auth/login", api.AuthLoginRedirect) router.GET("/auth/logout", api.AuthLogoutRedirect) router.GET("/user/token/generate", api.CreateAPIToken) - // Create - router.POST("/:object", api.CreateObject) - // Update - router.PUT("/:object", api.UpdateObject) - // Read router.GET("/user/info", api.GetDiscordUser) router.GET("/user/authorized", api.GetUserLoggedIn) - router.GET("/:object", api.GetObjects) - router.GET("/all/:object", api.GetAllObjects) + // Create + router.POST("/:object", api.ObjectRequest) + // Update + router.PUT("/:object", api.ObjectRequest) + // Read + router.GET("/:object", api.ObjectRequest) // Delete - router.DELETE("/:object", api.DeleteObject) + router.DELETE("/:object", api.ObjectRequest) router.Run(":31337") } diff --git a/src/lib/api/api.go b/src/lib/api/api.go index e893a96..a04ef74 100644 --- a/src/lib/api/api.go +++ b/src/lib/api/api.go @@ -6,7 +6,7 @@ import ( "encoding/pem" "log" "net/http" - "strconv" + "slices" "strings" authdiscord "example.com/auth/discord" @@ -87,51 +87,41 @@ func checkAuthentication(context *gin.Context) *oauth2.Token { } } } else { - signedString := strings.Split(context.Request.Header.Get("Authorization"), " ")[1] - token, err := jwt.ParseWithClaims(signedString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { - userId, err := token.Claims.GetIssuer() + if authHeader := context.Request.Header.Get("Authorization"); authHeader != "" { + signedString := strings.Split(authHeader, " ")[1] + token, err := jwt.ParseWithClaims(signedString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { + userId, err := token.Claims.GetIssuer() + if err != nil { + return []byte(""), nil + } + apiKeyName, err := token.Claims.GetSubject() + if err != nil { + return []byte(""), nil + } + var user user.User + user.Get(GlobalDatabase, userId) + key := user.GetAPIKeySecret(GlobalDatabase, apiKeyName) + keyBlock, _ := pem.Decode([]byte(key)) + privateKey, _ := x509.ParseECPrivateKey(keyBlock.Bytes) + return &privateKey.PublicKey, nil + }) if err != nil { - return []byte(""), nil + log.Println(err) + context.AbortWithStatus(http.StatusBadRequest) + return nil } - apiKeyName, err := token.Claims.GetSubject() - if err != nil { - return []byte(""), nil + if token.Valid { + var oauthToken *oauth2.Token + userId, _ := token.Claims.GetIssuer() + json.Unmarshal([]byte((*user.Get(GlobalDatabase, []string{userId}))[0].LoginToken), &oauthToken) + return oauthToken } - var user user.User - user.Get(GlobalDatabase, userId) - key := user.GetAPIKeySecret(GlobalDatabase, apiKeyName) - keyBlock, _ := pem.Decode([]byte(key)) - privateKey, err := x509.ParseECPrivateKey(keyBlock.Bytes) - return &privateKey.PublicKey, nil - }) - if err != nil { - log.Println(err) - context.AbortWithStatus(http.StatusBadRequest) - return nil - } - if token.Valid { - var oauthToken *oauth2.Token - userId, _ := token.Claims.GetIssuer() - json.Unmarshal([]byte((*user.Get(GlobalDatabase, []string{userId}))[0].LoginToken), &oauthToken) - return oauthToken + } } return nil } -func objectIDStringsToInts(context *gin.Context, objectIDs []string) *[]uint { - var objectIDInts []uint - for _, objectID := range objectIDs { - objectIDInt, err := strconv.Atoi(objectID) - if err == nil { - objectIDInts = append(objectIDInts, uint(objectIDInt)) - } else { - context.AbortWithStatus(http.StatusBadRequest) - } - } - return &objectIDInts -} - // Authentication Workflow func AuthCallback(context *gin.Context) { @@ -183,11 +173,82 @@ func AuthLogoutRedirect(context *gin.Context) { } else { log.Println(err) } - log.Println(GlobalConfig.GetFrontendRootDomain()) context.Redirect(http.StatusTemporaryRedirect, GlobalConfig.GetFrontendRootDomain()) } -// Create Endpoints (POST) +// Public Functions + +func ObjectRequest(context *gin.Context) { + if !slices.Contains([]string{"GET", "POST", "PUT", "DELETE"}, context.Request.Method) { + context.AbortWithStatus(http.StatusBadRequest) + return + } + if (context.Request.Method != "GET") && (checkAuthentication(context) == nil) { + context.AbortWithStatus(http.StatusUnauthorized) + return + } + var modelNames []string + result := GlobalDatabase.Table("sqlite_master").Where("type = ?", "table").Pluck("name", &modelNames) + if result.Error != nil { + context.AbortWithStatus(http.StatusInternalServerError) + return + } + var filteredModelNames []string + for _, model := range modelNames { + if slices.Contains([]string{"api_keys", "sqlite_sequence"}, model) || slices.Contains(strings.Split(model, "_"), "associations") { + continue + } + model = strings.Replace(model, "people", "persons", 1) + model = strings.Replace(model, "_", "-", -1) + filteredModelNames = append(filteredModelNames, model[:len(model)-1]) + } + objectType := context.Param("object") + if !slices.Contains(filteredModelNames, objectType) { + context.AbortWithStatus(http.StatusBadRequest) + return + } + var err error + switch objectType { + case "user": + user.GetAll(GlobalDatabase) + case "person": + person.GetAll(GlobalDatabase) + case "group": + err = group.HandleRequest(GlobalDatabase, context) + case "character": + character.GetAll(GlobalDatabase) + case "role": + role.GetAll(GlobalDatabase) + case "tier": + tier.GetAll(GlobalDatabase) + case "function-set": + functionset.GetAll(GlobalDatabase) + case "function": + // result = function.Create(GlobalDatabase, context) + function.GetAll(GlobalDatabase) + case "function-tag": + // result = functiontag.Create(GlobalDatabase, context) + functiontag.GetAll(GlobalDatabase) + case "inventory-slot": + inventoryslot.GetAll(GlobalDatabase) + case "item": + item.GetAll(GlobalDatabase) + case "item-tag": + itemtag.GetAll(GlobalDatabase) + case "customization": + customization.GetAll(GlobalDatabase) + case "schematic": + schematic.GetAll(GlobalDatabase) + default: + context.AbortWithStatus(http.StatusBadRequest) + return + } + if err != nil { + context.Status(http.StatusInternalServerError) + return + } + context.Status(http.StatusOK) +} func CreateAPIToken(context *gin.Context) { name, nameOK := context.GetQuery("name") @@ -196,7 +257,7 @@ func CreateAPIToken(context *gin.Context) { if oauthToken != nil { oauthTokenJSON, err := json.Marshal(&oauthToken) if err != nil { - log.Printf("This should never happen, how did this happen???\n%s", err) + log.Println("This should never happen, how did this happen???\n", err) context.AbortWithStatus(http.StatusInternalServerError) return } @@ -205,7 +266,7 @@ func CreateAPIToken(context *gin.Context) { currentUser := (*user.Get(GlobalDatabase, []string{userId}))[0] result := currentUser.GenerateAPIKey(GlobalDatabase, name) if result != nil { - log.Printf("This should also never happen, how did this happen???\n%s", err) + log.Println("This should also never happen, how did this happen???\n", err) context.AbortWithStatus(http.StatusInternalServerError) return } @@ -227,105 +288,6 @@ func CreateAPIToken(context *gin.Context) { } } -func CreateObject(context *gin.Context) { - if checkAuthentication(context) != nil { - var result error - switch objectType := context.Param("object"); objectType { - case "user": - // - case "person": - // - case "group": - result = group.Create(GlobalDatabase, context) - case "character": - // - case "role": - // - case "tier": - // - case "function-set": - // - case "function": - result = function.Create(GlobalDatabase, context) - case "function-tag": - result = functiontag.Create(GlobalDatabase, context) - case "inventory-slot": - // - case "item": - // - case "item-tag": - // - case "customization": - // - case "schematic": - // - } - if result != nil { - context.JSON(http.StatusBadRequest, gin.H{ - "message": result.Error(), - }) - } else { - context.Status(http.StatusOK) - } - } else { - context.AbortWithStatus(http.StatusUnauthorized) - } -} - -// Update Endpoints (PUT) - -func UpdateObject(context *gin.Context) { - if checkAuthentication(context) != nil { - _, idOk := context.GetQuery("id") - if idOk { - var result error - switch objectType := context.Param("object"); objectType { - case "user": - // - case "person": - // - case "group": - result = group.Update(GlobalDatabase, context) - case "character": - // - case "role": - // - case "tier": - // - case "function-set": - // - case "function": - result = function.Update(GlobalDatabase, context) - case "function-tag": - result = functiontag.Update(GlobalDatabase, context) - case "inventory-slot": - // - case "item": - // - case "item-tag": - // - case "customization": - // - case "schematic": - // - } - if result != nil { - context.JSON(http.StatusBadRequest, gin.H{ - "message": result.Error(), - }) - } else { - context.Status(http.StatusOK) - } - } else { - context.AbortWithStatus(http.StatusBadRequest) - } - } else { - context.AbortWithStatus(http.StatusUnauthorized) - } -} - -// Read Endpoints (GET) - func GetDiscordUser(context *gin.Context) { oauthToken := checkAuthentication(context) if oauthToken != nil { @@ -350,191 +312,3 @@ func GetUserLoggedIn(context *gin.Context) { "message": (checkAuthentication(context) != nil), }) } - -func GetObjects(context *gin.Context) { - objectIDs, idOk := context.GetQueryArray("id") - if idOk { - switch objectType := context.Param("object"); objectType { - case "user": - context.JSON(http.StatusOK, gin.H{ - "user": user.Get(GlobalDatabase, objectIDs)}) - case "persons": - context.JSON(http.StatusOK, gin.H{ - "persons": person.Get(GlobalDatabase, objectIDs), - }) - case "groups": - objectIDInts := objectIDStringsToInts(context, objectIDs) - context.JSON(http.StatusOK, gin.H{ - "groups": group.Get(GlobalDatabase, *objectIDInts), - }) - case "characters": - context.JSON(http.StatusOK, gin.H{ - "characters": character.Get(GlobalDatabase, objectIDs), - }) - case "roles": - context.JSON(http.StatusOK, gin.H{ - "roles": role.Get(GlobalDatabase, objectIDs), - }) - case "tiers": - objectIDInts := objectIDStringsToInts(context, objectIDs) - context.JSON(http.StatusOK, gin.H{ - "tiers": tier.Get(GlobalDatabase, *objectIDInts), - }) - case "function-sets": - objectIDInts := objectIDStringsToInts(context, objectIDs) - context.JSON(http.StatusOK, gin.H{ - "function_sets": functionset.Get(GlobalDatabase, *objectIDInts), - }) - case "functions": - var uintObjectIDs []uint - for _, objectID := range objectIDs { - uintObjectID, _ := strconv.Atoi(objectID) - uintObjectIDs = append(uintObjectIDs, uint(uintObjectID)) - } - context.JSON(http.StatusOK, gin.H{ - "functions": function.Get(GlobalDatabase, uintObjectIDs), - }) - case "function-tags": - objectIDInts := objectIDStringsToInts(context, objectIDs) - context.JSON(http.StatusOK, gin.H{ - "function_tags": functiontag.Get(GlobalDatabase, *objectIDInts), - }) - case "inventory-slot": - objectIDInts := objectIDStringsToInts(context, objectIDs) - context.JSON(http.StatusOK, gin.H{ - "inventory_slot": inventoryslot.Get(GlobalDatabase, *objectIDInts), - }) - case "items": - context.JSON(http.StatusOK, gin.H{ - "items": item.Get(GlobalDatabase, objectIDs), - }) - case "item-tags": - context.JSON(http.StatusOK, gin.H{ - "item_tags": itemtag.Get(GlobalDatabase, objectIDs), - }) - case "customizations": - context.JSON(http.StatusOK, gin.H{ - "customizations": customization.Get(GlobalDatabase, objectIDs), - }) - case "schematics": - objectIDInts := objectIDStringsToInts(context, objectIDs) - context.JSON(http.StatusOK, gin.H{ - "schematics": schematic.Get(GlobalDatabase, *objectIDInts), - }) - } - } else { - context.Status(http.StatusBadRequest) - } -} - -func GetAllObjects(context *gin.Context) { - switch objectType := context.Param("object"); objectType { - case "persons": - context.JSON(http.StatusOK, gin.H{ - "persons": person.GetAll(GlobalDatabase), - }) - case "groups": - context.JSON(http.StatusOK, gin.H{ - "groups": group.GetAll(GlobalDatabase), - }) - case "characters": - context.JSON(http.StatusOK, gin.H{ - "characters": character.GetAll(GlobalDatabase), - }) - case "roles": - context.JSON(http.StatusOK, gin.H{ - "roles": role.GetAll(GlobalDatabase), - }) - case "tiers": - context.JSON(http.StatusOK, gin.H{ - "tiers": tier.GetAll(GlobalDatabase), - }) - case "function-sets": - context.JSON(http.StatusOK, gin.H{ - "function_sets": functionset.GetAll(GlobalDatabase), - }) - case "functions": - context.JSON(http.StatusOK, gin.H{ - "functions": function.GetAll(GlobalDatabase), - }) - case "function-tags": - context.JSON(http.StatusOK, gin.H{ - "function_tags": functiontag.GetAll(GlobalDatabase), - }) - case "inventory-slot": - context.JSON(http.StatusOK, gin.H{ - "inventory_slot": inventoryslot.GetAll(GlobalDatabase), - }) - case "items": - context.JSON(http.StatusOK, gin.H{ - "items": item.GetAll(GlobalDatabase), - }) - case "item-tags": - context.JSON(http.StatusOK, gin.H{ - "item_tags": itemtag.GetAll(GlobalDatabase), - }) - case "customizations": - context.JSON(http.StatusOK, gin.H{ - "customizations": customization.GetAll(GlobalDatabase), - }) - case "schematics": - context.JSON(http.StatusOK, gin.H{ - "schematics": schematic.GetAll(GlobalDatabase), - }) - } -} - -// Delete Endpoints (DELETE) - -func DeleteObject(context *gin.Context) { - if checkAuthentication(context) != nil { - objectIDs, idOk := context.GetQueryArray("id") - if idOk { - uintObjectIDs := objectIDStringsToInts(context, objectIDs) - var result error - switch objectType := context.Param("object"); objectType { - case "users": - // - case "persons": - // - case "groups": - log.Println(*uintObjectIDs, objectIDs) - result = group.Delete(GlobalDatabase, *uintObjectIDs) - case "characters": - // - case "roles": - // - case "tiers": - // - case "function-sets": - // - case "functions": - log.Println(uintObjectIDs) - result = function.Delete(GlobalDatabase, *uintObjectIDs) - case "function-tags": - result = functiontag.Delete(GlobalDatabase, *uintObjectIDs) - case "inventory-slots": - // - case "items": - // - case "item-tags": - // - case "customizations": - // - case "schematics": - // - } - if result != nil { - context.JSON(http.StatusBadRequest, gin.H{ - "message": result.Error(), - }) - } else { - context.Status(http.StatusOK) - } - } else { - context.AbortWithStatus(http.StatusBadRequest) - } - } else { - context.AbortWithStatus(http.StatusUnauthorized) - } -} diff --git a/src/lib/database/character/character.go b/src/lib/database/character/character.go index 7c6652c..50d8a2b 100644 --- a/src/lib/database/character/character.go +++ b/src/lib/database/character/character.go @@ -70,7 +70,7 @@ func (character Character) Delete(db *gorm.DB) error { return nil } -func Create(db *gorm.DB, name string, owners []string, roles []string, functionsets []uint, inventory []uint) error { +func Create(db *gorm.DB, name string, owners []uint, roles []string, functionsets []uint, inventory []uint) error { return Character{ Name: name, Owners: *person.Get(db, owners), @@ -99,7 +99,7 @@ func GetAll(db *gorm.DB) *[]Character { return Get(db, outputCharacterNames) } -func Update(db *gorm.DB, name string, owners []string, roles []string, functionsets []uint, inventory []uint) error { +func Update(db *gorm.DB, name string, owners []uint, roles []string, functionsets []uint, inventory []uint) error { return Character{ Name: name, Owners: *person.Get(db, owners), diff --git a/src/lib/database/function/function.go b/src/lib/database/function/function.go index e43ad2c..0230e33 100644 --- a/src/lib/database/function/function.go +++ b/src/lib/database/function/function.go @@ -2,6 +2,7 @@ package function import ( "encoding/json" + "errors" "io" "log" "strconv" @@ -40,6 +41,10 @@ func (function *Function) getAssociations(db *gorm.DB) { } func (params *functionParams) validate(context *gin.Context) error { + ID, IDOk := context.GetQuery("id") + if !IDOk { + return errors.New("ID was not included in the request") + } body, err := io.ReadAll(context.Request.Body) if err != nil { log.Println(err) @@ -49,7 +54,7 @@ func (params *functionParams) validate(context *gin.Context) error { var newParams functionParams err = json.Unmarshal(body, &newParams) log.Println(err, newParams) - params.Id = newParams.Id + params.Id = ID params.Name = newParams.Name params.Tags = newParams.Tags params.Requirements = newParams.Requirements @@ -141,7 +146,6 @@ func Update(db *gorm.DB, context *gin.Context) error { func Delete(db *gorm.DB, inputFunctions []uint) error { functions := Get(db, inputFunctions) - log.Println(inputFunctions, functions) for _, function := range *functions { err := function.Delete(db) if err != nil { diff --git a/src/lib/database/functiontag/functiontag.go b/src/lib/database/functiontag/functiontag.go index c8ffe16..a2255b0 100644 --- a/src/lib/database/functiontag/functiontag.go +++ b/src/lib/database/functiontag/functiontag.go @@ -123,10 +123,10 @@ func Update(db *gorm.DB, context *gin.Context) error { }.update(db) } -func Delete(db *gorm.DB, inputFunctions []uint) error { - functions := Get(db, inputFunctions) - for _, function := range *functions { - err := function.delete(db) +func Delete(db *gorm.DB, inputFunctionTags []uint) error { + functionTags := Get(db, inputFunctionTags) + for _, functiontag := range *functionTags { + err := functiontag.delete(db) if err != nil { return err } diff --git a/src/lib/database/group/group.go b/src/lib/database/group/group.go index df509e7..3019406 100644 --- a/src/lib/database/group/group.go +++ b/src/lib/database/group/group.go @@ -4,7 +4,8 @@ import ( "encoding/json" "errors" "io" - "log" + "net/http" + "slices" "strconv" "github.com/gin-gonic/gin" @@ -25,12 +26,11 @@ type groupParams struct { func (params *groupParams) validate(context *gin.Context) error { ID, IDOk := context.GetQuery("id") - if !IDOk { + if !IDOk && slices.Contains([]string{"PUT"}, context.Request.Method) { return errors.New("ID was not included in the request") } body, err := io.ReadAll(context.Request.Body) if err != nil { - log.Println(err) return err } var name groupParams @@ -40,7 +40,7 @@ func (params *groupParams) validate(context *gin.Context) error { return nil } -func (group *Group) Get(db *gorm.DB, inputGroup uint) { +func (group *Group) get(db *gorm.DB, inputGroup uint) { db.Model(&Group{}).Where("ID = ?", inputGroup).Take(&group) } @@ -63,7 +63,7 @@ func (group Group) delete(db *gorm.DB) error { return nil } -func (group Group) Create(db *gorm.DB) error { +func (group Group) create(db *gorm.DB) error { result := db.Create(&group) if result.Error != nil { return result.Error @@ -83,30 +83,22 @@ func Create(db *gorm.DB, context *gin.Context) error { } return Group{ Name: newGroup.Name, - }.Create(db) + }.create(db) } func Get(db *gorm.DB, inputGroups []uint) *[]Group { var outputGroups []Group + if len(inputGroups) < 1 { + db.Model(&Group{}).Select("id").Find(&inputGroups) + } for _, inputGroup := range inputGroups { var outputGroup Group - outputGroup.Get(db, inputGroup) + outputGroup.get(db, inputGroup) outputGroups = append(outputGroups, outputGroup) } return &outputGroups } -func GetAll(db *gorm.DB) *[]Group { - var outputGroups []Group - var outputGroupIDs []uint - result := db.Model(&Group{}).Select("id").Find(&outputGroupIDs) - if result.Error != nil { - log.Println(result.Error) - } - outputGroups = *Get(db, outputGroupIDs) - return &outputGroups -} - func Update(db *gorm.DB, context *gin.Context) error { var params groupParams err := params.validate(context) @@ -124,8 +116,19 @@ func Update(db *gorm.DB, context *gin.Context) error { } func Delete(db *gorm.DB, inputGroups []uint) error { - groups := Get(db, inputGroups) - for _, group := range *groups { + var groups []Group + // if len(inputGroups) < 1 { + // result := db.Model(&Group{}).Select("id").Find(&inputGroups) + // if result.Error != nil { + // return result.Error + // } + // } + for _, inputGroup := range inputGroups { + var group Group + group.get(db, inputGroup) + groups = append(groups, group) + } + for _, group := range groups { err := group.delete(db) if err != nil { return err @@ -133,3 +136,34 @@ func Delete(db *gorm.DB, inputGroups []uint) error { } return nil } + +func HandleRequest(db *gorm.DB, context *gin.Context) error { + var params groupParams + if err := params.validate(context); err != nil { + return err + } + idArray, _ := context.GetQueryArray("id") + var idUintArray []uint + for _, id := range idArray { + idUint, err := strconv.Atoi(id) + if err != nil { + return err + } + idUintArray = append(idUintArray, uint(idUint)) + } + var err error + switch context.Request.Method { + case "GET": + result := Get(db, idUintArray) + context.JSON(http.StatusOK, gin.H{ + "result": result, + }) + case "POST": + err = Create(db, context) + case "PUT": + err = Update(db, context) + case "DELETE": + err = Delete(db, idUintArray) + } + return err +} diff --git a/src/lib/database/person/person.go b/src/lib/database/person/person.go index 681d18c..3cbbc8f 100644 --- a/src/lib/database/person/person.go +++ b/src/lib/database/person/person.go @@ -26,14 +26,14 @@ func (person *Person) getAssociations(db *gorm.DB) { db.Model(&person).Association("Groups").Find(&person.Groups) } -func (person *Person) Get(db *gorm.DB, inputPerson string) { - db.Where("name = ?", inputPerson).Take(&person) +func (person *Person) Get(db *gorm.DB, inputPerson uint) { + db.Where("id = ?", inputPerson).Take(&person) person.getAssociations(db) } func (person Person) Update(db *gorm.DB) error { var originalPerson Person - originalPerson.Get(db, person.Name) + originalPerson.Get(db, person.ID) groupsError := db.Model(&originalPerson).Association("Groups").Replace(&person.Groups) if groupsError != nil { return groupsError @@ -56,7 +56,14 @@ func Create(db *gorm.DB, name string, groups []uint) error { }.Create(db) } -func Get(db *gorm.DB, inputPersons []string) *[]Person { +func GetByName(db *gorm.DB, name string) *Person { + var person Person + db.Model(&Person{}).Where("name = ?", name).Take(&person) + person.getAssociations(db) + return &person +} + +func Get(db *gorm.DB, inputPersons []uint) *[]Person { var outputPersons []Person for _, inputPerson := range inputPersons { var outputPerson Person @@ -67,8 +74,8 @@ func Get(db *gorm.DB, inputPersons []string) *[]Person { } func GetAll(db *gorm.DB) *[]Person { - var outputPersonNames []string - result := db.Model(&Person{}).Select("name").Find(&outputPersonNames) + var outputPersonNames []uint + result := db.Model(&Person{}).Select("id").Find(&outputPersonNames) if result.Error != nil { log.Println(result.Error) } @@ -82,9 +89,13 @@ func Update(db *gorm.DB, name string, groups []uint) error { }.Update(db) } -func Delete(db *gorm.DB, inputPersons []string) { +func Delete(db *gorm.DB, inputPersons []uint) error { persons := Get(db, inputPersons) for _, person := range *persons { - person.Delete(db) + err := person.Delete(db) + if err != nil { + return err + } } + return nil } diff --git a/src/lib/database/user/user.go b/src/lib/database/user/user.go index ff46976..dfb30d2 100644 --- a/src/lib/database/user/user.go +++ b/src/lib/database/user/user.go @@ -145,11 +145,11 @@ func Exists(db *gorm.DB, id string) bool { return (err == nil) } -func Create(db *gorm.DB, id string, displayName string, username string, avatar string, avatarDecoration string, loginToken string, loggedIn bool) error { +func Create(db *gorm.DB, discordId string, displayName string, username string, avatar string, avatarDecoration string, loginToken string, loggedIn bool) error { person.Create(db, displayName, []uint{}) - newPerson := (*person.Get(db, []string{displayName}))[0] + newPerson := person.GetByName(db, username) newUser := User{ - Id: id, + Id: discordId, Person: person.Person{}, DisplayName: displayName, Username: username, @@ -189,10 +189,10 @@ func GetAll(db *gorm.DB) *[]User { return Get(db, outputUserIDs) } -func Update(db *gorm.DB, id string, displayName string, username string, avatar string, avatarDecoration string, loginToken string, loggedIn bool) error { - newPerson := (*person.Get(db, []string{displayName}))[0] +func Update(db *gorm.DB, discordId string, displayName string, username string, avatar string, avatarDecoration string, loginToken string, loggedIn bool) error { + newPerson := *person.GetByName(db, username) return User{ - Id: id, + Id: discordId, Person: newPerson, DisplayName: displayName, Username: username,