diff --git a/birdwatcher.go b/birdwatcher.go index 9f53a4e..ac55877 100644 --- a/birdwatcher.go +++ b/birdwatcher.go @@ -59,7 +59,7 @@ func makeRouter(config endpoints.ServerConfig) *httprouter.Router { r.GET("/routes/filtered/:protocol", endpoints.Endpoint(endpoints.RoutesFiltered)) } if isModuleEnabled("routes_prefixed", whitelist) { - r.GET("/routes/prefix/:prefix", endpoints.Endpoint(endpoints.RoutesPrefixed)) + r.GET("/routes/prefix", endpoints.Endpoint(endpoints.RoutesPrefixed)) } if isModuleEnabled("route_net", whitelist) { r.GET("/route/net/:net", endpoints.Endpoint(endpoints.RouteNet)) diff --git a/endpoints/endpoint.go b/endpoints/endpoint.go index e364239..587bb9d 100644 --- a/endpoints/endpoint.go +++ b/endpoints/endpoint.go @@ -12,6 +12,8 @@ import ( "github.com/julienschmidt/httprouter" ) +type endpoint func(*http.Request, httprouter.Params) (bird.Parsed, bool) + var Conf ServerConfig func CheckAccess(req *http.Request) error { @@ -38,7 +40,7 @@ func CheckAccess(req *http.Request) error { return fmt.Errorf("%s is not allowed to access this service.", ip) } -func Endpoint(wrapped func(httprouter.Params) (bird.Parsed, bool)) httprouter.Handle { +func Endpoint(wrapped endpoint) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { @@ -51,7 +53,7 @@ func Endpoint(wrapped func(httprouter.Params) (bird.Parsed, bool)) httprouter.Ha res := make(map[string]interface{}) - ret, from_cache := wrapped(ps) + ret, from_cache := wrapped(r, ps) if ret == nil { w.WriteHeader(http.StatusTooManyRequests) return diff --git a/endpoints/protocols.go b/endpoints/protocols.go index cfc7a9c..cc1b30a 100644 --- a/endpoints/protocols.go +++ b/endpoints/protocols.go @@ -1,14 +1,16 @@ package endpoints import ( + "net/http" + "github.com/ecix/birdwatcher/bird" "github.com/julienschmidt/httprouter" ) -func Protocols(ps httprouter.Params) (bird.Parsed, bool) { +func Protocols(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { return bird.Protocols() } -func Bgp(ps httprouter.Params) (bird.Parsed, bool) { +func Bgp(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { return bird.ProtocolsBgp() } diff --git a/endpoints/routes.go b/endpoints/routes.go index c8a95a4..36debdc 100644 --- a/endpoints/routes.go +++ b/endpoints/routes.go @@ -2,12 +2,13 @@ package endpoints import ( "fmt" + "net/http" "github.com/ecix/birdwatcher/bird" "github.com/julienschmidt/httprouter" ) -func ProtoRoutes(ps httprouter.Params) (bird.Parsed, bool) { +func ProtoRoutes(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { protocol, err := ValidateProtocolParam(ps.ByName("protocol")) if err != nil { return bird.Parsed{"error": fmt.Sprintf("%s", err)}, false @@ -15,7 +16,7 @@ func ProtoRoutes(ps httprouter.Params) (bird.Parsed, bool) { return bird.RoutesProto(protocol) } -func RoutesFiltered(ps httprouter.Params) (bird.Parsed, bool) { +func RoutesFiltered(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { protocol, err := ValidateProtocolParam(ps.ByName("protocol")) if err != nil { return bird.Parsed{"error": fmt.Sprintf("%s", err)}, false @@ -23,19 +24,25 @@ func RoutesFiltered(ps httprouter.Params) (bird.Parsed, bool) { return bird.RoutesFiltered(protocol) } -func RoutesPrefixed(ps httprouter.Params) (bird.Parsed, bool) { - prefix, err := ValidatePrefixParam(ps.ByName("prefix")) +func RoutesPrefixed(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { + qs := r.URL.Query() + prefixl := qs["prefix"] + if len(prefixl) != 1 { + return bird.Parsed{"error": "need a prefix as single query parameter"}, false + } + + prefix, err := ValidatePrefixParam(prefixl[0]) if err != nil { return bird.Parsed{"error": fmt.Sprintf("%s", err)}, false } return bird.RoutesPrefixed(prefix) } -func TableRoutes(ps httprouter.Params) (bird.Parsed, bool) { +func TableRoutes(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { return bird.RoutesTable(ps.ByName("table")) } -func ProtoCount(ps httprouter.Params) (bird.Parsed, bool) { +func ProtoCount(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { protocol, err := ValidateProtocolParam(ps.ByName("protocol")) if err != nil { return bird.Parsed{"error": fmt.Sprintf("%s", err)}, false @@ -43,14 +50,14 @@ func ProtoCount(ps httprouter.Params) (bird.Parsed, bool) { return bird.RoutesProtoCount(protocol) } -func TableCount(ps httprouter.Params) (bird.Parsed, bool) { +func TableCount(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { return bird.RoutesTable(ps.ByName("table")) } -func RouteNet(ps httprouter.Params) (bird.Parsed, bool) { +func RouteNet(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { return bird.RoutesLookupTable(ps.ByName("net"), "master") } -func RouteNetTable(ps httprouter.Params) (bird.Parsed, bool) { +func RouteNetTable(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { return bird.RoutesLookupTable(ps.ByName("net"), ps.ByName("table")) } diff --git a/endpoints/status.go b/endpoints/status.go index 63d6f55..ae9e5a5 100644 --- a/endpoints/status.go +++ b/endpoints/status.go @@ -1,10 +1,12 @@ package endpoints import ( + "net/http" + "github.com/ecix/birdwatcher/bird" "github.com/julienschmidt/httprouter" ) -func Status(ps httprouter.Params) (bird.Parsed, bool) { +func Status(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { return bird.Status() } diff --git a/endpoints/symbols.go b/endpoints/symbols.go index 46e16ab..4bcc15a 100644 --- a/endpoints/symbols.go +++ b/endpoints/symbols.go @@ -1,20 +1,22 @@ package endpoints import ( + "net/http" + "github.com/ecix/birdwatcher/bird" "github.com/julienschmidt/httprouter" ) -func Symbols(ps httprouter.Params) (bird.Parsed, bool) { +func Symbols(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { return bird.Symbols() } -func SymbolTables(ps httprouter.Params) (bird.Parsed, bool) { +func SymbolTables(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { val, from_cache := bird.Symbols() return bird.Parsed{"symbols": val["routing table"]}, from_cache } -func SymbolProtocols(ps httprouter.Params) (bird.Parsed, bool) { +func SymbolProtocols(r *http.Request, ps httprouter.Params) (bird.Parsed, bool) { val, from_cache := bird.Symbols() return bird.Parsed{"symbols": val["protocols"]}, from_cache }