Skip to content

Commit

Permalink
contrib/miekg/dns: use random port for test (#3164)
Browse files Browse the repository at this point in the history
  • Loading branch information
rarguelloF authored Feb 7, 2025
1 parent 9462231 commit e08fe35
Showing 1 changed file with 55 additions and 61 deletions.
116 changes: 55 additions & 61 deletions contrib/miekg/dns/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package dns_test
import (
"context"
"net"
"sync"
"testing"
"time"

Expand All @@ -28,136 +29,134 @@ func (th *testHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m)
}

func startServer(t *testing.T, traced bool) (*dns.Server, func()) {
func startServer(t *testing.T, traced bool) (*dns.Server, string) {
var h dns.Handler = &testHandler{}
if traced {
h = dnstrace.WrapHandler(h)
}
addr := getAddr(t).String()
server := &dns.Server{
Addr: addr,
Net: "udp",
Handler: h,
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)

srv := &dns.Server{
PacketConn: pc,
ReadTimeout: time.Hour,
WriteTimeout: time.Hour,
Handler: h,
}

// start the server
waitLock := sync.Mutex{}
waitLock.Lock()
srv.NotifyStartedFunc = waitLock.Unlock

go func() {
err := server.ListenAndServe()
if err != nil {
t.Error(err)
}
require.NoError(t, srv.ActivateAndServe())
}()
waitTillUDPReady(addr)
stopServer := func() {
err := server.Shutdown()
assert.NoError(t, err)
}
return server, stopServer
t.Cleanup(func() {
require.NoError(t, srv.Shutdown())
})

waitLock.Lock()
return srv, pc.LocalAddr().String()
}

func TestExchange(t *testing.T) {
server, stopServer := startServer(t, false)
defer stopServer()
_, addr := startServer(t, false)

mt := mocktracer.Start()
defer mt.Stop()

m := newMessage()

_, err := dnstrace.Exchange(m, server.Addr)
assert.NoError(t, err)
_, err := dnstrace.Exchange(m, addr)
require.NoError(t, err)

spans := mt.FinishedSpans()
require.Len(t, spans, 1)
assertClientSpan(t, spans[0])
}

func TestExchangeContext(t *testing.T) {
server, stopServer := startServer(t, false)
defer stopServer()
_, addr := startServer(t, false)

mt := mocktracer.Start()
defer mt.Stop()

m := newMessage()

_, err := dnstrace.ExchangeContext(context.Background(), m, server.Addr)
assert.NoError(t, err)
_, err := dnstrace.ExchangeContext(context.Background(), m, addr)
require.NoError(t, err)

spans := mt.FinishedSpans()
require.Len(t, spans, 1)
assertClientSpan(t, spans[0])
}

func TestExchangeConn(t *testing.T) {
server, stopServer := startServer(t, false)
defer stopServer()
_, addr := startServer(t, false)

mt := mocktracer.Start()
defer mt.Stop()

m := newMessage()

conn, err := net.Dial("udp", server.Addr)
conn, err := net.Dial("udp", addr)
require.NoError(t, err)

_, err = dnstrace.ExchangeConn(conn, m)
assert.NoError(t, err)
require.NoError(t, err)

spans := mt.FinishedSpans()
require.Len(t, spans, 1)
assertClientSpan(t, spans[0])
}

func TestClient_Exchange(t *testing.T) {
server, stopServer := startServer(t, false)
defer stopServer()
_, addr := startServer(t, false)

mt := mocktracer.Start()
defer mt.Stop()

m := newMessage()

client := newTracedClient()

_, _, err := client.Exchange(m, server.Addr)
assert.NoError(t, err)
_, _, err := client.Exchange(m, addr)
require.NoError(t, err)

spans := mt.FinishedSpans()
require.Len(t, spans, 1)
assertClientSpan(t, spans[0])
}

func TestClient_ExchangeContext(t *testing.T) {
server, stopServer := startServer(t, false)
defer stopServer()
_, addr := startServer(t, false)

mt := mocktracer.Start()
defer mt.Stop()

m := newMessage()

client := newTracedClient()

_, _, err := client.ExchangeContext(context.Background(), m, server.Addr)
assert.NoError(t, err)
_, _, err := client.ExchangeContext(context.Background(), m, addr)
require.NoError(t, err)

spans := mt.FinishedSpans()
require.Len(t, spans, 1)
assertClientSpan(t, spans[0])
}

func TestWrapHandler(t *testing.T) {
server, stopServer := startServer(t, true)
_, addr := startServer(t, true)

mt := mocktracer.Start()
defer mt.Stop()

m := newMessage()
_, err := dns.Exchange(m, server.Addr)
assert.NoError(t, err)
client := newClient()

_, _, err := client.Exchange(m, addr)
require.NoError(t, err)

stopServer() // Shutdown server so span is closed after DNS request
waitForSpans(mt, 1)

spans := mt.FinishedSpans()
require.Len(t, spans, 1)
Expand All @@ -177,8 +176,12 @@ func newMessage() *dns.Msg {
return m
}

func newClient() *dns.Client {
return &dns.Client{Net: "udp"}
}

func newTracedClient() *dnstrace.Client {
return &dnstrace.Client{Client: &dns.Client{Net: "udp"}}
return &dnstrace.Client{Client: newClient()}
}

func assertClientSpan(t *testing.T, s mocktracer.Span) {
Expand All @@ -190,24 +193,15 @@ func assertClientSpan(t *testing.T, s mocktracer.Span) {
assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind))
}

func getAddr(t *testing.T) net.Addr {
li, err := net.Listen("tcp4", "127.0.0.1:2020")
if err != nil {
t.Fatal(err)
}
addr := li.Addr()
li.Close()
return addr
}
func waitForSpans(mt mocktracer.Tracer, sz int) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

func waitTillUDPReady(addr string) {
deadline := time.Now().Add(time.Second * 10)
for time.Now().Before(deadline) {
m := new(dns.Msg)
m.SetQuestion("miek.nl.", dns.TypeMX)
_, err := dns.Exchange(m, addr)
if err == nil {
break
for len(mt.FinishedSpans()) < sz {
select {
case <-ctx.Done():
return
default:
}
time.Sleep(time.Millisecond * 100)
}
Expand Down

0 comments on commit e08fe35

Please sign in to comment.