diff --git a/serve_mux.go b/serve_mux.go index e7f36e221..17837fd1f 100644 --- a/serve_mux.go +++ b/serve_mux.go @@ -28,6 +28,25 @@ func NewServeMux() *ServeMux { // DefaultServeMux is the default ServeMux used by Serve. var DefaultServeMux = NewServeMux() +func (mux *ServeMux) matchWildcard(q string) Handler { + + wildcards := "*." + // replace the labels of q with wildcard labels until we get a match + for off, end := 0, false; !end; off, end = NextLabel(q, off) { + // skip to removing the first label + if off == 0 { + continue + } + if h, ok := mux.z[wildcards+q[off:]]; ok { + return h + } + wildcards += "*." + } + + // we found nothing + return nil +} + func (mux *ServeMux) match(q string, t uint16) Handler { mux.m.RLock() defer mux.m.RUnlock() @@ -46,9 +65,20 @@ func (mux *ServeMux) match(q string, t uint16) Handler { // Continue for DS to see if we have a parent too, if so delegate to the parent handler = h } + + // we did not find a match - try wildcards if this is the first iteration only, + // as otherwise we will be attempting the same match on every iteration + if off == 0 { + if h := mux.matchWildcard(q); h != nil { + if t != TypeDS { + return h + } + handler = h + } + } } - // Wildcard match, if we have found nothing try the root zone as a last resort. + // If we have found nothing try the root zone as a last resort. if h, ok := mux.z["."]; ok { return h } diff --git a/serve_mux_test.go b/serve_mux_test.go index 3d990ce52..0eb660940 100644 --- a/serve_mux_test.go +++ b/serve_mux_test.go @@ -28,6 +28,87 @@ func TestDotAsCatchAllWildcard(t *testing.T) { } } +type mockHandler string + +func (mockHandler) ServeDNS(w ResponseWriter, r *Msg) { + panic("implement me") +} + +func TestWildcardMatch(t *testing.T) { + mux := NewServeMux() + mux.Handle("example.com.", mockHandler("example.com")) + mux.Handle("*.example.com.", mockHandler("*.example.com")) + mux.Handle("a.example.com.", mockHandler("a.example.com")) + + handler := mux.match("www.example.com.", TypeTXT) + if handler == nil { + t.Error("example.com match failed") + } + if string(handler.(mockHandler)) != "*.example.com" { + t.Error("www.example.com did not match *.example.com wildcard") + } + + handler = mux.match("a.example.com.", TypeTXT) + if handler == nil { + t.Error("a.example.com match failed") + } + if string(handler.(mockHandler)) != "a.example.com" { + t.Error("a.example.com did not match subdomain a") + } + + handler = mux.match("example.com", TypeTXT) + if handler == nil { + t.Error("example.com match failed") + } + if string(handler.(mockHandler)) != "example.com" { + t.Error("example.com did not match example.com, but with", handler) + } + + handler = mux.match("foo.bar.example.com", TypeTXT) + // see https://datatracker.ietf.org/doc/html/rfc4592#section-2.2.1 + // a wildcard does not match names below its zone + if handler != nil && string(handler.(mockHandler)) == "*.example.com" { + t.Error("foo.bar.example.com matched unexpectedly with non terminal") + } +} + +func TestTwoWildcardMatch(t *testing.T) { + mux := NewServeMux() + mux.Handle(".", mockHandler("root")) + mux.Handle("example.com.", mockHandler("example")) + mux.Handle("*.*.example.com.", mockHandler("2wildcard")) + + handler := mux.match("foo.bar.example.com.", TypeTXT) + if handler == nil { + t.Error("foo.bar.example.com match failed") + } + if string(handler.(mockHandler)) != "2wildcard" { + t.Error("foo.bar.example.com did not match *.*.example.com wildcard") + } + + // this tests wildcards as empty non-terminals + // (`*.example.com` is empty in this example) + handler = mux.match("www.example.com.", TypeTXT) + if handler != nil && string(handler.(mockHandler)) != "example" { + t.Error("www.example.com unexpectedly matched", string(handler.(mockHandler))) + } +} + +func TestWildcardMustNotMatchEntireZone(t *testing.T) { + mux := NewServeMux() + mux.Handle(".", mockHandler("root")) + mux.Handle("*.example.com.", mockHandler("2wildcard")) + + handler := mux.match("foo.bar.example.com.", TypeTXT) + if handler == nil { + t.Error("match failed") + } + + if handler != nil && string(handler.(mockHandler)) != "root" { + t.Error("foo.bar.example.com unexpectedly matched", string(handler.(mockHandler))) + } +} + func TestCaseFolding(t *testing.T) { mux := NewServeMux() mux.Handle("_udp.example.com.", HandlerFunc(HelloServer))