mirror of
https://github.com/alice-lg/birdwatcher.git
synced 2025-03-09 00:00:05 +01:00
added access control
This commit is contained in:
parent
830bf89944
commit
0a4befdcaa
4 changed files with 58 additions and 8 deletions
|
@ -4,6 +4,7 @@ import (
|
|||
"flag"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/ecix/birdwatcher/bird"
|
||||
"github.com/ecix/birdwatcher/endpoints"
|
||||
|
@ -33,8 +34,15 @@ func makeRouter() *httprouter.Router {
|
|||
func PrintServiceInfo(conf *Config, birdConf bird.BirdConfig) {
|
||||
// General Info
|
||||
log.Println("Starting Birdwatcher")
|
||||
log.Println(" Using:", birdConf.BirdCmd)
|
||||
log.Println(" Listen:", birdConf.Listen)
|
||||
log.Println(" Using:", birdConf.BirdCmd)
|
||||
log.Println(" Listen:", birdConf.Listen)
|
||||
|
||||
// Endpoint Info
|
||||
if len(conf.Server.AllowFrom) == 0 {
|
||||
log.Println(" AllowFrom: ALL")
|
||||
} else {
|
||||
log.Println(" AllowFrom:", strings.Join(conf.Server.AllowFrom, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
@ -60,8 +68,9 @@ func main() {
|
|||
|
||||
PrintServiceInfo(conf, birdConf)
|
||||
|
||||
// Configure client
|
||||
// Configuration
|
||||
bird.Conf = birdConf
|
||||
endpoints.Conf = conf.Server
|
||||
|
||||
// Make server
|
||||
r := makeRouter()
|
||||
|
|
|
@ -4,24 +4,22 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"github.com/imdario/mergo"
|
||||
|
||||
"github.com/ecix/birdwatcher/bird"
|
||||
"github.com/ecix/birdwatcher/endpoints"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig
|
||||
Server endpoints.ServerConfig
|
||||
|
||||
Status bird.StatusConfig
|
||||
Bird bird.BirdConfig
|
||||
Bird6 bird.BirdConfig
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
AllowFrom []string `toml:"allow_from"`
|
||||
}
|
||||
|
||||
// Try to load configfiles as specified in the files
|
||||
// list. For example:
|
||||
//
|
||||
|
|
6
endpoints/config.go
Normal file
6
endpoints/config.go
Normal file
|
@ -0,0 +1,6 @@
|
|||
package endpoints
|
||||
|
||||
// Endpoints / Server configuration
|
||||
type ServerConfig struct {
|
||||
AllowFrom []string `toml:"allow_from"`
|
||||
}
|
|
@ -1,6 +1,10 @@
|
|||
package endpoints
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
|
@ -8,10 +12,43 @@ import (
|
|||
"github.com/julienschmidt/httprouter"
|
||||
)
|
||||
|
||||
var Conf ServerConfig
|
||||
|
||||
func CheckAccess(req *http.Request) error {
|
||||
if len(Conf.AllowFrom) == 0 {
|
||||
return nil // AllowFrom ALL
|
||||
}
|
||||
|
||||
// Extract IP
|
||||
tokens := strings.Split(req.RemoteAddr, ":")
|
||||
ip := strings.Join(tokens[:len(tokens)-1], ":")
|
||||
ip = strings.Replace(ip, "[", "", -1)
|
||||
ip = strings.Replace(ip, "]", "", -1)
|
||||
|
||||
// Check Access
|
||||
for _, allowed := range Conf.AllowFrom {
|
||||
if ip == allowed {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Log this request
|
||||
log.Println("Rejecting access from:", ip)
|
||||
|
||||
return fmt.Errorf("%s is not allowed to access this service.", ip)
|
||||
}
|
||||
|
||||
func Endpoint(wrapped func(httprouter.Params) (bird.Parsed, bool)) httprouter.Handle {
|
||||
return func(w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
ps httprouter.Params) {
|
||||
|
||||
// Access Control
|
||||
if err := CheckAccess(r); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
res := make(map[string]interface{})
|
||||
|
||||
ret, from_cache := wrapped(ps)
|
||||
|
|
Loading…
Add table
Reference in a new issue