From ef6d90c5eddc8cced7c44d4b10fd7b33d9579d41 Mon Sep 17 00:00:00 2001
From: Daniel Czerwonk <daniel@dan-nrw.de>
Date: Thu, 18 Jan 2018 02:08:30 +0100
Subject: [PATCH] use net.type only on bird 2.0 and higher

---
 bird/bird.go | 42 ++++++++++++++++++++++++++++++------------
 1 file changed, 30 insertions(+), 12 deletions(-)

diff --git a/bird/bird.go b/bird/bird.go
index b114b73..1279000 100644
--- a/bird/bird.go
+++ b/bird/bird.go
@@ -2,7 +2,6 @@ package bird
 
 import (
 	"bytes"
-	"fmt"
 	"io"
 	"reflect"
 	"strings"
@@ -185,27 +184,27 @@ func Symbols() (Parsed, bool) {
 }
 
 func RoutesPrefixed(prefix string) (Parsed, bool) {
-	cmd := fmt.Sprintf("route all where net.type = NET_IP%s", IPVersion)
+	cmd := routeQueryForChannel("route all")
 	return RunAndParse(cmd, parseRoutes)
 }
 
 func RoutesProto(protocol string) (Parsed, bool) {
-	cmd := fmt.Sprintf("route all protocol %s where net.type = NET_IP%s", protocol, IPVersion)
+	cmd := routeQueryForChannel("route all protocol " + protocol)
 	return RunAndParse(cmd, parseRoutes)
 }
 
 func RoutesProtoCount(protocol string) (Parsed, bool) {
-	cmd := fmt.Sprintf("route all protocol %s where net.type = NET_IP%s count", protocol, IPVersion)
+	cmd := routeQueryForChannel("route all protocol "+protocol) + " count"
 	return RunAndParse(cmd, parseRoutes)
 }
 
 func RoutesFiltered(protocol string) (Parsed, bool) {
-	cmd := fmt.Sprintf("route all filtered %s where net.type = NET_IP%s", protocol, IPVersion)
+	cmd := routeQueryForChannel("route all filtered " + protocol)
 	return RunAndParse(cmd, parseRoutes)
 }
 
 func RoutesExport(protocol string) (Parsed, bool) {
-	cmd := fmt.Sprintf("route all export %s where net.type = NET_IP%s", protocol, IPVersion)
+	cmd := routeQueryForChannel("route all export " + protocol)
 	return RunAndParse(cmd, parseRoutes)
 }
 
@@ -221,12 +220,12 @@ func RoutesNoExport(protocol string) (Parsed, bool) {
 			protocol[len(ParserConf.PeerProtocolPrefix):]
 	}
 
-	cmd := fmt.Sprintf("route all noexport %s where net.type = NET_IP%s", protocol, IPVersion)
+	cmd := routeQueryForChannel("route all noexport " + protocol)
 	return RunAndParse(cmd, parseRoutes)
 }
 
 func RoutesExportCount(protocol string) (Parsed, bool) {
-	cmd := fmt.Sprintf("route all export %s where net.type = NET_IP%s count", protocol, IPVersion)
+	cmd := routeQueryForChannel("route all export "+protocol) + " count"
 	return RunAndParse(cmd, parseRoutesCount)
 }
 
@@ -247,7 +246,7 @@ func RoutesLookupProtocol(net string, protocol string) (Parsed, bool) {
 }
 
 func RoutesPeer(peer string) (Parsed, bool) {
-	cmd := fmt.Sprintf("route export %s where net.type = NET_IP%s", peer, IPVersion)
+	cmd := routeQueryForChannel("route export " + peer)
 	return RunAndParse(cmd, parseRoutes)
 }
 
@@ -260,8 +259,8 @@ func RoutesDump() (Parsed, bool) {
 }
 
 func RoutesDumpSingleTable() (Parsed, bool) {
-	importedRes, cached := RunAndParse(fmt.Sprintf("route all where net.type = NET_IP%s", IPVersion), parseRoutes)
-	filteredRes, _ := RunAndParse(fmt.Sprintf("route all filtered where net.type = NET_IP%s", IPVersion), parseRoutes)
+	importedRes, cached := RunAndParse(routeQueryForChannel("route all"), parseRoutes)
+	filteredRes, _ := RunAndParse(routeQueryForChannel("route all filtered"), parseRoutes)
 
 	imported := importedRes["routes"]
 	filtered := filteredRes["routes"]
@@ -275,7 +274,7 @@ func RoutesDumpSingleTable() (Parsed, bool) {
 }
 
 func RoutesDumpPerPeerTable() (Parsed, bool) {
-	importedRes, cached := RunAndParse("route all where net.type = NET_IP"+IPVersion, parseRoutes)
+	importedRes, cached := RunAndParse(routeQueryForChannel("route all"), parseRoutes)
 	imported := importedRes["routes"]
 	filtered := []Parsed{}
 
@@ -314,3 +313,22 @@ func RoutesDumpPerPeerTable() (Parsed, bool) {
 
 	return result, cached
 }
+
+func routeQueryForChannel(cmd string) string {
+	status, _ := Status()
+	birdStatus, ok := status["status"].(Parsed)
+	if !ok {
+		return cmd
+	}
+
+	version, ok := birdStatus["version"].(string)
+	if !ok {
+		return cmd
+	}
+
+	if len(version) == 0 || int(version[0]) < 2 {
+		return cmd
+	}
+
+	return cmd + " where net.type = NET_IP" + IPVersion
+}