package rfc2136 import ( "context" "net" "testing" "github.com/coredns/coredns/plugin" "github.com/miekg/dns" ) // captureWriter implements dns.ResponseWriter and stashes the message // passed to WriteMsg so tests can inspect it after ServeDNS returns. type captureWriter struct { msg *dns.Msg } func (cw *captureWriter) WriteMsg(m *dns.Msg) error { cw.msg = m; return nil } func (cw *captureWriter) Write([]byte) (int, error) { return 0, nil } func (cw *captureWriter) Close() error { return nil } func (cw *captureWriter) TsigStatus() error { return nil } func (cw *captureWriter) TsigTimersOnly(bool) {} func (cw *captureWriter) Hijack() {} func (cw *captureWriter) LocalAddr() net.Addr { return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} } func (cw *captureWriter) RemoteAddr() net.Addr { return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} } func (cw *captureWriter) Network() string { return "udp" } // passthroughNext is a stand-in for the next plugin in the chain. // Returns a fixed rcode so we can detect "we passed through" in tests. type passthroughNext struct{ called bool } func (n *passthroughNext) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { n.called = true msg := new(dns.Msg) msg.SetReply(r) msg.Rcode = dns.RcodeRefused // arbitrary marker _ = w.WriteMsg(msg) return dns.RcodeRefused, nil } func (n *passthroughNext) Name() string { return "passthroughNext" } // newTestPlugin builds an RFC2136 with sensible defaults for tests. func newTestPlugin(zone, ns string, next plugin.Handler) *RFC2136 { return &RFC2136{ Next: next, Zones: []string{dns.Fqdn(zone)}, TTL: 60, Nameserver: dns.Fqdn(ns), store: newStore(), } } func TestServeDNS_OutsideZone_PassesThrough(t *testing.T) { next := &passthroughNext{} p := newTestPlugin("auth.example.com.", "ns.example.com.", next) req := new(dns.Msg) req.SetQuestion("unrelated.other.tld.", dns.TypeA) w := &captureWriter{} rcode, _ := p.ServeDNS(context.Background(), w, req) if !next.called { t.Errorf("expected pass-through to Next, but Next was not called") } if rcode != dns.RcodeRefused { t.Errorf("rcode = %d (want %d from passthroughNext marker)", rcode, dns.RcodeRefused) } } func TestServeDNS_ApexSOA_Synthetic(t *testing.T) { p := newTestPlugin("auth.example.com.", "ns.example.com.", nil) req := new(dns.Msg) req.SetQuestion("auth.example.com.", dns.TypeSOA) w := &captureWriter{} rcode, _ := p.ServeDNS(context.Background(), w, req) if rcode != dns.RcodeSuccess { t.Fatalf("rcode = %d, want NOERROR", rcode) } if w.msg == nil || !w.msg.Authoritative { t.Fatalf("response not authoritative: %+v", w.msg) } if len(w.msg.Answer) != 1 { t.Fatalf("Answer len = %d, want 1", len(w.msg.Answer)) } soa, ok := w.msg.Answer[0].(*dns.SOA) if !ok { t.Fatalf("answer is not SOA: %T", w.msg.Answer[0]) } if soa.Ns != "ns.example.com." { t.Errorf("SOA.Ns = %q, want ns.example.com.", soa.Ns) } if soa.Mbox != "admin.auth.example.com." { t.Errorf("SOA.Mbox = %q, want admin.auth.example.com.", soa.Mbox) } } func TestServeDNS_ApexNS_Synthetic(t *testing.T) { p := newTestPlugin("auth.example.com.", "ns.example.com.", nil) req := new(dns.Msg) req.SetQuestion("auth.example.com.", dns.TypeNS) w := &captureWriter{} p.ServeDNS(context.Background(), w, req) if len(w.msg.Answer) != 1 { t.Fatalf("Answer len = %d, want 1", len(w.msg.Answer)) } ns, ok := w.msg.Answer[0].(*dns.NS) if !ok { t.Fatalf("answer is not NS: %T", w.msg.Answer[0]) } if ns.Ns != "ns.example.com." { t.Errorf("NS.Ns = %q", ns.Ns) } } func TestServeDNS_ExistingTXT_Returned(t *testing.T) { p := newTestPlugin("auth.example.com.", "ns.example.com.", nil) p.store.Add(mustRR(t, `foo.auth.example.com. 60 IN TXT "token-1"`)) req := new(dns.Msg) req.SetQuestion("foo.auth.example.com.", dns.TypeTXT) w := &captureWriter{} rcode, _ := p.ServeDNS(context.Background(), w, req) if rcode != dns.RcodeSuccess { t.Fatalf("rcode = %d, want NOERROR", rcode) } if len(w.msg.Answer) != 1 { t.Fatalf("Answer len = %d, want 1", len(w.msg.Answer)) } txt := w.msg.Answer[0].(*dns.TXT) if txt.Txt[0] != "token-1" { t.Errorf("TXT = %q, want token-1", txt.Txt[0]) } } func TestServeDNS_NonExistentName_NXDOMAIN(t *testing.T) { p := newTestPlugin("auth.example.com.", "ns.example.com.", nil) req := new(dns.Msg) req.SetQuestion("missing.auth.example.com.", dns.TypeTXT) w := &captureWriter{} rcode, _ := p.ServeDNS(context.Background(), w, req) if rcode != dns.RcodeNameError { t.Errorf("rcode = %d, want NXDOMAIN (%d)", rcode, dns.RcodeNameError) } if len(w.msg.Answer) != 0 { t.Errorf("expected empty Answer for NXDOMAIN, got %v", w.msg.Answer) } if len(w.msg.Ns) != 1 { t.Errorf("expected SOA in authority section, got %v", w.msg.Ns) } if _, ok := w.msg.Ns[0].(*dns.SOA); !ok { t.Errorf("authority section is not SOA: %T", w.msg.Ns[0]) } } func TestServeDNS_WrongType_NODATA(t *testing.T) { p := newTestPlugin("auth.example.com.", "ns.example.com.", nil) p.store.Add(mustRR(t, `foo.auth.example.com. 60 IN A 192.0.2.1`)) req := new(dns.Msg) req.SetQuestion("foo.auth.example.com.", dns.TypeTXT) w := &captureWriter{} rcode, _ := p.ServeDNS(context.Background(), w, req) if rcode != dns.RcodeSuccess { t.Errorf("rcode = %d, want NOERROR (NODATA)", rcode) } if len(w.msg.Answer) != 0 { t.Errorf("NODATA must have empty Answer, got %v", w.msg.Answer) } if len(w.msg.Ns) != 1 { t.Errorf("expected SOA in authority for NODATA") } } func TestServeDNS_UpdateOpcode_Refused(t *testing.T) { p := newTestPlugin("auth.example.com.", "ns.example.com.", nil) req := new(dns.Msg) req.SetUpdate("auth.example.com.") w := &captureWriter{} rcode, _ := p.ServeDNS(context.Background(), w, req) if rcode != dns.RcodeRefused { t.Errorf("UPDATE rcode = %d, want REFUSED (%d)", rcode, dns.RcodeRefused) } } func TestFindZone_LongestSuffixWins(t *testing.T) { p := &RFC2136{ Zones: []string{"example.com.", "auth.example.com."}, } got := p.findZone("foo.auth.example.com.") if got != "auth.example.com." { t.Errorf("findZone returned %q, expected longest-match auth.example.com.", got) } } func TestFindZone_OutsideAllZones(t *testing.T) { p := &RFC2136{Zones: []string{"auth.example.com."}} if got := p.findZone("other.tld."); got != "" { t.Errorf("findZone for unrelated qname returned %q, want empty", got) } } func TestFindZone_CaseInsensitive(t *testing.T) { p := &RFC2136{Zones: []string{"auth.example.com."}} if got := p.findZone("Foo.AUTH.example.COM."); got != "auth.example.com." { t.Errorf("case-insensitive findZone returned %q", got) } }