diff --git a/birdwatcher.go b/birdwatcher.go index e5a9621..759c4ad 100644 --- a/birdwatcher.go +++ b/birdwatcher.go @@ -4,6 +4,7 @@ import ( "flag" "log" "net/http" + "strings" "github.com/ecix/birdwatcher/bird" "github.com/ecix/birdwatcher/endpoints" @@ -33,8 +34,15 @@ func makeRouter() *httprouter.Router { func PrintServiceInfo(conf *Config, birdConf bird.BirdConfig) { // General Info log.Println("Starting Birdwatcher") - log.Println(" Using:", birdConf.BirdCmd) - log.Println(" Listen:", birdConf.Listen) + log.Println(" Using:", birdConf.BirdCmd) + log.Println(" Listen:", birdConf.Listen) + + // Endpoint Info + if len(conf.Server.AllowFrom) == 0 { + log.Println(" AllowFrom: ALL") + } else { + log.Println(" AllowFrom:", strings.Join(conf.Server.AllowFrom, ", ")) + } } func main() { @@ -60,8 +68,9 @@ func main() { PrintServiceInfo(conf, birdConf) - // Configure client + // Configuration bird.Conf = birdConf + endpoints.Conf = conf.Server // Make server r := makeRouter() diff --git a/config.go b/config.go index d8a82b4..6127f0f 100644 --- a/config.go +++ b/config.go @@ -4,24 +4,22 @@ package main import ( "fmt" + "github.com/BurntSushi/toml" "github.com/imdario/mergo" "github.com/ecix/birdwatcher/bird" + "github.com/ecix/birdwatcher/endpoints" ) type Config struct { - Server ServerConfig + Server endpoints.ServerConfig Status bird.StatusConfig Bird bird.BirdConfig Bird6 bird.BirdConfig } -type ServerConfig struct { - AllowFrom []string `toml:"allow_from"` -} - // Try to load configfiles as specified in the files // list. For example: // diff --git a/endpoints/config.go b/endpoints/config.go new file mode 100644 index 0000000..b727d4c --- /dev/null +++ b/endpoints/config.go @@ -0,0 +1,6 @@ +package endpoints + +// Endpoints / Server configuration +type ServerConfig struct { + AllowFrom []string `toml:"allow_from"` +} diff --git a/endpoints/endpoint.go b/endpoints/endpoint.go index e27ee35..ac95194 100644 --- a/endpoints/endpoint.go +++ b/endpoints/endpoint.go @@ -1,6 +1,10 @@ package endpoints import ( + "fmt" + "log" + "strings" + "encoding/json" "net/http" @@ -8,10 +12,43 @@ import ( "github.com/julienschmidt/httprouter" ) +var Conf ServerConfig + +func CheckAccess(req *http.Request) error { + if len(Conf.AllowFrom) == 0 { + return nil // AllowFrom ALL + } + + // Extract IP + tokens := strings.Split(req.RemoteAddr, ":") + ip := strings.Join(tokens[:len(tokens)-1], ":") + ip = strings.Replace(ip, "[", "", -1) + ip = strings.Replace(ip, "]", "", -1) + + // Check Access + for _, allowed := range Conf.AllowFrom { + if ip == allowed { + return nil + } + } + + // Log this request + log.Println("Rejecting access from:", ip) + + return fmt.Errorf("%s is not allowed to access this service.", ip) +} + func Endpoint(wrapped func(httprouter.Params) (bird.Parsed, bool)) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + + // Access Control + if err := CheckAccess(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + res := make(map[string]interface{}) ret, from_cache := wrapped(ps)