diff --git a/helper/utilities.go b/helper/utilities.go new file mode 100644 index 0000000..b5eb746 --- /dev/null +++ b/helper/utilities.go @@ -0,0 +1,31 @@ +package helper + +import ( + "fmt" + "github.com/gin-gonic/gin" + "strconv" +) + +func GetIDOfElement(c *gin.Context, elementName string, source string, providedID int) (int, error) { + + if source == "path" { + id, err := strconv.Atoi(c.Param(elementName)) + if err != nil { + BadRequestError(c, fmt.Sprintf("No or incorrect format of path parameter")) + return -1, err + } + return id, nil + } else if source == "query" { + id, err := strconv.Atoi(c.Request.URL.Query().Get(elementName)) + if err != nil { + BadRequestError(c, fmt.Sprintf("No or incorrect format of query parameter")) + return -1, err + } + return id, nil + } else if source == "body" { + id := providedID + return id, nil + } else { + return -1, fmt.Errorf("invalid source of element ID") + } +} diff --git a/routes/dashboard/dashboard_middleware.go b/routes/dashboard/dashboard_middleware.go index b51b9f8..301b52b 100644 --- a/routes/dashboard/dashboard_middleware.go +++ b/routes/dashboard/dashboard_middleware.go @@ -4,8 +4,6 @@ import ( "fmt" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/helper" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/routes/scenario" - "strconv" - "github.com/gin-gonic/gin" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/database" @@ -21,21 +19,9 @@ func CheckPermissions(c *gin.Context, operation database.CRUD, dabIDSource strin return false, dab } - var dabID int - if dabIDSource == "path" { - dabID, err = strconv.Atoi(c.Param("dashboardID")) - if err != nil { - helper.BadRequestError(c, fmt.Sprintf("No or incorrect format of dashboardID path parameter")) - return false, dab - } - } else if dabIDSource == "query" { - dabID, err = strconv.Atoi(c.Request.URL.Query().Get("dashboardID")) - if err != nil { - helper.BadRequestError(c, fmt.Sprintf("No or incorrect format of dashboardID query parameter")) - return false, dab - } - } else if dabIDSource == "body" { - dabID = dabIDBody + dabID, err := helper.GetIDOfElement(c, "dashboardID", dabIDSource, dabIDBody) + if err != nil { + return false, dab } err = dab.ByID(uint(dabID)) diff --git a/routes/file/file_middleware.go b/routes/file/file_middleware.go index 9faf537..f98c651 100644 --- a/routes/file/file_middleware.go +++ b/routes/file/file_middleware.go @@ -7,7 +7,6 @@ import ( "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/routes/simulationmodel" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/routes/widget" "github.com/gin-gonic/gin" - "strconv" ) func checkPermissions(c *gin.Context, operation database.CRUD) (bool, File) { @@ -20,9 +19,8 @@ func checkPermissions(c *gin.Context, operation database.CRUD) (bool, File) { return false, f } - fileID, err := strconv.Atoi(c.Param("fileID")) + fileID, err := helper.GetIDOfElement(c, "fileID", "path", -1) if err != nil { - helper.BadRequestError(c, fmt.Sprintf("No or incorrect format of fileID path parameter")) return false, f } diff --git a/routes/scenario/scenario_middleware.go b/routes/scenario/scenario_middleware.go index db6a4de..13cc1e2 100644 --- a/routes/scenario/scenario_middleware.go +++ b/routes/scenario/scenario_middleware.go @@ -3,8 +3,6 @@ package scenario import ( "fmt" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/helper" - "strconv" - "github.com/gin-gonic/gin" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/database" @@ -24,24 +22,8 @@ func CheckPermissions(c *gin.Context, operation database.CRUD, screnarioIDSource return true, so } - var scenarioID int - if screnarioIDSource == "path" { - scenarioID, err = strconv.Atoi(c.Param("scenarioID")) - if err != nil { - helper.BadRequestError(c, fmt.Sprintf("No or incorrect format of scenarioID path parameter")) - return false, so - } - } else if screnarioIDSource == "query" { - scenarioID, err = strconv.Atoi(c.Request.URL.Query().Get("scenarioID")) - if err != nil { - helper.BadRequestError(c, fmt.Sprintf("No or incorrect format of scenarioID query parameter")) - return false, so - } - } else if screnarioIDSource == "body" { - scenarioID = scenarioIDbody - - } else { - helper.BadRequestError(c, fmt.Sprintf("The following source of scenario ID is not valid: %s", screnarioIDSource)) + scenarioID, err := helper.GetIDOfElement(c, "scenarioID", screnarioIDSource, scenarioIDbody) + if err != nil { return false, so } diff --git a/routes/signal/signal_middleware.go b/routes/signal/signal_middleware.go index d8c368e..672e26d 100644 --- a/routes/signal/signal_middleware.go +++ b/routes/signal/signal_middleware.go @@ -3,8 +3,6 @@ package signal import ( "fmt" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/helper" - "strconv" - "github.com/gin-gonic/gin" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/database" @@ -21,9 +19,8 @@ func checkPermissions(c *gin.Context, operation database.CRUD) (bool, Signal) { return false, sig } - signalID, err := strconv.Atoi(c.Param("signalID")) + signalID, err := helper.GetIDOfElement(c, "signalID", "path", -1) if err != nil { - helper.BadRequestError(c, fmt.Sprintf("No or incorrect format of signalID path parameter")) return false, sig } diff --git a/routes/simulationmodel/simulationmodel_middleware.go b/routes/simulationmodel/simulationmodel_middleware.go index a25c56f..ada13bb 100644 --- a/routes/simulationmodel/simulationmodel_middleware.go +++ b/routes/simulationmodel/simulationmodel_middleware.go @@ -3,8 +3,6 @@ package simulationmodel import ( "fmt" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/helper" - "strconv" - "github.com/gin-gonic/gin" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/database" @@ -21,21 +19,9 @@ func CheckPermissions(c *gin.Context, operation database.CRUD, modelIDSource str return false, m } - var modelID int - if modelIDSource == "path" { - modelID, err = strconv.Atoi(c.Param("modelID")) - if err != nil { - helper.BadRequestError(c, fmt.Sprintf("No or incorrect format of modelID path parameter")) - return false, m - } - } else if modelIDSource == "query" { - modelID, err = strconv.Atoi(c.Request.URL.Query().Get("modelID")) - if err != nil { - helper.BadRequestError(c, fmt.Sprintf("No or incorrect format of modelID query parameter")) - return false, m - } - } else if modelIDSource == "body" { - modelID = modelIDBody + modelID, err := helper.GetIDOfElement(c, "modelID", modelIDSource, modelIDBody) + if err != nil { + return false, m } err = m.ByID(uint(modelID)) diff --git a/routes/simulator/simulator_middleware.go b/routes/simulator/simulator_middleware.go index 5dc3c35..0c02b46 100644 --- a/routes/simulator/simulator_middleware.go +++ b/routes/simulator/simulator_middleware.go @@ -1,11 +1,9 @@ package simulator import ( - "fmt" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/database" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/helper" "github.com/gin-gonic/gin" - "strconv" ) func checkPermissions(c *gin.Context, modeltype database.ModelName, operation database.CRUD, hasID bool) (bool, Simulator) { @@ -20,9 +18,8 @@ func checkPermissions(c *gin.Context, modeltype database.ModelName, operation da if hasID { // Get the ID of the simulator from the context - simulatorID, err := strconv.Atoi(c.Param("simulatorID")) + simulatorID, err := helper.GetIDOfElement(c, "simulatorID", "path", -1) if err != nil { - helper.BadRequestError(c, fmt.Sprintf("Could not get simulator's ID from context")) return false, s } @@ -30,7 +27,6 @@ func checkPermissions(c *gin.Context, modeltype database.ModelName, operation da if helper.DBError(c, err) { return false, s } - } return true, s diff --git a/routes/widget/widget_middleware.go b/routes/widget/widget_middleware.go index f565ec1..7d3cb3e 100644 --- a/routes/widget/widget_middleware.go +++ b/routes/widget/widget_middleware.go @@ -3,8 +3,6 @@ package widget import ( "fmt" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/helper" - "strconv" - "github.com/gin-gonic/gin" "git.rwth-aachen.de/acs/public/villas/villasweb-backend-go/database" @@ -14,8 +12,8 @@ import ( func CheckPermissions(c *gin.Context, operation database.CRUD, widgetIDBody int) (bool, Widget) { var w Widget - - err := database.ValidateRole(c, database.ModelWidget, operation) + var err error + err = database.ValidateRole(c, database.ModelWidget, operation) if err != nil { helper.UnprocessableEntityError(c, fmt.Sprintf("Access denied (role validation failed): %v", err.Error())) return false, w @@ -23,9 +21,8 @@ func CheckPermissions(c *gin.Context, operation database.CRUD, widgetIDBody int) var widgetID int if widgetIDBody < 0 { - widgetID, err = strconv.Atoi(c.Param("widgetID")) + widgetID, err = helper.GetIDOfElement(c, "widgetID", "path", -1) if err != nil { - helper.BadRequestError(c, fmt.Sprintf("No or incorrect format of widgetID path parameter")) return false, w } } else {