From de3dbfb08ac0e4e15619f317a33ddf6a478e9c5e Mon Sep 17 00:00:00 2001 From: Miguel Angel Nubla Date: Mon, 30 Dec 2024 18:42:39 +0100 Subject: [PATCH] Major refactor --- .gitignore | 3 +- .goreleaser.yaml | 13 +- README.md | 4 +- addr.go | 106 ++++++++++++ addr_collection.go | 182 +++++++++++++++++++++ cmd/ipv6disc/ipv6disc.go | 19 ++- main.go | 7 - pkg/ndp/ndp.go | 14 +- pkg/worker/table.go | 260 ------------------------------ state.go | 93 +++++++++++ pkg/worker/worker.go => worker.go | 100 ++++++------ 11 files changed, 459 insertions(+), 342 deletions(-) create mode 100644 addr.go create mode 100644 addr_collection.go delete mode 100644 main.go delete mode 100644 pkg/worker/table.go create mode 100644 state.go rename pkg/worker/worker.go => worker.go (74%) diff --git a/.gitignore b/.gitignore index f868bb2..0b0222e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ -/ipv6disc +cmd/ipv6disc/ipv6disc +cmd/ipv6disc/ipv6disc.exe dist/ diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 813641e..4d79b32 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -1,9 +1,11 @@ +version: 2 before: hooks: - - go mod tidy - go generate ./... + - go mod tidy builds: - - env: + - main: cmd/ipv6disc/ipv6disc.go + env: - CGO_ENABLED=0 goos: # - aix @@ -85,6 +87,8 @@ dockers: - "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.url=https://github.com/miguelangel-nubla/{{ .ProjectName }}" - "--label=org.opencontainers.image.source=https://github.com/miguelangel-nubla/{{ .ProjectName }}" + extra_files: + - LICENSE.txt - image_templates: - "ghcr.io/miguelangel-nubla/{{.ProjectName}}:{{ .Tag }}-armv7" @@ -99,6 +103,8 @@ dockers: - "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.url=https://github.com/miguelangel-nubla/{{ .ProjectName }}" - "--label=org.opencontainers.image.source=https://github.com/miguelangel-nubla/{{ .ProjectName }}" + extra_files: + - LICENSE.txt - image_templates: - "ghcr.io/miguelangel-nubla/{{.ProjectName}}:{{ .Tag }}-arm64v8" @@ -151,7 +157,6 @@ docker_manifests: archives: - - rlcp: true files: - LICENSE* - README* @@ -171,7 +176,7 @@ archives: checksum: name_template: 'checksums.txt' snapshot: - name_template: "{{ incpatch .Version }}-next" + version_template: "{{ incpatch .Version }}-next" changelog: sort: asc filters: diff --git a/README.md b/README.md index b023ca4..445ae00 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Download the [latest release](https://github.com/miguelangel-nubla/ipv6disc/rele Ensure you have Go installed on your system. If not, follow the instructions on the official [Go website](https://golang.org/doc/install) to install it. Then: ``` -go install github.com/miguelangel-nubla/ipv6disc +go install github.com/miguelangel-nubla/ipv6disc/cmd/ipv6disc ``` ### Or use the docker image @@ -39,7 +39,7 @@ If you need to pause and select/copy data use `screen`, launch `ipv6disc -live [ ## Flags - `-log_level`: Set the logging level (default: "info"). Available options: "debug", "info", "warn", "error", "fatal", "panic". -- `-ttl`: Set the time-to-live (TTL) for a discovered host entry in the table after it has been last seen (default: 4 hours). +- `-lifetime`: Set the lifetime for a discovered host entry after it has been last seen (default: 4 hours). - `-live`: Show the current state live on the terminal (default: false). ## License diff --git a/addr.go b/addr.go new file mode 100644 index 0000000..4566304 --- /dev/null +++ b/addr.go @@ -0,0 +1,106 @@ +package ipv6disc + +import ( + "net" + "net/netip" + "sync" + "time" +) + +type Addr struct { + netip.Addr + Hw net.HardwareAddr + + lifetime time.Duration + onExpiration func(*Addr, AddrExpirationRemainingEvents) + + expiration3 *time.Timer + expiration2 *time.Timer + expiration1 *time.Timer + expiration *time.Timer + expirationTime time.Time + + unwatch chan bool + + mutex sync.RWMutex +} + +type AddrExpirationRemainingEvents int + +func (a *Addr) resetTimers(resetExpirationTime bool) { + a.mutex.Lock() + defer a.mutex.Unlock() + + a.expiration3.Reset(a.lifetime / 3 * 2) + a.expiration2.Reset(a.lifetime / 4 * 3) + a.expiration1.Reset(a.lifetime / 5 * 4) + a.expiration.Reset(a.lifetime) + + if resetExpirationTime { + a.expirationTime = time.Now().Add(a.lifetime) + } +} + +func (a *Addr) IsStillValid() bool { + return a.expirationTime.After(time.Now()) +} + +func (a *Addr) GetExpiration() time.Time { + a.mutex.RLock() + defer a.mutex.RUnlock() + + return a.expirationTime +} + +func (a *Addr) Seen() { + a.resetTimers(true) +} + +func (a *Addr) Watch() { + go func() { + for { + select { + case <-a.expiration3.C: + a.onExpiration(a, 3) + case <-a.expiration2.C: + a.onExpiration(a, 2) + case <-a.expiration1.C: + a.onExpiration(a, 1) + case <-a.expiration.C: + a.onExpiration(a, 0) + return + case <-a.unwatch: + return + } + } + }() + + a.resetTimers(false) +} + +func (a *Addr) Unwatch() { + a.mutex.Lock() + defer a.mutex.Unlock() + + a.expiration3.Stop() + a.expiration2.Stop() + a.expiration1.Stop() + a.expiration.Stop() + + close(a.unwatch) +} + +func NewAddr(hw net.HardwareAddr, addr netip.Addr, lifetime time.Duration, onExpiration func(*Addr, AddrExpirationRemainingEvents)) *Addr { + return &Addr{ + Addr: addr, + Hw: hw, + lifetime: lifetime, + onExpiration: onExpiration, + expiration3: time.NewTimer(lifetime / 3 * 2), + expiration2: time.NewTimer(lifetime / 4 * 3), + expiration1: time.NewTimer(lifetime / 5 * 4), + expiration: time.NewTimer(lifetime), + expirationTime: time.Now().Add(lifetime), + unwatch: make(chan bool), + } +} diff --git a/addr_collection.go b/addr_collection.go new file mode 100644 index 0000000..414938c --- /dev/null +++ b/addr_collection.go @@ -0,0 +1,182 @@ +package ipv6disc + +import ( + "fmt" + "maps" + "net/netip" + "slices" + "sort" + "strings" + "sync" + "time" +) + +type AddrCollection struct { + // string in key avoids looping over Addr.String() in the map + addresses map[string]*Addr + addressesMutex sync.RWMutex +} + +func (c *AddrCollection) Enlist(addr *Addr) (*Addr, bool) { + c.addressesMutex.Lock() + defer c.addressesMutex.Unlock() + + addString := addr.String() + + existing := false + if c.Contains(addr) { + existing = true + } else { + c.addresses[addString] = addr + } + + c.addresses[addString].Seen() + + return c.addresses[addString], existing +} + +func (c *AddrCollection) Remove(addr *Addr) { + c.addressesMutex.Lock() + defer c.addressesMutex.Unlock() + + if c.Contains(addr) { + c.addresses[addr.String()].Unwatch() + delete(c.addresses, addr.String()) + } +} + +func (c *AddrCollection) Join(addr *AddrCollection) { + c.addressesMutex.Lock() + defer c.addressesMutex.Unlock() + + for addr, info := range addr.addresses { + c.addresses[addr] = info + } +} + +func (c *AddrCollection) Contains(addr *Addr) bool { + _, ok := c.addresses[addr.String()] + return ok +} + +func (c *AddrCollection) Equal(other *AddrCollection) bool { + c.addressesMutex.RLock() + defer c.addressesMutex.RUnlock() + other.addressesMutex.RLock() + defer other.addressesMutex.RUnlock() + + if len(c.addresses) != len(other.addresses) { + return false + } + + for addrKey, _ := range c.addresses { + if _, ok := other.addresses[addrKey]; !ok { + return false + } + } + + return true +} + +func (c *AddrCollection) FilterPrefix(prefix netip.Prefix) *AddrCollection { + c.addressesMutex.RLock() + defer c.addressesMutex.RUnlock() + + result := NewAddrCollection() + for _, addr := range c.addresses { + if prefix.Contains(addr.Addr.WithZone("")) { + result.Enlist(addr) + } + } + + return result +} + +func (c *AddrCollection) Filter6() *AddrCollection { + c.addressesMutex.RLock() + defer c.addressesMutex.RUnlock() + + result := NewAddrCollection() + for _, addr := range c.addresses { + if addr.Addr.Is6() { + result.Enlist(addr) + } + } + return result +} + +func (c *AddrCollection) Filter4() *AddrCollection { + c.addressesMutex.RLock() + defer c.addressesMutex.RUnlock() + + result := NewAddrCollection() + for _, addr := range c.addresses { + if addr.Addr.Is4() { + result.Enlist(addr) + } + } + return result +} + +func (c *AddrCollection) Get() []*Addr { + c.addressesMutex.RLock() + defer c.addressesMutex.RUnlock() + + keys := make([]string, 0, len(c.addresses)) + for key := range c.addresses { + keys = append(keys, key) + } + sort.Strings(keys) + + addresses := make([]*Addr, 0, len(c.addresses)) + for _, key := range keys { + addresses = append(addresses, c.addresses[key]) + } + + return addresses +} + +func (c *AddrCollection) Strings() []string { + c.addressesMutex.RLock() + defer c.addressesMutex.RUnlock() + + addressesMap := make(map[string]bool) + for _, addr := range c.addresses { + ip := addr.WithZone("").String() + addressesMap[ip] = true + } + + return slices.Collect(maps.Keys(addressesMap)) +} + +func (c *AddrCollection) PrettyPrint(prefix string) string { + c.addressesMutex.RLock() + defer c.addressesMutex.RUnlock() + + var result strings.Builder + c.addressesMutex.RLock() + + // Get the keys from the map + keys := make([]netip.Addr, 0, len(c.addresses)) + for _, addr := range c.addresses { + keys = append(keys, addr.Addr) + } + + // Sort the keys + sort.Slice(keys, func(i, j int) bool { + return keys[i].Less(keys[j]) + }) + + // Iterate ordered + for _, key := range keys { + ipAddressInfo := c.addresses[key.String()] + fmt.Fprintf(&result, prefix+"%s %s\n", ipAddressInfo.Addr.String(), time.Until(ipAddressInfo.GetExpiration()).Round(time.Second)) + } + c.addressesMutex.RUnlock() + + return result.String() +} + +func NewAddrCollection() *AddrCollection { + return &AddrCollection{addresses: make(map[string]*Addr)} +} diff --git a/cmd/ipv6disc/ipv6disc.go b/cmd/ipv6disc/ipv6disc.go index 623c49b..8718e7a 100644 --- a/cmd/ipv6disc/ipv6disc.go +++ b/cmd/ipv6disc/ipv6disc.go @@ -1,4 +1,4 @@ -package ipv6disc +package main import ( "flag" @@ -9,21 +9,21 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zapcore" + "github.com/miguelangel-nubla/ipv6disc" "github.com/miguelangel-nubla/ipv6disc/pkg/terminal" - "github.com/miguelangel-nubla/ipv6disc/pkg/worker" ) var logLevel string -var ttl time.Duration +var lifetime time.Duration var live bool func init() { flag.StringVar(&logLevel, "log_level", "info", "Logging level (debug, info, warn, error, fatal, panic) default: info") - flag.DurationVar(&ttl, "ttl", 4*time.Hour, "Time to keep a discovered host entry in the table after it has been last seen. This is not the TTL of the DDNS record. Default: 4h") + flag.DurationVar(&lifetime, "lifetime", 4*time.Hour, "Time to keep a discovered host entry after it has been last seen. Default: 4h") flag.BoolVar(&live, "live", false, "Show the currrent state live on the terminal, default: false") } -func Start() { +func main() { flag.Parse() startUpdater() @@ -34,8 +34,11 @@ func startUpdater() { sugar := initializeLogger() - table := worker.NewTable() - err := worker.NewWorker(table, ttl, sugar).Start() + rediscover := lifetime / 3 + + worker := ipv6disc.NewWorker(sugar, rediscover, lifetime) + + err := worker.Start() if err != nil { sugar.Fatalf("can't start worker: %s", err) } @@ -44,7 +47,7 @@ func startUpdater() { for { if live { var result strings.Builder - result.WriteString(table.PrettyPrint(4)) + result.WriteString(worker.State.PrettyPrint(" ")) liveOutput <- result.String() } diff --git a/main.go b/main.go deleted file mode 100644 index 0dbd21a..0000000 --- a/main.go +++ /dev/null @@ -1,7 +0,0 @@ -package main - -import "github.com/miguelangel-nubla/ipv6disc/cmd/ipv6disc" - -func main() { - ipv6disc.Start() -} diff --git a/pkg/ndp/ndp.go b/pkg/ndp/ndp.go index 0860317..59f17fc 100644 --- a/pkg/ndp/ndp.go +++ b/pkg/ndp/ndp.go @@ -47,7 +47,7 @@ func (conn *Conn) SendNeighborSolicitation(target *netip.Addr) error { return nil } -func ListenForNDP(iface *net.Interface, addr netip.Addr, onFoundLinkLayerAddr func(netip.Addr, net.HardwareAddr)) (*Conn, error) { +func ListenForNDP(iface *net.Interface, addr netip.Addr, processNDP func(netip.Addr, net.HardwareAddr)) (*Conn, error) { conn, _, err := ndp.Listen(iface, ndp.Addr(addr.String())) if err != nil { return nil, fmt.Errorf("failed to listen for NDP packets: %v", err) @@ -61,28 +61,28 @@ func ListenForNDP(iface *net.Interface, addr netip.Addr, onFoundLinkLayerAddr fu continue } - processNDPPacket(msg, from, onFoundLinkLayerAddr) + processNDPPacket(msg, from, processNDP) } }() return &Conn{conn, iface, &addr}, nil } -func processNDPPacket(message ndp.Message, from netip.Addr, onFoundLinkLayerAddr func(netip.Addr, net.HardwareAddr)) { +func processNDPPacket(message ndp.Message, from netip.Addr, processNDP func(netip.Addr, net.HardwareAddr)) { switch msg := message.(type) { case *ndp.NeighborAdvertisement: - processNDPOptions(&msg.Options, from, onFoundLinkLayerAddr) + processNDPOptions(&msg.Options, from, processNDP) case *ndp.NeighborSolicitation: - processNDPOptions(&msg.Options, from, onFoundLinkLayerAddr) + processNDPOptions(&msg.Options, from, processNDP) default: } } -func processNDPOptions(options *[]ndp.Option, from netip.Addr, onFoundLinkLayerAddr func(netip.Addr, net.HardwareAddr)) { +func processNDPOptions(options *[]ndp.Option, from netip.Addr, processNDP func(netip.Addr, net.HardwareAddr)) { for _, o := range *options { if linkLayerAddr, ok := o.(*ndp.LinkLayerAddress); ok { - onFoundLinkLayerAddr(from, linkLayerAddr.Addr) + processNDP(from, linkLayerAddr.Addr) } } } diff --git a/pkg/worker/table.go b/pkg/worker/table.go deleted file mode 100644 index 68db3f3..0000000 --- a/pkg/worker/table.go +++ /dev/null @@ -1,260 +0,0 @@ -package worker - -import ( - "fmt" - "net" - "net/netip" - "sort" - "strings" - "sync" - "time" -) - -type IPAddressInfo struct { - Address netip.Addr - Hw net.HardwareAddr - mutex sync.RWMutex - timer *time.Timer - timerExpiration time.Time - timerPreventive1 *time.Timer - timerPreventive2 *time.Timer - timerPreventive3 *time.Timer - stopChannel chan bool - onExpiration func(*IPAddressInfo, int) -} - -type IPAddressSet struct { - addresses map[netip.Addr]*IPAddressInfo - addressesMapMutex sync.RWMutex -} - -type Table struct { - macs map[string]*IPAddressSet - macsMapMutex sync.RWMutex -} - -func (i *IPAddressInfo) IsStillValid() bool { - return i.timerExpiration.After(time.Now()) -} - -func (i *IPAddressInfo) Extend(ttl time.Duration) { - i.mutex.Lock() - i.timerExpiration = time.Now().Add(ttl) - i.timer.Reset(ttl) - i.timerPreventive1.Reset(ttl / 3 * 2) - i.timerPreventive2.Reset(ttl / 4 * 3) - i.timerPreventive3.Reset(ttl / 5 * 4) - i.mutex.Unlock() -} - -func (i *IPAddressInfo) GetExpiration() time.Time { - i.mutex.RLock() - expiration := i.timerExpiration - i.mutex.RUnlock() - return expiration -} - -func (i *IPAddressInfo) GetAddress() netip.Addr { - return i.Address -} - -func (i *IPAddressInfo) Clear() { - i.mutex.Lock() - i.timerExpiration = time.Now() - i.timer.Stop() - i.timerPreventive1.Stop() - i.timerPreventive2.Stop() - i.timerPreventive3.Stop() - close(i.stopChannel) - i.mutex.Unlock() -} - -func NewIPAddressInfo(hw net.HardwareAddr, addr netip.Addr, ttl time.Duration, onExpiration func(*IPAddressInfo, int)) *IPAddressInfo { - info := &IPAddressInfo{ - onExpiration: onExpiration, - Address: addr, - Hw: hw, - timer: time.NewTimer(ttl), - timerExpiration: time.Now().Add(ttl), - timerPreventive1: time.NewTimer(ttl / 3 * 2), - timerPreventive2: time.NewTimer(ttl / 4 * 3), - timerPreventive3: time.NewTimer(ttl / 5 * 4), - stopChannel: make(chan bool), - } - - go func() { - for { - select { - case <-info.timer.C: - info.onExpiration(info, 0) - return - case <-info.timerPreventive1.C: - info.onExpiration(info, 1) - case <-info.timerPreventive2.C: - info.onExpiration(info, 2) - case <-info.timerPreventive3.C: - info.onExpiration(info, 3) - case <-info.stopChannel: - return - } - } - }() - - return info -} - -// Add an IPv6 address to the set with a TTL (in seconds). -func (s *IPAddressSet) Add(hw net.HardwareAddr, addr netip.Addr, ttl time.Duration, onExpiration func(*IPAddressInfo, int)) bool { - existing := false - if !s.Contains(addr) { - s.addressesMapMutex.Lock() - s.addresses[addr] = NewIPAddressInfo(hw, addr, ttl, onExpiration) - s.addressesMapMutex.Unlock() - } else { - if (s.addresses[addr]).IsStillValid() { - existing = true - } - (s.addresses[addr]).Extend(ttl) - (s.addresses[addr]).onExpiration = onExpiration - } - return existing -} - -// Remove an IPv6 address from the set. -func (s *IPAddressSet) Remove(addr netip.Addr) { - if s.Contains(addr) { - (s.addresses[addr]).Clear() - - s.addressesMapMutex.Lock() - delete(s.addresses, addr) - s.addressesMapMutex.Unlock() - } -} - -// Check if an IPv6 address exists in the set and is not expired. -func (s *IPAddressSet) Contains(addr netip.Addr) bool { - s.addressesMapMutex.RLock() - info, ok := s.addresses[addr] - s.addressesMapMutex.RUnlock() - if !ok { - return false - } - return time.Now().Before(info.GetExpiration()) -} - -func (s *IPAddressSet) PrettyPrint(tabSize int) string { - indent := func(level int) string { - return strings.Repeat(" ", level*tabSize) - } - - var result strings.Builder - s.addressesMapMutex.RLock() - - // Get the keys from the map - keys := make([]netip.Addr, 0, len(s.addresses)) - for k := range s.addresses { - keys = append(keys, k) - } - - // Sort the keys - sort.Slice(keys, func(i, j int) bool { - return keys[i].Less(keys[j]) - }) - - // Iterate ordered - for _, key := range keys { - ipAddressInfo := s.addresses[key] - fmt.Fprintf(&result, indent(2)+"%s %s\n", ipAddressInfo.Address.String(), time.Until(ipAddressInfo.GetExpiration()).Round(time.Second)) - } - s.addressesMapMutex.RUnlock() - - return result.String() -} - -func NewIPAddressSet() *IPAddressSet { - return &IPAddressSet{addresses: make(map[netip.Addr]*IPAddressInfo)} -} - -// Add macs address to the set with a TTL (in seconds). -func (t *Table) Add(hw net.HardwareAddr, addr netip.Addr, ttl time.Duration, onExpiration func(*IPAddressInfo, int)) bool { - mac := hw.String() - if !t.Contains(hw) { - t.macsMapMutex.Lock() - t.macs[mac] = NewIPAddressSet() - t.macsMapMutex.Unlock() - } - return (t.macs[mac]).Add(hw, addr, ttl, onExpiration) -} - -// Remove an MACs address from the set. -func (t *Table) Remove(hw net.HardwareAddr) { - mac := hw.String() - t.macsMapMutex.Lock() - delete(t.macs, mac) - t.macsMapMutex.Unlock() -} - -// Check if an MAC address exists in the set and is not expired. -func (t *Table) Contains(hw net.HardwareAddr) bool { - mac := hw.String() - t.macsMapMutex.RLock() - _, ok := t.macs[mac] - t.macsMapMutex.RUnlock() - return ok -} - -func (t *Table) Filter(hws []net.HardwareAddr, prefixes []netip.Prefix) []*IPAddressInfo { - found := []*IPAddressInfo{} - for _, prefix := range prefixes { - for _, hw := range hws { - mac := hw.String() - t.macsMapMutex.RLock() - if t.Contains(hw) { - for _, ipAddressInfo := range t.macs[mac].addresses { - // Remove zone identifier from netip.Addr, zones strip prefixes - test := netip.AddrFrom16(ipAddressInfo.Address.As16()) - if prefix.Contains(test) { - found = append(found, ipAddressInfo) - } - } - } - t.macsMapMutex.RUnlock() - } - } - - return found -} - -func (t *Table) PrettyPrint(tabSize int) string { - indent := func(level int) string { - return strings.Repeat(" ", level*tabSize) - } - var result strings.Builder - - result.WriteString("Table:\n") - - t.macsMapMutex.RLock() - - // Get the keys from the map - keys := make([]string, 0, len(t.macs)) - for k := range t.macs { - keys = append(keys, k) - } - - // Sort the keys - sort.Strings(keys) - - // Iterate ordered - for _, key := range keys { - fmt.Fprintf(&result, indent(1)+"%s:\n", key) - result.WriteString(t.macs[key].PrettyPrint(tabSize)) - } - - t.macsMapMutex.RUnlock() - - return result.String() -} - -func NewTable() *Table { - return &Table{macs: make(map[string]*IPAddressSet)} -} diff --git a/state.go b/state.go new file mode 100644 index 0000000..a6e6572 --- /dev/null +++ b/state.go @@ -0,0 +1,93 @@ +package ipv6disc + +import ( + "fmt" + "net" + "net/netip" + "sort" + "strings" + "sync" + "time" +) + +type State struct { + macs map[string]*AddrCollection + macsMutex sync.RWMutex + addrDefaultLifetime time.Duration + addrDefaultOnExpiration func(*Addr, AddrExpirationRemainingEvents) +} + +// accepts default TTL and onExpiration function +func (s *State) Enlist(hw net.HardwareAddr, netipAddr netip.Addr, ttl time.Duration, onExpiration func(*Addr, AddrExpirationRemainingEvents)) (*Addr, bool) { + s.macsMutex.Lock() + defer s.macsMutex.Unlock() + + mac := hw.String() + _, exists := s.macs[mac] + if !exists { + s.macs[mac] = NewAddrCollection() + } + + if ttl == 0 { + ttl = s.addrDefaultLifetime + } + + if onExpiration == nil { + onExpiration = s.addrDefaultOnExpiration + } + + newAddr := NewAddr(hw, netipAddr, ttl, onExpiration) + + return s.macs[mac].Enlist(newAddr) +} + +func (s *State) Filter(hws []net.HardwareAddr, prefixes []netip.Prefix) *AddrCollection { + results := NewAddrCollection() + for _, prefix := range prefixes { + for _, hw := range hws { + s.macsMutex.RLock() + collection, exists := s.macs[hw.String()] + if exists { + results.Join(collection.FilterPrefix(prefix)) + } + s.macsMutex.RUnlock() + } + } + + return results +} + +func (s *State) PrettyPrint(prefix string) string { + var result strings.Builder + + fmt.Fprintf(&result, "%sDiscovery:\n", prefix) + + s.macsMutex.RLock() + + // Get the keys from the map + keys := make([]string, 0, len(s.macs)) + for k := range s.macs { + keys = append(keys, k) + } + + // Sort the keys + sort.Strings(keys) + + // Iterate ordered + for _, key := range keys { + fmt.Fprintf(&result, "%s %s\n", prefix, key) + fmt.Fprint(&result, s.macs[key].PrettyPrint(prefix+" ")) + } + + s.macsMutex.RUnlock() + + return result.String() +} + +func NewState(lifetime time.Duration) *State { + return &State{ + macs: make(map[string]*AddrCollection), + addrDefaultLifetime: lifetime, + addrDefaultOnExpiration: func(addr *Addr, remainingEvents AddrExpirationRemainingEvents) {}, + } +} diff --git a/pkg/worker/worker.go b/worker.go similarity index 74% rename from pkg/worker/worker.go rename to worker.go index 3401c5b..71f638e 100644 --- a/pkg/worker/worker.go +++ b/worker.go @@ -1,4 +1,4 @@ -package worker +package ipv6disc import ( "errors" @@ -23,16 +23,18 @@ func (e *InvalidInterfaceError) Error() string { } type Worker struct { - logger *zap.SugaredLogger - Table *Table - ttl time.Duration + *State + logger *zap.SugaredLogger + rediscover time.Duration + lifetime time.Duration } -func NewWorker(table *Table, ttl time.Duration, logger *zap.SugaredLogger) *Worker { +func NewWorker(logger *zap.SugaredLogger, rediscover time.Duration, lifetime time.Duration) *Worker { return &Worker{ - logger: logger, - Table: table, - ttl: ttl, + State: NewState(lifetime), + logger: logger, + rediscover: rediscover, + lifetime: lifetime, } } @@ -92,60 +94,52 @@ func (w *Worker) StartInterfaceAddr(iface net.Interface, addr netip.Addr) { var ssdpConn *ssdp.Conn var wsdConn *wsd.Conn - // manage NDP - onFoundLinkLayerAddr := func(hostAddr netip.Addr, linkLayerAddr net.HardwareAddr) { - onExpiration := func(info *IPAddressInfo, attempt int) { - address := info.GetAddress() - if attempt == 0 { - w.logger.Infow("host expired", - zap.String("ipv6", netip.AddrFrom16(address.As16()).String()), - zap.String("mac", linkLayerAddr.String()), - zap.String("iface", address.Zone()), - ) - } else { - w.logger.Debugw("host not seen for a while", - zap.String("ipv6", netip.AddrFrom16(address.As16()).String()), - zap.String("mac", linkLayerAddr.String()), - zap.String("iface", address.Zone()), - zap.Int("attempt", attempt), - ) - } - - // do ping - if pingConn != nil { - w.logger.Debugw("ping", - zap.String("ipv6", address.String()), - ) - target := netip.MustParseAddr(address.String()) - err := pingConn.SendPing(&target) - if err != nil { - w.logger.Errorw("ping failed", - zap.String("ipv6", address.String()), - zap.Error(err), - ) - } - } else { - w.logger.Errorw("unable to ping, connection not available", - zap.String("ipv6", address.String()), - ) - } + w.State.addrDefaultOnExpiration = func(addr *Addr, remainingEvents AddrExpirationRemainingEvents) { + if remainingEvents == 0 { + w.logger.Infow("host expired", + zap.String("ipv6", netip.AddrFrom16(addr.As16()).String()), + zap.String("mac", addr.Hw.String()), + zap.String("iface", addr.Zone()), + ) + return + } + + w.logger.Debugw("host not seen for a while, pinging", + zap.String("ipv6", netip.AddrFrom16(addr.As16()).String()), + zap.String("mac", addr.Hw.String()), + zap.String("iface", addr.Zone()), + zap.Int("remainingEvents", int(remainingEvents)), + ) + + err := pingConn.SendPing(&addr.Addr) + if err != nil { + w.logger.Errorw("ping failed", + zap.String("ipv6", netip.AddrFrom16(addr.As16()).String()), + zap.String("mac", addr.Hw.String()), + zap.String("iface", addr.Zone()), + zap.Error(err), + ) } - existing := w.Table.Add(linkLayerAddr, hostAddr, w.ttl, onExpiration) + } + + processNDP := func(netipAddr netip.Addr, netHardwareAddr net.HardwareAddr) { + addr, existing := w.State.Enlist(netHardwareAddr, netipAddr, 0, nil) if existing { w.logger.Debugw("ttl refreshed", - zap.String("ipv6", hostAddr.String()), - zap.String("mac", linkLayerAddr.String()), + zap.String("ipv6", netipAddr.String()), + zap.String("mac", netHardwareAddr.String()), ) } else { + addr.Watch() w.logger.Infow("host identified", - zap.String("ipv6", netip.AddrFrom16(hostAddr.As16()).String()), - zap.String("mac", linkLayerAddr.String()), - zap.String("iface", hostAddr.Zone()), + zap.String("ipv6", netip.AddrFrom16(netipAddr.As16()).String()), + zap.String("mac", netHardwareAddr.String()), + zap.String("iface", netipAddr.Zone()), ) } } - ndpConn, err = ndp.ListenForNDP(&iface, addr, onFoundLinkLayerAddr) + ndpConn, err = ndp.ListenForNDP(&iface, addr, processNDP) if err != nil { w.logger.Fatalf("error listening for NDP on interface %s: %s", iface.Name, err) } @@ -183,7 +177,7 @@ func (w *Worker) StartInterfaceAddr(iface net.Interface, addr netip.Addr) { discover() // Periodic re-discovery - ticker := time.NewTicker(w.ttl / 3) + ticker := time.NewTicker(w.rediscover) defer ticker.Stop() for range ticker.C { discover()