diff --git a/plugin.go b/plugin.go
index 2c639ab..7fc6807 100644
--- a/plugin.go
+++ b/plugin.go
@@ -19,6 +19,8 @@ package rfc2136
 
 import (
 	"context"
+	"strings"
+	"time"
 
 	"github.com/coredns/coredns/plugin"
 	"github.com/miekg/dns"
@@ -62,6 +64,11 @@ type RFC2136 struct {
 	// zones holds per-zone file handlers, keyed by canonical zone name.
 	// Populated in setup; mutexes live inside each zoneFile.
 	zones map[string]*zoneFile
+
+	// rateLimit caps UPDATE traffic per TSIG key (Hamilton M8). nil
+	// disables rate limiting (test mode, or insecure deployments).
+	// Populated in setup() once TSIG keys are known.
+	rateLimit *rateLimiter
 }
 
 // Name implements plugin.Handler.
@@ -88,6 +95,19 @@ func (p *RFC2136) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
 			_ = w.WriteMsg(resp)
 			return dns.RcodeRefused, nil
 		}
+		// Hamilton M8: per-key rate limit. TSIG just authenticates the
+		// sender — it doesn't prove the sender's behavior is sane. A
+		// compromised key or a runaway client must not be able to
+		// exhaust disk/git/serial-counter resources.
+		if p.rateLimit != nil {
+			if tsig := r.IsTsig(); tsig != nil && !p.rateLimit.allow(strings.ToLower(tsig.Hdr.Name), time.Now()) {
+				log.Warningf("UPDATE rate-limited for key %q", tsig.Hdr.Name)
+				resp := new(dns.Msg)
+				resp.SetRcode(r, dns.RcodeRefused)
+				_ = w.WriteMsg(resp)
+				return dns.RcodeRefused, nil
+			}
+		}
 		return p.handleUpdate(w, r, true)
 	}
 	return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
