Refactor to have models handle their own requests.

This commit is contained in:
Ada Werefox 2025-05-13 23:10:55 -05:00
parent 490be49808
commit c273d061b5
8 changed files with 202 additions and 379 deletions

View file

@ -21,21 +21,21 @@ func main() {
ExposeHeaders: []string{"Content-Length"}, ExposeHeaders: []string{"Content-Length"},
AllowCredentials: true, AllowCredentials: true,
})) }))
// Authentication Workflow // Authentication Workflow
router.GET("/auth/callback", api.AuthCallback) router.GET("/auth/callback", api.AuthCallback)
router.GET("/auth/login", api.AuthLoginRedirect) router.GET("/auth/login", api.AuthLoginRedirect)
router.GET("/auth/logout", api.AuthLogoutRedirect) router.GET("/auth/logout", api.AuthLogoutRedirect)
router.GET("/user/token/generate", api.CreateAPIToken) router.GET("/user/token/generate", api.CreateAPIToken)
// Create
router.POST("/:object", api.CreateObject)
// Update
router.PUT("/:object", api.UpdateObject)
// Read
router.GET("/user/info", api.GetDiscordUser) router.GET("/user/info", api.GetDiscordUser)
router.GET("/user/authorized", api.GetUserLoggedIn) router.GET("/user/authorized", api.GetUserLoggedIn)
router.GET("/:object", api.GetObjects) // Create
router.GET("/all/:object", api.GetAllObjects) router.POST("/:object", api.ObjectRequest)
// Update
router.PUT("/:object", api.ObjectRequest)
// Read
router.GET("/:object", api.ObjectRequest)
// Delete // Delete
router.DELETE("/:object", api.DeleteObject) router.DELETE("/:object", api.ObjectRequest)
router.Run(":31337") router.Run(":31337")
} }

View file

