summaryrefslogtreecommitdiff
path: root/libgo/go/net/lookup_windows.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/net/lookup_windows.go')
-rw-r--r--libgo/go/net/lookup_windows.go79
1 files changed, 69 insertions, 10 deletions
diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go
index 130364231d4..6a925b0a7ad 100644
--- a/libgo/go/net/lookup_windows.go
+++ b/libgo/go/net/lookup_windows.go
@@ -210,14 +210,21 @@ func lookupCNAME(name string) (cname string, err error) {
defer releaseThread()
var r *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
+ // windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
+ if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
+ // if there are no aliases, the canonical name is the input name
+ if name == "" || name[len(name)-1] != '.' {
+ return name + ".", nil
+ }
+ return name, nil
+ }
if e != nil {
return "", os.NewSyscallError("LookupCNAME", e)
}
defer syscall.DnsRecordListFree(r, 1)
- if r != nil && r.Type == syscall.DNS_TYPE_CNAME {
- v := (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0]))
- cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."
- }
+
+ resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
+ cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:]) + "."
return
}
@@ -236,8 +243,9 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err
return "", nil, os.NewSyscallError("LookupSRV", e)
}
defer syscall.DnsRecordListFree(r, 1)
+
addrs = make([]*SRV, 0, 10)
- for p := r; p != nil && p.Type == syscall.DNS_TYPE_SRV; p = p.Next {
+ for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
addrs = append(addrs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight})
}
@@ -254,8 +262,9 @@ func lookupMX(name string) (mx []*MX, err error) {
return nil, os.NewSyscallError("LookupMX", e)
}
defer syscall.DnsRecordListFree(r, 1)
+
mx = make([]*MX, 0, 10)
- for p := r; p != nil && p.Type == syscall.DNS_TYPE_MX; p = p.Next {
+ for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
mx = append(mx, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference})
}
@@ -272,8 +281,9 @@ func lookupNS(name string) (ns []*NS, err error) {
return nil, os.NewSyscallError("LookupNS", e)
}
defer syscall.DnsRecordListFree(r, 1)
+
ns = make([]*NS, 0, 10)
- for p := r; p != nil && p.Type == syscall.DNS_TYPE_NS; p = p.Next {
+ for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
ns = append(ns, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."})
}
@@ -289,9 +299,10 @@ func lookupTXT(name string) (txt []string, err error) {
return nil, os.NewSyscallError("LookupTXT", e)
}
defer syscall.DnsRecordListFree(r, 1)
+
txt = make([]string, 0, 10)
- if r != nil && r.Type == syscall.DNS_TYPE_TEXT {
- d := (*syscall.DNSTXTData)(unsafe.Pointer(&r.Data[0]))
+ for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
+ d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
txt = append(txt, s)
@@ -313,10 +324,58 @@ func lookupAddr(addr string) (name []string, err error) {
return nil, os.NewSyscallError("LookupAddr", e)
}
defer syscall.DnsRecordListFree(r, 1)
+
name = make([]string, 0, 10)
- for p := r; p != nil && p.Type == syscall.DNS_TYPE_PTR; p = p.Next {
+ for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
name = append(name, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))
}
return name, nil
}
+
+const dnsSectionMask = 0x0003
+
+// returns only results applicable to name and resolves CNAME entries
+func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
+ cname := syscall.StringToUTF16Ptr(name)
+ if dnstype != syscall.DNS_TYPE_CNAME {
+ cname = resolveCNAME(cname, r)
+ }
+ rec := make([]*syscall.DNSRecord, 0, 10)
+ for p := r; p != nil; p = p.Next {
+ if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
+ continue
+ }
+ if p.Type != dnstype {
+ continue
+ }
+ if !syscall.DnsNameCompare(cname, p.Name) {
+ continue
+ }
+ rec = append(rec, p)
+ }
+ return rec
+}
+
+// returns the last CNAME in chain
+func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
+ // limit cname resolving to 10 in case of a infinite CNAME loop
+Cname:
+ for cnameloop := 0; cnameloop < 10; cnameloop++ {
+ for p := r; p != nil; p = p.Next {
+ if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
+ continue
+ }
+ if p.Type != syscall.DNS_TYPE_CNAME {
+ continue
+ }
+ if !syscall.DnsNameCompare(name, p.Name) {
+ continue
+ }
+ name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
+ continue Cname
+ }
+ break
+ }
+ return name
+}