diff --git a/ratelimit.go b/ratelimit.go
new file mode 100644
index 0000000..d59db1b
--- /dev/null
+++ b/ratelimit.go
@@ -0,0 +1,86 @@
+package rfc2136
+
+import (
+	"sync"
+	"time"
+)
+
+// Per-key token bucket. Hamilton M8: a compromised TSIG key — or a
+// misconfigured client retrying forever — must not be able to drive
+// unbounded UPDATE traffic. Each UPDATE costs disk IOPS, a git commit,
+// and a slot in the SOA serial counter (9999/day per zone). 100
+// UPDATEs/minute per key is well above any legitimate ACME workflow
+// (a full renewal storm across our ~84 zones might emit ~200 UPDATEs
+// total over several minutes); anything beyond is suspicious.
+const (
+	defaultRateBurst  = 100         // max tokens
+	defaultRatePeriod = time.Minute // refill window
+)
+
+// rateLimiter is a goroutine-safe per-key token bucket. The zero value
+// is unusable; construct via newRateLimiter.
+type rateLimiter struct {
+	mu      sync.Mutex
+	buckets map[string]*bucket
+	burst   float64       // max tokens
+	period  time.Duration // time to fully refill
+}
+
+type bucket struct {
+	tokens     float64
+	lastRefill time.Time
+}
+
+func newRateLimiter(burst int, period time.Duration) *rateLimiter {
+	if burst <= 0 {
+		burst = defaultRateBurst
+	}
+	if period <= 0 {
+		period = defaultRatePeriod
+	}
+	return &rateLimiter{
+		buckets: make(map[string]*bucket),
+		burst:   float64(burst),
+		period:  period,
+	}
+}
+
+// allow attempts to take one token for `key`. Returns true if a token
+// was available, false otherwise. New keys start full (burst tokens).
+//
+// Refill is continuous: tokens accumulate at burst/period per second.
+// The bucket caps at burst tokens. A `now` at or before the bucket's
+// last refill earns nothing and leaves the refill point untouched.
+func (rl *rateLimiter) allow(key string, now time.Time) bool {
+	rl.mu.Lock()
+	defer rl.mu.Unlock()
+
+	b, ok := rl.buckets[key]
+	if !ok {
+		// First time we see this key — start the bucket full so
+		// legitimate clients don't see refill delays at boot.
+		rl.buckets[key] = &bucket{
+			tokens:     rl.burst - 1,
+			lastRefill: now,
+		}
+		return true
+	}
+
+	// Refill: tokens earned since last access. Callers sample the clock
+	// before rl.mu is acquired, so concurrent requests can arrive out of
+	// order (and a wall clock can step back); a non-positive elapsed
+	// must not drain tokens or rewind lastRefill.
+	if elapsed := now.Sub(b.lastRefill); elapsed > 0 {
+		b.tokens += elapsed.Seconds() * (rl.burst / rl.period.Seconds())
+		if b.tokens > rl.burst {
+			b.tokens = rl.burst
+		}
+		b.lastRefill = now
+	}
+
+	if b.tokens >= 1.0 {
+		b.tokens -= 1.0
+		return true
+	}
+	return false
+}
diff --git a/ratelimit_test.go b/ratelimit_test.go
new file mode 100644
index 0000000..317c78c
--- /dev/null
+++ b/ratelimit_test.go
@@ -0,0 +1,104 @@
+package rfc2136
+
+import (
+	"testing"
+	"time"
+)
+
+func TestRateLimiter_FirstCallAllowed(t *testing.T) {
+	rl := newRateLimiter(5, time.Minute)
+	now := time.Now()
+	if !rl.allow("key-a", now) {
+		t.Errorf("first call for new key must be allowed")
+	}
+}
+
+func TestRateLimiter_BurstExhausts(t *testing.T) {
+	rl := newRateLimiter(3, time.Minute)
+	now := time.Now()
+	// First 3 calls succeed.
+	for i := 0; i < 3; i++ {
+		if !rl.allow("key-a", now) {
+			t.Fatalf("call %d should be allowed (burst=3)", i+1)
+		}
+	}
+	// 4th immediately after burst should be denied (no time elapsed
+	// for refill).
+	if rl.allow("key-a", now) {
+		t.Errorf("4th call exceeded burst; should be denied")
+	}
+}
+
+func TestRateLimiter_RefillsOverTime(t *testing.T) {
+	// burst=2, period=1s → refill rate is 2 tokens/sec.
+	rl := newRateLimiter(2, time.Second)
+	t0 := time.Now()
+	if !rl.allow("k", t0) {
+		t.Fatal("call 1")
+	}
+	if !rl.allow("k", t0) {
+		t.Fatal("call 2")
+	}
+	if rl.allow("k", t0) {
+		t.Fatal("call 3 should be denied; bucket empty")
+	}
+	// Advance time by 500ms — should refill ~1 token.
+	if !rl.allow("k", t0.Add(500*time.Millisecond)) {
+		t.Errorf("expected refill after 500ms")
+	}
+}
+
+func TestRateLimiter_PerKeyIsolation(t *testing.T) {
+	rl := newRateLimiter(2, time.Minute)
+	now := time.Now()
+	// Exhaust key-a.
+	rl.allow("key-a", now)
+	rl.allow("key-a", now)
+	if rl.allow("key-a", now) {
+		t.Fatal("key-a still has tokens; setup wrong")
+	}
+	// key-b is independent — must still be allowed.
+	if !rl.allow("key-b", now) {
+		t.Errorf("key-b was rate-limited despite no prior use")
+	}
+}
+
+// TestRateLimiter_DoesNotOverflow guards against refill math
+// accumulating beyond burst (which would let an attacker burst more
+// after a long idle period than the configured cap).
+func TestRateLimiter_DoesNotOverflow(t *testing.T) {
+	rl := newRateLimiter(5, time.Second)
+	t0 := time.Now()
+	rl.allow("k", t0) // create bucket
+	// Advance time 1 hour. Refill should cap at burst=5.
+	tFuture := t0.Add(time.Hour)
+	for i := 0; i < 5; i++ {
+		if !rl.allow("k", tFuture) {
+			t.Fatalf("post-idle call %d should be allowed (cap=5)", i+1)
+		}
+	}
+	if rl.allow("k", tFuture) {
+		t.Errorf("post-idle call 6 should be denied; cap exceeded")
+	}
+}
+
+// TestRateLimiter_OutOfOrderTimestamp guards against a timestamp older
+// than the bucket's last refill draining tokens or rewinding the refill
+// point. Callers sample the clock before the limiter's lock is taken,
+// so concurrent requests can reach allow out of order.
+func TestRateLimiter_OutOfOrderTimestamp(t *testing.T) {
+	// burst=2, period=2s → refill rate is 1 token/sec.
+	rl := newRateLimiter(2, 2*time.Second)
+	t0 := time.Now()
+	rl.allow("k", t0) // create bucket; 1 token left
+	// A stale timestamp must not be charged as negative refill: the
+	// remaining token is still spendable.
+	if !rl.allow("k", t0.Add(-time.Hour)) {
+		t.Fatal("stale timestamp drained the bucket")
+	}
+	// The bucket is now empty and the refill point must still be t0,
+	// so a call at t0 has earned nothing.
+	if rl.allow("k", t0) {
+		t.Errorf("stale timestamp rewound lastRefill; got a free refill")
+	}
+}
diff --git a/setup.go b/setup.go
index e048d57..c92a94f 100644
--- a/setup.go
+++ b/setup.go
@@ -6,6 +6,7 @@ import (
 	"os"
 	"path/filepath"
 	"strconv"
+	"time"
 
 	"github.com/coredns/caddy"
 	"github.com/coredns/coredns/core/dnsserver"
@@ -164,6 +165,13 @@ func parse(c *caddy.Controller) (*RFC2136, error) {
 	// Per-zone git author overrides. Defaults are applied later.
 	var gitAuthorName, gitAuthorEmail string
 
+	// Rate-limit config (Hamilton M8). Defaults are
+	// defaultRateBurst/defaultRatePeriod from ratelimit.go; an explicit
+	// `rate-limit <burst> <period-seconds>` directive overrides.
+	rateBurst := defaultRateBurst
+	ratePeriod := defaultRatePeriod
+	rateLimitEnabled := true
+
 	for c.Next() {
 		args := c.RemainingArgs()
 		if len(args) < 1 {
@@ -235,6 +243,30 @@ func parse(c *caddy.Controller) (*RFC2136, error) {
 				gitAuthorName = gArgs[0]
 				gitAuthorEmail = gArgs[1]
 
+			case "rate-limit":
+				rArgs := c.RemainingArgs()
+				switch len(rArgs) {
+				case 1:
+					if rArgs[0] == "off" || rArgs[0] == "false" || rArgs[0] == "no" {
+						rateLimitEnabled = false
+						break
+					}
+					return nil, c.Errf("rate-limit single-arg form must be 'off'; for limits use 'rate-limit <burst> <period-seconds>'")
+				case 2:
+					b, err := strconv.ParseUint(rArgs[0], 10, 31)
+					if err != nil || b < 1 {
+						return nil, c.Errf("rate-limit burst must be positive integer, got %q", rArgs[0])
+					}
+					pSec, err := strconv.ParseUint(rArgs[1], 10, 31)
+					if err != nil || pSec < 1 {
+						return nil, c.Errf("rate-limit period must be positive integer seconds, got %q", rArgs[1])
+					}
+					rateBurst = int(b)
+					ratePeriod = time.Duration(pSec) * time.Second
+				default:
+					return nil, c.Errf("rate-limit takes 'off' OR '<burst> <period-seconds>', got %d args", len(rArgs))
+				}
+
 			default:
 				return nil, c.Errf("unknown directive: %s", c.Val())
 			}
@@ -248,6 +280,11 @@ func parse(c *caddy.Controller) (*RFC2136, error) {
 		return nil, c.Err("zones-dir is required")
 	}
 
+	// Construct rate limiter if enabled.
+	if rateLimitEnabled {
+		p.rateLimit = newRateLimiter(rateBurst, ratePeriod)
+	}
+
 	// Build zoneFile handles for each declared zone.
 	p.zones = make(map[string]*zoneFile, len(p.Zones))
 	for _, z := range p.Zones {