Skip to content

Commit

Permalink
update mdns to remove race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
asim committed Feb 1, 2019
1 parent 652b106 commit 88e1234
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions registry/mdns_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package registry

import (
"context"
"net"
"strings"
"sync"
Expand Down Expand Up @@ -194,20 +195,20 @@ func (m *mdnsRegistry) Deregister(service *Service) error {
}

func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
p := mdns.DefaultParams(service)
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh

exit := make(chan bool)
defer close(exit)

serviceMap := make(map[string]*Service)
entries := make(chan *mdns.ServiceEntry, 10)
done := make(chan bool)

p := mdns.DefaultParams(service)
// set context with timeout
p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout)
// set entries channel
p.Entries = entries

go func() {
for {
select {
case e := <-entryCh:
case e := <-entries:
// list record so skip
if p.Service == "_services" {
continue
Expand Down Expand Up @@ -243,16 +244,21 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
})

serviceMap[txt.Version] = s
case <-exit:
case <-p.Context.Done():
close(done)
return
}
}
}()

// execute the query
if err := mdns.Query(p); err != nil {
return nil, err
}

// wait for completion
<-done

// create list and return
var services []*Service

Expand All @@ -264,21 +270,22 @@ func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
}

func (m *mdnsRegistry) ListServices() ([]*Service, error) {
p := mdns.DefaultParams("_services")
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
serviceMap := make(map[string]bool)
entries := make(chan *mdns.ServiceEntry, 10)
done := make(chan bool)

exit := make(chan bool)
defer close(exit)
p := mdns.DefaultParams("_services")
// set context with timeout
p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout)
// set entries channel
p.Entries = entries

serviceMap := make(map[string]bool)
var services []*Service

go func() {
for {
select {
case e := <-entryCh:
case e := <-entries:
if e.TTL == 0 {
continue
}
Expand All @@ -288,16 +295,21 @@ func (m *mdnsRegistry) ListServices() ([]*Service, error) {
serviceMap[name] = true
services = append(services, &Service{Name: name})
}
case <-exit:
case <-p.Context.Done():
close(done)
return
}
}
}()

// execute query
if err := mdns.Query(p); err != nil {
return nil, err
}

// wait till done
<-done

return services, nil
}

Expand Down

0 comments on commit 88e1234

Please sign in to comment.