@ -6,7 +6,7 @@ import (
"encoding/pem" "encoding/pem"
"log" "log"
"net/http" "net/http"
"strconv" "slices"
"strings" "strings"
authdiscord "example.com/auth/discord" authdiscord "example.com/auth/discord"
@ -87,51 +87,41 @@ func checkAuthentication(context *gin.Context) *oauth2.Token {
} }
} }
} else { } else {
signedString := strings.Split(context.Request.Header.Get("Authorization"), " ")[1] if authHeader := context.Request.Header.Get("Authorization"); authHeader != "" {
token, err := jwt.ParseWithClaims(signedString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { signedString := strings.Split(authHeader, " ")[1]
userId, err := token.Claims.GetIssuer() token, err := jwt.ParseWithClaims(signedString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
userId, err := token.Claims.GetIssuer()
if err != nil {
return []byte(""), nil
}
apiKeyName, err := token.Claims.GetSubject()
if err != nil {
return []byte(""), nil
}
var user user.User
user.Get(GlobalDatabase, userId)
key := user.GetAPIKeySecret(GlobalDatabase, apiKeyName)
keyBlock, _ := pem.Decode([]byte(key))
privateKey, _ := x509.ParseECPrivateKey(keyBlock.Bytes)
return &privateKey.PublicKey, nil
})
if err != nil { if err != nil {
return []byte(""), nil log.Println(err)
context.AbortWithStatus(http.StatusBadRequest)
return nil
} }
apiKeyName, err := token.Claims.GetSubject() if token.Valid {
if err != nil { var oauthToken *oauth2.Token
return []byte(""), nil userId, _ := token.Claims.GetIssuer()
json.Unmarshal([]byte((*user.Get(GlobalDatabase, []string{userId}))[0].LoginToken), &oauthToken)
return oauthToken
} }
var user user.User
user.Get(GlobalDatabase, userId)
key := user.GetAPIKeySecret(GlobalDatabase, apiKeyName)
keyBlock, _ := pem.Decode([]byte(key))
privateKey, err := x509.ParseECPrivateKey(keyBlock.Bytes)
return &privateKey.PublicKey, nil
})
if err != nil {
log.Println(err)
context.AbortWithStatus(http.StatusBadRequest)
return nil
}
if token.Valid {
var oauthToken *oauth2.Token
userId, _ := token.Claims.GetIssuer()
json.Unmarshal([]byte((*user.Get(GlobalDatabase, []string{userId}))[0].LoginToken), &oauthToken)
return oauthToken
} }
} }
return nil return nil
} }
func objectIDStringsToInts(context *gin.Context, objectIDs []string) *[]uint {
var objectIDInts []uint
for _, objectID := range objectIDs {
objectIDInt, err := strconv.Atoi(objectID)
if err == nil {
objectIDInts = append(objectIDInts, uint(objectIDInt))
} else {
context.AbortWithStatus(http.StatusBadRequest)
}
}
return &objectIDInts
}
// Authentication Workflow // Authentication Workflow
func AuthCallback(context *gin.Context) { func AuthCallback(context *gin.Context) {
@ -183,11 +173,82 @@ func AuthLogoutRedirect(context *gin.Context) {
} else { } else {
log.Println(err) log.Println(err)
} }
log.Println(GlobalConfig.GetFrontendRootDomain())
context.Redirect(http.StatusTemporaryRedirect, GlobalConfig.GetFrontendRootDomain()) context.Redirect(http.StatusTemporaryRedirect, GlobalConfig.GetFrontendRootDomain())
} }
// Create Endpoints (POST) // Public Functions
func ObjectRequest(context *gin.Context) {
if !slices.Contains([]string{"GET", "POST", "PUT", "DELETE"}, context.Request.Method) {
context.AbortWithStatus(http.StatusBadRequest)
return
}
if (context.Request.Method != "GET") && (checkAuthentication(context) == nil) {
context.AbortWithStatus(http.StatusUnauthorized)
return
}
var modelNames []string
result := GlobalDatabase.Table("sqlite_master").Where("type = ?", "table").Pluck("name", &modelNames)
if result.Error != nil {
context.AbortWithStatus(http.StatusInternalServerError)
return
}
var filteredModelNames []string
for _, model := range modelNames {
if slices.Contains([]string{"api_keys", "sqlite_sequence"}, model) || slices.Contains(strings.Split(model, "_"), "associations") {
continue
}
model = strings.Replace(model, "people", "persons", 1)
model = strings.Replace(model, "_", "-", -1)
filteredModelNames = append(filteredModelNames, model[:len(model)-1])
}
objectType := context.Param("object")
if !slices.Contains(filteredModelNames, objectType) {
context.AbortWithStatus(http.StatusBadRequest)
return
}
var err error
switch objectType {
case "user":
user.GetAll(GlobalDatabase)
case "person":
person.GetAll(GlobalDatabase)
case "group":
err = group.HandleRequest(GlobalDatabase, context)
case "character":
character.GetAll(GlobalDatabase)
case "role":
role.GetAll(GlobalDatabase)
case "tier":
tier.GetAll(GlobalDatabase)
case "function-set":
functionset.GetAll(GlobalDatabase)
case "function":
// result = function.Create(GlobalDatabase, context)
function.GetAll(GlobalDatabase)
case "function-tag":
// result = functiontag.Create(GlobalDatabase, context)
functiontag.GetAll(GlobalDatabase)
case "inventory-slot":
inventoryslot.GetAll(GlobalDatabase)
case "item":
item.GetAll(GlobalDatabase)
case "item-tag":
itemtag.GetAll(GlobalDatabase)
case "customization":
customization.GetAll(GlobalDatabase)
case "schematic":
schematic.GetAll(GlobalDatabase)
default:
context.AbortWithStatus(http.StatusBadRequest)
return
}
if err != nil {
context.Status(http.StatusInternalServerError)
return
}
context.Status(http.StatusOK)
}
func CreateAPIToken(context *gin.Context) { func CreateAPIToken(context *gin.Context) {
name, nameOK := context.GetQuery("name") name, nameOK := context.GetQuery("name")
@ -196,7 +257,7 @@ func CreateAPIToken(context *gin.Context) {
if oauthToken != nil { if oauthToken != nil {
oauthTokenJSON, err := json.Marshal(&oauthToken) oauthTokenJSON, err := json.Marshal(&oauthToken)
if err != nil { if err != nil {
log.Printf("This should never happen, how did this happen???\n%s", err) log.Println("This should never happen, how did this happen???\n", err)
context.AbortWithStatus(http.StatusInternalServerError) context.AbortWithStatus(http.StatusInternalServerError)
return return
} }
@ -205,7 +266,7 @@ func CreateAPIToken(context *gin.Context) {
currentUser := (*user.Get(GlobalDatabase, []string{userId}))[0] currentUser := (*user.Get(GlobalDatabase, []string{userId}))[0]
result := currentUser.GenerateAPIKey(GlobalDatabase, name) result := currentUser.GenerateAPIKey(GlobalDatabase, name)
if result != nil { if result != nil {
log.Printf("This should also never happen, how did this happen???\n%s", err) log.Println("This should also never happen, how did this happen???\n", err)
context.AbortWithStatus(http.StatusInternalServerError) context.AbortWithStatus(http.StatusInternalServerError)
return return
} }
@ -227,105 +288,6 @@ func CreateAPIToken(context *gin.Context) {
} }
} }
func CreateObject(context *gin.Context) {
if checkAuthentication(context) != nil {
var result error
switch objectType := context.Param("object"); objectType {
case "user":
//
case "person":
//
case "group":
result = group.Create(GlobalDatabase, context)
case "character":
//
case "role":
//
case "tier":
//
case "function-set":
//
case "function":
result = function.Create(GlobalDatabase, context)
case "function-tag":
result = functiontag.Create(GlobalDatabase, context)
case "inventory-slot":
//
case "item":
//
case "item-tag":
//
case "customization":
//
case "schematic":
//
}
if result != nil {
context.JSON(http.StatusBadRequest, gin.H{
"message": result.Error(),
})
} else {
context.Status(http.StatusOK)
}
} else {
context.AbortWithStatus(http.StatusUnauthorized)
}
}
// Update Endpoints (PUT)
func UpdateObject(context *gin.Context) {
if checkAuthentication(context) != nil {
_, idOk := context.GetQuery("id")
if idOk {
var result error
switch objectType := context.Param("object"); objectType {
case "user":
//
case "person":
//
case "group":
result = group.Update(GlobalDatabase, context)
case "character":
//
case "role":
//
case "tier":
//
case "function-set":
//
case "function":
result = function.Update(GlobalDatabase, context)
case "function-tag":
result = functiontag.Update(GlobalDatabase, context)
case "inventory-slot":
//
case "item":
//
case "item-tag":
//
case "customization":
//
case "schematic":
//
}
if result != nil {
context.JSON(http.StatusBadRequest, gin.H{
"message": result.Error(),
})
} else {
context.Status(http.StatusOK)
}
} else {
context.AbortWithStatus(http.StatusBadRequest)
}
} else {
context.AbortWithStatus(http.StatusUnauthorized)
}
}
// Read Endpoints (GET)
func GetDiscordUser(context *gin.Context) { func GetDiscordUser(context *gin.Context) {
oauthToken := checkAuthentication(context) oauthToken := checkAuthentication(context)
if oauthToken != nil { if oauthToken != nil {
@ -350,191 +312,3 @@ func GetUserLoggedIn(context *gin.Context) {
"message": (checkAuthentication(context) != nil), "message": (checkAuthentication(context) != nil),
}) })
} }
func GetObjects(context *gin.Context) {
objectIDs, idOk := context.GetQueryArray("id")
if idOk {
switch objectType := context.Param("object"); objectType {
case "user":
context.JSON(http.StatusOK, gin.H{
"user": user.Get(GlobalDatabase, objectIDs)})
case "persons":
context.JSON(http.StatusOK, gin.H{
"persons": person.Get(GlobalDatabase, objectIDs),
})
case "groups":
objectIDInts := objectIDStringsToInts(context, objectIDs)
context.JSON(http.StatusOK, gin.H{
"groups": group.Get(GlobalDatabase, *objectIDInts),
})
case "characters":
context.JSON(http.StatusOK, gin.H{
"characters": character.Get(GlobalDatabase, objectIDs),
})
case "roles":
context.JSON(http.StatusOK, gin.H{
"roles": role.Get(GlobalDatabase, objectIDs),
})
case "tiers":
objectIDInts := objectIDStringsToInts(context, objectIDs)
context.JSON(http.StatusOK, gin.H{
"tiers": tier.Get(GlobalDatabase, *objectIDInts),
})
case "function-sets":
objectIDInts := objectIDStringsToInts(context, objectIDs)
context.JSON(http.StatusOK, gin.H{
"function_sets": functionset.Get(GlobalDatabase, *objectIDInts),
})
case "functions":
var uintObjectIDs []uint
for _, objectID := range objectIDs {
uintObjectID, _ := strconv.Atoi(objectID)
uintObjectIDs = append(uintObjectIDs, uint(uintObjectID))
}
context.JSON(http.StatusOK, gin.H{
"functions": function.Get(GlobalDatabase, uintObjectIDs),
})
case "function-tags":
objectIDInts := objectIDStringsToInts(context, objectIDs)
context.JSON(http.StatusOK, gin.H{
"function_tags": functiontag.Get(GlobalDatabase, *objectIDInts),
})
case "inventory-slot":
objectIDInts := objectIDStringsToInts(context, objectIDs)
context.JSON(http.StatusOK, gin.H{
"inventory_slot": inventoryslot.Get(GlobalDatabase, *objectIDInts),
})
case "items":
context.JSON(http.StatusOK, gin.H{
"items": item.Get(GlobalDatabase, objectIDs),
})
case "item-tags":
context.JSON(http.StatusOK, gin.H{
"item_tags": itemtag.Get(GlobalDatabase, objectIDs),
})
case "customizations":
context.JSON(http.StatusOK, gin.H{
"customizations": customization.Get(GlobalDatabase, objectIDs),
})
case "schematics":
objectIDInts := objectIDStringsToInts(context, objectIDs)
context.JSON(http.StatusOK, gin.H{
"schematics": schematic.Get(GlobalDatabase, *objectIDInts),
})
}
} else {
context.Status(http.StatusBadRequest)
}
}
func GetAllObjects(context *gin.Context) {
switch objectType := context.Param("object"); objectType {
case "persons":
context.JSON(http.StatusOK, gin.H{
"persons": person.GetAll(GlobalDatabase),
})
case "groups":
context.JSON(http.StatusOK, gin.H{
"groups": group.GetAll(GlobalDatabase),
})
case "characters":
context.JSON(http.StatusOK, gin.H{
"characters": character.GetAll(GlobalDatabase),
})
case "roles":
context.JSON(http.StatusOK, gin.H{
"roles": role.GetAll(GlobalDatabase),
})
case "tiers":
context.JSON(http.StatusOK, gin.H{
"tiers": tier.GetAll(GlobalDatabase),
})
case "function-sets":
context.JSON(http.StatusOK, gin.H{
"function_sets": functionset.GetAll(GlobalDatabase),
})
case "functions":
context.JSON(http.StatusOK, gin.H{
"functions": function.GetAll(GlobalDatabase),
})
case "function-tags":
context.JSON(http.StatusOK, gin.H{
"function_tags": functiontag.GetAll(GlobalDatabase),
})
case "inventory-slot":
context.JSON(http.StatusOK, gin.H{
"inventory_slot": inventoryslot.GetAll(GlobalDatabase),
})
case "items":
context.JSON(http.StatusOK, gin.H{
"items": item.GetAll(GlobalDatabase),
})
case "item-tags":
context.JSON(http.StatusOK, gin.H{
"item_tags": itemtag.GetAll(GlobalDatabase),
})
case "customizations":
context.JSON(http.StatusOK, gin.H{
"customizations": customization.GetAll(GlobalDatabase),
})
case "schematics":
context.JSON(http.StatusOK, gin.H{
"schematics": schematic.GetAll(GlobalDatabase),
})
}
}
// Delete Endpoints (DELETE)
func DeleteObject(context *gin.Context) {
if checkAuthentication(context) != nil {
objectIDs, idOk := context.GetQueryArray("id")
if idOk {
uintObjectIDs := objectIDStringsToInts(context, objectIDs)
var result error
switch objectType := context.Param("object"); objectType {
case "users":
//
case "persons":
//
case "groups":
log.Println(*uintObjectIDs, objectIDs)
result = group.Delete(GlobalDatabase, *uintObjectIDs)
case "characters":
//
case "roles":
//
case "tiers":
//
case "function-sets":
//
case "functions":
log.Println(uintObjectIDs)
result = function.Delete(GlobalDatabase, *uintObjectIDs)
case "function-tags":
result = functiontag.Delete(GlobalDatabase, *uintObjectIDs)
case "inventory-slots":
//
case "items":
//
case "item-tags":
//
case "customizations":
//
case "schematics":
//
}
if result != nil {
context.JSON(http.StatusBadRequest, gin.H{
"message": result.Error(),
})
} else {
context.Status(http.StatusOK)
}
} else {
context.AbortWithStatus(http.StatusBadRequest)
}
} else {
context.AbortWithStatus(http.StatusUnauthorized)
}
}

View file

@ -70,7 +70,7 @@ func (character Character) Delete(db *gorm.DB) error {
return nil return nil
} }
func Create(db *gorm.DB, name string, owners []string, roles []string, functionsets []uint, inventory []uint) error { func Create(db *gorm.DB, name string, owners []uint, roles []string, functionsets []uint, inventory []uint) error {
return Character{ return Character{
Name: name, Name: name,
Owners: *person.Get(db, owners), Owners: *person.Get(db, owners),
@ -99,7 +99,7 @@ func GetAll(db *gorm.DB) *[]Character {
return Get(db, outputCharacterNames) return Get(db, outputCharacterNames)
} }
func Update(db *gorm.DB, name string, owners []string, roles []string, functionsets []uint, inventory []uint) error { func Update(db *gorm.DB, name string, owners []uint, roles []string, functionsets []uint, inventory []uint) error {
return Character{ return Character{
Name: name, Name: name,
Owners: *person.Get(db, owners), Owners: *person.Get(db, owners),

View file

@ -2,6 +2,7 @@ package function
import ( import (
"encoding/json" "encoding/json"
"errors"
"io" "io"
"log" "log"
"strconv" "strconv"
@ -40,6 +41,10 @@ func (function *Function) getAssociations(db *gorm.DB) {
} }
func (params *functionParams) validate(context *gin.Context) error { func (params *functionParams) validate(context *gin.Context) error {
ID, IDOk := context.GetQuery("id")
if !IDOk {
return errors.New("ID was not included in the request")
}
body, err := io.ReadAll(context.Request.Body) body, err := io.ReadAll(context.Request.Body)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
@ -49,7 +54,7 @@ func (params *functionParams) validate(context *gin.Context) error {
var newParams functionParams var newParams functionParams
err = json.Unmarshal(body, &newParams) err = json.Unmarshal(body, &newParams)
log.Println(err, newParams) log.Println(err, newParams)
params.Id = newParams.Id params.Id = ID
params.Name = newParams.Name params.Name = newParams.Name
params.Tags = newParams.Tags params.Tags = newParams.Tags
params.Requirements = newParams.Requirements params.Requirements = newParams.Requirements
@ -141,7 +146,6 @@ func Update(db *gorm.DB, context *gin.Context) error {
func Delete(db *gorm.DB, inputFunctions []uint) error { func Delete(db *gorm.DB, inputFunctions []uint) error {
functions := Get(db, inputFunctions) functions := Get(db, inputFunctions)
log.Println(inputFunctions, functions)
for _, function := range *functions { for _, function := range *functions {
err := function.Delete(db) err := function.Delete(db)
if err != nil { if err != nil {

View file

@ -123,10 +123,10 @@ func Update(db *gorm.DB, context *gin.Context) error {
}.update(db) }.update(db)
} }
func Delete(db *gorm.DB, inputFunctions []uint) error { func Delete(db *gorm.DB, inputFunctionTags []uint) error {
functions := Get(db, inputFunctions) functionTags := Get(db, inputFunctionTags)
for _, function := range *functions { for _, functiontag := range *functionTags {
err := function.delete(db) err := functiontag.delete(db)
if err != nil { if err != nil {
return err return err
} }

View file

@ -4,7 +4,8 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
"log" "net/http"
"slices"
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -25,12 +26,11 @@ type groupParams struct {
func (params *groupParams) validate(context *gin.Context) error { func (params *groupParams) validate(context *gin.Context) error {
ID, IDOk := context.GetQuery("id") ID, IDOk := context.GetQuery("id")
if !IDOk { if !IDOk && slices.Contains([]string{"PUT"}, context.Request.Method) {
return errors.New("ID was not included in the request") return errors.New("ID was not included in the request")
} }
body, err := io.ReadAll(context.Request.Body) body, err := io.ReadAll(context.Request.Body)
if err != nil { if err != nil {
log.Println(err)
return err return err
} }
var name groupParams var name groupParams
@ -40,7 +40,7 @@ func (params *groupParams) validate(context *gin.Context) error {
return nil return nil
} }
func (group *Group) Get(db *gorm.DB, inputGroup uint) { func (group *Group) get(db *gorm.DB, inputGroup uint) {
db.Model(&Group{}).Where("ID = ?", inputGroup).Take(&group) db.Model(&Group{}).Where("ID = ?", inputGroup).Take(&group)
} }
@ -63,7 +63,7 @@ func (group Group) delete(db *gorm.DB) error {
return nil return nil
} }
func (group Group) Create(db *gorm.DB) error { func (group Group) create(db *gorm.DB) error {
result := db.Create(&group) result := db.Create(&group)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
@ -83,30 +83,22 @@ func Create(db *gorm.DB, context *gin.Context) error {
} }
return Group{ return Group{
Name: newGroup.Name, Name: newGroup.Name,
}.Create(db) }.create(db)
} }
func Get(db *gorm.DB, inputGroups []uint) *[]Group { func Get(db *gorm.DB, inputGroups []uint) *[]Group {
var outputGroups []Group var outputGroups []Group
if len(inputGroups) < 1 {
db.Model(&Group{}).Select("id").Find(&inputGroups)
}
for _, inputGroup := range inputGroups { for _, inputGroup := range inputGroups {
var outputGroup Group var outputGroup Group
outputGroup.Get(db, inputGroup) outputGroup.get(db, inputGroup)
outputGroups = append(outputGroups, outputGroup) outputGroups = append(outputGroups, outputGroup)
} }
return &outputGroups return &outputGroups
} }
func GetAll(db *gorm.DB) *[]Group {
var outputGroups []Group
var outputGroupIDs []uint
result := db.Model(&Group{}).Select("id").Find(&outputGroupIDs)
if result.Error != nil {
log.Println(result.Error)
}
outputGroups = *Get(db, outputGroupIDs)
return &outputGroups
}
func Update(db *gorm.DB, context *gin.Context) error { func Update(db *gorm.DB, context *gin.Context) error {
var params groupParams var params groupParams
err := params.validate(context) err := params.validate(context)
@ -124,8 +116,19 @@ func Update(db *gorm.DB, context *gin.Context) error {
} }
func Delete(db *gorm.DB, inputGroups []uint) error { func Delete(db *gorm.DB, inputGroups []uint) error {
groups := Get(db, inputGroups) var groups []Group
for _, group := range *groups { // if len(inputGroups) < 1 {
// result := db.Model(&Group{}).Select("id").Find(&inputGroups)
// if result.Error != nil {
// return result.Error
// }
// }
for _, inputGroup := range inputGroups {
var group Group
group.get(db, inputGroup)
groups = append(groups, group)
}
for _, group := range groups {
err := group.delete(db) err := group.delete(db)
if err != nil { if err != nil {
return err return err
@ -133,3 +136,34 @@ func Delete(db *gorm.DB, inputGroups []uint) error {
} }
return nil return nil
} }
func HandleRequest(db *gorm.DB, context *gin.Context) error {
var params groupParams
if err := params.validate(context); err != nil {
return err
}
idArray, _ := context.GetQueryArray("id")
var idUintArray []uint
for _, id := range idArray {
idUint, err := strconv.Atoi(id)
if err != nil {
return err
}
idUintArray = append(idUintArray, uint(idUint))
}
var err error
switch context.Request.Method {
case "GET":
result := Get(db, idUintArray)
context.JSON(http.StatusOK, gin.H{
"result": result,
})
case "POST":
err = Create(db, context)
case "PUT":
err = Update(db, context)
case "DELETE":
err = Delete(db, idUintArray)
}
return err
}

View file

@ -26,14 +26,14 @@ func (person *Person) getAssociations(db *gorm.DB) {
db.Model(&person).Association("Groups").Find(&person.Groups) db.Model(&person).Association("Groups").Find(&person.Groups)
} }
func (person *Person) Get(db *gorm.DB, inputPerson string) { func (person *Person) Get(db *gorm.DB, inputPerson uint) {
db.Where("name = ?", inputPerson).Take(&person) db.Where("id = ?", inputPerson).Take(&person)
person.getAssociations(db) person.getAssociations(db)
} }
func (person Person) Update(db *gorm.DB) error { func (person Person) Update(db *gorm.DB) error {
var originalPerson Person var originalPerson Person
originalPerson.Get(db, person.Name) originalPerson.Get(db, person.ID)
groupsError := db.Model(&originalPerson).Association("Groups").Replace(&person.Groups) groupsError := db.Model(&originalPerson).Association("Groups").Replace(&person.Groups)
if groupsError != nil { if groupsError != nil {
return groupsError return groupsError
@ -56,7 +56,14 @@ func Create(db *gorm.DB, name string, groups []uint) error {
}.Create(db) }.Create(db)
} }
func Get(db *gorm.DB, inputPersons []string) *[]Person { func GetByName(db *gorm.DB, name string) *Person {
var person Person
db.Model(&Person{}).Where("name = ?", name).Take(&person)
person.getAssociations(db)
return &person
}
func Get(db *gorm.DB, inputPersons []uint) *[]Person {
var outputPersons []Person var outputPersons []Person
for _, inputPerson := range inputPersons { for _, inputPerson := range inputPersons {
var outputPerson Person var outputPerson Person
@ -67,8 +74,8 @@ func Get(db *gorm.DB, inputPersons []string) *[]Person {
} }
func GetAll(db *gorm.DB) *[]Person { func GetAll(db *gorm.DB) *[]Person {
var outputPersonNames []string var outputPersonNames []uint
result := db.Model(&Person{}).Select("name").Find(&outputPersonNames) result := db.Model(&Person{}).Select("id").Find(&outputPersonNames)
if result.Error != nil { if result.Error != nil {
log.Println(result.Error) log.Println(result.Error)
} }
@ -82,9 +89,13 @@ func Update(db *gorm.DB, name string, groups []uint) error {
}.Update(db) }.Update(db)
} }
func Delete(db *gorm.DB, inputPersons []string) { func Delete(db *gorm.DB, inputPersons []uint) error {
persons := Get(db, inputPersons) persons := Get(db, inputPersons)
for _, person := range *persons { for _, person := range *persons {
person.Delete(db) err := person.Delete(db)
if err != nil {
return err
}
} }
return nil
} }

View file

@ -145,11 +145,11 @@ func Exists(db *gorm.DB, id string) bool {
return (err == nil) return (err == nil)
} }
func Create(db *gorm.DB, id string, displayName string, username string, avatar string, avatarDecoration string, loginToken string, loggedIn bool) error { func Create(db *gorm.DB, discordId string, displayName string, username string, avatar string, avatarDecoration string, loginToken string, loggedIn bool) error {
person.Create(db, displayName, []uint{}) person.Create(db, displayName, []uint{})
newPerson := (*person.Get(db, []string{displayName}))[0] newPerson := person.GetByName(db, username)
newUser := User{ newUser := User{
Id: id, Id: discordId,
Person: person.Person{}, Person: person.Person{},
DisplayName: displayName, DisplayName: displayName,
Username: username, Username: username,
@ -189,10 +189,10 @@ func GetAll(db *gorm.DB) *[]User {
return Get(db, outputUserIDs) return Get(db, outputUserIDs)
} }
func Update(db *gorm.DB, id string, displayName string, username string, avatar string, avatarDecoration string, loginToken string, loggedIn bool) error { func Update(db *gorm.DB, discordId string, displayName string, username string, avatar string, avatarDecoration string, loginToken string, loggedIn bool) error {
newPerson := (*person.Get(db, []string{displayName}))[0] newPerson := *person.GetByName(db, username)
return User{ return User{
Id: id, Id: discordId,
Person: newPerson, Person: newPerson,
DisplayName: displayName, DisplayName: displayName,
Username: username, Username: username,