diff options
Diffstat (limited to 'libgo/go/net/lookup_windows.go')
-rw-r--r-- | libgo/go/net/lookup_windows.go | 79 |
1 files changed, 69 insertions, 10 deletions
diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go index 130364231d4..6a925b0a7ad 100644 --- a/libgo/go/net/lookup_windows.go +++ b/libgo/go/net/lookup_windows.go @@ -210,14 +210,21 @@ func lookupCNAME(name string) (cname string, err error) { defer releaseThread() var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) + // windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s + if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS { + // if there are no aliases, the canonical name is the input name + if name == "" || name[len(name)-1] != '.' { + return name + ".", nil + } + return name, nil + } if e != nil { return "", os.NewSyscallError("LookupCNAME", e) } defer syscall.DnsRecordListFree(r, 1) - if r != nil && r.Type == syscall.DNS_TYPE_CNAME { - v := (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])) - cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "." - } + + resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r) + cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:]) + "." return } @@ -236,8 +243,9 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err return "", nil, os.NewSyscallError("LookupSRV", e) } defer syscall.DnsRecordListFree(r, 1) + addrs = make([]*SRV, 0, 10) - for p := r; p != nil && p.Type == syscall.DNS_TYPE_SRV; p = p.Next { + for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) { v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0])) addrs = append(addrs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight}) } @@ -254,8 +262,9 @@ func lookupMX(name string) (mx []*MX, err error) { return nil, os.NewSyscallError("LookupMX", e) } defer syscall.DnsRecordListFree(r, 1) + mx = make([]*MX, 0, 10) - for p := r; p != nil && p.Type == syscall.DNS_TYPE_MX; p = p.Next { + for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) { v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0])) mx = append(mx, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference}) } @@ -272,8 +281,9 @@ func lookupNS(name string) (ns []*NS, err error) { return nil, os.NewSyscallError("LookupNS", e) } defer syscall.DnsRecordListFree(r, 1) + ns = make([]*NS, 0, 10) - for p := r; p != nil && p.Type == syscall.DNS_TYPE_NS; p = p.Next { + for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) { v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) ns = append(ns, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."}) } @@ -289,9 +299,10 @@ func lookupTXT(name string) (txt []string, err error) { return nil, os.NewSyscallError("LookupTXT", e) } defer syscall.DnsRecordListFree(r, 1) + txt = make([]string, 0, 10) - if r != nil && r.Type == syscall.DNS_TYPE_TEXT { - d := (*syscall.DNSTXTData)(unsafe.Pointer(&r.Data[0])) + for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) { + d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0])) for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] { s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:]) txt = append(txt, s) @@ -313,10 +324,58 @@ func lookupAddr(addr string) (name []string, err error) { return nil, os.NewSyscallError("LookupAddr", e) } defer syscall.DnsRecordListFree(r, 1) + name = make([]string, 0, 10) - for p := r; p != nil && p.Type == syscall.DNS_TYPE_PTR; p = p.Next { + for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) { v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) name = append(name, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:])) } return name, nil } + +const dnsSectionMask = 0x0003 + +// returns only results applicable to name and resolves CNAME entries +func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord { + cname := syscall.StringToUTF16Ptr(name) + if dnstype != syscall.DNS_TYPE_CNAME { + cname = resolveCNAME(cname, r) + } + rec := make([]*syscall.DNSRecord, 0, 10) + for p := r; p != nil; p = p.Next { + if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer { + continue + } + if p.Type != dnstype { + continue + } + if !syscall.DnsNameCompare(cname, p.Name) { + continue + } + rec = append(rec, p) + } + return rec +} + +// returns the last CNAME in chain +func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 { + // limit cname resolving to 10 in case of a infinite CNAME loop +Cname: + for cnameloop := 0; cnameloop < 10; cnameloop++ { + for p := r; p != nil; p = p.Next { + if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer { + continue + } + if p.Type != syscall.DNS_TYPE_CNAME { + continue + } + if !syscall.DnsNameCompare(name, p.Name) { + continue + } + name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host + continue Cname + } + break + } + return name +} |