diff --git a/bird/bird.go b/bird/bird.go index 809be30..363bbaf 100644 --- a/bird/bird.go +++ b/bird/bird.go @@ -49,13 +49,19 @@ func InitializeCache() { log.Println("Could not initialize redis cache, falling back to memory cache:", err) } } else { // initialize the MemoryCache - cache, err = NewMemoryCache() - if err != nil { - log.Fatal("Could not initialize MemoryCache:", err) + maxEntries := CacheConf.MaxEntries + maxEntriesDefault := 100 + if maxEntries == 0 { + log.Println("MaxEntries not set, using default value:", maxEntriesDefault) + maxEntries = maxEntriesDefault } + + cache = NewMemoryCache(maxEntries) + log.Println("Initialized MemoryCache with maxEntries:", maxEntries) } } +// ExpireCache is a convenience method to expire the cache. func ExpireCache() int { return cache.Expire() } @@ -73,12 +79,12 @@ func toCache(key string, val Parsed) bool { ttl = 5 // five minutes } - if err := cache.Set(key, val, ttl); err == nil { - return true - } else { + if err := cache.Set(key, val, ttl); err != nil { log.Println(err) return false } + + return true } /* Convenience method to retrieve entries from the cache. diff --git a/bird/config.go b/bird/config.go index a18f55a..4a8b821 100644 --- a/bird/config.go +++ b/bird/config.go @@ -31,4 +31,6 @@ type CacheConfig struct { RedisServer string `toml:"redis_server"` RedisPassword string `toml:"redis_password"` RedisDb int `toml:"redis_db"` + + MaxEntries int `toml:"max_entries"` } diff --git a/bird/memory_cache.go b/bird/memory_cache.go index 7475e06..1547efa 100644 --- a/bird/memory_cache.go +++ b/bird/memory_cache.go @@ -6,76 +6,124 @@ import ( "time" ) -// Implementation of the MemoryCache backend. - +// MemoryCache is a simple in-memory cache for parsed BIRD output. +// Limiting the number of cached results is using a simple LRU algorithm. type MemoryCache struct { - sync.RWMutex - m map[string]Parsed + sync.Mutex + m map[string]Parsed // Cached data + a map[string]time.Time // Access times + + maxKeys int // Maximum number of keys to cache } -func NewMemoryCache() (*MemoryCache, error) { +// NewMemoryCache creates a new MemoryCache with a maximum number of keys. +func NewMemoryCache(maxKeys int) *MemoryCache { var cache *MemoryCache - cache = &MemoryCache{m: make(map[string]Parsed)} - return cache, nil + cache = &MemoryCache{ + m: make(map[string]Parsed), + a: make(map[string]time.Time), + + maxKeys: maxKeys, + } + return cache } +// Get a key from the cache. func (c *MemoryCache) Get(key string) (Parsed, error) { - c.RLock() + c.Lock() val, ok := c.m[key] - c.RUnlock() + c.a[key] = time.Now().UTC() // Update access + c.Unlock() + if !ok { // cache miss return NilParse, errors.New("Failed to retrive key '" + key + "' from MemoryCache.") } - ttl, correct := val["ttl"].(time.Time) - if !correct { + // Check if the TTL is still valid + ttl, ok := val["ttl"].(time.Time) + if !ok { return NilParse, errors.New("Invalid TTL value for key '" + key + "'") } if ttl.Before(time.Now()) { return val, errors.New("TTL expired for key '" + key + "'") // TTL expired - } else { - return val, nil // cache hit } + + return val, nil // cache hit } +// Set a key in the cache. func (c *MemoryCache) Set(key string, val Parsed, ttl int) error { - switch { - case ttl == 0: + c.Lock() + defer c.Unlock() + + // Check if the key exists, if not clear the oldest key if + // the number of entries exceeds maxKeys. + if _, ok := c.a[key]; !ok { + if len(c.a) >= c.maxKeys { + c.expireLRU() + } + } + + if ttl == 0 { return nil // do not cache - case ttl > 0: - cachedAt := time.Now().UTC() - cacheTtl := cachedAt.Add(time.Duration(ttl) * time.Minute) - - c.Lock() - // This is not a really ... clean way of doing this. - val["ttl"] = cacheTtl - val["cached_at"] = cachedAt - - c.m[key] = val - c.Unlock() - return nil - default: // ttl negative - invalid + } + if ttl < 0 { return errors.New("Negative TTL value for key" + key) } + + cachedAt := time.Now().UTC() + cacheTTL := cachedAt.Add(time.Duration(ttl) * time.Minute) + + // This is not a really ... clean way of doing this. + val["ttl"] = cacheTTL + val["cached_at"] = cachedAt + + c.m[key] = val + c.a[key] = cachedAt + + return nil } +// Expire oldest key in cache. +// WARNING: this is not thread safe and a mutex +// should be acquired before calling this function. +func (c *MemoryCache) expireLRU() { + oldestKey := "" + oldestTime := time.Now().UTC() + for key := range c.m { + if c.a[key].Before(oldestTime) { + oldestKey = key + oldestTime = c.a[key] + } + } + if oldestKey == "" { + return // Nothing to do here. + } + delete(c.m, oldestKey) + delete(c.a, oldestKey) +} + +// Expire all keys in cache that are older than the +// TTL value. func (c *MemoryCache) Expire() int { c.Lock() + defer c.Unlock() + + now := time.Now().UTC() expiredKeys := []string{} - for key, _ := range c.m { - ttl, correct := c.m[key]["ttl"].(time.Time) - if !correct || ttl.Before(time.Now()) { + for key := range c.m { + ttl, ok := c.m[key]["ttl"].(time.Time) + if !ok || ttl.Before(now) { expiredKeys = append(expiredKeys, key) } } for _, key := range expiredKeys { delete(c.m, key) + delete(c.a, key) } - c.Unlock() - return len(expiredKeys) } diff --git a/bird/memory_cache_test.go b/bird/memory_cache_test.go index 4026fae..be2dff5 100644 --- a/bird/memory_cache_test.go +++ b/bird/memory_cache_test.go @@ -4,9 +4,9 @@ import ( "testing" ) -func Test_MemoryCacheAccess(t *testing.T) { +func TestMemoryCacheAccess(t *testing.T) { - cache, err := NewMemoryCache() + cache := NewMemoryCache(100) parsed := Parsed{ "foo": 23, @@ -15,13 +15,12 @@ func Test_MemoryCacheAccess(t *testing.T) { } t.Log("Setting memory cache...") - err = cache.Set("testkey", parsed, 5) - if err != nil { + if err := cache.Set("testkey", parsed, 5); err != nil { t.Error(err) } t.Log("Fetching from memory cache...") - parsed, err = cache.Get("testkey") + parsed, err := cache.Get("testkey") if err != nil { t.Error(err) } @@ -30,10 +29,8 @@ func Test_MemoryCacheAccess(t *testing.T) { t.Log(parsed) } -func Test_MemoryCacheAccessKeyMissing(t *testing.T) { - - cache, err := NewMemoryCache() - +func TestMemoryCacheAccessKeyMissing(t *testing.T) { + cache := NewMemoryCache(100) parsed, err := cache.Get("test_missing_key") if !IsSpecial(parsed) { t.Error(err) @@ -42,7 +39,7 @@ func Test_MemoryCacheAccessKeyMissing(t *testing.T) { t.Log(parsed) } -func Test_MemoryCacheRoutes(t *testing.T) { +func TestMemoryCacheRoutes(t *testing.T) { f, err := openFile("routes_bird1_ipv4.sample") if err != nil { t.Error(err) @@ -55,10 +52,9 @@ func Test_MemoryCacheRoutes(t *testing.T) { t.Fatal("Error getting routes") } - cache, err := NewMemoryCache() + cache := NewMemoryCache(100) - err = cache.Set("routes_protocol_test", parsed, 5) - if err != nil { + if err := cache.Set("routes_protocol_test", parsed, 5); err != nil { t.Error(err) } @@ -73,3 +69,37 @@ func Test_MemoryCacheRoutes(t *testing.T) { } t.Log("Retrieved routes:", len(routes)) } + +func TestMemoryCacheMaxEntries(t *testing.T) { + cache := NewMemoryCache(2) + + parsed := Parsed{ + "foo": 23, + "bar": 42, + } + + // Set 3 entries + if err := cache.Set("testkey1", parsed, 5); err != nil { + t.Error(err) + } + if err := cache.Set("testkey2", parsed, 5); err != nil { + t.Error(err) + } + if err := cache.Set("testkey3", parsed, 5); err != nil { + t.Error(err) + } + + // Check that the first entry is gone + _, err := cache.Get("testkey1") + if err == nil { + t.Error("Expected error, got nil") + } + // Check that the second entry is still there + value, err := cache.Get("testkey2") + if err != nil { + t.Error("Expected no error, got", err) + } + if value["foo"] != 23 { + t.Error("Expected 23, got", value["foo"]) + } +}