summaryrefslogtreecommitdiff
path: root/chromium/net/dns
diff options
context:
space:
mode:
Diffstat (limited to 'chromium/net/dns')
-rw-r--r--chromium/net/dns/address_sorter.h46
-rw-r--r--chromium/net/dns/address_sorter_posix.cc426
-rw-r--r--chromium/net/dns/address_sorter_posix.h94
-rw-r--r--chromium/net/dns/address_sorter_posix_unittest.cc327
-rw-r--r--chromium/net/dns/address_sorter_unittest.cc66
-rw-r--r--chromium/net/dns/address_sorter_win.cc198
-rw-r--r--chromium/net/dns/dns_client.cc71
-rw-r--r--chromium/net/dns/dns_client.h44
-rw-r--r--chromium/net/dns/dns_config_service.cc226
-rw-r--r--chromium/net/dns/dns_config_service.h174
-rw-r--r--chromium/net/dns/dns_config_service_posix.cc404
-rw-r--r--chromium/net/dns/dns_config_service_posix.h67
-rw-r--r--chromium/net/dns/dns_config_service_posix_unittest.cc156
-rw-r--r--chromium/net/dns/dns_config_service_unittest.cc258
-rw-r--r--chromium/net/dns/dns_config_service_win.cc737
-rw-r--r--chromium/net/dns/dns_config_service_win.h154
-rw-r--r--chromium/net/dns/dns_config_service_win_unittest.cc430
-rw-r--r--chromium/net/dns/dns_hosts.cc169
-rw-r--r--chromium/net/dns/dns_hosts.h79
-rw-r--r--chromium/net/dns/dns_hosts_unittest.cc124
-rw-r--r--chromium/net/dns/dns_protocol.h143
-rw-r--r--chromium/net/dns/dns_query.cc89
-rw-r--r--chromium/net/dns/dns_query.h58
-rw-r--r--chromium/net/dns/dns_query_unittest.cc69
-rw-r--r--chromium/net/dns/dns_response.cc337
-rw-r--r--chromium/net/dns/dns_response.h169
-rw-r--r--chromium/net/dns/dns_response_unittest.cc581
-rw-r--r--chromium/net/dns/dns_session.cc298
-rw-r--r--chromium/net/dns/dns_session.h147
-rw-r--r--chromium/net/dns/dns_session_unittest.cc252
-rw-r--r--chromium/net/dns/dns_socket_pool.cc234
-rw-r--r--chromium/net/dns/dns_socket_pool.h91
-rw-r--r--chromium/net/dns/dns_test_util.cc210
-rw-r--r--chromium/net/dns/dns_test_util.h205
-rw-r--r--chromium/net/dns/dns_transaction.cc963
-rw-r--r--chromium/net/dns/dns_transaction.h78
-rw-r--r--chromium/net/dns/dns_transaction_unittest.cc940
-rw-r--r--chromium/net/dns/host_cache.cc122
-rw-r--r--chromium/net/dns/host_cache.h124
-rw-r--r--chromium/net/dns/host_cache_unittest.cc388
-rw-r--r--chromium/net/dns/host_resolver.cc145
-rw-r--r--chromium/net/dns/host_resolver.h204
-rw-r--r--chromium/net/dns/host_resolver_impl.cc2206
-rw-r--r--chromium/net/dns/host_resolver_impl.h285
-rw-r--r--chromium/net/dns/host_resolver_impl_unittest.cc1641
-rw-r--r--chromium/net/dns/host_resolver_proc.cc267
-rw-r--r--chromium/net/dns/host_resolver_proc.h111
-rw-r--r--chromium/net/dns/mapped_host_resolver.cc63
-rw-r--r--chromium/net/dns/mapped_host_resolver.h71
-rw-r--r--chromium/net/dns/mapped_host_resolver_unittest.cc219
-rw-r--r--chromium/net/dns/mdns_cache.cc212
-rw-r--r--chromium/net/dns/mdns_cache.h119
-rw-r--r--chromium/net/dns/mdns_cache_unittest.cc375
-rw-r--r--chromium/net/dns/mdns_client.cc17
-rw-r--r--chromium/net/dns/mdns_client.h158
-rw-r--r--chromium/net/dns/mdns_client_impl.cc671
-rw-r--r--chromium/net/dns/mdns_client_impl.h298
-rw-r--r--chromium/net/dns/mdns_client_unittest.cc1176
-rw-r--r--chromium/net/dns/mock_host_resolver.cc408
-rw-r--r--chromium/net/dns/mock_host_resolver.h284
-rw-r--r--chromium/net/dns/mock_mdns_socket_factory.cc115
-rw-r--r--chromium/net/dns/mock_mdns_socket_factory.h101
-rw-r--r--chromium/net/dns/notify_watcher_mac.cc64
-rw-r--r--chromium/net/dns/notify_watcher_mac.h47
-rw-r--r--chromium/net/dns/record_parsed.cc86
-rw-r--r--chromium/net/dns/record_parsed.h64
-rw-r--r--chromium/net/dns/record_parsed_unittest.cc75
-rw-r--r--chromium/net/dns/record_rdata.cc287
-rw-r--r--chromium/net/dns/record_rdata.h217
-rw-r--r--chromium/net/dns/record_rdata_unittest.cc222
-rw-r--r--chromium/net/dns/serial_worker.cc104
-rw-r--r--chromium/net/dns/serial_worker.h96
-rw-r--r--chromium/net/dns/serial_worker_unittest.cc163
-rw-r--r--chromium/net/dns/single_request_host_resolver.cc77
-rw-r--r--chromium/net/dns/single_request_host_resolver.h56
-rw-r--r--chromium/net/dns/single_request_host_resolver_unittest.cc124
76 files changed, 20576 insertions, 0 deletions
diff --git a/chromium/net/dns/address_sorter.h b/chromium/net/dns/address_sorter.h
new file mode 100644
index 00000000000..6ac943086ef
--- /dev/null
+++ b/chromium/net/dns/address_sorter.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_ADDRESS_SORTER_H_
+#define NET_DNS_ADDRESS_SORTER_H_
+
+#include "base/basictypes.h"
+#include "base/callback.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+class AddressList;
+
+// Sorts AddressList according to RFC3484, by likelihood of successful
+// connection. Depending on the platform, the sort could be performed
+// asynchronously by the OS, or synchronously by local implementation.
+// AddressSorter does not necessarily preserve port numbers on the sorted list.
+class NET_EXPORT AddressSorter {
+ public:
+ typedef base::Callback<void(bool success,
+ const AddressList& list)> CallbackType;
+
+ virtual ~AddressSorter() {}
+
+ // Sorts |list|, which must include at least one IPv6 address.
+ // Calls |callback| upon completion. Could complete synchronously. Could
+ // complete after this AddressSorter is destroyed.
+ virtual void Sort(const AddressList& list,
+ const CallbackType& callback) const = 0;
+
+ // Creates platform-dependent AddressSorter.
+ static scoped_ptr<AddressSorter> CreateAddressSorter();
+
+ protected:
+ AddressSorter() {}
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(AddressSorter);
+};
+
+} // namespace net
+
+#endif // NET_DNS_ADDRESS_SORTER_H_
diff --git a/chromium/net/dns/address_sorter_posix.cc b/chromium/net/dns/address_sorter_posix.cc
new file mode 100644
index 00000000000..8d87774587d
--- /dev/null
+++ b/chromium/net/dns/address_sorter_posix.cc
@@ -0,0 +1,426 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/address_sorter_posix.h"
+
+#include <netinet/in.h>
+
+#if defined(OS_MACOSX) || defined(OS_BSD)
+#include <sys/socket.h> // Must be included before ifaddrs.h.
+#include <ifaddrs.h>
+#include <net/if.h>
+#include <netinet/in_var.h>
+#include <string.h>
+#include <sys/ioctl.h>
+#endif
+
+#include <algorithm>
+
+#include "base/logging.h"
+#include "base/memory/scoped_vector.h"
+#include "base/posix/eintr_wrapper.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/udp/datagram_client_socket.h"
+
+#if defined(OS_LINUX)
+#include "net/base/address_tracker_linux.h"
+#endif
+
+namespace net {
+
+namespace {
+
+// Address sorting is performed according to RFC3484 with revisions.
+// http://tools.ietf.org/html/draft-ietf-6man-rfc3484bis-06
+// Precedence and label are separate to support override through /etc/gai.conf.
+
+// Returns true if |p1| should precede |p2| in the table.
+// Sorts table by decreasing prefix size to allow longest prefix matching.
+bool ComparePolicy(const AddressSorterPosix::PolicyEntry& p1,
+ const AddressSorterPosix::PolicyEntry& p2) {
+ return p1.prefix_length > p2.prefix_length;
+}
+
+// Creates sorted PolicyTable from |table| with |size| entries.
+AddressSorterPosix::PolicyTable LoadPolicy(
+ AddressSorterPosix::PolicyEntry* table,
+ size_t size) {
+ AddressSorterPosix::PolicyTable result(table, table + size);
+ std::sort(result.begin(), result.end(), ComparePolicy);
+ return result;
+}
+
+// Search |table| for matching prefix of |address|. |table| must be sorted by
+// descending prefix (prefix of another prefix must be later in table).
+unsigned GetPolicyValue(const AddressSorterPosix::PolicyTable& table,
+ const IPAddressNumber& address) {
+ if (address.size() == kIPv4AddressSize)
+ return GetPolicyValue(table, ConvertIPv4NumberToIPv6Number(address));
+ for (unsigned i = 0; i < table.size(); ++i) {
+ const AddressSorterPosix::PolicyEntry& entry = table[i];
+ IPAddressNumber prefix(entry.prefix, entry.prefix + kIPv6AddressSize);
+ if (IPNumberMatchesPrefix(address, prefix, entry.prefix_length))
+ return entry.value;
+ }
+ NOTREACHED();
+ // The last entry is the least restrictive, so assume it's default.
+ return table.back().value;
+}
+
+bool IsIPv6Multicast(const IPAddressNumber& address) {
+ DCHECK_EQ(kIPv6AddressSize, address.size());
+ return address[0] == 0xFF;
+}
+
+AddressSorterPosix::AddressScope GetIPv6MulticastScope(
+ const IPAddressNumber& address) {
+ DCHECK_EQ(kIPv6AddressSize, address.size());
+ return static_cast<AddressSorterPosix::AddressScope>(address[1] & 0x0F);
+}
+
+bool IsIPv6Loopback(const IPAddressNumber& address) {
+ DCHECK_EQ(kIPv6AddressSize, address.size());
+ // IN6_IS_ADDR_LOOPBACK
+ unsigned char kLoopback[kIPv6AddressSize] = {
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 1,
+ };
+ return address == IPAddressNumber(kLoopback, kLoopback + kIPv6AddressSize);
+}
+
+bool IsIPv6LinkLocal(const IPAddressNumber& address) {
+ DCHECK_EQ(kIPv6AddressSize, address.size());
+ // IN6_IS_ADDR_LINKLOCAL
+ return (address[0] == 0xFE) && ((address[1] & 0xC0) == 0x80);
+}
+
+bool IsIPv6SiteLocal(const IPAddressNumber& address) {
+ DCHECK_EQ(kIPv6AddressSize, address.size());
+ // IN6_IS_ADDR_SITELOCAL
+ return (address[0] == 0xFE) && ((address[1] & 0xC0) == 0xC0);
+}
+
+AddressSorterPosix::AddressScope GetScope(
+ const AddressSorterPosix::PolicyTable& ipv4_scope_table,
+ const IPAddressNumber& address) {
+ if (address.size() == kIPv6AddressSize) {
+ if (IsIPv6Multicast(address)) {
+ return GetIPv6MulticastScope(address);
+ } else if (IsIPv6Loopback(address) || IsIPv6LinkLocal(address)) {
+ return AddressSorterPosix::SCOPE_LINKLOCAL;
+ } else if (IsIPv6SiteLocal(address)) {
+ return AddressSorterPosix::SCOPE_SITELOCAL;
+ } else {
+ return AddressSorterPosix::SCOPE_GLOBAL;
+ }
+ } else if (address.size() == kIPv4AddressSize) {
+ return static_cast<AddressSorterPosix::AddressScope>(
+ GetPolicyValue(ipv4_scope_table, address));
+ } else {
+ NOTREACHED();
+ return AddressSorterPosix::SCOPE_NODELOCAL;
+ }
+}
+
+// Default policy table. RFC 3484, Section 2.1.
+AddressSorterPosix::PolicyEntry kDefaultPrecedenceTable[] = {
+ // ::1/128 -- loopback
+ { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, 128, 50 },
+ // ::/0 -- any
+ { { }, 0, 40 },
+ // ::ffff:0:0/96 -- IPv4 mapped
+ { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF }, 96, 35 },
+ // 2002::/16 -- 6to4
+ { { 0x20, 0x02, }, 16, 30 },
+ // 2001::/32 -- Teredo
+ { { 0x20, 0x01, 0, 0 }, 32, 5 },
+ // fc00::/7 -- unique local address
+ { { 0xFC }, 7, 3 },
+ // ::/96 -- IPv4 compatible
+ { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 96, 1 },
+ // fec0::/10 -- site-local expanded scope
+ { { 0xFE, 0xC0 }, 10, 1 },
+ // 3ffe::/16 -- 6bone
+ { { 0x3F, 0xFE }, 16, 1 },
+};
+
+AddressSorterPosix::PolicyEntry kDefaultLabelTable[] = {
+ // ::1/128 -- loopback
+ { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, 128, 0 },
+ // ::/0 -- any
+ { { }, 0, 1 },
+ // ::ffff:0:0/96 -- IPv4 mapped
+ { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF }, 96, 4 },
+ // 2002::/16 -- 6to4
+ { { 0x20, 0x02, }, 16, 2 },
+ // 2001::/32 -- Teredo
+ { { 0x20, 0x01, 0, 0 }, 32, 5 },
+ // fc00::/7 -- unique local address
+ { { 0xFC }, 7, 13 },
+ // ::/96 -- IPv4 compatible
+ { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 96, 3 },
+ // fec0::/10 -- site-local expanded scope
+ { { 0xFE, 0xC0 }, 10, 11 },
+ // 3ffe::/16 -- 6bone
+ { { 0x3F, 0xFE }, 16, 12 },
+};
+
+// Default mapping of IPv4 addresses to scope.
+AddressSorterPosix::PolicyEntry kDefaultIPv4ScopeTable[] = {
+ { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0x7F }, 104,
+ AddressSorterPosix::SCOPE_LINKLOCAL },
+ { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xA9, 0xFE }, 112,
+ AddressSorterPosix::SCOPE_LINKLOCAL },
+ { { }, 0, AddressSorterPosix::SCOPE_GLOBAL },
+};
+
+// Returns number of matching initial bits between the addresses |a1| and |a2|.
+unsigned CommonPrefixLength(const IPAddressNumber& a1,
+ const IPAddressNumber& a2) {
+ DCHECK_EQ(a1.size(), a2.size());
+ for (size_t i = 0; i < a1.size(); ++i) {
+ unsigned diff = a1[i] ^ a2[i];
+ if (!diff)
+ continue;
+ for (unsigned j = 0; j < CHAR_BIT; ++j) {
+ if (diff & (1 << (CHAR_BIT - 1)))
+ return i * CHAR_BIT + j;
+ diff <<= 1;
+ }
+ NOTREACHED();
+ }
+ return a1.size() * CHAR_BIT;
+}
+
+// Computes the number of leading 1-bits in |mask|.
+unsigned MaskPrefixLength(const IPAddressNumber& mask) {
+ IPAddressNumber all_ones(mask.size(), 0xFF);
+ return CommonPrefixLength(mask, all_ones);
+}
+
+struct DestinationInfo {
+ IPAddressNumber address;
+ AddressSorterPosix::AddressScope scope;
+ unsigned precedence;
+ unsigned label;
+ const AddressSorterPosix::SourceAddressInfo* src;
+ unsigned common_prefix_length;
+};
+
+// Returns true iff |dst_a| should precede |dst_b| in the address list.
+// RFC 3484, section 6.
+bool CompareDestinations(const DestinationInfo* dst_a,
+ const DestinationInfo* dst_b) {
+ // Rule 1: Avoid unusable destinations.
+ // Unusable destinations are already filtered out.
+ DCHECK(dst_a->src);
+ DCHECK(dst_b->src);
+
+ // Rule 2: Prefer matching scope.
+ bool scope_match1 = (dst_a->src->scope == dst_a->scope);
+ bool scope_match2 = (dst_b->src->scope == dst_b->scope);
+ if (scope_match1 != scope_match2)
+ return scope_match1;
+
+ // Rule 3: Avoid deprecated addresses.
+ if (dst_a->src->deprecated != dst_b->src->deprecated)
+ return !dst_a->src->deprecated;
+
+ // Rule 4: Prefer home addresses.
+ if (dst_a->src->home != dst_b->src->home)
+ return dst_a->src->home;
+
+ // Rule 5: Prefer matching label.
+ bool label_match1 = (dst_a->src->label == dst_a->label);
+ bool label_match2 = (dst_b->src->label == dst_b->label);
+ if (label_match1 != label_match2)
+ return label_match1;
+
+ // Rule 6: Prefer higher precedence.
+ if (dst_a->precedence != dst_b->precedence)
+ return dst_a->precedence > dst_b->precedence;
+
+ // Rule 7: Prefer native transport.
+ if (dst_a->src->native != dst_b->src->native)
+ return dst_a->src->native;
+
+ // Rule 8: Prefer smaller scope.
+ if (dst_a->scope != dst_b->scope)
+ return dst_a->scope < dst_b->scope;
+
+ // Rule 9: Use longest matching prefix. Only for matching address families.
+ if (dst_a->address.size() == dst_b->address.size()) {
+ if (dst_a->common_prefix_length != dst_b->common_prefix_length)
+ return dst_a->common_prefix_length > dst_b->common_prefix_length;
+ }
+
+ // Rule 10: Leave the order unchanged.
+ // stable_sort takes care of that.
+ return false;
+}
+
+} // namespace
+
+AddressSorterPosix::AddressSorterPosix(ClientSocketFactory* socket_factory)
+ : socket_factory_(socket_factory),
+ precedence_table_(LoadPolicy(kDefaultPrecedenceTable,
+ arraysize(kDefaultPrecedenceTable))),
+ label_table_(LoadPolicy(kDefaultLabelTable,
+ arraysize(kDefaultLabelTable))),
+ ipv4_scope_table_(LoadPolicy(kDefaultIPv4ScopeTable,
+ arraysize(kDefaultIPv4ScopeTable))) {
+ NetworkChangeNotifier::AddIPAddressObserver(this);
+ OnIPAddressChanged();
+}
+
+AddressSorterPosix::~AddressSorterPosix() {
+ NetworkChangeNotifier::RemoveIPAddressObserver(this);
+}
+
+void AddressSorterPosix::Sort(const AddressList& list,
+ const CallbackType& callback) const {
+ DCHECK(CalledOnValidThread());
+ ScopedVector<DestinationInfo> sort_list;
+
+ for (size_t i = 0; i < list.size(); ++i) {
+ scoped_ptr<DestinationInfo> info(new DestinationInfo());
+ info->address = list[i].address();
+ info->scope = GetScope(ipv4_scope_table_, info->address);
+ info->precedence = GetPolicyValue(precedence_table_, info->address);
+ info->label = GetPolicyValue(label_table_, info->address);
+
+ // Each socket can only be bound once.
+ scoped_ptr<DatagramClientSocket> socket(
+ socket_factory_->CreateDatagramClientSocket(
+ DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ NULL /* NetLog */,
+ NetLog::Source()));
+
+ // Even though no packets are sent, cannot use port 0 in Connect.
+ IPEndPoint dest(info->address, 80 /* port */);
+ int rv = socket->Connect(dest);
+ if (rv != OK) {
+ LOG(WARNING) << "Could not connect to " << dest.ToStringWithoutPort()
+ << " reason " << rv;
+ continue;
+ }
+ // Filter out unusable destinations.
+ IPEndPoint src;
+ rv = socket->GetLocalAddress(&src);
+ if (rv != OK) {
+ LOG(WARNING) << "Could not get local address for "
+ << src.ToStringWithoutPort() << " reason " << rv;
+ continue;
+ }
+
+ SourceAddressInfo& src_info = source_map_[src.address()];
+ if (src_info.scope == SCOPE_UNDEFINED) {
+ // If |source_info_| is out of date, |src| might be missing, but we still
+ // want to sort, even though the HostCache will be cleared soon.
+ FillPolicy(src.address(), &src_info);
+ }
+ info->src = &src_info;
+
+ if (info->address.size() == src.address().size()) {
+ info->common_prefix_length = std::min(
+ CommonPrefixLength(info->address, src.address()),
+ info->src->prefix_length);
+ }
+ sort_list.push_back(info.release());
+ }
+
+ std::stable_sort(sort_list.begin(), sort_list.end(), CompareDestinations);
+
+ AddressList result;
+ for (size_t i = 0; i < sort_list.size(); ++i)
+ result.push_back(IPEndPoint(sort_list[i]->address, 0 /* port */));
+
+ callback.Run(true, result);
+}
+
+void AddressSorterPosix::OnIPAddressChanged() {
+ DCHECK(CalledOnValidThread());
+ source_map_.clear();
+#if defined(OS_LINUX)
+ const internal::AddressTrackerLinux* tracker =
+ NetworkChangeNotifier::GetAddressTracker();
+ if (!tracker)
+ return;
+ typedef internal::AddressTrackerLinux::AddressMap AddressMap;
+ AddressMap map = tracker->GetAddressMap();
+ for (AddressMap::const_iterator it = map.begin(); it != map.end(); ++it) {
+ const IPAddressNumber& address = it->first;
+ const struct ifaddrmsg& msg = it->second;
+ SourceAddressInfo& info = source_map_[address];
+ info.native = false; // TODO(szym): obtain this via netlink.
+ info.deprecated = msg.ifa_flags & IFA_F_DEPRECATED;
+ info.home = msg.ifa_flags & IFA_F_HOMEADDRESS;
+ info.prefix_length = msg.ifa_prefixlen;
+ FillPolicy(address, &info);
+ }
+#elif defined(OS_MACOSX) || defined(OS_BSD)
+ // It's not clear we will receive notification when deprecated flag changes.
+ // Socket for ioctl.
+ int ioctl_socket = socket(AF_INET6, SOCK_DGRAM, 0);
+ if (ioctl_socket < 0)
+ return;
+
+ struct ifaddrs* addrs;
+ int rv = getifaddrs(&addrs);
+ if (rv < 0) {
+ LOG(WARNING) << "getifaddrs failed " << rv;
+ close(ioctl_socket);
+ return;
+ }
+
+ for (struct ifaddrs* ifa = addrs; ifa != NULL; ifa = ifa->ifa_next) {
+ IPEndPoint src;
+ if (!src.FromSockAddr(ifa->ifa_addr, ifa->ifa_addr->sa_len))
+ continue;
+ SourceAddressInfo& info = source_map_[src.address()];
+ // Note: no known way to fill in |native| and |home|.
+ info.native = info.home = info.deprecated = false;
+ if (ifa->ifa_addr->sa_family == AF_INET6) {
+ struct in6_ifreq ifr = {};
+ strncpy(ifr.ifr_name, ifa->ifa_name, sizeof(ifr.ifr_name) - 1);
+ DCHECK_LE(ifa->ifa_addr->sa_len, sizeof(ifr.ifr_ifru.ifru_addr));
+ memcpy(&ifr.ifr_ifru.ifru_addr, ifa->ifa_addr, ifa->ifa_addr->sa_len);
+ int rv = ioctl(ioctl_socket, SIOCGIFAFLAG_IN6, &ifr);
+ if (rv >= 0) {
+ info.deprecated = ifr.ifr_ifru.ifru_flags & IN6_IFF_DEPRECATED;
+ } else {
+ LOG(WARNING) << "SIOCGIFAFLAG_IN6 failed " << rv;
+ }
+ }
+ if (ifa->ifa_netmask) {
+ IPEndPoint netmask;
+ if (netmask.FromSockAddr(ifa->ifa_netmask, ifa->ifa_addr->sa_len)) {
+ info.prefix_length = MaskPrefixLength(netmask.address());
+ } else {
+ LOG(WARNING) << "FromSockAddr failed on netmask";
+ }
+ }
+ FillPolicy(src.address(), &info);
+ }
+ freeifaddrs(addrs);
+ close(ioctl_socket);
+#endif
+}
+
+void AddressSorterPosix::FillPolicy(const IPAddressNumber& address,
+ SourceAddressInfo* info) const {
+ DCHECK(CalledOnValidThread());
+ info->scope = GetScope(ipv4_scope_table_, address);
+ info->label = GetPolicyValue(label_table_, address);
+}
+
+// static
+scoped_ptr<AddressSorter> AddressSorter::CreateAddressSorter() {
+ return scoped_ptr<AddressSorter>(
+ new AddressSorterPosix(ClientSocketFactory::GetDefaultFactory()));
+}
+
+} // namespace net
+
diff --git a/chromium/net/dns/address_sorter_posix.h b/chromium/net/dns/address_sorter_posix.h
new file mode 100644
index 00000000000..1c88ad2f978
--- /dev/null
+++ b/chromium/net/dns/address_sorter_posix.h
@@ -0,0 +1,94 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_ADDRESS_SORTER_POSIX_H_
+#define NET_DNS_ADDRESS_SORTER_POSIX_H_
+
+#include <map>
+#include <vector>
+
+#include "base/threading/non_thread_safe.h"
+#include "net/base/address_list.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_export.h"
+#include "net/base/net_util.h"
+#include "net/base/network_change_notifier.h"
+#include "net/dns/address_sorter.h"
+
+namespace net {
+
+class ClientSocketFactory;
+
+// This implementation uses explicit policy to perform the sorting. It is not
+// thread-safe and always completes synchronously.
+class NET_EXPORT_PRIVATE AddressSorterPosix
+ : public AddressSorter,
+ public base::NonThreadSafe,
+ public NetworkChangeNotifier::IPAddressObserver {
+ public:
+ // Generic policy entry.
+ struct PolicyEntry {
+ // IPv4 addresses must be mapped to IPv6.
+ unsigned char prefix[kIPv6AddressSize];
+ unsigned prefix_length;
+ unsigned value;
+ };
+
+ typedef std::vector<PolicyEntry> PolicyTable;
+
+ enum AddressScope {
+ SCOPE_UNDEFINED = 0,
+ SCOPE_NODELOCAL = 1,
+ SCOPE_LINKLOCAL = 2,
+ SCOPE_SITELOCAL = 5,
+ SCOPE_ORGLOCAL = 8,
+ SCOPE_GLOBAL = 14,
+ };
+
+ struct SourceAddressInfo {
+ // Values read from policy tables.
+ AddressScope scope;
+ unsigned label;
+
+ // Values from the OS, matter only if more than one source address is used.
+ unsigned prefix_length;
+ bool deprecated; // vs. preferred RFC4862
+ bool home; // vs. care-of RFC6275
+ bool native;
+ };
+
+ typedef std::map<IPAddressNumber, SourceAddressInfo> SourceAddressMap;
+
+ explicit AddressSorterPosix(ClientSocketFactory* socket_factory);
+ virtual ~AddressSorterPosix();
+
+ // AddressSorter:
+ virtual void Sort(const AddressList& list,
+ const CallbackType& callback) const OVERRIDE;
+
+ private:
+ friend class AddressSorterPosixTest;
+
+ // NetworkChangeNotifier::IPAddressObserver:
+ virtual void OnIPAddressChanged() OVERRIDE;
+
+ // Fills |info| with values for |address| from policy tables.
+ void FillPolicy(const IPAddressNumber& address,
+ SourceAddressInfo* info) const;
+
+ // Mutable to allow using default values for source addresses which were not
+ // found in most recent OnIPAddressChanged.
+ mutable SourceAddressMap source_map_;
+
+ ClientSocketFactory* socket_factory_;
+ PolicyTable precedence_table_;
+ PolicyTable label_table_;
+ PolicyTable ipv4_scope_table_;
+
+ DISALLOW_COPY_AND_ASSIGN(AddressSorterPosix);
+};
+
+} // namespace net
+
+#endif // NET_DNS_ADDRESS_SORTER_POSIX_H_
diff --git a/chromium/net/dns/address_sorter_posix_unittest.cc b/chromium/net/dns/address_sorter_posix_unittest.cc
new file mode 100644
index 00000000000..c4517379957
--- /dev/null
+++ b/chromium/net/dns/address_sorter_posix_unittest.cc
@@ -0,0 +1,327 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/address_sorter_posix.h"
+
+#include "base/bind.h"
+#include "base/logging.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+#include "net/base/test_completion_callback.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/socket/stream_socket.h"
+#include "net/udp/datagram_client_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+namespace {
+
+// Used to map destination address to source address.
+typedef std::map<IPAddressNumber, IPAddressNumber> AddressMapping;
+
+IPAddressNumber ParseIP(const std::string& str) {
+ IPAddressNumber addr;
+ CHECK(ParseIPLiteralToNumber(str, &addr));
+ return addr;
+}
+
+// A mock socket which binds to source address according to AddressMapping.
+class TestUDPClientSocket : public DatagramClientSocket {
+ public:
+ explicit TestUDPClientSocket(const AddressMapping* mapping)
+ : mapping_(mapping), connected_(false) {}
+
+ virtual ~TestUDPClientSocket() {}
+
+ virtual int Read(IOBuffer*, int, const CompletionCallback&) OVERRIDE {
+ NOTIMPLEMENTED();
+ return OK;
+ }
+ virtual int Write(IOBuffer*, int, const CompletionCallback&) OVERRIDE {
+ NOTIMPLEMENTED();
+ return OK;
+ }
+ virtual bool SetReceiveBufferSize(int32) OVERRIDE {
+ return true;
+ }
+ virtual bool SetSendBufferSize(int32) OVERRIDE {
+ return true;
+ }
+
+ virtual void Close() OVERRIDE {}
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
+ NOTIMPLEMENTED();
+ return OK;
+ }
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
+ if (!connected_)
+ return ERR_UNEXPECTED;
+ *address = local_endpoint_;
+ return OK;
+ }
+
+ virtual int Connect(const IPEndPoint& remote) OVERRIDE {
+ if (connected_)
+ return ERR_UNEXPECTED;
+ AddressMapping::const_iterator it = mapping_->find(remote.address());
+ if (it == mapping_->end())
+ return ERR_FAILED;
+ connected_ = true;
+ local_endpoint_ = IPEndPoint(it->second, 39874 /* arbitrary port */);
+ return OK;
+ }
+
+ virtual const BoundNetLog& NetLog() const OVERRIDE {
+ return net_log_;
+ }
+
+ private:
+ BoundNetLog net_log_;
+ const AddressMapping* mapping_;
+ bool connected_;
+ IPEndPoint local_endpoint_;
+
+ DISALLOW_COPY_AND_ASSIGN(TestUDPClientSocket);
+};
+
+// Creates TestUDPClientSockets and maintains an AddressMapping.
+class TestSocketFactory : public ClientSocketFactory {
+ public:
+ TestSocketFactory() {}
+ virtual ~TestSocketFactory() {}
+
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ DatagramSocket::BindType,
+ const RandIntCallback&,
+ NetLog*,
+ const NetLog::Source&) OVERRIDE {
+ return scoped_ptr<DatagramClientSocket>(new TestUDPClientSocket(&mapping_));
+ }
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ const AddressList&,
+ NetLog*,
+ const NetLog::Source&) OVERRIDE {
+ NOTIMPLEMENTED();
+ return scoped_ptr<StreamSocket>();
+ }
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle>,
+ const HostPortPair&,
+ const SSLConfig&,
+ const SSLClientSocketContext&) OVERRIDE {
+ NOTIMPLEMENTED();
+ return scoped_ptr<SSLClientSocket>();
+ }
+ virtual void ClearSSLSessionCache() OVERRIDE {
+ NOTIMPLEMENTED();
+ }
+
+ void AddMapping(const IPAddressNumber& dst, const IPAddressNumber& src) {
+ mapping_[dst] = src;
+ }
+
+ private:
+ AddressMapping mapping_;
+
+ DISALLOW_COPY_AND_ASSIGN(TestSocketFactory);
+};
+
+void OnSortComplete(AddressList* result_buf,
+ const CompletionCallback& callback,
+ bool success,
+ const AddressList& result) {
+ EXPECT_TRUE(success);
+ if (success)
+ *result_buf = result;
+ callback.Run(OK);
+}
+
+} // namespace
+
+class AddressSorterPosixTest : public testing::Test {
+ protected:
+ AddressSorterPosixTest() : sorter_(&socket_factory_) {}
+
+ void AddMapping(const std::string& dst, const std::string& src) {
+ socket_factory_.AddMapping(ParseIP(dst), ParseIP(src));
+ }
+
+ AddressSorterPosix::SourceAddressInfo* GetSourceInfo(
+ const std::string& addr) {
+ IPAddressNumber address = ParseIP(addr);
+ AddressSorterPosix::SourceAddressInfo* info = &sorter_.source_map_[address];
+ if (info->scope == AddressSorterPosix::SCOPE_UNDEFINED)
+ sorter_.FillPolicy(address, info);
+ return info;
+ }
+
+ // Verify that NULL-terminated |addresses| matches (-1)-terminated |order|
+ // after sorting.
+ void Verify(const char* addresses[], const int order[]) {
+ AddressList list;
+ for (const char** addr = addresses; *addr != NULL; ++addr)
+ list.push_back(IPEndPoint(ParseIP(*addr), 80));
+ for (size_t i = 0; order[i] >= 0; ++i)
+ CHECK_LT(order[i], static_cast<int>(list.size()));
+
+ AddressList result;
+ TestCompletionCallback callback;
+ sorter_.Sort(list, base::Bind(&OnSortComplete, &result,
+ callback.callback()));
+ callback.WaitForResult();
+
+ for (size_t i = 0; (i < result.size()) || (order[i] >= 0); ++i) {
+ IPEndPoint expected = order[i] >= 0 ? list[order[i]] : IPEndPoint();
+ IPEndPoint actual = i < result.size() ? result[i] : IPEndPoint();
+ EXPECT_TRUE(expected.address() == actual.address()) <<
+ "Address out of order at position " << i << "\n" <<
+ " Actual: " << actual.ToStringWithoutPort() << "\n" <<
+ "Expected: " << expected.ToStringWithoutPort();
+ }
+ }
+
+ TestSocketFactory socket_factory_;
+ AddressSorterPosix sorter_;
+};
+
+// Rule 1: Avoid unusable destinations.
+TEST_F(AddressSorterPosixTest, Rule1) {
+ AddMapping("10.0.0.231", "10.0.0.1");
+ const char* addresses[] = { "::1", "10.0.0.231", "127.0.0.1", NULL };
+ const int order[] = { 1, -1 };
+ Verify(addresses, order);
+}
+
+// Rule 2: Prefer matching scope.
+TEST_F(AddressSorterPosixTest, Rule2) {
+ AddMapping("3002::1", "4000::10"); // matching global
+ AddMapping("ff32::1", "fe81::10"); // matching link-local
+ AddMapping("fec1::1", "fec1::10"); // matching node-local
+ AddMapping("3002::2", "::1"); // global vs. link-local
+ AddMapping("fec1::2", "fe81::10"); // site-local vs. link-local
+ AddMapping("8.0.0.1", "169.254.0.10"); // global vs. link-local
+ // In all three cases, matching scope is preferred.
+ const int order[] = { 1, 0, -1 };
+ const char* addresses1[] = { "3002::2", "3002::1", NULL };
+ Verify(addresses1, order);
+ const char* addresses2[] = { "fec1::2", "ff32::1", NULL };
+ Verify(addresses2, order);
+ const char* addresses3[] = { "8.0.0.1", "fec1::1", NULL };
+ Verify(addresses3, order);
+}
+
+// Rule 3: Avoid deprecated addresses.
+TEST_F(AddressSorterPosixTest, Rule3) {
+ // Matching scope.
+ AddMapping("3002::1", "4000::10");
+ GetSourceInfo("4000::10")->deprecated = true;
+ AddMapping("3002::2", "4000::20");
+ const char* addresses[] = { "3002::1", "3002::2", NULL };
+ const int order[] = { 1, 0, -1 };
+ Verify(addresses, order);
+}
+
+// Rule 4: Prefer home addresses.
+TEST_F(AddressSorterPosixTest, Rule4) {
+ AddMapping("3002::1", "4000::10");
+ AddMapping("3002::2", "4000::20");
+ GetSourceInfo("4000::20")->home = true;
+ const char* addresses[] = { "3002::1", "3002::2", NULL };
+ const int order[] = { 1, 0, -1 };
+ Verify(addresses, order);
+}
+
+// Rule 5: Prefer matching label.
+TEST_F(AddressSorterPosixTest, Rule5) {
+ AddMapping("::1", "::1"); // matching loopback
+ AddMapping("::ffff:1234:1", "::ffff:1234:10"); // matching IPv4-mapped
+ AddMapping("2001::1", "::ffff:1234:10"); // Teredo vs. IPv4-mapped
+ AddMapping("2002::1", "2001::10"); // 6to4 vs. Teredo
+ const int order[] = { 1, 0, -1 };
+ {
+ const char* addresses[] = { "2001::1", "::1", NULL };
+ Verify(addresses, order);
+ }
+ {
+ const char* addresses[] = { "2002::1", "::ffff:1234:1", NULL };
+ Verify(addresses, order);
+ }
+}
+
+// Rule 6: Prefer higher precedence.
+TEST_F(AddressSorterPosixTest, Rule6) {
+ AddMapping("::1", "::1"); // loopback
+ AddMapping("ff32::1", "fe81::10"); // multicast
+ AddMapping("::ffff:1234:1", "::ffff:1234:10"); // IPv4-mapped
+ AddMapping("2001::1", "2001::10"); // Teredo
+ const char* addresses[] = { "2001::1", "::ffff:1234:1", "ff32::1", "::1",
+ NULL };
+ const int order[] = { 3, 2, 1, 0, -1 };
+ Verify(addresses, order);
+}
+
+// Rule 7: Prefer native transport.
+TEST_F(AddressSorterPosixTest, Rule7) {
+ AddMapping("3002::1", "4000::10");
+ AddMapping("3002::2", "4000::20");
+ GetSourceInfo("4000::20")->native = true;
+ const char* addresses[] = { "3002::1", "3002::2", NULL };
+ const int order[] = { 1, 0, -1 };
+ Verify(addresses, order);
+}
+
+// Rule 8: Prefer smaller scope.
+TEST_F(AddressSorterPosixTest, Rule8) {
+ // Matching scope. Should precede the others by Rule 2.
+ AddMapping("fe81::1", "fe81::10"); // link-local
+ AddMapping("3000::1", "4000::10"); // global
+ // Mismatched scope.
+ AddMapping("ff32::1", "4000::10"); // link-local
+ AddMapping("ff35::1", "4000::10"); // site-local
+ AddMapping("ff38::1", "4000::10"); // org-local
+ const char* addresses[] = { "ff38::1", "3000::1", "ff35::1", "ff32::1",
+ "fe81::1", NULL };
+ const int order[] = { 4, 1, 3, 2, 0, -1 };
+ Verify(addresses, order);
+}
+
+// Rule 9: Use longest matching prefix.
+TEST_F(AddressSorterPosixTest, Rule9) {
+ AddMapping("3000::1", "3000:ffff::10"); // 16 bit match
+ GetSourceInfo("3000:ffff::10")->prefix_length = 16;
+ AddMapping("4000::1", "4000::10"); // 123 bit match, limited to 15
+ GetSourceInfo("4000::10")->prefix_length = 15;
+ AddMapping("4002::1", "4000::10"); // 14 bit match
+ AddMapping("4080::1", "4000::10"); // 8 bit match
+ const char* addresses[] = { "4080::1", "4002::1", "4000::1", "3000::1",
+ NULL };
+ const int order[] = { 3, 2, 1, 0, -1 };
+ Verify(addresses, order);
+}
+
+// Rule 10: Leave the order unchanged.
+TEST_F(AddressSorterPosixTest, Rule10) {
+ AddMapping("4000::1", "4000::10");
+ AddMapping("4000::2", "4000::10");
+ AddMapping("4000::3", "4000::10");
+ const char* addresses[] = { "4000::1", "4000::2", "4000::3", NULL };
+ const int order[] = { 0, 1, 2, -1 };
+ Verify(addresses, order);
+}
+
+TEST_F(AddressSorterPosixTest, MultipleRules) {
+ AddMapping("::1", "::1"); // loopback
+ AddMapping("ff32::1", "fe81::10"); // link-local multicast
+ AddMapping("ff3e::1", "4000::10"); // global multicast
+ AddMapping("4000::1", "4000::10"); // global unicast
+ AddMapping("ff32::2", "fe81::20"); // deprecated link-local multicast
+ GetSourceInfo("fe81::20")->deprecated = true;
+ const char* addresses[] = { "ff3e::1", "ff32::2", "4000::1", "ff32::1", "::1",
+ "8.0.0.1", NULL };
+ const int order[] = { 4, 3, 0, 2, 1, -1 };
+ Verify(addresses, order);
+}
+
+} // namespace net
diff --git a/chromium/net/dns/address_sorter_unittest.cc b/chromium/net/dns/address_sorter_unittest.cc
new file mode 100644
index 00000000000..0c2be884d29
--- /dev/null
+++ b/chromium/net/dns/address_sorter_unittest.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/address_sorter.h"
+
+#if defined(OS_WIN)
+#include <winsock2.h>
+#endif
+
+#include "base/bind.h"
+#include "base/logging.h"
+#include "net/base/address_list.h"
+#include "net/base/net_util.h"
+#include "net/base/test_completion_callback.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+#if defined(OS_WIN)
+#include "net/base/winsock_init.h"
+#endif
+
+namespace net {
+namespace {
+
+IPEndPoint MakeEndPoint(const std::string& str) {
+ IPAddressNumber addr;
+ CHECK(ParseIPLiteralToNumber(str, &addr));
+ return IPEndPoint(addr, 0);
+}
+
+void OnSortComplete(AddressList* result_buf,
+ const CompletionCallback& callback,
+ bool success,
+ const AddressList& result) {
+ if (success)
+ *result_buf = result;
+ callback.Run(success ? OK : ERR_FAILED);
+}
+
+TEST(AddressSorterTest, Sort) {
+ int expected_result = OK;
+#if defined(OS_WIN)
+ EnsureWinsockInit();
+ SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
+ if (sock == INVALID_SOCKET) {
+ expected_result = ERR_FAILED;
+ } else {
+ closesocket(sock);
+ }
+#endif
+ scoped_ptr<AddressSorter> sorter(AddressSorter::CreateAddressSorter());
+ AddressList list;
+ list.push_back(MakeEndPoint("10.0.0.1"));
+ list.push_back(MakeEndPoint("8.8.8.8"));
+ list.push_back(MakeEndPoint("::1"));
+ list.push_back(MakeEndPoint("2001:4860:4860::8888"));
+
+ AddressList result;
+ TestCompletionCallback callback;
+ sorter->Sort(list, base::Bind(&OnSortComplete, &result,
+ callback.callback()));
+ EXPECT_EQ(expected_result, callback.WaitForResult());
+}
+
+} // namespace
+} // namespace net
diff --git a/chromium/net/dns/address_sorter_win.cc b/chromium/net/dns/address_sorter_win.cc
new file mode 100644
index 00000000000..3e1afa734cc
--- /dev/null
+++ b/chromium/net/dns/address_sorter_win.cc
@@ -0,0 +1,198 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/address_sorter.h"
+
+#include <winsock2.h>
+
+#include <algorithm>
+
+#include "base/bind.h"
+#include "base/location.h"
+#include "base/logging.h"
+#include "base/threading/worker_pool.h"
+#include "base/win/windows_version.h"
+#include "net/base/address_list.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/winsock_init.h"
+
+namespace net {
+
+namespace {
+
+class AddressSorterWin : public AddressSorter {
+ public:
+ AddressSorterWin() {
+ EnsureWinsockInit();
+ }
+
+ virtual ~AddressSorterWin() {}
+
+ // AddressSorter:
+ virtual void Sort(const AddressList& list,
+ const CallbackType& callback) const OVERRIDE {
+ DCHECK(!list.empty());
+ scoped_refptr<Job> job = new Job(list, callback);
+ }
+
+ private:
+ // Executes the SIO_ADDRESS_LIST_SORT ioctl on the WorkerPool, and
+ // performs the necessary conversions to/from AddressList.
+ class Job : public base::RefCountedThreadSafe<Job> {
+ public:
+ Job(const AddressList& list, const CallbackType& callback)
+ : callback_(callback),
+ buffer_size_(sizeof(SOCKET_ADDRESS_LIST) +
+ list.size() * (sizeof(SOCKET_ADDRESS) +
+ sizeof(SOCKADDR_STORAGE))),
+ input_buffer_(reinterpret_cast<SOCKET_ADDRESS_LIST*>(
+ malloc(buffer_size_))),
+ output_buffer_(reinterpret_cast<SOCKET_ADDRESS_LIST*>(
+ malloc(buffer_size_))),
+ success_(false) {
+ input_buffer_->iAddressCount = list.size();
+ SOCKADDR_STORAGE* storage = reinterpret_cast<SOCKADDR_STORAGE*>(
+ input_buffer_->Address + input_buffer_->iAddressCount);
+
+ for (size_t i = 0; i < list.size(); ++i) {
+ IPEndPoint ipe = list[i];
+ // Addresses must be sockaddr_in6.
+ if (ipe.GetFamily() == ADDRESS_FAMILY_IPV4) {
+ ipe = IPEndPoint(ConvertIPv4NumberToIPv6Number(ipe.address()),
+ ipe.port());
+ }
+
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(storage + i);
+ socklen_t addr_len = sizeof(SOCKADDR_STORAGE);
+ bool result = ipe.ToSockAddr(addr, &addr_len);
+ DCHECK(result);
+ input_buffer_->Address[i].lpSockaddr = addr;
+ input_buffer_->Address[i].iSockaddrLength = addr_len;
+ }
+
+ if (!base::WorkerPool::PostTaskAndReply(
+ FROM_HERE,
+ base::Bind(&Job::Run, this),
+ base::Bind(&Job::OnComplete, this),
+ false /* task is slow */)) {
+ LOG(ERROR) << "WorkerPool::PostTaskAndReply failed";
+ OnComplete();
+ }
+ }
+
+ private:
+ friend class base::RefCountedThreadSafe<Job>;
+ ~Job() {}
+
+ // Executed on the WorkerPool.
+ void Run() {
+ SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
+ if (sock == INVALID_SOCKET)
+ return;
+ DWORD result_size = 0;
+ int result = WSAIoctl(sock, SIO_ADDRESS_LIST_SORT, input_buffer_.get(),
+ buffer_size_, output_buffer_.get(), buffer_size_,
+ &result_size, NULL, NULL);
+ if (result == SOCKET_ERROR) {
+ LOG(ERROR) << "SIO_ADDRESS_LIST_SORT failed " << WSAGetLastError();
+ } else {
+ success_ = true;
+ }
+ closesocket(sock);
+ }
+
+ // Executed on the calling thread.
+ void OnComplete() {
+ AddressList list;
+ if (success_) {
+ list.reserve(output_buffer_->iAddressCount);
+ for (int i = 0; i < output_buffer_->iAddressCount; ++i) {
+ IPEndPoint ipe;
+ ipe.FromSockAddr(output_buffer_->Address[i].lpSockaddr,
+ output_buffer_->Address[i].iSockaddrLength);
+ // Unmap V4MAPPED IPv6 addresses so that Happy Eyeballs works.
+ if (IsIPv4Mapped(ipe.address())) {
+ ipe = IPEndPoint(ConvertIPv4MappedToIPv4(ipe.address()),
+ ipe.port());
+ }
+ list.push_back(ipe);
+ }
+ }
+ callback_.Run(success_, list);
+ }
+
+ const CallbackType callback_;
+ const size_t buffer_size_;
+ scoped_ptr_malloc<SOCKET_ADDRESS_LIST> input_buffer_;
+ scoped_ptr_malloc<SOCKET_ADDRESS_LIST> output_buffer_;
+ bool success_;
+
+ DISALLOW_COPY_AND_ASSIGN(Job);
+ };
+
+ DISALLOW_COPY_AND_ASSIGN(AddressSorterWin);
+};
+
+// Merges |list_ipv4| and |list_ipv6| before passing it to |callback|, but
+// only if |success| is true.
+void MergeResults(const AddressSorter::CallbackType& callback,
+ const AddressList& list_ipv4,
+ bool success,
+ const AddressList& list_ipv6) {
+ if (!success) {
+ callback.Run(false, AddressList());
+ return;
+ }
+ AddressList list;
+ list.insert(list.end(), list_ipv6.begin(), list_ipv6.end());
+ list.insert(list.end(), list_ipv4.begin(), list_ipv4.end());
+ callback.Run(true, list);
+}
+
+// Wrapper for AddressSorterWin which does not sort IPv4 or IPv4-mapped
+// addresses but always puts them at the end of the list. Needed because the
+// SIO_ADDRESS_LIST_SORT does not support IPv4 addresses on Windows XP.
+class AddressSorterWinXP : public AddressSorter {
+ public:
+ AddressSorterWinXP() {}
+ virtual ~AddressSorterWinXP() {}
+
+ // AddressSorter:
+ virtual void Sort(const AddressList& list,
+ const CallbackType& callback) const OVERRIDE {
+ AddressList list_ipv4;
+ AddressList list_ipv6;
+ for (size_t i = 0; i < list.size(); ++i) {
+ const IPEndPoint& ipe = list[i];
+ if (ipe.GetFamily() == ADDRESS_FAMILY_IPV4) {
+ list_ipv4.push_back(ipe);
+ } else {
+ list_ipv6.push_back(ipe);
+ }
+ }
+ if (!list_ipv6.empty()) {
+ sorter_.Sort(list_ipv6, base::Bind(&MergeResults, callback, list_ipv4));
+ } else {
+ NOTREACHED() << "Should not be called with IPv4-only addresses.";
+ callback.Run(true, list);
+ }
+ }
+
+ private:
+ AddressSorterWin sorter_;
+
+ DISALLOW_COPY_AND_ASSIGN(AddressSorterWinXP);
+};
+
+} // namespace
+
+// static
+scoped_ptr<AddressSorter> AddressSorter::CreateAddressSorter() {
+ if (base::win::GetVersion() < base::win::VERSION_VISTA)
+ return scoped_ptr<AddressSorter>(new AddressSorterWinXP());
+ return scoped_ptr<AddressSorter>(new AddressSorterWin());
+}
+
+} // namespace net
+
diff --git a/chromium/net/dns/dns_client.cc b/chromium/net/dns/dns_client.cc
new file mode 100644
index 00000000000..976f1533905
--- /dev/null
+++ b/chromium/net/dns/dns_client.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_client.h"
+
+#include "base/bind.h"
+#include "base/rand_util.h"
+#include "net/base/net_log.h"
+#include "net/dns/address_sorter.h"
+#include "net/dns/dns_config_service.h"
+#include "net/dns/dns_session.h"
+#include "net/dns/dns_socket_pool.h"
+#include "net/dns/dns_transaction.h"
+#include "net/socket/client_socket_factory.h"
+
+namespace net {
+
+namespace {
+
+class DnsClientImpl : public DnsClient {
+ public:
+ explicit DnsClientImpl(NetLog* net_log)
+ : address_sorter_(AddressSorter::CreateAddressSorter()),
+ net_log_(net_log) {}
+
+ virtual void SetConfig(const DnsConfig& config) OVERRIDE {
+ factory_.reset();
+ session_ = NULL;
+ if (config.IsValid()) {
+ ClientSocketFactory* factory = ClientSocketFactory::GetDefaultFactory();
+ scoped_ptr<DnsSocketPool> socket_pool(
+ config.randomize_ports ? DnsSocketPool::CreateDefault(factory)
+ : DnsSocketPool::CreateNull(factory));
+ session_ = new DnsSession(config,
+ socket_pool.Pass(),
+ base::Bind(&base::RandInt),
+ net_log_);
+ factory_ = DnsTransactionFactory::CreateFactory(session_.get());
+ }
+ }
+
+ virtual const DnsConfig* GetConfig() const OVERRIDE {
+ return session_.get() ? &session_->config() : NULL;
+ }
+
+ virtual DnsTransactionFactory* GetTransactionFactory() OVERRIDE {
+ return session_.get() ? factory_.get() : NULL;
+ }
+
+ virtual AddressSorter* GetAddressSorter() OVERRIDE {
+ return address_sorter_.get();
+ }
+
+ private:
+ scoped_refptr<DnsSession> session_;
+ scoped_ptr<DnsTransactionFactory> factory_;
+ scoped_ptr<AddressSorter> address_sorter_;
+
+ NetLog* net_log_;
+};
+
+} // namespace
+
+// static
+scoped_ptr<DnsClient> DnsClient::CreateClient(NetLog* net_log) {
+ return scoped_ptr<DnsClient>(new DnsClientImpl(net_log));
+}
+
+} // namespace net
+
diff --git a/chromium/net/dns/dns_client.h b/chromium/net/dns/dns_client.h
new file mode 100644
index 00000000000..650c7d0d416
--- /dev/null
+++ b/chromium/net/dns/dns_client.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_CLIENT_H_
+#define NET_DNS_DNS_CLIENT_H_
+
+#include "base/memory/scoped_ptr.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+class AddressSorter;
+struct DnsConfig;
+class DnsTransactionFactory;
+class NetLog;
+
+// Convenience wrapper which allows easy injection of DnsTransaction into
+// HostResolverImpl. Pointers returned by the Get* methods are only guaranteed
+// to remain valid until next time SetConfig is called.
+class NET_EXPORT DnsClient {
+ public:
+ virtual ~DnsClient() {}
+
+ // Creates a new DnsTransactionFactory according to the new |config|.
+ virtual void SetConfig(const DnsConfig& config) = 0;
+
+ // Returns NULL if the current config is not valid.
+ virtual const DnsConfig* GetConfig() const = 0;
+
+ // Returns NULL if the current config is not valid.
+ virtual DnsTransactionFactory* GetTransactionFactory() = 0;
+
+ // Returns NULL if the current config is not valid.
+ virtual AddressSorter* GetAddressSorter() = 0;
+
+ // Creates default client.
+ static scoped_ptr<DnsClient> CreateClient(NetLog* net_log);
+};
+
+} // namespace net
+
+#endif // NET_DNS_DNS_CLIENT_H_
+
diff --git a/chromium/net/dns/dns_config_service.cc b/chromium/net/dns/dns_config_service.cc
new file mode 100644
index 00000000000..ea8a3421cd2
--- /dev/null
+++ b/chromium/net/dns/dns_config_service.cc
@@ -0,0 +1,226 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_config_service.h"
+
+#include "base/logging.h"
+#include "base/metrics/histogram.h"
+#include "base/values.h"
+#include "net/base/ip_endpoint.h"
+
+namespace net {
+
+// Default values are taken from glibc resolv.h except timeout which is set to
+// |kDnsTimeoutSeconds|.
+DnsConfig::DnsConfig()
+ : append_to_multi_label_name(true),
+ randomize_ports(false),
+ ndots(1),
+ timeout(base::TimeDelta::FromSeconds(kDnsTimeoutSeconds)),
+ attempts(2),
+ rotate(false),
+ edns0(false) {}
+
+DnsConfig::~DnsConfig() {}
+
+bool DnsConfig::Equals(const DnsConfig& d) const {
+ return EqualsIgnoreHosts(d) && (hosts == d.hosts);
+}
+
+bool DnsConfig::EqualsIgnoreHosts(const DnsConfig& d) const {
+ return (nameservers == d.nameservers) &&
+ (search == d.search) &&
+ (append_to_multi_label_name == d.append_to_multi_label_name) &&
+ (ndots == d.ndots) &&
+ (timeout == d.timeout) &&
+ (attempts == d.attempts) &&
+ (rotate == d.rotate) &&
+ (edns0 == d.edns0);
+}
+
+void DnsConfig::CopyIgnoreHosts(const DnsConfig& d) {
+ nameservers = d.nameservers;
+ search = d.search;
+ append_to_multi_label_name = d.append_to_multi_label_name;
+ ndots = d.ndots;
+ timeout = d.timeout;
+ attempts = d.attempts;
+ rotate = d.rotate;
+ edns0 = d.edns0;
+}
+
+base::Value* DnsConfig::ToValue() const {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+
+ base::ListValue* list = new base::ListValue();
+ for (size_t i = 0; i < nameservers.size(); ++i)
+ list->Append(new base::StringValue(nameservers[i].ToString()));
+ dict->Set("nameservers", list);
+
+ list = new base::ListValue();
+ for (size_t i = 0; i < search.size(); ++i)
+ list->Append(new base::StringValue(search[i]));
+ dict->Set("search", list);
+
+ dict->SetBoolean("append_to_multi_label_name", append_to_multi_label_name);
+ dict->SetInteger("ndots", ndots);
+ dict->SetDouble("timeout", timeout.InSecondsF());
+ dict->SetInteger("attempts", attempts);
+ dict->SetBoolean("rotate", rotate);
+ dict->SetBoolean("edns0", edns0);
+ dict->SetInteger("num_hosts", hosts.size());
+
+ return dict;
+}
+
+
+DnsConfigService::DnsConfigService()
+ : watch_failed_(false),
+ have_config_(false),
+ have_hosts_(false),
+ need_update_(false),
+ last_sent_empty_(true) {}
+
+DnsConfigService::~DnsConfigService() {
+}
+
+void DnsConfigService::ReadConfig(const CallbackType& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(!callback.is_null());
+ DCHECK(callback_.is_null());
+ callback_ = callback;
+ ReadNow();
+}
+
+void DnsConfigService::WatchConfig(const CallbackType& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(!callback.is_null());
+ DCHECK(callback_.is_null());
+ callback_ = callback;
+ watch_failed_ = !StartWatching();
+ ReadNow();
+}
+
+void DnsConfigService::InvalidateConfig() {
+ DCHECK(CalledOnValidThread());
+ base::TimeTicks now = base::TimeTicks::Now();
+ if (!last_invalidate_config_time_.is_null()) {
+ UMA_HISTOGRAM_LONG_TIMES("AsyncDNS.ConfigNotifyInterval",
+ now - last_invalidate_config_time_);
+ }
+ last_invalidate_config_time_ = now;
+ if (!have_config_)
+ return;
+ have_config_ = false;
+ StartTimer();
+}
+
+void DnsConfigService::InvalidateHosts() {
+ DCHECK(CalledOnValidThread());
+ base::TimeTicks now = base::TimeTicks::Now();
+ if (!last_invalidate_hosts_time_.is_null()) {
+ UMA_HISTOGRAM_LONG_TIMES("AsyncDNS.HostsNotifyInterval",
+ now - last_invalidate_hosts_time_);
+ }
+ last_invalidate_hosts_time_ = now;
+ if (!have_hosts_)
+ return;
+ have_hosts_ = false;
+ StartTimer();
+}
+
+void DnsConfigService::OnConfigRead(const DnsConfig& config) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(config.IsValid());
+
+ bool changed = false;
+ if (!config.EqualsIgnoreHosts(dns_config_)) {
+ dns_config_.CopyIgnoreHosts(config);
+ need_update_ = true;
+ changed = true;
+ }
+ if (!changed && !last_sent_empty_time_.is_null()) {
+ UMA_HISTOGRAM_LONG_TIMES("AsyncDNS.UnchangedConfigInterval",
+ base::TimeTicks::Now() - last_sent_empty_time_);
+ }
+ UMA_HISTOGRAM_BOOLEAN("AsyncDNS.ConfigChange", changed);
+
+ have_config_ = true;
+ if (have_hosts_ || watch_failed_)
+ OnCompleteConfig();
+}
+
+void DnsConfigService::OnHostsRead(const DnsHosts& hosts) {
+ DCHECK(CalledOnValidThread());
+
+ bool changed = false;
+ if (hosts != dns_config_.hosts) {
+ dns_config_.hosts = hosts;
+ need_update_ = true;
+ changed = true;
+ }
+ if (!changed && !last_sent_empty_time_.is_null()) {
+ UMA_HISTOGRAM_LONG_TIMES("AsyncDNS.UnchangedHostsInterval",
+ base::TimeTicks::Now() - last_sent_empty_time_);
+ }
+ UMA_HISTOGRAM_BOOLEAN("AsyncDNS.HostsChange", changed);
+
+ have_hosts_ = true;
+ if (have_config_ || watch_failed_)
+ OnCompleteConfig();
+}
+
+void DnsConfigService::StartTimer() {
+ DCHECK(CalledOnValidThread());
+ if (last_sent_empty_) {
+ DCHECK(!timer_.IsRunning());
+ return; // No need to withdraw again.
+ }
+ timer_.Stop();
+
+ // Give it a short timeout to come up with a valid config. Otherwise withdraw
+ // the config from the receiver. The goal is to avoid perceivable network
+ // outage (when using the wrong config) but at the same time avoid
+ // unnecessary Job aborts in HostResolverImpl. The signals come from multiple
+ // sources so it might receive multiple events during a config change.
+
+ // DHCP and user-induced changes are on the order of seconds, so 150ms should
+ // not add perceivable delay. On the other hand, config readers should finish
+ // within 150ms with the rare exception of I/O block or extra large HOSTS.
+ const base::TimeDelta kTimeout = base::TimeDelta::FromMilliseconds(150);
+
+ timer_.Start(FROM_HERE,
+ kTimeout,
+ this,
+ &DnsConfigService::OnTimeout);
+}
+
+void DnsConfigService::OnTimeout() {
+ DCHECK(CalledOnValidThread());
+ DCHECK(!last_sent_empty_);
+ // Indicate that even if there is no change in On*Read, we will need to
+ // update the receiver when the config becomes complete.
+ need_update_ = true;
+ // Empty config is considered invalid.
+ last_sent_empty_ = true;
+ last_sent_empty_time_ = base::TimeTicks::Now();
+ callback_.Run(DnsConfig());
+}
+
+void DnsConfigService::OnCompleteConfig() {
+ timer_.Stop();
+ if (!need_update_)
+ return;
+ need_update_ = false;
+ last_sent_empty_ = false;
+ if (watch_failed_) {
+ // If a watch failed, the config may not be accurate, so report empty.
+ callback_.Run(DnsConfig());
+ } else {
+ callback_.Run(dns_config_);
+ }
+}
+
+} // namespace net
+
diff --git a/chromium/net/dns/dns_config_service.h b/chromium/net/dns/dns_config_service.h
new file mode 100644
index 00000000000..4babb9e7d37
--- /dev/null
+++ b/chromium/net/dns/dns_config_service.h
@@ -0,0 +1,174 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_CONFIG_SERVICE_H_
+#define NET_DNS_DNS_CONFIG_SERVICE_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "base/gtest_prod_util.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/threading/non_thread_safe.h"
+#include "base/time/time.h"
+#include "base/timer/timer.h"
+// Needed on shared build with MSVS2010 to avoid multiple definitions of
+// std::vector<IPEndPoint>.
+#include "net/base/address_list.h"
+#include "net/base/ip_endpoint.h" // win requires size of IPEndPoint
+#include "net/base/net_export.h"
+#include "net/dns/dns_hosts.h"
+
+namespace base {
+class Value;
+}
+
+namespace net {
+
+// Always use 1 second timeout (followed by binary exponential backoff).
+// TODO(szym): Remove code which reads timeout from system.
+const unsigned kDnsTimeoutSeconds = 1;
+
+// DnsConfig stores configuration of the system resolver.
+struct NET_EXPORT_PRIVATE DnsConfig {
+ DnsConfig();
+ virtual ~DnsConfig();
+
+ bool Equals(const DnsConfig& d) const;
+
+ bool EqualsIgnoreHosts(const DnsConfig& d) const;
+
+ void CopyIgnoreHosts(const DnsConfig& src);
+
+ // Returns a Value representation of |this|. Caller takes ownership of the
+ // returned Value. For performance reasons, the Value only contains the
+ // number of hosts rather than the full list.
+ base::Value* ToValue() const;
+
+ bool IsValid() const {
+ return !nameservers.empty();
+ }
+
+ // List of name server addresses.
+ std::vector<IPEndPoint> nameservers;
+ // Suffix search list; used on first lookup when number of dots in given name
+ // is less than |ndots|.
+ std::vector<std::string> search;
+
+ DnsHosts hosts;
+
+ // AppendToMultiLabelName: is suffix search performed for multi-label names?
+ // True, except on Windows where it can be configured.
+ bool append_to_multi_label_name;
+
+ // Indicates that source port randomization is required. This uses additional
+ // resources on some platforms.
+ bool randomize_ports;
+
+ // Resolver options; see man resolv.conf.
+
+ // Minimum number of dots before global resolution precedes |search|.
+ int ndots;
+ // Time between retransmissions, see res_state.retrans.
+ base::TimeDelta timeout;
+ // Maximum number of attempts, see res_state.retry.
+ int attempts;
+ // Round robin entries in |nameservers| for subsequent requests.
+ bool rotate;
+ // Enable EDNS0 extensions.
+ bool edns0;
+};
+
+
+// Service for reading system DNS settings, on demand or when signalled by
+// internal watchers and NetworkChangeNotifier.
+class NET_EXPORT_PRIVATE DnsConfigService
+ : NON_EXPORTED_BASE(public base::NonThreadSafe) {
+ public:
+ // Callback interface for the client, called on the same thread as
+ // ReadConfig() and WatchConfig().
+ typedef base::Callback<void(const DnsConfig& config)> CallbackType;
+
+ // Creates the platform-specific DnsConfigService.
+ static scoped_ptr<DnsConfigService> CreateSystemService();
+
+ DnsConfigService();
+ virtual ~DnsConfigService();
+
+ // Attempts to read the configuration. Will run |callback| when succeeded.
+ // Can be called at most once.
+ void ReadConfig(const CallbackType& callback);
+
+ // Registers systems watchers. Will attempt to read config after watch starts,
+ // but only if watchers started successfully. Will run |callback| iff config
+ // changes from last call or has to be withdrawn. Can be called at most once.
+ // Might require MessageLoopForIO.
+ void WatchConfig(const CallbackType& callback);
+
+ protected:
+ enum WatchStatus {
+ DNS_CONFIG_WATCH_STARTED = 0,
+ DNS_CONFIG_WATCH_FAILED_TO_START_CONFIG,
+ DNS_CONFIG_WATCH_FAILED_TO_START_HOSTS,
+ DNS_CONFIG_WATCH_FAILED_CONFIG,
+ DNS_CONFIG_WATCH_FAILED_HOSTS,
+ DNS_CONFIG_WATCH_MAX,
+ };
+
+ // Immediately attempts to read the current configuration.
+ virtual void ReadNow() = 0;
+ // Registers system watchers. Returns true iff succeeds.
+ virtual bool StartWatching() = 0;
+
+ // Called when the current config (except hosts) has changed.
+ void InvalidateConfig();
+ // Called when the current hosts have changed.
+ void InvalidateHosts();
+
+ // Called with new config. |config|.hosts is ignored.
+ void OnConfigRead(const DnsConfig& config);
+ // Called with new hosts. Rest of the config is assumed unchanged.
+ void OnHostsRead(const DnsHosts& hosts);
+
+ void set_watch_failed(bool value) { watch_failed_ = value; }
+
+ private:
+ // The timer counts from the last Invalidate* until complete config is read.
+ void StartTimer();
+ void OnTimeout();
+ // Called when the config becomes complete. Stops the timer.
+ void OnCompleteConfig();
+
+ CallbackType callback_;
+
+ DnsConfig dns_config_;
+
+ // True if any of the necessary watchers failed. In that case, the service
+ // will communicate changes via OnTimeout, but will only send empty DnsConfig.
+ bool watch_failed_;
+ // True after On*Read, before Invalidate*. Tells if the config is complete.
+ bool have_config_;
+ bool have_hosts_;
+ // True if receiver needs to be updated when the config becomes complete.
+ bool need_update_;
+ // True if the last config sent was empty (instead of |dns_config_|).
+ // Set when |timer_| expires.
+ bool last_sent_empty_;
+
+ // Initialized and updated on Invalidate* call.
+ base::TimeTicks last_invalidate_config_time_;
+ base::TimeTicks last_invalidate_hosts_time_;
+ // Initialized and updated when |timer_| expires.
+ base::TimeTicks last_sent_empty_time_;
+
+ // Started in Invalidate*, cleared in On*Read.
+ base::OneShotTimer<DnsConfigService> timer_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsConfigService);
+};
+
+} // namespace net
+
+#endif // NET_DNS_DNS_CONFIG_SERVICE_H_
diff --git a/chromium/net/dns/dns_config_service_posix.cc b/chromium/net/dns/dns_config_service_posix.cc
new file mode 100644
index 00000000000..ff2295e704a
--- /dev/null
+++ b/chromium/net/dns/dns_config_service_posix.cc
@@ -0,0 +1,404 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_config_service_posix.h"
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/bind.h"
+#include "base/files/file_path.h"
+#include "base/files/file_path_watcher.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/metrics/histogram.h"
+#include "base/time/time.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_util.h"
+#include "net/dns/dns_hosts.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/notify_watcher_mac.h"
+#include "net/dns/serial_worker.h"
+
+namespace net {
+
+#if !defined(OS_ANDROID)
+namespace internal {
+
+namespace {
+
+const base::FilePath::CharType* kFilePathHosts =
+ FILE_PATH_LITERAL("/etc/hosts");
+
+#if defined(OS_MACOSX)
+// From 10.7.3 configd-395.10/dnsinfo/dnsinfo.h
+static const char* kDnsNotifyKey =
+ "com.apple.system.SystemConfiguration.dns_configuration";
+
+class ConfigWatcher {
+ public:
+ bool Watch(const base::Callback<void(bool succeeded)>& callback) {
+ return watcher_.Watch(kDnsNotifyKey, callback);
+ }
+
+ private:
+ NotifyWatcherMac watcher_;
+};
+#else
+
+#ifndef _PATH_RESCONF // Normally defined in <resolv.h>
+#define _PATH_RESCONF "/etc/resolv.conf"
+#endif
+
+static const base::FilePath::CharType* kFilePathConfig =
+ FILE_PATH_LITERAL(_PATH_RESCONF);
+
+class ConfigWatcher {
+ public:
+ typedef base::Callback<void(bool succeeded)> CallbackType;
+
+ bool Watch(const CallbackType& callback) {
+ callback_ = callback;
+ return watcher_.Watch(base::FilePath(kFilePathConfig), false,
+ base::Bind(&ConfigWatcher::OnCallback,
+ base::Unretained(this)));
+ }
+
+ private:
+ void OnCallback(const base::FilePath& path, bool error) {
+ callback_.Run(!error);
+ }
+
+ base::FilePathWatcher watcher_;
+ CallbackType callback_;
+};
+#endif
+
+ConfigParsePosixResult ReadDnsConfig(DnsConfig* config) {
+ ConfigParsePosixResult result;
+#if defined(OS_OPENBSD)
+ // Note: res_ninit in glibc always returns 0 and sets RES_INIT.
+ // res_init behaves the same way.
+ memset(&_res, 0, sizeof(_res));
+ if (res_init() == 0) {
+ result = ConvertResStateToDnsConfig(_res, config);
+ } else {
+ result = CONFIG_PARSE_POSIX_RES_INIT_FAILED;
+ }
+#else // all other OS_POSIX
+ struct __res_state res;
+ memset(&res, 0, sizeof(res));
+ if (res_ninit(&res) == 0) {
+ result = ConvertResStateToDnsConfig(res, config);
+ } else {
+ result = CONFIG_PARSE_POSIX_RES_INIT_FAILED;
+ }
+ // Prefer res_ndestroy where available.
+#if defined(OS_MACOSX) || defined(OS_FREEBSD)
+ res_ndestroy(&res);
+#else
+ res_nclose(&res);
+#endif
+#endif
+ // Override timeout value to match default setting on Windows.
+ config->timeout = base::TimeDelta::FromSeconds(kDnsTimeoutSeconds);
+ return result;
+}
+
+} // namespace
+
+class DnsConfigServicePosix::Watcher {
+ public:
+ explicit Watcher(DnsConfigServicePosix* service)
+ : weak_factory_(this),
+ service_(service) {}
+ ~Watcher() {}
+
+ bool Watch() {
+ bool success = true;
+ if (!config_watcher_.Watch(base::Bind(&Watcher::OnConfigChanged,
+ base::Unretained(this)))) {
+ LOG(ERROR) << "DNS config watch failed to start.";
+ success = false;
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus",
+ DNS_CONFIG_WATCH_FAILED_TO_START_CONFIG,
+ DNS_CONFIG_WATCH_MAX);
+ }
+ if (!hosts_watcher_.Watch(base::FilePath(kFilePathHosts), false,
+ base::Bind(&Watcher::OnHostsChanged,
+ base::Unretained(this)))) {
+ LOG(ERROR) << "DNS hosts watch failed to start.";
+ success = false;
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus",
+ DNS_CONFIG_WATCH_FAILED_TO_START_HOSTS,
+ DNS_CONFIG_WATCH_MAX);
+ }
+ return success;
+ }
+
+ private:
+ void OnConfigChanged(bool succeeded) {
+ // Ignore transient flutter of resolv.conf by delaying the signal a bit.
+ const base::TimeDelta kDelay = base::TimeDelta::FromMilliseconds(50);
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(&Watcher::OnConfigChangedDelayed,
+ weak_factory_.GetWeakPtr(),
+ succeeded),
+ kDelay);
+ }
+ void OnConfigChangedDelayed(bool succeeded) {
+ service_->OnConfigChanged(succeeded);
+ }
+ void OnHostsChanged(const base::FilePath& path, bool error) {
+ service_->OnHostsChanged(!error);
+ }
+
+ base::WeakPtrFactory<Watcher> weak_factory_;
+ DnsConfigServicePosix* service_;
+ ConfigWatcher config_watcher_;
+ base::FilePathWatcher hosts_watcher_;
+
+ DISALLOW_COPY_AND_ASSIGN(Watcher);
+};
+
+// A SerialWorker that uses libresolv to initialize res_state and converts
+// it to DnsConfig.
+class DnsConfigServicePosix::ConfigReader : public SerialWorker {
+ public:
+ explicit ConfigReader(DnsConfigServicePosix* service)
+ : service_(service), success_(false) {}
+
+ virtual void DoWork() OVERRIDE {
+ base::TimeTicks start_time = base::TimeTicks::Now();
+ ConfigParsePosixResult result = ReadDnsConfig(&dns_config_);
+ success_ = (result == CONFIG_PARSE_POSIX_OK);
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ConfigParsePosix",
+ result, CONFIG_PARSE_POSIX_MAX);
+ UMA_HISTOGRAM_BOOLEAN("AsyncDNS.ConfigParseResult", success_);
+ UMA_HISTOGRAM_TIMES("AsyncDNS.ConfigParseDuration",
+ base::TimeTicks::Now() - start_time);
+ }
+
+ virtual void OnWorkFinished() OVERRIDE {
+ DCHECK(!IsCancelled());
+ if (success_) {
+ service_->OnConfigRead(dns_config_);
+ } else {
+ LOG(WARNING) << "Failed to read DnsConfig.";
+ }
+ }
+
+ private:
+ virtual ~ConfigReader() {}
+
+ DnsConfigServicePosix* service_;
+ // Written in DoWork, read in OnWorkFinished, no locking necessary.
+ DnsConfig dns_config_;
+ bool success_;
+
+ DISALLOW_COPY_AND_ASSIGN(ConfigReader);
+};
+
+// A SerialWorker that reads the HOSTS file and runs Callback.
+class DnsConfigServicePosix::HostsReader : public SerialWorker {
+ public:
+ explicit HostsReader(DnsConfigServicePosix* service)
+ : service_(service), path_(kFilePathHosts), success_(false) {}
+
+ private:
+ virtual ~HostsReader() {}
+
+ virtual void DoWork() OVERRIDE {
+ base::TimeTicks start_time = base::TimeTicks::Now();
+ success_ = ParseHostsFile(path_, &hosts_);
+ UMA_HISTOGRAM_BOOLEAN("AsyncDNS.HostParseResult", success_);
+ UMA_HISTOGRAM_TIMES("AsyncDNS.HostsParseDuration",
+ base::TimeTicks::Now() - start_time);
+ }
+
+ virtual void OnWorkFinished() OVERRIDE {
+ if (success_) {
+ service_->OnHostsRead(hosts_);
+ } else {
+ LOG(WARNING) << "Failed to read DnsHosts.";
+ }
+ }
+
+ DnsConfigServicePosix* service_;
+ const base::FilePath path_;
+ // Written in DoWork, read in OnWorkFinished, no locking necessary.
+ DnsHosts hosts_;
+ bool success_;
+
+ DISALLOW_COPY_AND_ASSIGN(HostsReader);
+};
+
+DnsConfigServicePosix::DnsConfigServicePosix()
+ : config_reader_(new ConfigReader(this)),
+ hosts_reader_(new HostsReader(this)) {}
+
+DnsConfigServicePosix::~DnsConfigServicePosix() {
+ config_reader_->Cancel();
+ hosts_reader_->Cancel();
+}
+
+void DnsConfigServicePosix::ReadNow() {
+ config_reader_->WorkNow();
+ hosts_reader_->WorkNow();
+}
+
+bool DnsConfigServicePosix::StartWatching() {
+ // TODO(szym): re-start watcher if that makes sense. http://crbug.com/116139
+ watcher_.reset(new Watcher(this));
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", DNS_CONFIG_WATCH_STARTED,
+ DNS_CONFIG_WATCH_MAX);
+ return watcher_->Watch();
+}
+
+void DnsConfigServicePosix::OnConfigChanged(bool succeeded) {
+ InvalidateConfig();
+ if (succeeded) {
+ config_reader_->WorkNow();
+ } else {
+ LOG(ERROR) << "DNS config watch failed.";
+ set_watch_failed(true);
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus",
+ DNS_CONFIG_WATCH_FAILED_CONFIG,
+ DNS_CONFIG_WATCH_MAX);
+ }
+}
+
+void DnsConfigServicePosix::OnHostsChanged(bool succeeded) {
+ InvalidateHosts();
+ if (succeeded) {
+ hosts_reader_->WorkNow();
+ } else {
+ LOG(ERROR) << "DNS hosts watch failed.";
+ set_watch_failed(true);
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus",
+ DNS_CONFIG_WATCH_FAILED_HOSTS,
+ DNS_CONFIG_WATCH_MAX);
+ }
+}
+
+ConfigParsePosixResult ConvertResStateToDnsConfig(const struct __res_state& res,
+ DnsConfig* dns_config) {
+ CHECK(dns_config != NULL);
+ if (!(res.options & RES_INIT))
+ return CONFIG_PARSE_POSIX_RES_INIT_UNSET;
+
+ dns_config->nameservers.clear();
+
+#if defined(OS_MACOSX) || defined(OS_FREEBSD)
+ union res_sockaddr_union addresses[MAXNS];
+ int nscount = res_getservers(const_cast<res_state>(&res), addresses, MAXNS);
+ DCHECK_GE(nscount, 0);
+ DCHECK_LE(nscount, MAXNS);
+ for (int i = 0; i < nscount; ++i) {
+ IPEndPoint ipe;
+ if (!ipe.FromSockAddr(
+ reinterpret_cast<const struct sockaddr*>(&addresses[i]),
+ sizeof addresses[i])) {
+ return CONFIG_PARSE_POSIX_BAD_ADDRESS;
+ }
+ dns_config->nameservers.push_back(ipe);
+ }
+#elif defined(OS_LINUX)
+ COMPILE_ASSERT(arraysize(res.nsaddr_list) >= MAXNS &&
+ arraysize(res._u._ext.nsaddrs) >= MAXNS,
+ incompatible_libresolv_res_state);
+ DCHECK_LE(res.nscount, MAXNS);
+ // Initially, glibc stores IPv6 in |_ext.nsaddrs| and IPv4 in |nsaddr_list|.
+ // In res_send.c:res_nsend, it merges |nsaddr_list| into |nsaddrs|,
+ // but we have to combine the two arrays ourselves.
+ for (int i = 0; i < res.nscount; ++i) {
+ IPEndPoint ipe;
+ const struct sockaddr* addr = NULL;
+ size_t addr_len = 0;
+ if (res.nsaddr_list[i].sin_family) { // The indicator used by res_nsend.
+ addr = reinterpret_cast<const struct sockaddr*>(&res.nsaddr_list[i]);
+ addr_len = sizeof res.nsaddr_list[i];
+ } else if (res._u._ext.nsaddrs[i] != NULL) {
+ addr = reinterpret_cast<const struct sockaddr*>(res._u._ext.nsaddrs[i]);
+ addr_len = sizeof *res._u._ext.nsaddrs[i];
+ } else {
+ return CONFIG_PARSE_POSIX_BAD_EXT_STRUCT;
+ }
+ if (!ipe.FromSockAddr(addr, addr_len))
+ return CONFIG_PARSE_POSIX_BAD_ADDRESS;
+ dns_config->nameservers.push_back(ipe);
+ }
+#else // !(defined(OS_LINUX) || defined(OS_MACOSX) || defined(OS_FREEBSD))
+ DCHECK_LE(res.nscount, MAXNS);
+ for (int i = 0; i < res.nscount; ++i) {
+ IPEndPoint ipe;
+ if (!ipe.FromSockAddr(
+ reinterpret_cast<const struct sockaddr*>(&res.nsaddr_list[i]),
+ sizeof res.nsaddr_list[i])) {
+ return CONFIG_PARSE_POSIX_BAD_ADDRESS;
+ }
+ dns_config->nameservers.push_back(ipe);
+ }
+#endif
+
+ dns_config->search.clear();
+ for (int i = 0; (i < MAXDNSRCH) && res.dnsrch[i]; ++i) {
+ dns_config->search.push_back(std::string(res.dnsrch[i]));
+ }
+
+ dns_config->ndots = res.ndots;
+ dns_config->timeout = base::TimeDelta::FromSeconds(res.retrans);
+ dns_config->attempts = res.retry;
+#if defined(RES_ROTATE)
+ dns_config->rotate = res.options & RES_ROTATE;
+#endif
+ dns_config->edns0 = res.options & RES_USE_EDNS0;
+
+ // The current implementation assumes these options are set. They normally
+ // cannot be overwritten by /etc/resolv.conf
+ unsigned kRequiredOptions = RES_RECURSE | RES_DEFNAMES | RES_DNSRCH;
+ if ((res.options & kRequiredOptions) != kRequiredOptions)
+ return CONFIG_PARSE_POSIX_MISSING_OPTIONS;
+
+ unsigned kUnhandledOptions = RES_USEVC | RES_IGNTC | RES_USE_DNSSEC;
+ if (res.options & kUnhandledOptions)
+ return CONFIG_PARSE_POSIX_UNHANDLED_OPTIONS;
+
+ if (dns_config->nameservers.empty())
+ return CONFIG_PARSE_POSIX_NO_NAMESERVERS;
+
+ // If any name server is 0.0.0.0, assume the configuration is invalid.
+ // TODO(szym): Measure how often this happens. http://crbug.com/125599
+ const IPAddressNumber kEmptyAddress(kIPv4AddressSize);
+ for (unsigned i = 0; i < dns_config->nameservers.size(); ++i) {
+ if (dns_config->nameservers[i].address() == kEmptyAddress)
+ return CONFIG_PARSE_POSIX_NULL_ADDRESS;
+ }
+ return CONFIG_PARSE_POSIX_OK;
+}
+
+} // namespace internal
+
+// static
+scoped_ptr<DnsConfigService> DnsConfigService::CreateSystemService() {
+ return scoped_ptr<DnsConfigService>(new internal::DnsConfigServicePosix());
+}
+
+#else // defined(OS_ANDROID)
+// Android NDK provides only a stub <resolv.h> header.
+class StubDnsConfigService : public DnsConfigService {
+ public:
+ StubDnsConfigService() {}
+ virtual ~StubDnsConfigService() {}
+ private:
+ virtual void ReadNow() OVERRIDE {}
+ virtual bool StartWatching() OVERRIDE { return false; }
+};
+// static
+scoped_ptr<DnsConfigService> DnsConfigService::CreateSystemService() {
+ return scoped_ptr<DnsConfigService>(new StubDnsConfigService());
+}
+#endif
+
+} // namespace net
diff --git a/chromium/net/dns/dns_config_service_posix.h b/chromium/net/dns/dns_config_service_posix.h
new file mode 100644
index 00000000000..95a4377d932
--- /dev/null
+++ b/chromium/net/dns/dns_config_service_posix.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_CONFIG_SERVICE_POSIX_H_
+#define NET_DNS_DNS_CONFIG_SERVICE_POSIX_H_
+
+#include <sys/types.h>
+#include <netinet/in.h>
+#include <resolv.h>
+
+#include "base/compiler_specific.h"
+#include "net/base/net_export.h"
+#include "net/dns/dns_config_service.h"
+
+namespace net {
+
+// Use DnsConfigService::CreateSystemService to use it outside of tests.
+namespace internal {
+
+class NET_EXPORT_PRIVATE DnsConfigServicePosix : public DnsConfigService {
+ public:
+ DnsConfigServicePosix();
+ virtual ~DnsConfigServicePosix();
+
+ protected:
+ // DnsConfigService:
+ virtual void ReadNow() OVERRIDE;
+ virtual bool StartWatching() OVERRIDE;
+
+ private:
+ class Watcher;
+ class ConfigReader;
+ class HostsReader;
+
+ void OnConfigChanged(bool succeeded);
+ void OnHostsChanged(bool succeeded);
+
+ scoped_ptr<Watcher> watcher_;
+ scoped_refptr<ConfigReader> config_reader_;
+ scoped_refptr<HostsReader> hosts_reader_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsConfigServicePosix);
+};
+
+enum ConfigParsePosixResult {
+ CONFIG_PARSE_POSIX_OK = 0,
+ CONFIG_PARSE_POSIX_RES_INIT_FAILED,
+ CONFIG_PARSE_POSIX_RES_INIT_UNSET,
+ CONFIG_PARSE_POSIX_BAD_ADDRESS,
+ CONFIG_PARSE_POSIX_BAD_EXT_STRUCT,
+ CONFIG_PARSE_POSIX_NULL_ADDRESS,
+ CONFIG_PARSE_POSIX_NO_NAMESERVERS,
+ CONFIG_PARSE_POSIX_MISSING_OPTIONS,
+ CONFIG_PARSE_POSIX_UNHANDLED_OPTIONS,
+ CONFIG_PARSE_POSIX_MAX // Bounding values for enumeration.
+};
+
+// Fills in |dns_config| from |res|.
+ConfigParsePosixResult NET_EXPORT_PRIVATE ConvertResStateToDnsConfig(
+ const struct __res_state& res, DnsConfig* dns_config);
+
+} // namespace internal
+
+} // namespace net
+
+#endif // NET_DNS_DNS_CONFIG_SERVICE_POSIX_H_
diff --git a/chromium/net/dns/dns_config_service_posix_unittest.cc b/chromium/net/dns/dns_config_service_posix_unittest.cc
new file mode 100644
index 00000000000..92e46ed1977
--- /dev/null
+++ b/chromium/net/dns/dns_config_service_posix_unittest.cc
@@ -0,0 +1,156 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <resolv.h>
+
+#include "base/sys_byteorder.h"
+#include "net/dns/dns_config_service_posix.h"
+
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+namespace {
+
+// MAXNS is normally 3, but let's test 4 if possible.
+const char* kNameserversIPv4[] = {
+ "8.8.8.8",
+ "192.168.1.1",
+ "63.1.2.4",
+ "1.0.0.1",
+};
+
+#if defined(OS_LINUX)
+const char* kNameserversIPv6[] = {
+ NULL,
+ "2001:DB8:0::42",
+ NULL,
+ "::FFFF:129.144.52.38",
+};
+#endif
+
+// Fills in |res| with sane configuration.
+void InitializeResState(res_state res) {
+ memset(res, 0, sizeof(*res));
+ res->options = RES_INIT | RES_RECURSE | RES_DEFNAMES | RES_DNSRCH |
+ RES_ROTATE;
+ res->ndots = 2;
+ res->retrans = 4;
+ res->retry = 7;
+
+ const char kDnsrch[] = "chromium.org" "\0" "example.com";
+ memcpy(res->defdname, kDnsrch, sizeof(kDnsrch));
+ res->dnsrch[0] = res->defdname;
+ res->dnsrch[1] = res->defdname + sizeof("chromium.org");
+
+ for (unsigned i = 0; i < arraysize(kNameserversIPv4) && i < MAXNS; ++i) {
+ struct sockaddr_in sa;
+ sa.sin_family = AF_INET;
+ sa.sin_port = base::HostToNet16(NS_DEFAULTPORT + i);
+ inet_pton(AF_INET, kNameserversIPv4[i], &sa.sin_addr);
+ res->nsaddr_list[i] = sa;
+ ++res->nscount;
+ }
+
+#if defined(OS_LINUX)
+ // Install IPv6 addresses, replacing the corresponding IPv4 addresses.
+ unsigned nscount6 = 0;
+ for (unsigned i = 0; i < arraysize(kNameserversIPv6) && i < MAXNS; ++i) {
+ if (!kNameserversIPv6[i])
+ continue;
+ // Must use malloc to mimick res_ninit.
+ struct sockaddr_in6 *sa6;
+ sa6 = (struct sockaddr_in6 *)malloc(sizeof(*sa6));
+ sa6->sin6_family = AF_INET6;
+ sa6->sin6_port = base::HostToNet16(NS_DEFAULTPORT - i);
+ inet_pton(AF_INET6, kNameserversIPv6[i], &sa6->sin6_addr);
+ res->_u._ext.nsaddrs[i] = sa6;
+ memset(&res->nsaddr_list[i], 0, sizeof res->nsaddr_list[i]);
+ ++nscount6;
+ }
+ res->_u._ext.nscount6 = nscount6;
+#endif
+}
+
+void CloseResState(res_state res) {
+#if defined(OS_LINUX)
+ for (int i = 0; i < res->nscount; ++i) {
+ if (res->_u._ext.nsaddrs[i] != NULL)
+ free(res->_u._ext.nsaddrs[i]);
+ }
+#endif
+}
+
+void InitializeExpectedConfig(DnsConfig* config) {
+ config->ndots = 2;
+ config->timeout = base::TimeDelta::FromSeconds(4);
+ config->attempts = 7;
+ config->rotate = true;
+ config->edns0 = false;
+ config->append_to_multi_label_name = true;
+ config->search.clear();
+ config->search.push_back("chromium.org");
+ config->search.push_back("example.com");
+
+ config->nameservers.clear();
+ for (unsigned i = 0; i < arraysize(kNameserversIPv4) && i < MAXNS; ++i) {
+ IPAddressNumber ip;
+ ParseIPLiteralToNumber(kNameserversIPv4[i], &ip);
+ config->nameservers.push_back(IPEndPoint(ip, NS_DEFAULTPORT + i));
+ }
+
+#if defined(OS_LINUX)
+ for (unsigned i = 0; i < arraysize(kNameserversIPv6) && i < MAXNS; ++i) {
+ if (!kNameserversIPv6[i])
+ continue;
+ IPAddressNumber ip;
+ ParseIPLiteralToNumber(kNameserversIPv6[i], &ip);
+ config->nameservers[i] = IPEndPoint(ip, NS_DEFAULTPORT - i);
+ }
+#endif
+}
+
+TEST(DnsConfigServicePosixTest, ConvertResStateToDnsConfig) {
+ struct __res_state res;
+ DnsConfig config;
+ EXPECT_FALSE(config.IsValid());
+ InitializeResState(&res);
+ ASSERT_EQ(internal::CONFIG_PARSE_POSIX_OK,
+ internal::ConvertResStateToDnsConfig(res, &config));
+ CloseResState(&res);
+ EXPECT_TRUE(config.IsValid());
+
+ DnsConfig expected_config;
+ EXPECT_FALSE(expected_config.EqualsIgnoreHosts(config));
+ InitializeExpectedConfig(&expected_config);
+ EXPECT_TRUE(expected_config.EqualsIgnoreHosts(config));
+}
+
+TEST(DnsConfigServicePosixTest, RejectEmptyNameserver) {
+ struct __res_state res = {};
+ res.options = RES_INIT | RES_RECURSE | RES_DEFNAMES | RES_DNSRCH;
+ const char kDnsrch[] = "chromium.org";
+ memcpy(res.defdname, kDnsrch, sizeof(kDnsrch));
+ res.dnsrch[0] = res.defdname;
+
+ struct sockaddr_in sa = {};
+ sa.sin_family = AF_INET;
+ sa.sin_port = base::HostToNet16(NS_DEFAULTPORT);
+ sa.sin_addr.s_addr = INADDR_ANY;
+ res.nsaddr_list[0] = sa;
+ sa.sin_addr.s_addr = 0xCAFE1337;
+ res.nsaddr_list[1] = sa;
+ res.nscount = 2;
+
+ DnsConfig config;
+ EXPECT_EQ(internal::CONFIG_PARSE_POSIX_NULL_ADDRESS,
+ internal::ConvertResStateToDnsConfig(res, &config));
+
+ sa.sin_addr.s_addr = 0xDEADBEEF;
+ res.nsaddr_list[0] = sa;
+ EXPECT_EQ(internal::CONFIG_PARSE_POSIX_OK,
+ internal::ConvertResStateToDnsConfig(res, &config));
+}
+
+} // namespace
+} // namespace net
diff --git a/chromium/net/dns/dns_config_service_unittest.cc b/chromium/net/dns/dns_config_service_unittest.cc
new file mode 100644
index 00000000000..f42068032de
--- /dev/null
+++ b/chromium/net/dns/dns_config_service_unittest.cc
@@ -0,0 +1,258 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_config_service.h"
+
+#include "base/basictypes.h"
+#include "base/bind.h"
+#include "base/cancelable_callback.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "base/test/test_timeouts.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+class DnsConfigServiceTest : public testing::Test {
+ public:
+ void OnConfigChanged(const DnsConfig& config) {
+ last_config_ = config;
+ if (quit_on_config_)
+ base::MessageLoop::current()->Quit();
+ }
+
+ protected:
+ class TestDnsConfigService : public DnsConfigService {
+ public:
+ virtual void ReadNow() OVERRIDE {}
+ virtual bool StartWatching() OVERRIDE { return true; }
+
+ // Expose the protected methods to this test suite.
+ void InvalidateConfig() {
+ DnsConfigService::InvalidateConfig();
+ }
+
+ void InvalidateHosts() {
+ DnsConfigService::InvalidateHosts();
+ }
+
+ void OnConfigRead(const DnsConfig& config) {
+ DnsConfigService::OnConfigRead(config);
+ }
+
+ void OnHostsRead(const DnsHosts& hosts) {
+ DnsConfigService::OnHostsRead(hosts);
+ }
+
+ void set_watch_failed(bool value) {
+ DnsConfigService::set_watch_failed(value);
+ }
+ };
+
+ void WaitForConfig(base::TimeDelta timeout) {
+ base::CancelableClosure closure(base::MessageLoop::QuitClosure());
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE, closure.callback(), timeout);
+ quit_on_config_ = true;
+ base::MessageLoop::current()->Run();
+ quit_on_config_ = false;
+ closure.Cancel();
+ }
+
+ // Generate a config using the given seed..
+ DnsConfig MakeConfig(unsigned seed) {
+ DnsConfig config;
+ IPAddressNumber ip;
+ CHECK(ParseIPLiteralToNumber("1.2.3.4", &ip));
+ config.nameservers.push_back(IPEndPoint(ip, seed & 0xFFFF));
+ EXPECT_TRUE(config.IsValid());
+ return config;
+ }
+
+ // Generate hosts using the given seed.
+ DnsHosts MakeHosts(unsigned seed) {
+ DnsHosts hosts;
+ std::string hosts_content = "127.0.0.1 localhost";
+ hosts_content.append(seed, '1');
+ ParseHosts(hosts_content, &hosts);
+ EXPECT_FALSE(hosts.empty());
+ return hosts;
+ }
+
+ virtual void SetUp() OVERRIDE {
+ quit_on_config_ = false;
+
+ service_.reset(new TestDnsConfigService());
+ service_->WatchConfig(base::Bind(&DnsConfigServiceTest::OnConfigChanged,
+ base::Unretained(this)));
+ EXPECT_FALSE(last_config_.IsValid());
+ }
+
+ DnsConfig last_config_;
+ bool quit_on_config_;
+
+ // Service under test.
+ scoped_ptr<TestDnsConfigService> service_;
+};
+
+} // namespace
+
+TEST_F(DnsConfigServiceTest, FirstConfig) {
+ DnsConfig config = MakeConfig(1);
+
+ service_->OnConfigRead(config);
+ // No hosts yet, so no config.
+ EXPECT_TRUE(last_config_.Equals(DnsConfig()));
+
+ service_->OnHostsRead(config.hosts);
+ EXPECT_TRUE(last_config_.Equals(config));
+}
+
+TEST_F(DnsConfigServiceTest, Timeout) {
+ DnsConfig config = MakeConfig(1);
+ config.hosts = MakeHosts(1);
+ ASSERT_TRUE(config.IsValid());
+
+ service_->OnConfigRead(config);
+ service_->OnHostsRead(config.hosts);
+ EXPECT_FALSE(last_config_.Equals(DnsConfig()));
+ EXPECT_TRUE(last_config_.Equals(config));
+
+ service_->InvalidateConfig();
+ WaitForConfig(TestTimeouts::action_timeout());
+ EXPECT_FALSE(last_config_.Equals(config));
+ EXPECT_TRUE(last_config_.Equals(DnsConfig()));
+
+ service_->OnConfigRead(config);
+ EXPECT_FALSE(last_config_.Equals(DnsConfig()));
+ EXPECT_TRUE(last_config_.Equals(config));
+
+ service_->InvalidateHosts();
+ WaitForConfig(TestTimeouts::action_timeout());
+ EXPECT_FALSE(last_config_.Equals(config));
+ EXPECT_TRUE(last_config_.Equals(DnsConfig()));
+
+ DnsConfig bad_config = last_config_ = MakeConfig(0xBAD);
+ service_->InvalidateConfig();
+ // We don't expect an update. This should time out.
+ WaitForConfig(base::TimeDelta::FromMilliseconds(100) +
+ TestTimeouts::tiny_timeout());
+ EXPECT_TRUE(last_config_.Equals(bad_config)) << "Unexpected change";
+
+ last_config_ = DnsConfig();
+ service_->OnConfigRead(config);
+ service_->OnHostsRead(config.hosts);
+ EXPECT_FALSE(last_config_.Equals(DnsConfig()));
+ EXPECT_TRUE(last_config_.Equals(config));
+}
+
+TEST_F(DnsConfigServiceTest, SameConfig) {
+ DnsConfig config = MakeConfig(1);
+ config.hosts = MakeHosts(1);
+
+ service_->OnConfigRead(config);
+ service_->OnHostsRead(config.hosts);
+ EXPECT_FALSE(last_config_.Equals(DnsConfig()));
+ EXPECT_TRUE(last_config_.Equals(config));
+
+ last_config_ = DnsConfig();
+ service_->OnConfigRead(config);
+ EXPECT_TRUE(last_config_.Equals(DnsConfig())) << "Unexpected change";
+
+ service_->OnHostsRead(config.hosts);
+ EXPECT_TRUE(last_config_.Equals(DnsConfig())) << "Unexpected change";
+}
+
+TEST_F(DnsConfigServiceTest, DifferentConfig) {
+ DnsConfig config1 = MakeConfig(1);
+ DnsConfig config2 = MakeConfig(2);
+ DnsConfig config3 = MakeConfig(1);
+ config1.hosts = MakeHosts(1);
+ config2.hosts = MakeHosts(1);
+ config3.hosts = MakeHosts(2);
+ ASSERT_TRUE(config1.EqualsIgnoreHosts(config3));
+ ASSERT_FALSE(config1.Equals(config2));
+ ASSERT_FALSE(config1.Equals(config3));
+ ASSERT_FALSE(config2.Equals(config3));
+
+ service_->OnConfigRead(config1);
+ service_->OnHostsRead(config1.hosts);
+ EXPECT_FALSE(last_config_.Equals(DnsConfig()));
+ EXPECT_TRUE(last_config_.Equals(config1));
+
+ // It doesn't matter for this tests, but increases coverage.
+ service_->InvalidateConfig();
+ service_->InvalidateHosts();
+
+ service_->OnConfigRead(config2);
+ EXPECT_TRUE(last_config_.Equals(config1)) << "Unexpected change";
+ service_->OnHostsRead(config2.hosts); // Not an actual change.
+ EXPECT_FALSE(last_config_.Equals(config1));
+ EXPECT_TRUE(last_config_.Equals(config2));
+
+ service_->OnConfigRead(config3);
+ EXPECT_TRUE(last_config_.EqualsIgnoreHosts(config3));
+ service_->OnHostsRead(config3.hosts);
+ EXPECT_FALSE(last_config_.Equals(config2));
+ EXPECT_TRUE(last_config_.Equals(config3));
+}
+
+TEST_F(DnsConfigServiceTest, WatchFailure) {
+ DnsConfig config1 = MakeConfig(1);
+ DnsConfig config2 = MakeConfig(2);
+ config1.hosts = MakeHosts(1);
+ config2.hosts = MakeHosts(2);
+
+ service_->OnConfigRead(config1);
+ service_->OnHostsRead(config1.hosts);
+ EXPECT_FALSE(last_config_.Equals(DnsConfig()));
+ EXPECT_TRUE(last_config_.Equals(config1));
+
+ // Simulate watch failure.
+ service_->set_watch_failed(true);
+ service_->InvalidateConfig();
+ WaitForConfig(TestTimeouts::action_timeout());
+ EXPECT_FALSE(last_config_.Equals(config1));
+ EXPECT_TRUE(last_config_.Equals(DnsConfig()));
+
+ DnsConfig bad_config = last_config_ = MakeConfig(0xBAD);
+ // Actual change in config, so expect an update, but it should be empty.
+ service_->OnConfigRead(config1);
+ EXPECT_FALSE(last_config_.Equals(bad_config));
+ EXPECT_TRUE(last_config_.Equals(DnsConfig()));
+
+ last_config_ = bad_config;
+ // Actual change in config, so expect an update, but it should be empty.
+ service_->InvalidateConfig();
+ service_->OnConfigRead(config2);
+ EXPECT_FALSE(last_config_.Equals(bad_config));
+ EXPECT_TRUE(last_config_.Equals(DnsConfig()));
+
+ last_config_ = bad_config;
+ // No change, so no update.
+ service_->InvalidateConfig();
+ service_->OnConfigRead(config2);
+ EXPECT_TRUE(last_config_.Equals(bad_config));
+}
+
+#if (defined(OS_POSIX) && !defined(OS_ANDROID)) || defined(OS_WIN)
+// TODO(szym): This is really an integration test and can time out if HOSTS is
+// huge. http://crbug.com/107810
+TEST_F(DnsConfigServiceTest, DISABLED_GetSystemConfig) {
+ service_.reset();
+ scoped_ptr<DnsConfigService> service(DnsConfigService::CreateSystemService());
+
+ service->ReadConfig(base::Bind(&DnsConfigServiceTest::OnConfigChanged,
+ base::Unretained(this)));
+ base::TimeDelta kTimeout = TestTimeouts::action_max_timeout();
+ WaitForConfig(kTimeout);
+ ASSERT_TRUE(last_config_.IsValid()) << "Did not receive DnsConfig in " <<
+ kTimeout.InSecondsF() << "s";
+}
+#endif // OS_POSIX || OS_WIN
+
+} // namespace net
+
diff --git a/chromium/net/dns/dns_config_service_win.cc b/chromium/net/dns/dns_config_service_win.cc
new file mode 100644
index 00000000000..6d12155d8cc
--- /dev/null
+++ b/chromium/net/dns/dns_config_service_win.cc
@@ -0,0 +1,737 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_config_service_win.h"
+
+#include <algorithm>
+#include <string>
+
+#include "base/bind.h"
+#include "base/callback.h"
+#include "base/compiler_specific.h"
+#include "base/files/file_path.h"
+#include "base/files/file_path_watcher.h"
+#include "base/logging.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/metrics/histogram.h"
+#include "base/strings/string_split.h"
+#include "base/strings/string_util.h"
+#include "base/strings/utf_string_conversions.h"
+#include "base/synchronization/lock.h"
+#include "base/threading/non_thread_safe.h"
+#include "base/threading/thread_restrictions.h"
+#include "base/time/time.h"
+#include "base/win/object_watcher.h"
+#include "base/win/registry.h"
+#include "base/win/windows_version.h"
+#include "net/base/net_util.h"
+#include "net/base/network_change_notifier.h"
+#include "net/dns/dns_hosts.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/serial_worker.h"
+#include "url/url_canon.h"
+
+#pragma comment(lib, "iphlpapi.lib")
+
+namespace net {
+
+namespace internal {
+
+namespace {
+
+// Interval between retries to parse config. Used only until parsing succeeds.
+const int kRetryIntervalSeconds = 5;
+
+const wchar_t* const kPrimaryDnsSuffixPath =
+ L"SOFTWARE\\Policies\\Microsoft\\System\\DNSClient";
+
+enum HostsParseWinResult {
+ HOSTS_PARSE_WIN_OK = 0,
+ HOSTS_PARSE_WIN_UNREADABLE_HOSTS_FILE,
+ HOSTS_PARSE_WIN_COMPUTER_NAME_FAILED,
+ HOSTS_PARSE_WIN_IPHELPER_FAILED,
+ HOSTS_PARSE_WIN_BAD_ADDRESS,
+ HOSTS_PARSE_WIN_MAX // Bounding values for enumeration.
+};
+
+// Convenience for reading values using RegKey.
+class RegistryReader : public base::NonThreadSafe {
+ public:
+ explicit RegistryReader(const wchar_t* key) {
+ // Ignoring the result. |key_.Valid()| will catch failures.
+ key_.Open(HKEY_LOCAL_MACHINE, key, KEY_QUERY_VALUE);
+ }
+
+ bool ReadString(const wchar_t* name,
+ DnsSystemSettings::RegString* out) const {
+ DCHECK(CalledOnValidThread());
+ out->set = false;
+ if (!key_.Valid()) {
+ // Assume that if the |key_| is invalid then the key is missing.
+ return true;
+ }
+ LONG result = key_.ReadValue(name, &out->value);
+ if (result == ERROR_SUCCESS) {
+ out->set = true;
+ return true;
+ }
+ return (result == ERROR_FILE_NOT_FOUND);
+ }
+
+ bool ReadDword(const wchar_t* name,
+ DnsSystemSettings::RegDword* out) const {
+ DCHECK(CalledOnValidThread());
+ out->set = false;
+ if (!key_.Valid()) {
+ // Assume that if the |key_| is invalid then the key is missing.
+ return true;
+ }
+ LONG result = key_.ReadValueDW(name, &out->value);
+ if (result == ERROR_SUCCESS) {
+ out->set = true;
+ return true;
+ }
+ return (result == ERROR_FILE_NOT_FOUND);
+ }
+
+ private:
+ base::win::RegKey key_;
+
+ DISALLOW_COPY_AND_ASSIGN(RegistryReader);
+};
+
+// Wrapper for GetAdaptersAddresses. Returns NULL if failed.
+scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> ReadIpHelper(ULONG flags) {
+ base::ThreadRestrictions::AssertIOAllowed();
+
+ scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> out;
+ ULONG len = 15000; // As recommended by MSDN for GetAdaptersAddresses.
+ UINT rv = ERROR_BUFFER_OVERFLOW;
+ // Try up to three times.
+ for (unsigned tries = 0; (tries < 3) && (rv == ERROR_BUFFER_OVERFLOW);
+ tries++) {
+ out.reset(reinterpret_cast<PIP_ADAPTER_ADDRESSES>(malloc(len)));
+ rv = GetAdaptersAddresses(AF_UNSPEC, flags, NULL, out.get(), &len);
+ }
+ if (rv != NO_ERROR)
+ out.reset();
+ return out.Pass();
+}
+
+// Converts a base::string16 domain name to ASCII, possibly using punycode.
+// Returns true if the conversion succeeds and output is not empty. In case of
+// failure, |domain| might become dirty.
+bool ParseDomainASCII(const base::string16& widestr, std::string* domain) {
+ DCHECK(domain);
+ if (widestr.empty())
+ return false;
+
+ // Check if already ASCII.
+ if (IsStringASCII(widestr)) {
+ *domain = UTF16ToASCII(widestr);
+ return true;
+ }
+
+ // Otherwise try to convert it from IDN to punycode.
+ const int kInitialBufferSize = 256;
+ url_canon::RawCanonOutputT<char16, kInitialBufferSize> punycode;
+ if (!url_canon::IDNToASCII(widestr.data(), widestr.length(), &punycode))
+ return false;
+
+ // |punycode_output| should now be ASCII; convert it to a std::string.
+ // (We could use UTF16ToASCII() instead, but that requires an extra string
+ // copy. Since ASCII is a subset of UTF8 the following is equivalent).
+ bool success = UTF16ToUTF8(punycode.data(), punycode.length(), domain);
+ DCHECK(success);
+ DCHECK(IsStringASCII(*domain));
+ return success && !domain->empty();
+}
+
+bool ReadDevolutionSetting(const RegistryReader& reader,
+ DnsSystemSettings::DevolutionSetting* setting) {
+ return reader.ReadDword(L"UseDomainNameDevolution", &setting->enabled) &&
+ reader.ReadDword(L"DomainNameDevolutionLevel", &setting->level);
+}
+
+// Reads DnsSystemSettings from IpHelper and registry.
+ConfigParseWinResult ReadSystemSettings(DnsSystemSettings* settings) {
+ settings->addresses = ReadIpHelper(GAA_FLAG_SKIP_ANYCAST |
+ GAA_FLAG_SKIP_UNICAST |
+ GAA_FLAG_SKIP_MULTICAST |
+ GAA_FLAG_SKIP_FRIENDLY_NAME);
+ if (!settings->addresses.get())
+ return CONFIG_PARSE_WIN_READ_IPHELPER;
+
+ RegistryReader tcpip_reader(kTcpipPath);
+ RegistryReader tcpip6_reader(kTcpip6Path);
+ RegistryReader dnscache_reader(kDnscachePath);
+ RegistryReader policy_reader(kPolicyPath);
+ RegistryReader primary_dns_suffix_reader(kPrimaryDnsSuffixPath);
+
+ if (!policy_reader.ReadString(L"SearchList",
+ &settings->policy_search_list)) {
+ return CONFIG_PARSE_WIN_READ_POLICY_SEARCHLIST;
+ }
+
+ if (!tcpip_reader.ReadString(L"SearchList", &settings->tcpip_search_list))
+ return CONFIG_PARSE_WIN_READ_TCPIP_SEARCHLIST;
+
+ if (!tcpip_reader.ReadString(L"Domain", &settings->tcpip_domain))
+ return CONFIG_PARSE_WIN_READ_DOMAIN;
+
+ if (!ReadDevolutionSetting(policy_reader, &settings->policy_devolution))
+ return CONFIG_PARSE_WIN_READ_POLICY_DEVOLUTION;
+
+ if (!ReadDevolutionSetting(dnscache_reader, &settings->dnscache_devolution))
+ return CONFIG_PARSE_WIN_READ_DNSCACHE_DEVOLUTION;
+
+ if (!ReadDevolutionSetting(tcpip_reader, &settings->tcpip_devolution))
+ return CONFIG_PARSE_WIN_READ_TCPIP_DEVOLUTION;
+
+ if (!policy_reader.ReadDword(L"AppendToMultiLabelName",
+ &settings->append_to_multi_label_name)) {
+ return CONFIG_PARSE_WIN_READ_APPEND_MULTILABEL;
+ }
+
+ if (!primary_dns_suffix_reader.ReadString(L"PrimaryDnsSuffix",
+ &settings->primary_dns_suffix)) {
+ return CONFIG_PARSE_WIN_READ_PRIMARY_SUFFIX;
+ }
+ return CONFIG_PARSE_WIN_OK;
+}
+
+// Default address of "localhost" and local computer name can be overridden
+// by the HOSTS file, but if it's not there, then we need to fill it in.
+HostsParseWinResult AddLocalhostEntries(DnsHosts* hosts) {
+ const unsigned char kIPv4Localhost[] = { 127, 0, 0, 1 };
+ const unsigned char kIPv6Localhost[] = { 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 1 };
+ IPAddressNumber loopback_ipv4(kIPv4Localhost,
+ kIPv4Localhost + arraysize(kIPv4Localhost));
+ IPAddressNumber loopback_ipv6(kIPv6Localhost,
+ kIPv6Localhost + arraysize(kIPv6Localhost));
+
+ // This does not override any pre-existing entries from the HOSTS file.
+ hosts->insert(std::make_pair(DnsHostsKey("localhost", ADDRESS_FAMILY_IPV4),
+ loopback_ipv4));
+ hosts->insert(std::make_pair(DnsHostsKey("localhost", ADDRESS_FAMILY_IPV6),
+ loopback_ipv6));
+
+ WCHAR buffer[MAX_PATH];
+ DWORD size = MAX_PATH;
+ std::string localname;
+ if (!GetComputerNameExW(ComputerNameDnsHostname, buffer, &size) ||
+ !ParseDomainASCII(buffer, &localname)) {
+ return HOSTS_PARSE_WIN_COMPUTER_NAME_FAILED;
+ }
+ StringToLowerASCII(&localname);
+
+ bool have_ipv4 =
+ hosts->count(DnsHostsKey(localname, ADDRESS_FAMILY_IPV4)) > 0;
+ bool have_ipv6 =
+ hosts->count(DnsHostsKey(localname, ADDRESS_FAMILY_IPV6)) > 0;
+
+ if (have_ipv4 && have_ipv6)
+ return HOSTS_PARSE_WIN_OK;
+
+ scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> addresses =
+ ReadIpHelper(GAA_FLAG_SKIP_ANYCAST |
+ GAA_FLAG_SKIP_DNS_SERVER |
+ GAA_FLAG_SKIP_MULTICAST |
+ GAA_FLAG_SKIP_FRIENDLY_NAME);
+ if (!addresses.get())
+ return HOSTS_PARSE_WIN_IPHELPER_FAILED;
+
+ // The order of adapters is the network binding order, so stick to the
+ // first good adapter for each family.
+ for (const IP_ADAPTER_ADDRESSES* adapter = addresses.get();
+ adapter != NULL && (!have_ipv4 || !have_ipv6);
+ adapter = adapter->Next) {
+ if (adapter->OperStatus != IfOperStatusUp)
+ continue;
+ if (adapter->IfType == IF_TYPE_SOFTWARE_LOOPBACK)
+ continue;
+
+ for (const IP_ADAPTER_UNICAST_ADDRESS* address =
+ adapter->FirstUnicastAddress;
+ address != NULL;
+ address = address->Next) {
+ IPEndPoint ipe;
+ if (!ipe.FromSockAddr(address->Address.lpSockaddr,
+ address->Address.iSockaddrLength)) {
+ return HOSTS_PARSE_WIN_BAD_ADDRESS;
+ }
+ if (!have_ipv4 && (ipe.GetFamily() == ADDRESS_FAMILY_IPV4)) {
+ have_ipv4 = true;
+ (*hosts)[DnsHostsKey(localname, ADDRESS_FAMILY_IPV4)] = ipe.address();
+ } else if (!have_ipv6 && (ipe.GetFamily() == ADDRESS_FAMILY_IPV6)) {
+ have_ipv6 = true;
+ (*hosts)[DnsHostsKey(localname, ADDRESS_FAMILY_IPV6)] = ipe.address();
+ }
+ }
+ }
+ return HOSTS_PARSE_WIN_OK;
+}
+
+// Watches a single registry key for changes.
+class RegistryWatcher : public base::win::ObjectWatcher::Delegate,
+ public base::NonThreadSafe {
+ public:
+ typedef base::Callback<void(bool succeeded)> CallbackType;
+ RegistryWatcher() {}
+
+ bool Watch(const wchar_t* key, const CallbackType& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(!callback.is_null());
+ DCHECK(callback_.is_null());
+ callback_ = callback;
+ if (key_.Open(HKEY_LOCAL_MACHINE, key, KEY_NOTIFY) != ERROR_SUCCESS)
+ return false;
+ if (key_.StartWatching() != ERROR_SUCCESS)
+ return false;
+ if (!watcher_.StartWatching(key_.watch_event(), this))
+ return false;
+ return true;
+ }
+
+ virtual void OnObjectSignaled(HANDLE object) OVERRIDE {
+ DCHECK(CalledOnValidThread());
+ bool succeeded = (key_.StartWatching() == ERROR_SUCCESS) &&
+ watcher_.StartWatching(key_.watch_event(), this);
+ if (!succeeded && key_.Valid()) {
+ watcher_.StopWatching();
+ key_.StopWatching();
+ key_.Close();
+ }
+ if (!callback_.is_null())
+ callback_.Run(succeeded);
+ }
+
+ private:
+ CallbackType callback_;
+ base::win::RegKey key_;
+ base::win::ObjectWatcher watcher_;
+
+ DISALLOW_COPY_AND_ASSIGN(RegistryWatcher);
+};
+
+// Returns true iff |address| is DNS address from IPv6 stateless discovery,
+// i.e., matches fec0:0:0:ffff::{1,2,3}.
+// http://tools.ietf.org/html/draft-ietf-ipngwg-dns-discovery
+bool IsStatelessDiscoveryAddress(const IPAddressNumber& address) {
+ if (address.size() != kIPv6AddressSize)
+ return false;
+ const uint8 kPrefix[] = {
+ 0xfe, 0xc0, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ };
+ return std::equal(kPrefix, kPrefix + arraysize(kPrefix),
+ address.begin()) && (address.back() < 4);
+}
+
+} // namespace
+
+base::FilePath GetHostsPath() {
+ TCHAR buffer[MAX_PATH];
+ UINT rc = GetSystemDirectory(buffer, MAX_PATH);
+ DCHECK(0 < rc && rc < MAX_PATH);
+ return base::FilePath(buffer).Append(
+ FILE_PATH_LITERAL("drivers\\etc\\hosts"));
+}
+
+bool ParseSearchList(const base::string16& value,
+ std::vector<std::string>* output) {
+ DCHECK(output);
+ if (value.empty())
+ return false;
+
+ output->clear();
+
+ // If the list includes an empty hostname (",," or ", ,"), it is terminated.
+ // Although nslookup and network connection property tab ignore such
+ // fragments ("a,b,,c" becomes ["a", "b", "c"]), our reference is getaddrinfo
+ // (which sees ["a", "b"]). WMI queries also return a matching search list.
+ std::vector<base::string16> woutput;
+ base::SplitString(value, ',', &woutput);
+ for (size_t i = 0; i < woutput.size(); ++i) {
+ // Convert non-ASCII to punycode, although getaddrinfo does not properly
+ // handle such suffixes.
+ const base::string16& t = woutput[i];
+ std::string parsed;
+ if (!ParseDomainASCII(t, &parsed))
+ break;
+ output->push_back(parsed);
+ }
+ return !output->empty();
+}
+
+ConfigParseWinResult ConvertSettingsToDnsConfig(
+ const DnsSystemSettings& settings,
+ DnsConfig* config) {
+ *config = DnsConfig();
+
+ // Use GetAdapterAddresses to get effective DNS server order and
+ // connection-specific DNS suffix. Ignore disconnected and loopback adapters.
+ // The order of adapters is the network binding order, so stick to the
+ // first good adapter.
+ for (const IP_ADAPTER_ADDRESSES* adapter = settings.addresses.get();
+ adapter != NULL && config->nameservers.empty();
+ adapter = adapter->Next) {
+ if (adapter->OperStatus != IfOperStatusUp)
+ continue;
+ if (adapter->IfType == IF_TYPE_SOFTWARE_LOOPBACK)
+ continue;
+
+ for (const IP_ADAPTER_DNS_SERVER_ADDRESS* address =
+ adapter->FirstDnsServerAddress;
+ address != NULL;
+ address = address->Next) {
+ IPEndPoint ipe;
+ if (ipe.FromSockAddr(address->Address.lpSockaddr,
+ address->Address.iSockaddrLength)) {
+ if (IsStatelessDiscoveryAddress(ipe.address()))
+ continue;
+ // Override unset port.
+ if (!ipe.port())
+ ipe = IPEndPoint(ipe.address(), dns_protocol::kDefaultPort);
+ config->nameservers.push_back(ipe);
+ } else {
+ return CONFIG_PARSE_WIN_BAD_ADDRESS;
+ }
+ }
+
+ // IP_ADAPTER_ADDRESSES in Vista+ has a search list at |FirstDnsSuffix|,
+ // but it came up empty in all trials.
+ // |DnsSuffix| stores the effective connection-specific suffix, which is
+ // obtained via DHCP (regkey: Tcpip\Parameters\Interfaces\{XXX}\DhcpDomain)
+ // or specified by the user (regkey: Tcpip\Parameters\Domain).
+ std::string dns_suffix;
+ if (ParseDomainASCII(adapter->DnsSuffix, &dns_suffix))
+ config->search.push_back(dns_suffix);
+ }
+
+ if (config->nameservers.empty())
+ return CONFIG_PARSE_WIN_NO_NAMESERVERS; // No point continuing.
+
+ // Windows always tries a multi-label name "as is" before using suffixes.
+ config->ndots = 1;
+
+ if (!settings.append_to_multi_label_name.set) {
+ // The default setting is true for XP, false for Vista+.
+ if (base::win::GetVersion() >= base::win::VERSION_VISTA) {
+ config->append_to_multi_label_name = false;
+ } else {
+ config->append_to_multi_label_name = true;
+ }
+ } else {
+ config->append_to_multi_label_name =
+ (settings.append_to_multi_label_name.value != 0);
+ }
+
+ // SearchList takes precedence, so check it first.
+ if (settings.policy_search_list.set) {
+ std::vector<std::string> search;
+ if (ParseSearchList(settings.policy_search_list.value, &search)) {
+ config->search.swap(search);
+ return CONFIG_PARSE_WIN_OK;
+ }
+ // Even if invalid, the policy disables the user-specified setting below.
+ } else if (settings.tcpip_search_list.set) {
+ std::vector<std::string> search;
+ if (ParseSearchList(settings.tcpip_search_list.value, &search)) {
+ config->search.swap(search);
+ return CONFIG_PARSE_WIN_OK;
+ }
+ }
+
+ // In absence of explicit search list, suffix search is:
+ // [primary suffix, connection-specific suffix, devolution of primary suffix].
+ // Primary suffix can be set by policy (primary_dns_suffix) or
+ // user setting (tcpip_domain).
+ //
+ // The policy (primary_dns_suffix) can be edited via Group Policy Editor
+ // (gpedit.msc) at Local Computer Policy => Computer Configuration
+ // => Administrative Template => Network => DNS Client => Primary DNS Suffix.
+ //
+ // The user setting (tcpip_domain) can be configurred at Computer Name in
+ // System Settings
+ std::string primary_suffix;
+ if ((settings.primary_dns_suffix.set &&
+ ParseDomainASCII(settings.primary_dns_suffix.value, &primary_suffix)) ||
+ (settings.tcpip_domain.set &&
+ ParseDomainASCII(settings.tcpip_domain.value, &primary_suffix))) {
+ // Primary suffix goes in front.
+ config->search.insert(config->search.begin(), primary_suffix);
+ } else {
+ return CONFIG_PARSE_WIN_OK; // No primary suffix, hence no devolution.
+ }
+
+ // Devolution is determined by precedence: policy > dnscache > tcpip.
+ // |enabled|: UseDomainNameDevolution and |level|: DomainNameDevolutionLevel
+ // are overridden independently.
+ DnsSystemSettings::DevolutionSetting devolution = settings.policy_devolution;
+
+ if (!devolution.enabled.set)
+ devolution.enabled = settings.dnscache_devolution.enabled;
+ if (!devolution.enabled.set)
+ devolution.enabled = settings.tcpip_devolution.enabled;
+ if (devolution.enabled.set && (devolution.enabled.value == 0))
+ return CONFIG_PARSE_WIN_OK; // Devolution disabled.
+
+ // By default devolution is enabled.
+
+ if (!devolution.level.set)
+ devolution.level = settings.dnscache_devolution.level;
+ if (!devolution.level.set)
+ devolution.level = settings.tcpip_devolution.level;
+
+ // After the recent update, Windows will try to determine a safe default
+ // value by comparing the forest root domain (FRD) to the primary suffix.
+ // See http://support.microsoft.com/kb/957579 for details.
+ // For now, if the level is not set, we disable devolution, assuming that
+ // we will fallback to the system getaddrinfo anyway. This might cause
+ // performance loss for resolutions which depend on the system default
+ // devolution setting.
+ //
+ // If the level is explicitly set below 2, devolution is disabled.
+ if (!devolution.level.set || devolution.level.value < 2)
+ return CONFIG_PARSE_WIN_OK; // Devolution disabled.
+
+ // Devolve the primary suffix. This naive logic matches the observed
+ // behavior (see also ParseSearchList). If a suffix is not valid, it will be
+ // discarded when the fully-qualified name is converted to DNS format.
+
+ unsigned num_dots = std::count(primary_suffix.begin(),
+ primary_suffix.end(), '.');
+
+ for (size_t offset = 0; num_dots >= devolution.level.value; --num_dots) {
+ offset = primary_suffix.find('.', offset + 1);
+ config->search.push_back(primary_suffix.substr(offset + 1));
+ }
+ return CONFIG_PARSE_WIN_OK;
+}
+
+// Watches registry and HOSTS file for changes. Must live on a thread which
+// allows IO.
+class DnsConfigServiceWin::Watcher
+ : public NetworkChangeNotifier::IPAddressObserver {
+ public:
+ explicit Watcher(DnsConfigServiceWin* service) : service_(service) {}
+ ~Watcher() {
+ NetworkChangeNotifier::RemoveIPAddressObserver(this);
+ }
+
+ bool Watch() {
+ RegistryWatcher::CallbackType callback =
+ base::Bind(&DnsConfigServiceWin::OnConfigChanged,
+ base::Unretained(service_));
+
+ bool success = true;
+
+ // The Tcpip key must be present.
+ if (!tcpip_watcher_.Watch(kTcpipPath, callback)) {
+ LOG(ERROR) << "DNS registry watch failed to start.";
+ success = false;
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus",
+ DNS_CONFIG_WATCH_FAILED_TO_START_CONFIG,
+ DNS_CONFIG_WATCH_MAX);
+ }
+
+ // Watch for IPv6 nameservers.
+ tcpip6_watcher_.Watch(kTcpip6Path, callback);
+
+ // DNS suffix search list and devolution can be configured via group
+ // policy which sets this registry key. If the key is missing, the policy
+ // does not apply, and the DNS client uses Tcpip and Dnscache settings.
+ // If a policy is installed, DnsConfigService will need to be restarted.
+ // BUG=99509
+
+ dnscache_watcher_.Watch(kDnscachePath, callback);
+ policy_watcher_.Watch(kPolicyPath, callback);
+
+ if (!hosts_watcher_.Watch(GetHostsPath(), false,
+ base::Bind(&Watcher::OnHostsChanged,
+ base::Unretained(this)))) {
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus",
+ DNS_CONFIG_WATCH_FAILED_TO_START_HOSTS,
+ DNS_CONFIG_WATCH_MAX);
+ LOG(ERROR) << "DNS hosts watch failed to start.";
+ success = false;
+ } else {
+ // Also need to observe changes to local non-loopback IP for DnsHosts.
+ NetworkChangeNotifier::AddIPAddressObserver(this);
+ }
+ return success;
+ }
+
+ private:
+ void OnHostsChanged(const base::FilePath& path, bool error) {
+ if (error)
+ NetworkChangeNotifier::RemoveIPAddressObserver(this);
+ service_->OnHostsChanged(!error);
+ }
+
+ // NetworkChangeNotifier::IPAddressObserver:
+ virtual void OnIPAddressChanged() OVERRIDE {
+ // Need to update non-loopback IP of local host.
+ service_->OnHostsChanged(true);
+ }
+
+ DnsConfigServiceWin* service_;
+
+ RegistryWatcher tcpip_watcher_;
+ RegistryWatcher tcpip6_watcher_;
+ RegistryWatcher dnscache_watcher_;
+ RegistryWatcher policy_watcher_;
+ base::FilePathWatcher hosts_watcher_;
+
+ DISALLOW_COPY_AND_ASSIGN(Watcher);
+};
+
+// Reads config from registry and IpHelper. All work performed on WorkerPool.
+class DnsConfigServiceWin::ConfigReader : public SerialWorker {
+ public:
+ explicit ConfigReader(DnsConfigServiceWin* service)
+ : service_(service),
+ success_(false) {}
+
+ private:
+ virtual ~ConfigReader() {}
+
+ virtual void DoWork() OVERRIDE {
+ // Should be called on WorkerPool.
+ base::TimeTicks start_time = base::TimeTicks::Now();
+ DnsSystemSettings settings = {};
+ ConfigParseWinResult result = ReadSystemSettings(&settings);
+ if (result == CONFIG_PARSE_WIN_OK)
+ result = ConvertSettingsToDnsConfig(settings, &dns_config_);
+ success_ = (result == CONFIG_PARSE_WIN_OK);
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ConfigParseWin",
+ result, CONFIG_PARSE_WIN_MAX);
+ UMA_HISTOGRAM_BOOLEAN("AsyncDNS.ConfigParseResult", success_);
+ UMA_HISTOGRAM_TIMES("AsyncDNS.ConfigParseDuration",
+ base::TimeTicks::Now() - start_time);
+ }
+
+ virtual void OnWorkFinished() OVERRIDE {
+ DCHECK(loop()->BelongsToCurrentThread());
+ DCHECK(!IsCancelled());
+ if (success_) {
+ service_->OnConfigRead(dns_config_);
+ } else {
+ LOG(WARNING) << "Failed to read DnsConfig.";
+ // Try again in a while in case DnsConfigWatcher missed the signal.
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(&ConfigReader::WorkNow, this),
+ base::TimeDelta::FromSeconds(kRetryIntervalSeconds));
+ }
+ }
+
+ DnsConfigServiceWin* service_;
+ // Written in DoWork(), read in OnWorkFinished(). No locking required.
+ DnsConfig dns_config_;
+ bool success_;
+};
+
+// Reads hosts from HOSTS file and fills in localhost and local computer name if
+// necessary. All work performed on WorkerPool.
+class DnsConfigServiceWin::HostsReader : public SerialWorker {
+ public:
+ explicit HostsReader(DnsConfigServiceWin* service)
+ : path_(GetHostsPath()),
+ service_(service),
+ success_(false) {
+ }
+
+ private:
+ virtual ~HostsReader() {}
+
+ virtual void DoWork() OVERRIDE {
+ base::TimeTicks start_time = base::TimeTicks::Now();
+ HostsParseWinResult result = HOSTS_PARSE_WIN_UNREADABLE_HOSTS_FILE;
+ if (ParseHostsFile(path_, &hosts_))
+ result = AddLocalhostEntries(&hosts_);
+ success_ = (result == HOSTS_PARSE_WIN_OK);
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.HostsParseWin",
+ result, HOSTS_PARSE_WIN_MAX);
+ UMA_HISTOGRAM_BOOLEAN("AsyncDNS.HostParseResult", success_);
+ UMA_HISTOGRAM_TIMES("AsyncDNS.HostsParseDuration",
+ base::TimeTicks::Now() - start_time);
+ }
+
+ virtual void OnWorkFinished() OVERRIDE {
+ DCHECK(loop()->BelongsToCurrentThread());
+ if (success_) {
+ service_->OnHostsRead(hosts_);
+ } else {
+ LOG(WARNING) << "Failed to read DnsHosts.";
+ }
+ }
+
+ const base::FilePath path_;
+ DnsConfigServiceWin* service_;
+ // Written in DoWork, read in OnWorkFinished, no locking necessary.
+ DnsHosts hosts_;
+ bool success_;
+
+ DISALLOW_COPY_AND_ASSIGN(HostsReader);
+};
+
+DnsConfigServiceWin::DnsConfigServiceWin()
+ : config_reader_(new ConfigReader(this)),
+ hosts_reader_(new HostsReader(this)) {}
+
+DnsConfigServiceWin::~DnsConfigServiceWin() {
+ config_reader_->Cancel();
+ hosts_reader_->Cancel();
+}
+
+void DnsConfigServiceWin::ReadNow() {
+ config_reader_->WorkNow();
+ hosts_reader_->WorkNow();
+}
+
+bool DnsConfigServiceWin::StartWatching() {
+ // TODO(szym): re-start watcher if that makes sense. http://crbug.com/116139
+ watcher_.reset(new Watcher(this));
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", DNS_CONFIG_WATCH_STARTED,
+ DNS_CONFIG_WATCH_MAX);
+ return watcher_->Watch();
+}
+
+void DnsConfigServiceWin::OnConfigChanged(bool succeeded) {
+ InvalidateConfig();
+ if (succeeded) {
+ config_reader_->WorkNow();
+ } else {
+ LOG(ERROR) << "DNS config watch failed.";
+ set_watch_failed(true);
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus",
+ DNS_CONFIG_WATCH_FAILED_CONFIG,
+ DNS_CONFIG_WATCH_MAX);
+ }
+}
+
+void DnsConfigServiceWin::OnHostsChanged(bool succeeded) {
+ InvalidateHosts();
+ if (succeeded) {
+ hosts_reader_->WorkNow();
+ } else {
+ LOG(ERROR) << "DNS hosts watch failed.";
+ set_watch_failed(true);
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus",
+ DNS_CONFIG_WATCH_FAILED_HOSTS,
+ DNS_CONFIG_WATCH_MAX);
+ }
+}
+
+} // namespace internal
+
+// static
+scoped_ptr<DnsConfigService> DnsConfigService::CreateSystemService() {
+ return scoped_ptr<DnsConfigService>(new internal::DnsConfigServiceWin());
+}
+
+} // namespace net
diff --git a/chromium/net/dns/dns_config_service_win.h b/chromium/net/dns/dns_config_service_win.h
new file mode 100644
index 00000000000..06fc0d9663b
--- /dev/null
+++ b/chromium/net/dns/dns_config_service_win.h
@@ -0,0 +1,154 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_CONFIG_SERVICE_WIN_H_
+#define NET_DNS_DNS_CONFIG_SERVICE_WIN_H_
+
+// The sole purpose of dns_config_service_win.h is for unittests so we just
+// include these headers here.
+#include <winsock2.h>
+#include <iphlpapi.h>
+
+#include <string>
+#include <vector>
+
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/strings/string16.h"
+#include "net/base/net_export.h"
+#include "net/dns/dns_config_service.h"
+
+// The general effort of DnsConfigServiceWin is to configure |nameservers| and
+// |search| in DnsConfig. The settings are stored in the Windows registry, but
+// to simplify the task we use the IP Helper API wherever possible. That API
+// yields the complete and ordered |nameservers|, but to determine |search| we
+// need to use the registry. On Windows 7, WMI does return the correct |search|
+// but on earlier versions it is insufficient.
+//
+// Experimental evaluation of Windows behavior suggests that domain parsing is
+// naive. Domain suffixes in |search| are not validated until they are appended
+// to the resolved name. We attempt to replicate this behavior.
+
+namespace net {
+
+namespace internal {
+
+// Registry key paths.
+const wchar_t* const kTcpipPath =
+ L"SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters";
+const wchar_t* const kTcpip6Path =
+ L"SYSTEM\\CurrentControlSet\\Services\\Tcpip6\\Parameters";
+const wchar_t* const kDnscachePath =
+ L"SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters";
+const wchar_t* const kPolicyPath =
+ L"SOFTWARE\\Policies\\Microsoft\\Windows NT\\DNSClient";
+
+// Returns the path to the HOSTS file.
+base::FilePath GetHostsPath();
+
+// Parses |value| as search list (comma-delimited list of domain names) from
+// a registry key and stores it in |out|. Returns true on success. Empty
+// entries (e.g., "chromium.org,,org") terminate the list. Non-ascii hostnames
+// are converted to punycode.
+bool NET_EXPORT_PRIVATE ParseSearchList(const base::string16& value,
+ std::vector<std::string>* out);
+
+// All relevant settings read from registry and IP Helper. This isolates our
+// logic from system calls and is exposed for unit tests. Keep it an aggregate
+// struct for easy initialization.
+struct NET_EXPORT_PRIVATE DnsSystemSettings {
+ // The |set| flag distinguishes between empty and unset values.
+ struct RegString {
+ bool set;
+ base::string16 value;
+ };
+
+ struct RegDword {
+ bool set;
+ DWORD value;
+ };
+
+ struct DevolutionSetting {
+ // UseDomainNameDevolution
+ RegDword enabled;
+ // DomainNameDevolutionLevel
+ RegDword level;
+ };
+
+ // Filled in by GetAdapterAddresses. Note that the alternative
+ // GetNetworkParams does not include IPv6 addresses.
+ scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> addresses;
+
+ // SOFTWARE\Policies\Microsoft\Windows NT\DNSClient\SearchList
+ RegString policy_search_list;
+ // SYSTEM\CurrentControlSet\Tcpip\Parameters\SearchList
+ RegString tcpip_search_list;
+ // SYSTEM\CurrentControlSet\Tcpip\Parameters\Domain
+ RegString tcpip_domain;
+ // SOFTWARE\Policies\Microsoft\System\DNSClient\PrimaryDnsSuffix
+ RegString primary_dns_suffix;
+
+ // SOFTWARE\Policies\Microsoft\Windows NT\DNSClient
+ DevolutionSetting policy_devolution;
+ // SYSTEM\CurrentControlSet\Dnscache\Parameters
+ DevolutionSetting dnscache_devolution;
+ // SYSTEM\CurrentControlSet\Tcpip\Parameters
+ DevolutionSetting tcpip_devolution;
+
+ // SOFTWARE\Policies\Microsoft\Windows NT\DNSClient\AppendToMultiLabelName
+ RegDword append_to_multi_label_name;
+};
+
+enum ConfigParseWinResult {
+ CONFIG_PARSE_WIN_OK = 0,
+ CONFIG_PARSE_WIN_READ_IPHELPER,
+ CONFIG_PARSE_WIN_READ_POLICY_SEARCHLIST,
+ CONFIG_PARSE_WIN_READ_TCPIP_SEARCHLIST,
+ CONFIG_PARSE_WIN_READ_DOMAIN,
+ CONFIG_PARSE_WIN_READ_POLICY_DEVOLUTION,
+ CONFIG_PARSE_WIN_READ_DNSCACHE_DEVOLUTION,
+ CONFIG_PARSE_WIN_READ_TCPIP_DEVOLUTION,
+ CONFIG_PARSE_WIN_READ_APPEND_MULTILABEL,
+ CONFIG_PARSE_WIN_READ_PRIMARY_SUFFIX,
+ CONFIG_PARSE_WIN_BAD_ADDRESS,
+ CONFIG_PARSE_WIN_NO_NAMESERVERS,
+ CONFIG_PARSE_WIN_MAX // Bounding values for enumeration.
+};
+
+// Fills in |dns_config| from |settings|. Exposed for tests.
+ConfigParseWinResult NET_EXPORT_PRIVATE ConvertSettingsToDnsConfig(
+ const DnsSystemSettings& settings,
+ DnsConfig* dns_config);
+
+// Use DnsConfigService::CreateSystemService to use it outside of tests.
+class NET_EXPORT_PRIVATE DnsConfigServiceWin : public DnsConfigService {
+ public:
+ DnsConfigServiceWin();
+ virtual ~DnsConfigServiceWin();
+
+ private:
+ class Watcher;
+ class ConfigReader;
+ class HostsReader;
+
+ // DnsConfigService:
+ virtual void ReadNow() OVERRIDE;
+ virtual bool StartWatching() OVERRIDE;
+
+ void OnConfigChanged(bool succeeded);
+ void OnHostsChanged(bool succeeded);
+
+ scoped_ptr<Watcher> watcher_;
+ scoped_refptr<ConfigReader> config_reader_;
+ scoped_refptr<HostsReader> hosts_reader_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsConfigServiceWin);
+};
+
+} // namespace internal
+
+} // namespace net
+
+#endif // NET_DNS_DNS_CONFIG_SERVICE_WIN_H_
+
diff --git a/chromium/net/dns/dns_config_service_win_unittest.cc b/chromium/net/dns/dns_config_service_win_unittest.cc
new file mode 100644
index 00000000000..b28b8e9554a
--- /dev/null
+++ b/chromium/net/dns/dns_config_service_win_unittest.cc
@@ -0,0 +1,430 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_config_service_win.h"
+
+#include "base/logging.h"
+#include "base/win/windows_version.h"
+#include "net/dns/dns_protocol.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+TEST(DnsConfigServiceWinTest, ParseSearchList) {
+ const struct TestCase {
+ const wchar_t* input;
+ const char* output[4]; // NULL-terminated, empty if expected false
+ } cases[] = {
+ { L"chromium.org", { "chromium.org", NULL } },
+ { L"chromium.org,org", { "chromium.org", "org", NULL } },
+ // Empty suffixes terminate the list
+ { L"crbug.com,com,,org", { "crbug.com", "com", NULL } },
+ // IDN are converted to punycode
+ { L"\u017c\xf3\u0142ta.pi\u0119\u015b\u0107.pl,pl",
+ { "xn--ta-4ja03asj.xn--pi-wla5e0q.pl", "pl", NULL } },
+ // Empty search list is invalid
+ { L"", { NULL } },
+ { L",,", { NULL } },
+ };
+
+ std::vector<std::string> actual_output, expected_output;
+ for (unsigned i = 0; i < arraysize(cases); ++i) {
+ const TestCase& t = cases[i];
+ actual_output.clear();
+ actual_output.push_back("UNSET");
+ expected_output.clear();
+ for (const char* const* output = t.output; *output; ++output) {
+ expected_output.push_back(*output);
+ }
+ bool result = internal::ParseSearchList(t.input, &actual_output);
+ if (!expected_output.empty()) {
+ EXPECT_TRUE(result);
+ EXPECT_EQ(expected_output, actual_output);
+ } else {
+ EXPECT_FALSE(result) << "Unexpected parse success on " << t.input;
+ }
+ }
+}
+
+struct AdapterInfo {
+ IFTYPE if_type;
+ IF_OPER_STATUS oper_status;
+ const WCHAR* dns_suffix;
+ std::string dns_server_addresses[4]; // Empty string indicates end.
+ int ports[4];
+};
+
+scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> CreateAdapterAddresses(
+ const AdapterInfo* infos) {
+ size_t num_adapters = 0;
+ size_t num_addresses = 0;
+ for (size_t i = 0; infos[i].if_type; ++i) {
+ ++num_adapters;
+ for (size_t j = 0; !infos[i].dns_server_addresses[j].empty(); ++j) {
+ ++num_addresses;
+ }
+ }
+
+ size_t heap_size = num_adapters * sizeof(IP_ADAPTER_ADDRESSES) +
+ num_addresses * (sizeof(IP_ADAPTER_DNS_SERVER_ADDRESS) +
+ sizeof(struct sockaddr_storage));
+ scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> heap(
+ reinterpret_cast<IP_ADAPTER_ADDRESSES*>(malloc(heap_size)));
+ CHECK(heap.get());
+ memset(heap.get(), 0, heap_size);
+
+ IP_ADAPTER_ADDRESSES* adapters = heap.get();
+ IP_ADAPTER_DNS_SERVER_ADDRESS* addresses =
+ reinterpret_cast<IP_ADAPTER_DNS_SERVER_ADDRESS*>(adapters + num_adapters);
+ struct sockaddr_storage* storage =
+ reinterpret_cast<struct sockaddr_storage*>(addresses + num_addresses);
+
+ for (size_t i = 0; i < num_adapters; ++i) {
+ const AdapterInfo& info = infos[i];
+ IP_ADAPTER_ADDRESSES* adapter = adapters + i;
+ if (i + 1 < num_adapters)
+ adapter->Next = adapter + 1;
+ adapter->IfType = info.if_type;
+ adapter->OperStatus = info.oper_status;
+ adapter->DnsSuffix = const_cast<PWCHAR>(info.dns_suffix);
+ IP_ADAPTER_DNS_SERVER_ADDRESS* address = NULL;
+ for (size_t j = 0; !info.dns_server_addresses[j].empty(); ++j) {
+ --num_addresses;
+ if (j == 0) {
+ address = adapter->FirstDnsServerAddress = addresses + num_addresses;
+ } else {
+ // Note that |address| is moving backwards.
+ address = address->Next = address - 1;
+ }
+ IPAddressNumber ip;
+ CHECK(ParseIPLiteralToNumber(info.dns_server_addresses[j], &ip));
+ IPEndPoint ipe(ip, info.ports[j]);
+ address->Address.lpSockaddr =
+ reinterpret_cast<LPSOCKADDR>(storage + num_addresses);
+ socklen_t length = sizeof(struct sockaddr_storage);
+ CHECK(ipe.ToSockAddr(address->Address.lpSockaddr, &length));
+ address->Address.iSockaddrLength = static_cast<int>(length);
+ }
+ }
+
+ return heap.Pass();
+}
+
+TEST(DnsConfigServiceWinTest, ConvertAdapterAddresses) {
+ // Check nameservers and connection-specific suffix.
+ const struct TestCase {
+ AdapterInfo input_adapters[4]; // |if_type| == 0 indicates end.
+ std::string expected_nameservers[4]; // Empty string indicates end.
+ std::string expected_suffix;
+ int expected_ports[4];
+ } cases[] = {
+ { // Ignore loopback and inactive adapters.
+ {
+ { IF_TYPE_SOFTWARE_LOOPBACK, IfOperStatusUp, L"funnyloop",
+ { "2.0.0.2" } },
+ { IF_TYPE_FASTETHER, IfOperStatusDormant, L"example.com",
+ { "1.0.0.1" } },
+ { IF_TYPE_USB, IfOperStatusUp, L"chromium.org",
+ { "10.0.0.10", "2001:FFFF::1111" } },
+ { 0 },
+ },
+ { "10.0.0.10", "2001:FFFF::1111" },
+ "chromium.org",
+ },
+ { // Respect configured ports.
+ {
+ { IF_TYPE_USB, IfOperStatusUp, L"chromium.org",
+ { "10.0.0.10", "2001:FFFF::1111" }, { 1024, 24 } },
+ { 0 },
+ },
+ { "10.0.0.10", "2001:FFFF::1111" },
+ "chromium.org",
+ { 1024, 24 },
+ },
+ { // Use the preferred adapter (first in binding order) and filter
+ // stateless DNS discovery addresses.
+ {
+ { IF_TYPE_SOFTWARE_LOOPBACK, IfOperStatusUp, L"funnyloop",
+ { "2.0.0.2" } },
+ { IF_TYPE_FASTETHER, IfOperStatusUp, L"example.com",
+ { "1.0.0.1", "fec0:0:0:ffff::2", "8.8.8.8" } },
+ { IF_TYPE_USB, IfOperStatusUp, L"chromium.org",
+ { "10.0.0.10", "2001:FFFF::1111" } },
+ { 0 },
+ },
+ { "1.0.0.1", "8.8.8.8" },
+ "example.com",
+ },
+ { // No usable adapters.
+ {
+ { IF_TYPE_SOFTWARE_LOOPBACK, IfOperStatusUp, L"localhost",
+ { "2.0.0.2" } },
+ { IF_TYPE_FASTETHER, IfOperStatusDormant, L"example.com",
+ { "1.0.0.1" } },
+ { IF_TYPE_USB, IfOperStatusUp, L"chromium.org" },
+ { 0 },
+ },
+ },
+ };
+
+ for (size_t i = 0; i < arraysize(cases); ++i) {
+ const TestCase& t = cases[i];
+ internal::DnsSystemSettings settings = {
+ CreateAdapterAddresses(t.input_adapters),
+ // Default settings for the rest.
+ };
+ std::vector<IPEndPoint> expected_nameservers;
+ for (size_t j = 0; !t.expected_nameservers[j].empty(); ++j) {
+ IPAddressNumber ip;
+ ASSERT_TRUE(ParseIPLiteralToNumber(t.expected_nameservers[j], &ip));
+ int port = t.expected_ports[j];
+ if (!port)
+ port = dns_protocol::kDefaultPort;
+ expected_nameservers.push_back(IPEndPoint(ip, port));
+ }
+
+ DnsConfig config;
+ internal::ConfigParseWinResult result =
+ internal::ConvertSettingsToDnsConfig(settings, &config);
+ internal::ConfigParseWinResult expected_result =
+ expected_nameservers.empty() ? internal::CONFIG_PARSE_WIN_NO_NAMESERVERS
+ : internal::CONFIG_PARSE_WIN_OK;
+ EXPECT_EQ(expected_result, result);
+ EXPECT_EQ(expected_nameservers, config.nameservers);
+ if (result == internal::CONFIG_PARSE_WIN_OK) {
+ ASSERT_EQ(1u, config.search.size());
+ EXPECT_EQ(t.expected_suffix, config.search[0]);
+ }
+ }
+}
+
+TEST(DnsConfigServiceWinTest, ConvertSuffixSearch) {
+ AdapterInfo infos[2] = {
+ { IF_TYPE_USB, IfOperStatusUp, L"connection.suffix", { "1.0.0.1" } },
+ { 0 },
+ };
+
+ const struct TestCase {
+ internal::DnsSystemSettings input_settings;
+ std::string expected_search[5];
+ } cases[] = {
+ { // Policy SearchList override.
+ {
+ CreateAdapterAddresses(infos),
+ { true, L"policy.searchlist.a,policy.searchlist.b" },
+ { true, L"tcpip.searchlist.a,tcpip.searchlist.b" },
+ { true, L"tcpip.domain" },
+ { true, L"primary.dns.suffix" },
+ },
+ { "policy.searchlist.a", "policy.searchlist.b" },
+ },
+ { // User-specified SearchList override.
+ {
+ CreateAdapterAddresses(infos),
+ { false },
+ { true, L"tcpip.searchlist.a,tcpip.searchlist.b" },
+ { true, L"tcpip.domain" },
+ { true, L"primary.dns.suffix" },
+ },
+ { "tcpip.searchlist.a", "tcpip.searchlist.b" },
+ },
+ { // Void SearchList. Using tcpip.domain
+ {
+ CreateAdapterAddresses(infos),
+ { true, L",bad.searchlist,parsed.as.empty" },
+ { true, L"tcpip.searchlist,good.but.overridden" },
+ { true, L"tcpip.domain" },
+ { false },
+ },
+ { "tcpip.domain", "connection.suffix" },
+ },
+ { // Void SearchList. Using primary.dns.suffix
+ {
+ CreateAdapterAddresses(infos),
+ { true, L",bad.searchlist,parsed.as.empty" },
+ { true, L"tcpip.searchlist,good.but.overridden" },
+ { true, L"tcpip.domain" },
+ { true, L"primary.dns.suffix" },
+ },
+ { "primary.dns.suffix", "connection.suffix" },
+ },
+ { // Void SearchList. Using tcpip.domain when primary.dns.suffix is empty
+ {
+ CreateAdapterAddresses(infos),
+ { true, L",bad.searchlist,parsed.as.empty" },
+ { true, L"tcpip.searchlist,good.but.overridden" },
+ { true, L"tcpip.domain" },
+ { true, L"" },
+ },
+ { "tcpip.domain", "connection.suffix" },
+ },
+ { // Void SearchList. Using tcpip.domain when primary.dns.suffix is NULL
+ {
+ CreateAdapterAddresses(infos),
+ { true, L",bad.searchlist,parsed.as.empty" },
+ { true, L"tcpip.searchlist,good.but.overridden" },
+ { true, L"tcpip.domain" },
+ { true },
+ },
+ { "tcpip.domain", "connection.suffix" },
+ },
+ { // No primary suffix. Devolution does not matter.
+ {
+ CreateAdapterAddresses(infos),
+ { false },
+ { false },
+ { true },
+ { true },
+ { { true, 1 }, { true, 2 } },
+ },
+ { "connection.suffix" },
+ },
+ { // Devolution enabled by policy, level by dnscache.
+ {
+ CreateAdapterAddresses(infos),
+ { false },
+ { false },
+ { true, L"a.b.c.d.e" },
+ { false },
+ { { true, 1 }, { false } }, // policy_devolution: enabled, level
+ { { true, 0 }, { true, 3 } }, // dnscache_devolution
+ { { true, 0 }, { true, 1 } }, // tcpip_devolution
+ },
+ { "a.b.c.d.e", "connection.suffix", "b.c.d.e", "c.d.e" },
+ },
+ { // Devolution enabled by dnscache, level by policy.
+ {
+ CreateAdapterAddresses(infos),
+ { false },
+ { false },
+ { true, L"a.b.c.d.e" },
+ { true, L"f.g.i.l.j" },
+ { { false }, { true, 4 } },
+ { { true, 1 }, { false } },
+ { { true, 0 }, { true, 3 } },
+ },
+ { "f.g.i.l.j", "connection.suffix", "g.i.l.j" },
+ },
+ { // Devolution enabled by default.
+ {
+ CreateAdapterAddresses(infos),
+ { false },
+ { false },
+ { true, L"a.b.c.d.e" },
+ { false },
+ { { false }, { false } },
+ { { false }, { true, 3 } },
+ { { false }, { true, 1 } },
+ },
+ { "a.b.c.d.e", "connection.suffix", "b.c.d.e", "c.d.e" },
+ },
+ { // Devolution enabled at level = 2, but nothing to devolve.
+ {
+ CreateAdapterAddresses(infos),
+ { false },
+ { false },
+ { true, L"a.b" },
+ { false },
+ { { false }, { false } },
+ { { false }, { true, 2 } },
+ { { false }, { true, 2 } },
+ },
+ { "a.b", "connection.suffix" },
+ },
+ { // Devolution disabled when no explicit level.
+ // Windows XP and Vista use a default level = 2, but we don't.
+ {
+ CreateAdapterAddresses(infos),
+ { false },
+ { false },
+ { true, L"a.b.c.d.e" },
+ { false },
+ { { true, 1 }, { false } },
+ { { true, 1 }, { false } },
+ { { true, 1 }, { false } },
+ },
+ { "a.b.c.d.e", "connection.suffix" },
+ },
+ { // Devolution disabled by policy level.
+ {
+ CreateAdapterAddresses(infos),
+ { false },
+ { false },
+ { true, L"a.b.c.d.e" },
+ { false },
+ { { false }, { true, 1 } },
+ { { true, 1 }, { true, 3 } },
+ { { true, 1 }, { true, 4 } },
+ },
+ { "a.b.c.d.e", "connection.suffix" },
+ },
+ { // Devolution disabled by user setting.
+ {
+ CreateAdapterAddresses(infos),
+ { false },
+ { false },
+ { true, L"a.b.c.d.e" },
+ { false },
+ { { false }, { true, 3 } },
+ { { false }, { true, 3 } },
+ { { true, 0 }, { true, 3 } },
+ },
+ { "a.b.c.d.e", "connection.suffix" },
+ },
+ };
+
+ for (size_t i = 0; i < arraysize(cases); ++i) {
+ const TestCase& t = cases[i];
+ DnsConfig config;
+ EXPECT_EQ(internal::CONFIG_PARSE_WIN_OK,
+ internal::ConvertSettingsToDnsConfig(t.input_settings, &config));
+ std::vector<std::string> expected_search;
+ for (size_t j = 0; !t.expected_search[j].empty(); ++j) {
+ expected_search.push_back(t.expected_search[j]);
+ }
+ EXPECT_EQ(expected_search, config.search);
+ }
+}
+
+TEST(DnsConfigServiceWinTest, AppendToMultiLabelName) {
+ AdapterInfo infos[2] = {
+ { IF_TYPE_USB, IfOperStatusUp, L"connection.suffix", { "1.0.0.1" } },
+ { 0 },
+ };
+
+ // The default setting was true pre-Vista.
+ bool default_value = (base::win::GetVersion() < base::win::VERSION_VISTA);
+
+ const struct TestCase {
+ internal::DnsSystemSettings::RegDword input;
+ bool expected_output;
+ } cases[] = {
+ { { true, 0 }, false },
+ { { true, 1 }, true },
+ { { false, 0 }, default_value },
+ };
+
+ for (size_t i = 0; i < arraysize(cases); ++i) {
+ const TestCase& t = cases[i];
+ internal::DnsSystemSettings settings = {
+ CreateAdapterAddresses(infos),
+ { false }, { false }, { false }, { false },
+ { { false }, { false } },
+ { { false }, { false } },
+ { { false }, { false } },
+ t.input,
+ };
+ DnsConfig config;
+ EXPECT_EQ(internal::CONFIG_PARSE_WIN_OK,
+ internal::ConvertSettingsToDnsConfig(settings, &config));
+ EXPECT_EQ(config.append_to_multi_label_name, t.expected_output);
+ }
+}
+
+} // namespace
+
+} // namespace net
+
diff --git a/chromium/net/dns/dns_hosts.cc b/chromium/net/dns/dns_hosts.cc
new file mode 100644
index 00000000000..852d35c8bb4
--- /dev/null
+++ b/chromium/net/dns/dns_hosts.cc
@@ -0,0 +1,169 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_hosts.h"
+
+#include "base/file_util.h"
+#include "base/logging.h"
+#include "base/metrics/histogram.h"
+#include "base/strings/string_util.h"
+#include "base/strings/string_tokenizer.h"
+
+using base::StringPiece;
+
+namespace net {
+
+// Parses the contents of a hosts file. Returns one token (IP or hostname) at
+// a time. Doesn't copy anything; accepts the file as a StringPiece and
+// returns tokens as StringPieces.
+class HostsParser {
+ public:
+ explicit HostsParser(const StringPiece& text)
+ : text_(text),
+ data_(text.data()),
+ end_(text.size()),
+ pos_(0),
+ token_(),
+ token_is_ip_(false) {}
+
+ // Advances to the next token (IP or hostname). Returns whether another
+ // token was available. |token_is_ip| and |token| can be used to find out
+ // the type and text of the token.
+ bool Advance() {
+ bool next_is_ip = (pos_ == 0);
+ while (pos_ < end_ && pos_ != std::string::npos) {
+ switch (text_[pos_]) {
+ case ' ':
+ case '\t':
+ SkipWhitespace();
+ break;
+
+ case '\r':
+ case '\n':
+ next_is_ip = true;
+ pos_++;
+ break;
+
+ case '#':
+ SkipRestOfLine();
+ break;
+
+ default: {
+ size_t token_start = pos_;
+ SkipToken();
+ size_t token_end = (pos_ == std::string::npos) ? end_ : pos_;
+
+ token_ = StringPiece(data_ + token_start, token_end - token_start);
+ token_is_ip_ = next_is_ip;
+
+ return true;
+ }
+ }
+ }
+
+ text_ = StringPiece();
+ return false;
+ }
+
+ // Fast-forwards the parser to the next line. Should be called if an IP
+ // address doesn't parse, to avoid wasting time tokenizing hostnames that
+ // will be ignored.
+ void SkipRestOfLine() {
+ pos_ = text_.find("\n", pos_);
+ }
+
+ // Returns whether the last-parsed token is an IP address (true) or a
+ // hostname (false).
+ bool token_is_ip() { return token_is_ip_; }
+
+ // Returns the text of the last-parsed token as a StringPiece referencing
+ // the same underlying memory as the StringPiece passed to the constructor.
+ // Returns an empty StringPiece if no token has been parsed or the end of
+ // the input string has been reached.
+ const StringPiece& token() { return token_; }
+
+ private:
+ void SkipToken() {
+ pos_ = text_.find_first_of(" \t\n\r#", pos_);
+ }
+
+ void SkipWhitespace() {
+ pos_ = text_.find_first_not_of(" \t", pos_);
+ }
+
+ StringPiece text_;
+ const char* data_;
+ const size_t end_;
+
+ size_t pos_;
+ StringPiece token_;
+ bool token_is_ip_;
+
+ DISALLOW_COPY_AND_ASSIGN(HostsParser);
+};
+
+
+
+void ParseHosts(const std::string& contents, DnsHosts* dns_hosts) {
+ CHECK(dns_hosts);
+ DnsHosts& hosts = *dns_hosts;
+
+ StringPiece ip_text;
+ IPAddressNumber ip;
+ AddressFamily family = ADDRESS_FAMILY_IPV4;
+ HostsParser parser(contents);
+ while (parser.Advance()) {
+ if (parser.token_is_ip()) {
+ StringPiece new_ip_text = parser.token();
+ // Some ad-blocking hosts files contain thousands of entries pointing to
+ // the same IP address (usually 127.0.0.1). Don't bother parsing the IP
+ // again if it's the same as the one above it.
+ if (new_ip_text != ip_text) {
+ IPAddressNumber new_ip;
+ if (ParseIPLiteralToNumber(parser.token().as_string(), &new_ip)) {
+ ip_text = new_ip_text;
+ ip.swap(new_ip);
+ family = (ip.size() == 4) ? ADDRESS_FAMILY_IPV4 : ADDRESS_FAMILY_IPV6;
+ } else {
+ parser.SkipRestOfLine();
+ }
+ }
+ } else {
+ DnsHostsKey key(parser.token().as_string(), family);
+ StringToLowerASCII(&key.first);
+ IPAddressNumber& mapped_ip = hosts[key];
+ if (mapped_ip.empty())
+ mapped_ip = ip;
+ // else ignore this entry (first hit counts)
+ }
+ }
+}
+
+bool ParseHostsFile(const base::FilePath& path, DnsHosts* dns_hosts) {
+ dns_hosts->clear();
+ // Missing file indicates empty HOSTS.
+ if (!base::PathExists(path))
+ return true;
+
+ int64 size;
+ if (!file_util::GetFileSize(path, &size))
+ return false;
+
+ UMA_HISTOGRAM_COUNTS("AsyncDNS.HostsSize", size);
+
+ // Reject HOSTS files larger than |kMaxHostsSize| bytes.
+ const int64 kMaxHostsSize = 1 << 25; // 32MB
+ if (size > kMaxHostsSize)
+ return false;
+
+ std::string contents;
+ if (!file_util::ReadFileToString(path, &contents))
+ return false;
+
+ ParseHosts(contents, dns_hosts);
+ return true;
+}
+
+} // namespace net
+
diff --git a/chromium/net/dns/dns_hosts.h b/chromium/net/dns/dns_hosts.h
new file mode 100644
index 00000000000..c2b290907ad
--- /dev/null
+++ b/chromium/net/dns/dns_hosts.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_HOSTS_H_
+#define NET_DNS_DNS_HOSTS_H_
+
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "base/basictypes.h"
+#include "base/containers/hash_tables.h"
+#include "base/files/file_path.h"
+#include "net/base/address_family.h"
+#include "net/base/net_export.h"
+#include "net/base/net_util.h" // can't forward-declare IPAddressNumber
+
+namespace net {
+ typedef std::pair<std::string, AddressFamily> DnsHostsKey;
+};
+
+namespace BASE_HASH_NAMESPACE {
+#if defined(COMPILER_GCC)
+
+template<>
+struct hash<net::DnsHostsKey> {
+ std::size_t operator()(const net::DnsHostsKey& key) const {
+ hash<base::StringPiece> string_piece_hash;
+ return string_piece_hash(key.first) + key.second;
+ }
+};
+
+#elif defined(COMPILER_MSVC)
+
+inline size_t hash_value(const net::DnsHostsKey& key) {
+ return hash_value(key.first) + key.second;
+}
+
+#endif // COMPILER
+
+} // namespace BASE_HASH_NAMESPACE
+
+namespace net {
+
+// Parsed results of a Hosts file.
+//
+// Although Hosts files map IP address to a list of domain names, for name
+// resolution the desired mapping direction is: domain name to IP address.
+// When parsing Hosts, we apply the "first hit" rule as Windows and glibc do.
+// With a Hosts file of:
+// 300.300.300.300 localhost # bad ip
+// 127.0.0.1 localhost
+// 10.0.0.1 localhost
+// The expected resolution of localhost is 127.0.0.1.
+#if !defined(OS_ANDROID)
+typedef base::hash_map<DnsHostsKey, IPAddressNumber> DnsHosts;
+#else
+// Android's hash_map doesn't support ==, so fall back to map. (Chromium on
+// Android doesn't use the built-in DNS resolver anyway, so it's irrelevant.)
+typedef std::map<DnsHostsKey, IPAddressNumber> DnsHosts;
+#endif
+
+// Parses |contents| (as read from /etc/hosts or equivalent) and stores results
+// in |dns_hosts|. Invalid lines are ignored (as in most implementations).
+void NET_EXPORT_PRIVATE ParseHosts(const std::string& contents,
+ DnsHosts* dns_hosts);
+
+// As above but reads the file pointed to by |path|.
+bool NET_EXPORT_PRIVATE ParseHostsFile(const base::FilePath& path,
+ DnsHosts* dns_hosts);
+
+
+
+} // namespace net
+
+#endif // NET_DNS_DNS_HOSTS_H_
+
diff --git a/chromium/net/dns/dns_hosts_unittest.cc b/chromium/net/dns/dns_hosts_unittest.cc
new file mode 100644
index 00000000000..c0e8805fc47
--- /dev/null
+++ b/chromium/net/dns/dns_hosts_unittest.cc
@@ -0,0 +1,124 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_hosts.h"
+
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+TEST(DnsHostsTest, ParseHosts) {
+ std::string contents =
+ "127.0.0.1 localhost\tlocalhost.localdomain # standard\n"
+ "\n"
+ "1.0.0.1 localhost # ignored, first hit above\n"
+ "fe00::x example company # ignored, malformed IPv6\n"
+ "1.0.0.300 company # ignored, malformed IPv4\n"
+ "1.0.0.1 # ignored, missing hostname\n"
+ "1.0.0.1\t CoMpANy # normalized to 'company' \n"
+ "::1\tlocalhost ip6-localhost ip6-loopback # comment # within a comment\n"
+ "\t fe00::0 ip6-localnet\r\n"
+ "2048::2 example\n"
+ "2048::1 company example # ignored for 'example' \n"
+ "127.0.0.1 cache1\n"
+ "127.0.0.1 cache2 # should reuse parsed IP\n"
+ "256.0.0.0 cache3 # bogus IP should not clear parsed IP cache\n"
+ "127.0.0.1 cache4 # should still be reused\n"
+ "127.0.0.2 cache5\n"
+ "gibberish";
+
+ const struct {
+ const char* host;
+ AddressFamily family;
+ const char* ip;
+ } entries[] = {
+ { "localhost", ADDRESS_FAMILY_IPV4, "127.0.0.1" },
+ { "localhost.localdomain", ADDRESS_FAMILY_IPV4, "127.0.0.1" },
+ { "company", ADDRESS_FAMILY_IPV4, "1.0.0.1" },
+ { "localhost", ADDRESS_FAMILY_IPV6, "::1" },
+ { "ip6-localhost", ADDRESS_FAMILY_IPV6, "::1" },
+ { "ip6-loopback", ADDRESS_FAMILY_IPV6, "::1" },
+ { "ip6-localnet", ADDRESS_FAMILY_IPV6, "fe00::0" },
+ { "company", ADDRESS_FAMILY_IPV6, "2048::1" },
+ { "example", ADDRESS_FAMILY_IPV6, "2048::2" },
+ { "cache1", ADDRESS_FAMILY_IPV4, "127.0.0.1" },
+ { "cache2", ADDRESS_FAMILY_IPV4, "127.0.0.1" },
+ { "cache4", ADDRESS_FAMILY_IPV4, "127.0.0.1" },
+ { "cache5", ADDRESS_FAMILY_IPV4, "127.0.0.2" },
+ };
+
+ DnsHosts expected;
+ for (size_t i = 0; i < ARRAYSIZE_UNSAFE(entries); ++i) {
+ DnsHostsKey key(entries[i].host, entries[i].family);
+ IPAddressNumber& ip = expected[key];
+ ASSERT_TRUE(ip.empty());
+ ASSERT_TRUE(ParseIPLiteralToNumber(entries[i].ip, &ip));
+ ASSERT_EQ(ip.size(), (entries[i].family == ADDRESS_FAMILY_IPV4) ? 4u : 16u);
+ }
+
+ DnsHosts hosts;
+ ParseHosts(contents, &hosts);
+ ASSERT_EQ(expected, hosts);
+}
+
+TEST(DnsHostsTest, HostsParser_Empty) {
+ DnsHosts hosts;
+ ParseHosts("", &hosts);
+ EXPECT_EQ(0u, hosts.size());
+}
+
+TEST(DnsHostsTest, HostsParser_OnlyWhitespace) {
+ DnsHosts hosts;
+ ParseHosts(" ", &hosts);
+ EXPECT_EQ(0u, hosts.size());
+}
+
+TEST(DnsHostsTest, HostsParser_EndsWithNothing) {
+ DnsHosts hosts;
+ ParseHosts("127.0.0.1 localhost", &hosts);
+ EXPECT_EQ(1u, hosts.size());
+}
+
+TEST(DnsHostsTest, HostsParser_EndsWithWhitespace) {
+ DnsHosts hosts;
+ ParseHosts("127.0.0.1 localhost ", &hosts);
+ EXPECT_EQ(1u, hosts.size());
+}
+
+TEST(DnsHostsTest, HostsParser_EndsWithComment) {
+ DnsHosts hosts;
+ ParseHosts("127.0.0.1 localhost # comment", &hosts);
+ EXPECT_EQ(1u, hosts.size());
+}
+
+TEST(DnsHostsTest, HostsParser_EndsWithNewline) {
+ DnsHosts hosts;
+ ParseHosts("127.0.0.1 localhost\n", &hosts);
+ EXPECT_EQ(1u, hosts.size());
+}
+
+TEST(DnsHostsTest, HostsParser_EndsWithTwoNewlines) {
+ DnsHosts hosts;
+ ParseHosts("127.0.0.1 localhost\n\n", &hosts);
+ EXPECT_EQ(1u, hosts.size());
+}
+
+TEST(DnsHostsTest, HostsParser_EndsWithNewlineAndWhitespace) {
+ DnsHosts hosts;
+ ParseHosts("127.0.0.1 localhost\n ", &hosts);
+ EXPECT_EQ(1u, hosts.size());
+}
+
+TEST(DnsHostsTest, HostsParser_EndsWithNewlineAndToken) {
+ DnsHosts hosts;
+ ParseHosts("127.0.0.1 localhost\ntoken", &hosts);
+ EXPECT_EQ(1u, hosts.size());
+}
+
+} // namespace
+
+} // namespace net
+
diff --git a/chromium/net/dns/dns_protocol.h b/chromium/net/dns/dns_protocol.h
new file mode 100644
index 00000000000..a8aad65c2fb
--- /dev/null
+++ b/chromium/net/dns/dns_protocol.h
@@ -0,0 +1,143 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_PROTOCOL_H_
+#define NET_DNS_DNS_PROTOCOL_H_
+
+#include "base/basictypes.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+namespace dns_protocol {
+
+static const uint16 kDefaultPort = 53;
+static const uint16 kDefaultPortMulticast = 5353;
+
+// DNS packet consists of a header followed by questions and/or answers.
+// For the meaning of specific fields, please see RFC 1035 and 2535
+
+// Header format.
+// 1 1 1 1 1 1
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | ID |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// |QR| Opcode |AA|TC|RD|RA| Z|AD|CD| RCODE |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | QDCOUNT |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | ANCOUNT |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | NSCOUNT |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | ARCOUNT |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+
+// Question format.
+// 1 1 1 1 1 1
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | |
+// / QNAME /
+// / /
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | QTYPE |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | QCLASS |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+
+// Answer format.
+// 1 1 1 1 1 1
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | |
+// / /
+// / NAME /
+// | |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | TYPE |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | CLASS |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | TTL |
+// | |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | RDLENGTH |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
+// / RDATA /
+// / /
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+
+#pragma pack(push)
+#pragma pack(1)
+
+// On-the-wire header. All uint16 are in network order.
+// Used internally in DnsQuery and DnsResponseParser.
+struct NET_EXPORT_PRIVATE Header {
+ uint16 id;
+ uint16 flags;
+ uint16 qdcount;
+ uint16 ancount;
+ uint16 nscount;
+ uint16 arcount;
+};
+
+#pragma pack(pop)
+
+static const uint8 kLabelMask = 0xc0;
+static const uint8 kLabelPointer = 0xc0;
+static const uint8 kLabelDirect = 0x0;
+static const uint16 kOffsetMask = 0x3fff;
+
+// In MDns the most significant bit of the rrclass is designated as the
+// "cache-flush bit", as described in http://www.rfc-editor.org/rfc/rfc6762.txt
+// section 10.2.
+static const uint16 kMDnsClassMask = 0x7FFF;
+
+static const int kMaxNameLength = 255;
+
+// RFC 1035, section 4.2.1: Messages carried by UDP are restricted to 512
+// bytes (not counting the IP nor UDP headers).
+static const int kMaxUDPSize = 512;
+
+// RFC 6762, section 17: Messages over the local link are restricted by the
+// medium's MTU, and must be under 9000 bytes
+static const int kMaxMulticastSize = 9000;
+
+// DNS class types.
+static const uint16 kClassIN = 1;
+
+// DNS resource record types. See
+// http://www.iana.org/assignments/dns-parameters
+static const uint16 kTypeA = 1;
+static const uint16 kTypeCNAME = 5;
+static const uint16 kTypePTR = 12;
+static const uint16 kTypeTXT = 16;
+static const uint16 kTypeAAAA = 28;
+static const uint16 kTypeSRV = 33;
+static const uint16 kTypeNSEC = 47;
+
+
+// DNS rcode values.
+static const uint8 kRcodeMask = 0xf;
+static const uint8 kRcodeNOERROR = 0;
+static const uint8 kRcodeFORMERR = 1;
+static const uint8 kRcodeSERVFAIL = 2;
+static const uint8 kRcodeNXDOMAIN = 3;
+static const uint8 kRcodeNOTIMP = 4;
+static const uint8 kRcodeREFUSED = 5;
+
+// DNS flags.
+static const uint16 kFlagResponse = 0x8000;
+static const uint16 kFlagRA = 0x80;
+static const uint16 kFlagRD = 0x100;
+static const uint16 kFlagTC = 0x200;
+static const uint16 kFlagAA = 0x400;
+
+} // namespace dns_protocol
+
+} // namespace net
+
+#endif // NET_DNS_DNS_PROTOCOL_H_
diff --git a/chromium/net/dns/dns_query.cc b/chromium/net/dns/dns_query.cc
new file mode 100644
index 00000000000..270757e7be9
--- /dev/null
+++ b/chromium/net/dns/dns_query.cc
@@ -0,0 +1,89 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_query.h"
+
+#include <limits>
+
+#include "base/sys_byteorder.h"
+#include "net/base/big_endian.h"
+#include "net/base/dns_util.h"
+#include "net/base/io_buffer.h"
+#include "net/dns/dns_protocol.h"
+
+namespace net {
+
+// DNS query consists of a 12-byte header followed by a question section.
+// For details, see RFC 1035 section 4.1.1. This header template sets RD
+// bit, which directs the name server to pursue query recursively, and sets
+// the QDCOUNT to 1, meaning the question section has a single entry.
+DnsQuery::DnsQuery(uint16 id, const base::StringPiece& qname, uint16 qtype)
+ : qname_size_(qname.size()) {
+ DCHECK(!DNSDomainToString(qname).empty());
+ // QNAME + QTYPE + QCLASS
+ size_t question_size = qname_size_ + sizeof(uint16) + sizeof(uint16);
+ io_buffer_ = new IOBufferWithSize(sizeof(dns_protocol::Header) +
+ question_size);
+ dns_protocol::Header* header =
+ reinterpret_cast<dns_protocol::Header*>(io_buffer_->data());
+ memset(header, 0, sizeof(dns_protocol::Header));
+ header->id = base::HostToNet16(id);
+ header->flags = base::HostToNet16(dns_protocol::kFlagRD);
+ header->qdcount = base::HostToNet16(1);
+
+ // Write question section after the header.
+ BigEndianWriter writer(reinterpret_cast<char*>(header + 1), question_size);
+ writer.WriteBytes(qname.data(), qname.size());
+ writer.WriteU16(qtype);
+ writer.WriteU16(dns_protocol::kClassIN);
+}
+
+DnsQuery::~DnsQuery() {
+}
+
+DnsQuery* DnsQuery::CloneWithNewId(uint16 id) const {
+ return new DnsQuery(*this, id);
+}
+
+uint16 DnsQuery::id() const {
+ const dns_protocol::Header* header =
+ reinterpret_cast<const dns_protocol::Header*>(io_buffer_->data());
+ return base::NetToHost16(header->id);
+}
+
+base::StringPiece DnsQuery::qname() const {
+ return base::StringPiece(io_buffer_->data() + sizeof(dns_protocol::Header),
+ qname_size_);
+}
+
+uint16 DnsQuery::qtype() const {
+ uint16 type;
+ ReadBigEndian<uint16>(io_buffer_->data() +
+ sizeof(dns_protocol::Header) +
+ qname_size_, &type);
+ return type;
+}
+
+base::StringPiece DnsQuery::question() const {
+ return base::StringPiece(io_buffer_->data() + sizeof(dns_protocol::Header),
+ qname_size_ + sizeof(uint16) + sizeof(uint16));
+}
+
+DnsQuery::DnsQuery(const DnsQuery& orig, uint16 id) {
+ qname_size_ = orig.qname_size_;
+ io_buffer_ = new IOBufferWithSize(orig.io_buffer()->size());
+ memcpy(io_buffer_.get()->data(), orig.io_buffer()->data(),
+ io_buffer_.get()->size());
+ dns_protocol::Header* header =
+ reinterpret_cast<dns_protocol::Header*>(io_buffer_->data());
+ header->id = base::HostToNet16(id);
+}
+
+void DnsQuery::set_flags(uint16 flags) {
+ dns_protocol::Header* header =
+ reinterpret_cast<dns_protocol::Header*>(io_buffer_->data());
+ header->flags = flags;
+}
+
+} // namespace net
diff --git a/chromium/net/dns/dns_query.h b/chromium/net/dns/dns_query.h
new file mode 100644
index 00000000000..e1469bdfbc0
--- /dev/null
+++ b/chromium/net/dns/dns_query.h
@@ -0,0 +1,58 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_QUERY_H_
+#define NET_DNS_DNS_QUERY_H_
+
+#include "base/basictypes.h"
+#include "base/memory/ref_counted.h"
+#include "base/strings/string_piece.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+class IOBufferWithSize;
+
+// Represents on-the-wire DNS query message as an object.
+// TODO(szym): add support for the OPT pseudo-RR (EDNS0/DNSSEC).
+class NET_EXPORT_PRIVATE DnsQuery {
+ public:
+ // Constructs a query message from |qname| which *MUST* be in a valid
+ // DNS name format, and |qtype|. The qclass is set to IN.
+ DnsQuery(uint16 id, const base::StringPiece& qname, uint16 qtype);
+ ~DnsQuery();
+
+ // Clones |this| verbatim, with ID field of the header set to |id|.
+ DnsQuery* CloneWithNewId(uint16 id) const;
+
+ // DnsQuery field accessors.
+ uint16 id() const;
+ base::StringPiece qname() const;
+ uint16 qtype() const;
+
+ // Returns the Question section of the query. Used when matching the
+ // response.
+ base::StringPiece question() const;
+
+ // IOBuffer accessor to be used for writing out the query.
+ IOBufferWithSize* io_buffer() const { return io_buffer_.get(); }
+
+ void set_flags(uint16 flags);
+
+ private:
+ DnsQuery(const DnsQuery& orig, uint16 id);
+
+ // Size of the DNS name (*NOT* hostname) we are trying to resolve; used
+ // to calculate offsets.
+ size_t qname_size_;
+
+ // Contains query bytes to be consumed by higher level Write() call.
+ scoped_refptr<IOBufferWithSize> io_buffer_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsQuery);
+};
+
+} // namespace net
+
+#endif // NET_DNS_DNS_QUERY_H_
diff --git a/chromium/net/dns/dns_query_unittest.cc b/chromium/net/dns/dns_query_unittest.cc
new file mode 100644
index 00000000000..ffe02e70a61
--- /dev/null
+++ b/chromium/net/dns/dns_query_unittest.cc
@@ -0,0 +1,69 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_query.h"
+
+#include "base/bind.h"
+#include "net/base/dns_util.h"
+#include "net/base/io_buffer.h"
+#include "net/dns/dns_protocol.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+TEST(DnsQueryTest, Constructor) {
+ // This includes \0 at the end.
+ const char qname_data[] = "\x03""www""\x07""example""\x03""com";
+ const uint8 query_data[] = {
+ // Header
+ 0xbe, 0xef,
+ 0x01, 0x00, // Flags -- set RD (recursion desired) bit.
+ 0x00, 0x01, // Set QDCOUNT (question count) to 1, all the
+ // rest are 0 for a query.
+ 0x00, 0x00,
+ 0x00, 0x00,
+ 0x00, 0x00,
+
+ // Question
+ 0x03, 'w', 'w', 'w', // QNAME: www.example.com in DNS format.
+ 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00,
+
+ 0x00, 0x01, // QTYPE: A query.
+ 0x00, 0x01, // QCLASS: IN class.
+ };
+
+ base::StringPiece qname(qname_data, sizeof(qname_data));
+ DnsQuery q1(0xbeef, qname, dns_protocol::kTypeA);
+ EXPECT_EQ(dns_protocol::kTypeA, q1.qtype());
+
+ ASSERT_EQ(static_cast<int>(sizeof(query_data)), q1.io_buffer()->size());
+ EXPECT_EQ(0, memcmp(q1.io_buffer()->data(), query_data, sizeof(query_data)));
+ EXPECT_EQ(qname, q1.qname());
+
+ base::StringPiece question(reinterpret_cast<const char*>(query_data) + 12,
+ 21);
+ EXPECT_EQ(question, q1.question());
+}
+
+TEST(DnsQueryTest, Clone) {
+ // This includes \0 at the end.
+ const char qname_data[] = "\x03""www""\x07""example""\x03""com";
+ base::StringPiece qname(qname_data, sizeof(qname_data));
+
+ DnsQuery q1(0, qname, dns_protocol::kTypeA);
+ EXPECT_EQ(0, q1.id());
+ scoped_ptr<DnsQuery> q2(q1.CloneWithNewId(42));
+ EXPECT_EQ(42, q2->id());
+ EXPECT_EQ(q1.io_buffer()->size(), q2->io_buffer()->size());
+ EXPECT_EQ(q1.qtype(), q2->qtype());
+ EXPECT_EQ(q1.question(), q2->question());
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/dns/dns_response.cc b/chromium/net/dns/dns_response.cc
new file mode 100644
index 00000000000..d29d3c4813c
--- /dev/null
+++ b/chromium/net/dns/dns_response.cc
@@ -0,0 +1,337 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_response.h"
+
+#include "base/strings/string_util.h"
+#include "base/sys_byteorder.h"
+#include "net/base/address_list.h"
+#include "net/base/big_endian.h"
+#include "net/base/dns_util.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_query.h"
+
+namespace net {
+
+DnsResourceRecord::DnsResourceRecord() {
+}
+
+DnsResourceRecord::~DnsResourceRecord() {
+}
+
+DnsRecordParser::DnsRecordParser() : packet_(NULL), length_(0), cur_(0) {
+}
+
+DnsRecordParser::DnsRecordParser(const void* packet,
+ size_t length,
+ size_t offset)
+ : packet_(reinterpret_cast<const char*>(packet)),
+ length_(length),
+ cur_(packet_ + offset) {
+ DCHECK_LE(offset, length);
+}
+
+unsigned DnsRecordParser::ReadName(const void* const vpos,
+ std::string* out) const {
+ const char* const pos = reinterpret_cast<const char*>(vpos);
+ DCHECK(packet_);
+ DCHECK_LE(packet_, pos);
+ DCHECK_LE(pos, packet_ + length_);
+
+ const char* p = pos;
+ const char* end = packet_ + length_;
+ // Count number of seen bytes to detect loops.
+ unsigned seen = 0;
+ // Remember how many bytes were consumed before first jump.
+ unsigned consumed = 0;
+
+ if (pos >= end)
+ return 0;
+
+ if (out) {
+ out->clear();
+ out->reserve(dns_protocol::kMaxNameLength);
+ }
+
+ for (;;) {
+ // The first two bits of the length give the type of the length. It's
+ // either a direct length or a pointer to the remainder of the name.
+ switch (*p & dns_protocol::kLabelMask) {
+ case dns_protocol::kLabelPointer: {
+ if (p + sizeof(uint16) > end)
+ return 0;
+ if (consumed == 0) {
+ consumed = p - pos + sizeof(uint16);
+ if (!out)
+ return consumed; // If name is not stored, that's all we need.
+ }
+ seen += sizeof(uint16);
+ // If seen the whole packet, then we must be in a loop.
+ if (seen > length_)
+ return 0;
+ uint16 offset;
+ ReadBigEndian<uint16>(p, &offset);
+ offset &= dns_protocol::kOffsetMask;
+ p = packet_ + offset;
+ if (p >= end)
+ return 0;
+ break;
+ }
+ case dns_protocol::kLabelDirect: {
+ uint8 label_len = *p;
+ ++p;
+ // Note: root domain (".") is NOT included.
+ if (label_len == 0) {
+ if (consumed == 0) {
+ consumed = p - pos;
+ } // else we set |consumed| before first jump
+ return consumed;
+ }
+ if (p + label_len >= end)
+ return 0; // Truncated or missing label.
+ if (out) {
+ if (!out->empty())
+ out->append(".");
+ out->append(p, label_len);
+ }
+ p += label_len;
+ seen += 1 + label_len;
+ break;
+ }
+ default:
+ // unhandled label type
+ return 0;
+ }
+ }
+}
+
+bool DnsRecordParser::ReadRecord(DnsResourceRecord* out) {
+ DCHECK(packet_);
+ size_t consumed = ReadName(cur_, &out->name);
+ if (!consumed)
+ return false;
+ BigEndianReader reader(cur_ + consumed,
+ packet_ + length_ - (cur_ + consumed));
+ uint16 rdlen;
+ if (reader.ReadU16(&out->type) &&
+ reader.ReadU16(&out->klass) &&
+ reader.ReadU32(&out->ttl) &&
+ reader.ReadU16(&rdlen) &&
+ reader.ReadPiece(&out->rdata, rdlen)) {
+ cur_ = reader.ptr();
+ return true;
+ }
+ return false;
+}
+
+bool DnsRecordParser::SkipQuestion() {
+ size_t consumed = ReadName(cur_, NULL);
+ if (!consumed)
+ return false;
+
+ const char* next = cur_ + consumed + 2 * sizeof(uint16); // QTYPE + QCLASS
+ if (next > packet_ + length_)
+ return false;
+
+ cur_ = next;
+
+ return true;
+}
+
+DnsResponse::DnsResponse()
+ : io_buffer_(new IOBufferWithSize(dns_protocol::kMaxUDPSize + 1)) {
+}
+
+DnsResponse::DnsResponse(size_t length)
+ : io_buffer_(new IOBufferWithSize(length)) {
+}
+
+DnsResponse::DnsResponse(const void* data,
+ size_t length,
+ size_t answer_offset)
+ : io_buffer_(new IOBufferWithSize(length)),
+ parser_(io_buffer_->data(), length, answer_offset) {
+ DCHECK(data);
+ memcpy(io_buffer_->data(), data, length);
+}
+
+DnsResponse::~DnsResponse() {
+}
+
+bool DnsResponse::InitParse(int nbytes, const DnsQuery& query) {
+ DCHECK_GE(nbytes, 0);
+ // Response includes query, it should be at least that size.
+ if (nbytes < query.io_buffer()->size() || nbytes >= io_buffer_->size())
+ return false;
+
+ // Match the query id.
+ if (base::NetToHost16(header()->id) != query.id())
+ return false;
+
+ // Match question count.
+ if (base::NetToHost16(header()->qdcount) != 1)
+ return false;
+
+ // Match the question section.
+ const size_t hdr_size = sizeof(dns_protocol::Header);
+ const base::StringPiece question = query.question();
+ if (question != base::StringPiece(io_buffer_->data() + hdr_size,
+ question.size())) {
+ return false;
+ }
+
+ // Construct the parser.
+ parser_ = DnsRecordParser(io_buffer_->data(),
+ nbytes,
+ hdr_size + question.size());
+ return true;
+}
+
+bool DnsResponse::InitParseWithoutQuery(int nbytes) {
+ DCHECK_GE(nbytes, 0);
+
+ size_t hdr_size = sizeof(dns_protocol::Header);
+
+ if (nbytes < static_cast<int>(hdr_size) || nbytes >= io_buffer_->size())
+ return false;
+
+ parser_ = DnsRecordParser(
+ io_buffer_->data(), nbytes, hdr_size);
+
+ unsigned qdcount = base::NetToHost16(header()->qdcount);
+ for (unsigned i = 0; i < qdcount; ++i) {
+ if (!parser_.SkipQuestion()) {
+ parser_ = DnsRecordParser(); // Make parser invalid again.
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool DnsResponse::IsValid() const {
+ return parser_.IsValid();
+}
+
+uint16 DnsResponse::flags() const {
+ DCHECK(parser_.IsValid());
+ return base::NetToHost16(header()->flags) & ~(dns_protocol::kRcodeMask);
+}
+
+uint8 DnsResponse::rcode() const {
+ DCHECK(parser_.IsValid());
+ return base::NetToHost16(header()->flags) & dns_protocol::kRcodeMask;
+}
+
+unsigned DnsResponse::answer_count() const {
+ DCHECK(parser_.IsValid());
+ return base::NetToHost16(header()->ancount);
+}
+
+unsigned DnsResponse::additional_answer_count() const {
+ DCHECK(parser_.IsValid());
+ return base::NetToHost16(header()->arcount);
+}
+
+base::StringPiece DnsResponse::qname() const {
+ DCHECK(parser_.IsValid());
+ // The response is HEADER QNAME QTYPE QCLASS ANSWER.
+ // |parser_| is positioned at the beginning of ANSWER, so the end of QNAME is
+ // two uint16s before it.
+ const size_t hdr_size = sizeof(dns_protocol::Header);
+ const size_t qname_size = parser_.GetOffset() - 2 * sizeof(uint16) - hdr_size;
+ return base::StringPiece(io_buffer_->data() + hdr_size, qname_size);
+}
+
+uint16 DnsResponse::qtype() const {
+ DCHECK(parser_.IsValid());
+ // QTYPE starts where QNAME ends.
+ const size_t type_offset = parser_.GetOffset() - 2 * sizeof(uint16);
+ uint16 type;
+ ReadBigEndian<uint16>(io_buffer_->data() + type_offset, &type);
+ return type;
+}
+
+std::string DnsResponse::GetDottedName() const {
+ return DNSDomainToString(qname());
+}
+
+DnsRecordParser DnsResponse::Parser() const {
+ DCHECK(parser_.IsValid());
+ // Return a copy of the parser.
+ return parser_;
+}
+
+const dns_protocol::Header* DnsResponse::header() const {
+ return reinterpret_cast<const dns_protocol::Header*>(io_buffer_->data());
+}
+
+DnsResponse::Result DnsResponse::ParseToAddressList(
+ AddressList* addr_list,
+ base::TimeDelta* ttl) const {
+ DCHECK(IsValid());
+ // DnsTransaction already verified that |response| matches the issued query.
+ // We still need to determine if there is a valid chain of CNAMEs from the
+ // query name to the RR owner name.
+ // We err on the side of caution with the assumption that if we are too picky,
+ // we can always fall back to the system getaddrinfo.
+
+ // Expected owner of record. No trailing dot.
+ std::string expected_name = GetDottedName();
+
+ uint16 expected_type = qtype();
+ DCHECK(expected_type == dns_protocol::kTypeA ||
+ expected_type == dns_protocol::kTypeAAAA);
+
+ size_t expected_size = (expected_type == dns_protocol::kTypeAAAA)
+ ? kIPv6AddressSize : kIPv4AddressSize;
+
+ uint32 ttl_sec = kuint32max;
+ IPAddressList ip_addresses;
+ DnsRecordParser parser = Parser();
+ DnsResourceRecord record;
+ unsigned ancount = answer_count();
+ for (unsigned i = 0; i < ancount; ++i) {
+ if (!parser.ReadRecord(&record))
+ return DNS_MALFORMED_RESPONSE;
+
+ if (record.type == dns_protocol::kTypeCNAME) {
+ // Following the CNAME chain, only if no addresses seen.
+ if (!ip_addresses.empty())
+ return DNS_CNAME_AFTER_ADDRESS;
+
+ if (base::strcasecmp(record.name.c_str(), expected_name.c_str()) != 0)
+ return DNS_NAME_MISMATCH;
+
+ if (record.rdata.size() !=
+ parser.ReadName(record.rdata.begin(), &expected_name))
+ return DNS_MALFORMED_CNAME;
+
+ ttl_sec = std::min(ttl_sec, record.ttl);
+ } else if (record.type == expected_type) {
+ if (record.rdata.size() != expected_size)
+ return DNS_SIZE_MISMATCH;
+
+ if (base::strcasecmp(record.name.c_str(), expected_name.c_str()) != 0)
+ return DNS_NAME_MISMATCH;
+
+ ttl_sec = std::min(ttl_sec, record.ttl);
+ ip_addresses.push_back(IPAddressNumber(record.rdata.begin(),
+ record.rdata.end()));
+ }
+ }
+
+ // TODO(szym): Extract TTL for NODATA results. http://crbug.com/115051
+
+ // getcanonname in eglibc returns the first owner name of an A or AAAA RR.
+ // If the response passed all the checks so far, then |expected_name| is it.
+ *addr_list = AddressList::CreateFromIPAddressList(ip_addresses,
+ expected_name);
+ *ttl = base::TimeDelta::FromSeconds(ttl_sec);
+ return DNS_PARSE_OK;
+}
+
+} // namespace net
diff --git a/chromium/net/dns/dns_response.h b/chromium/net/dns/dns_response.h
new file mode 100644
index 00000000000..57abecf6253
--- /dev/null
+++ b/chromium/net/dns/dns_response.h
@@ -0,0 +1,169 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_RESPONSE_H_
+#define NET_DNS_DNS_RESPONSE_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/memory/ref_counted.h"
+#include "base/strings/string_piece.h"
+#include "base/time/time.h"
+#include "net/base/net_export.h"
+#include "net/base/net_util.h"
+
+namespace net {
+
+class AddressList;
+class DnsQuery;
+class IOBufferWithSize;
+
+namespace dns_protocol {
+struct Header;
+}
+
+// Parsed resource record.
+struct NET_EXPORT_PRIVATE DnsResourceRecord {
+ DnsResourceRecord();
+ ~DnsResourceRecord();
+
+ std::string name; // in dotted form
+ uint16 type;
+ uint16 klass;
+ uint32 ttl;
+ base::StringPiece rdata; // points to the original response buffer
+};
+
+// Iterator to walk over resource records of the DNS response packet.
+class NET_EXPORT_PRIVATE DnsRecordParser {
+ public:
+ // Construct an uninitialized iterator.
+ DnsRecordParser();
+
+ // Construct an iterator to process the |packet| of given |length|.
+ // |offset| points to the beginning of the answer section.
+ DnsRecordParser(const void* packet, size_t length, size_t offset);
+
+ // Returns |true| if initialized.
+ bool IsValid() const { return packet_ != NULL; }
+
+ // Returns |true| if no more bytes remain in the packet.
+ bool AtEnd() const { return cur_ == packet_ + length_; }
+
+ // Returns current offset into the packet.
+ size_t GetOffset() const { return cur_ - packet_; }
+
+ // Parses a (possibly compressed) DNS name from the packet starting at
+ // |pos|. Stores output (even partial) in |out| unless |out| is NULL. |out|
+ // is stored in the dotted form, e.g., "example.com". Returns number of bytes
+ // consumed or 0 on failure.
+ // This is exposed to allow parsing compressed names within RRDATA for TYPEs
+ // such as NS, CNAME, PTR, MX, SOA.
+ // See RFC 1035 section 4.1.4.
+ unsigned ReadName(const void* pos, std::string* out) const;
+
+ // Parses the next resource record into |record|. Returns true if succeeded.
+ bool ReadRecord(DnsResourceRecord* record);
+
+ // Skip a question section, returns true if succeeded.
+ bool SkipQuestion();
+
+ private:
+ const char* packet_;
+ size_t length_;
+ // Current offset within the packet.
+ const char* cur_;
+};
+
+// Buffer-holder for the DNS response allowing easy access to the header fields
+// and resource records. After reading into |io_buffer| must call InitParse to
+// position the RR parser.
+class NET_EXPORT_PRIVATE DnsResponse {
+ public:
+ // Possible results from ParseToAddressList.
+ enum Result {
+ DNS_PARSE_OK = 0,
+ DNS_MALFORMED_RESPONSE, // DnsRecordParser failed before the end of
+ // packet.
+ DNS_MALFORMED_CNAME, // Could not parse CNAME out of RRDATA.
+ DNS_NAME_MISMATCH, // Got an address but no ordered chain of CNAMEs
+ // leads there.
+ DNS_SIZE_MISMATCH, // Got an address but size does not match.
+ DNS_CNAME_AFTER_ADDRESS, // Found CNAME after an address record.
+ DNS_ADDRESS_TTL_MISMATCH, // OBSOLETE. No longer used.
+ DNS_NO_ADDRESSES, // OBSOLETE. No longer used.
+ // Only add new values here.
+ DNS_PARSE_RESULT_MAX, // Bounding value for histograms.
+ };
+
+ // Constructs a response buffer large enough to store one byte more than
+ // largest possible response, to detect malformed responses.
+ DnsResponse();
+
+ // Constructs a response buffer of given length. Used for TCP transactions.
+ explicit DnsResponse(size_t length);
+
+ // Constructs a response from |data|. Used for testing purposes only!
+ DnsResponse(const void* data, size_t length, size_t answer_offset);
+
+ ~DnsResponse();
+
+ // Internal buffer accessor into which actual bytes of response will be
+ // read.
+ IOBufferWithSize* io_buffer() { return io_buffer_.get(); }
+
+ // Assuming the internal buffer holds |nbytes| bytes, returns true iff the
+ // packet matches the |query| id and question.
+ bool InitParse(int nbytes, const DnsQuery& query);
+
+ // Assuming the internal buffer holds |nbytes| bytes, initialize the parser
+ // without matching it against an existing query.
+ bool InitParseWithoutQuery(int nbytes);
+
+ // Returns true if response is valid, that is, after successful InitParse.
+ bool IsValid() const;
+
+ // All of the methods below are valid only if the response is valid.
+
+ // Accessors for the header.
+ uint16 flags() const; // excluding rcode
+ uint8 rcode() const;
+
+ unsigned answer_count() const;
+ unsigned additional_answer_count() const;
+
+ // Accessors to the question. The qname is unparsed.
+ base::StringPiece qname() const;
+ uint16 qtype() const;
+
+ // Returns qname in dotted format.
+ std::string GetDottedName() const;
+
+ // Returns an iterator to the resource records in the answer section.
+ // The iterator is valid only in the scope of the DnsResponse.
+ // This operation is idempotent.
+ DnsRecordParser Parser() const;
+
+ // Extracts an AddressList from this response. Returns SUCCESS if succeeded.
+ // Otherwise returns a detailed error number.
+ Result ParseToAddressList(AddressList* addr_list, base::TimeDelta* ttl) const;
+
+ private:
+ // Convenience for header access.
+ const dns_protocol::Header* header() const;
+
+ // Buffer into which response bytes are read.
+ scoped_refptr<IOBufferWithSize> io_buffer_;
+
+ // Iterator constructed after InitParse positioned at the answer section.
+ // It is never updated afterwards, so can be used in accessors.
+ DnsRecordParser parser_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsResponse);
+};
+
+} // namespace net
+
+#endif // NET_DNS_DNS_RESPONSE_H_
diff --git a/chromium/net/dns/dns_response_unittest.cc b/chromium/net/dns/dns_response_unittest.cc
new file mode 100644
index 00000000000..12b0377fea6
--- /dev/null
+++ b/chromium/net/dns/dns_response_unittest.cc
@@ -0,0 +1,581 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_response.h"
+
+#include "base/time/time.h"
+#include "net/base/address_list.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_util.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_query.h"
+#include "net/dns/dns_test_util.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+TEST(DnsRecordParserTest, Constructor) {
+ const char data[] = { 0 };
+
+ EXPECT_FALSE(DnsRecordParser().IsValid());
+ EXPECT_TRUE(DnsRecordParser(data, 1, 0).IsValid());
+ EXPECT_TRUE(DnsRecordParser(data, 1, 1).IsValid());
+
+ EXPECT_FALSE(DnsRecordParser(data, 1, 0).AtEnd());
+ EXPECT_TRUE(DnsRecordParser(data, 1, 1).AtEnd());
+}
+
+TEST(DnsRecordParserTest, ReadName) {
+ const uint8 data[] = {
+ // all labels "foo.example.com"
+ 0x03, 'f', 'o', 'o',
+ 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ // byte 0x10
+ 0x00,
+ // byte 0x11
+ // part label, part pointer, "bar.example.com"
+ 0x03, 'b', 'a', 'r',
+ 0xc0, 0x04,
+ // byte 0x17
+ // all pointer to "bar.example.com", 2 jumps
+ 0xc0, 0x11,
+ // byte 0x1a
+ };
+
+ std::string out;
+ DnsRecordParser parser(data, sizeof(data), 0);
+ ASSERT_TRUE(parser.IsValid());
+
+ EXPECT_EQ(0x11u, parser.ReadName(data + 0x00, &out));
+ EXPECT_EQ("foo.example.com", out);
+ // Check that the last "." is never stored.
+ out.clear();
+ EXPECT_EQ(0x1u, parser.ReadName(data + 0x10, &out));
+ EXPECT_EQ("", out);
+ out.clear();
+ EXPECT_EQ(0x6u, parser.ReadName(data + 0x11, &out));
+ EXPECT_EQ("bar.example.com", out);
+ out.clear();
+ EXPECT_EQ(0x2u, parser.ReadName(data + 0x17, &out));
+ EXPECT_EQ("bar.example.com", out);
+
+ // Parse name without storing it.
+ EXPECT_EQ(0x11u, parser.ReadName(data + 0x00, NULL));
+ EXPECT_EQ(0x1u, parser.ReadName(data + 0x10, NULL));
+ EXPECT_EQ(0x6u, parser.ReadName(data + 0x11, NULL));
+ EXPECT_EQ(0x2u, parser.ReadName(data + 0x17, NULL));
+
+ // Check that it works even if initial position is different.
+ parser = DnsRecordParser(data, sizeof(data), 0x12);
+ EXPECT_EQ(0x6u, parser.ReadName(data + 0x11, NULL));
+}
+
+TEST(DnsRecordParserTest, ReadNameFail) {
+ const uint8 data[] = {
+ // label length beyond packet
+ 0x30, 'x', 'x',
+ 0x00,
+ // pointer offset beyond packet
+ 0xc0, 0x20,
+ // pointer loop
+ 0xc0, 0x08,
+ 0xc0, 0x06,
+ // incorrect label type (currently supports only direct and pointer)
+ 0x80, 0x00,
+ // truncated name (missing root label)
+ 0x02, 'x', 'x',
+ };
+
+ DnsRecordParser parser(data, sizeof(data), 0);
+ ASSERT_TRUE(parser.IsValid());
+
+ std::string out;
+ EXPECT_EQ(0u, parser.ReadName(data + 0x00, &out));
+ EXPECT_EQ(0u, parser.ReadName(data + 0x04, &out));
+ EXPECT_EQ(0u, parser.ReadName(data + 0x08, &out));
+ EXPECT_EQ(0u, parser.ReadName(data + 0x0a, &out));
+ EXPECT_EQ(0u, parser.ReadName(data + 0x0c, &out));
+ EXPECT_EQ(0u, parser.ReadName(data + 0x0e, &out));
+}
+
+TEST(DnsRecordParserTest, ReadRecord) {
+ const uint8 data[] = {
+ // Type CNAME record.
+ 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00,
+ 0x00, 0x05, // TYPE is CNAME.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, 0x24, 0x74, // TTL is 0x00012474.
+ 0x00, 0x06, // RDLENGTH is 6 bytes.
+ 0x03, 'f', 'o', 'o', // compressed name in record
+ 0xc0, 0x00,
+ // Type A record.
+ 0x03, 'b', 'a', 'r', // compressed owner name
+ 0xc0, 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x20, 0x13, 0x55, // TTL is 0x00201355.
+ 0x00, 0x04, // RDLENGTH is 4 bytes.
+ 0x7f, 0x02, 0x04, 0x01, // IP is 127.2.4.1
+ };
+
+ std::string out;
+ DnsRecordParser parser(data, sizeof(data), 0);
+
+ DnsResourceRecord record;
+ EXPECT_TRUE(parser.ReadRecord(&record));
+ EXPECT_EQ("example.com", record.name);
+ EXPECT_EQ(dns_protocol::kTypeCNAME, record.type);
+ EXPECT_EQ(dns_protocol::kClassIN, record.klass);
+ EXPECT_EQ(0x00012474u, record.ttl);
+ EXPECT_EQ(6u, record.rdata.length());
+ EXPECT_EQ(6u, parser.ReadName(record.rdata.data(), &out));
+ EXPECT_EQ("foo.example.com", out);
+ EXPECT_FALSE(parser.AtEnd());
+
+ EXPECT_TRUE(parser.ReadRecord(&record));
+ EXPECT_EQ("bar.example.com", record.name);
+ EXPECT_EQ(dns_protocol::kTypeA, record.type);
+ EXPECT_EQ(dns_protocol::kClassIN, record.klass);
+ EXPECT_EQ(0x00201355u, record.ttl);
+ EXPECT_EQ(4u, record.rdata.length());
+ EXPECT_EQ(base::StringPiece("\x7f\x02\x04\x01"), record.rdata);
+ EXPECT_TRUE(parser.AtEnd());
+
+ // Test truncated record.
+ parser = DnsRecordParser(data, sizeof(data) - 2, 0);
+ EXPECT_TRUE(parser.ReadRecord(&record));
+ EXPECT_FALSE(parser.AtEnd());
+ EXPECT_FALSE(parser.ReadRecord(&record));
+}
+
+TEST(DnsResponseTest, InitParse) {
+ // This includes \0 at the end.
+ const char qname_data[] = "\x0A""codereview""\x08""chromium""\x03""org";
+ const base::StringPiece qname(qname_data, sizeof(qname_data));
+ // Compilers want to copy when binding temporary to const &, so must use heap.
+ scoped_ptr<DnsQuery> query(new DnsQuery(0xcafe, qname, dns_protocol::kTypeA));
+
+ const uint8 response_data[] = {
+ // Header
+ 0xca, 0xfe, // ID
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x01, // 1 question
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Question
+ // This part is echoed back from the respective query.
+ 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w',
+ 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm',
+ 0x03, 'o', 'r', 'g',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+
+ // Answer 1
+ 0xc0, 0x0c, // NAME is a pointer to name in Question section.
+ 0x00, 0x05, // TYPE is CNAME.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x12, // RDLENGTH is 18 bytes.
+ // ghs.l.google.com in DNS format.
+ 0x03, 'g', 'h', 's',
+ 0x01, 'l',
+ 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00,
+
+ // Answer 2
+ 0xc0, 0x35, // NAME is a pointer to name in Answer 1.
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 53 seconds.
+ 0x00, 0x35,
+ 0x00, 0x04, // RDLENGTH is 4 bytes.
+ 0x4a, 0x7d, // RDATA is the IP: 74.125.95.121
+ 0x5f, 0x79,
+ };
+
+ DnsResponse resp;
+ memcpy(resp.io_buffer()->data(), response_data, sizeof(response_data));
+
+ // Reject too short.
+ EXPECT_FALSE(resp.InitParse(query->io_buffer()->size() - 1, *query));
+ EXPECT_FALSE(resp.IsValid());
+
+ // Reject wrong id.
+ scoped_ptr<DnsQuery> other_query(query->CloneWithNewId(0xbeef));
+ EXPECT_FALSE(resp.InitParse(sizeof(response_data), *other_query));
+ EXPECT_FALSE(resp.IsValid());
+
+ // Reject wrong question.
+ scoped_ptr<DnsQuery> wrong_query(
+ new DnsQuery(0xcafe, qname, dns_protocol::kTypeCNAME));
+ EXPECT_FALSE(resp.InitParse(sizeof(response_data), *wrong_query));
+ EXPECT_FALSE(resp.IsValid());
+
+ // Accept matching question.
+ EXPECT_TRUE(resp.InitParse(sizeof(response_data), *query));
+ EXPECT_TRUE(resp.IsValid());
+
+ // Check header access.
+ EXPECT_EQ(0x8180, resp.flags());
+ EXPECT_EQ(0x0, resp.rcode());
+ EXPECT_EQ(2u, resp.answer_count());
+
+ // Check question access.
+ EXPECT_EQ(query->qname(), resp.qname());
+ EXPECT_EQ(query->qtype(), resp.qtype());
+ EXPECT_EQ("codereview.chromium.org", resp.GetDottedName());
+
+ DnsResourceRecord record;
+ DnsRecordParser parser = resp.Parser();
+ EXPECT_TRUE(parser.ReadRecord(&record));
+ EXPECT_FALSE(parser.AtEnd());
+ EXPECT_TRUE(parser.ReadRecord(&record));
+ EXPECT_TRUE(parser.AtEnd());
+ EXPECT_FALSE(parser.ReadRecord(&record));
+}
+
+TEST(DnsResponseTest, InitParseWithoutQuery) {
+ DnsResponse resp;
+ memcpy(resp.io_buffer()->data(), kT0ResponseDatagram,
+ sizeof(kT0ResponseDatagram));
+
+ // Accept matching question.
+ EXPECT_TRUE(resp.InitParseWithoutQuery(sizeof(kT0ResponseDatagram)));
+ EXPECT_TRUE(resp.IsValid());
+
+ // Check header access.
+ EXPECT_EQ(0x8180, resp.flags());
+ EXPECT_EQ(0x0, resp.rcode());
+ EXPECT_EQ(kT0RecordCount, resp.answer_count());
+
+ // Check question access.
+ EXPECT_EQ(kT0Qtype, resp.qtype());
+ EXPECT_EQ(kT0HostName, resp.GetDottedName());
+
+ DnsResourceRecord record;
+ DnsRecordParser parser = resp.Parser();
+ for (unsigned i = 0; i < kT0RecordCount; i ++) {
+ EXPECT_FALSE(parser.AtEnd());
+ EXPECT_TRUE(parser.ReadRecord(&record));
+ }
+ EXPECT_TRUE(parser.AtEnd());
+ EXPECT_FALSE(parser.ReadRecord(&record));
+}
+
+TEST(DnsResponseTest, InitParseWithoutQueryNoQuestions) {
+ const uint8 response_data[] = {
+ // Header
+ 0xca, 0xfe, // ID
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No question
+ 0x00, 0x01, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w',
+ 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm',
+ 0x03, 'o', 'r', 'g',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 53 seconds.
+ 0x00, 0x35,
+ 0x00, 0x04, // RDLENGTH is 4 bytes.
+ 0x4a, 0x7d, // RDATA is the IP: 74.125.95.121
+ 0x5f, 0x79,
+ };
+
+ DnsResponse resp;
+ memcpy(resp.io_buffer()->data(), response_data, sizeof(response_data));
+
+ EXPECT_TRUE(resp.InitParseWithoutQuery(sizeof(response_data)));
+
+ // Check header access.
+ EXPECT_EQ(0x8180, resp.flags());
+ EXPECT_EQ(0x0, resp.rcode());
+ EXPECT_EQ(0x1u, resp.answer_count());
+
+ DnsResourceRecord record;
+ DnsRecordParser parser = resp.Parser();
+
+ EXPECT_FALSE(parser.AtEnd());
+ EXPECT_TRUE(parser.ReadRecord(&record));
+ EXPECT_EQ("codereview.chromium.org", record.name);
+ EXPECT_EQ(0x00000035u, record.ttl);
+ EXPECT_EQ(dns_protocol::kTypeA, record.type);
+
+ EXPECT_TRUE(parser.AtEnd());
+ EXPECT_FALSE(parser.ReadRecord(&record));
+}
+
+TEST(DnsResponseTest, InitParseWithoutQueryTwoQuestions) {
+ const uint8 response_data[] = {
+ // Header
+ 0xca, 0xfe, // ID
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x02, // 2 questions
+ 0x00, 0x01, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Question 1
+ 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w',
+ 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm',
+ 0x03, 'o', 'r', 'g',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+
+ // Question 2
+ 0x0b, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w', '2',
+ 0xc0, 0x18, // pointer to "chromium.org"
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+
+ // Answer 1
+ 0xc0, 0x0c, // NAME is a pointer to name in Question section.
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 53 seconds.
+ 0x00, 0x35,
+ 0x00, 0x04, // RDLENGTH is 4 bytes.
+ 0x4a, 0x7d, // RDATA is the IP: 74.125.95.121
+ 0x5f, 0x79,
+ };
+
+ DnsResponse resp;
+ memcpy(resp.io_buffer()->data(), response_data, sizeof(response_data));
+
+ EXPECT_TRUE(resp.InitParseWithoutQuery(sizeof(response_data)));
+
+ // Check header access.
+ EXPECT_EQ(0x8180, resp.flags());
+ EXPECT_EQ(0x0, resp.rcode());
+ EXPECT_EQ(0x01u, resp.answer_count());
+
+ DnsResourceRecord record;
+ DnsRecordParser parser = resp.Parser();
+
+ EXPECT_FALSE(parser.AtEnd());
+ EXPECT_TRUE(parser.ReadRecord(&record));
+ EXPECT_EQ("codereview.chromium.org", record.name);
+ EXPECT_EQ(0x35u, record.ttl);
+ EXPECT_EQ(dns_protocol::kTypeA, record.type);
+
+ EXPECT_TRUE(parser.AtEnd());
+ EXPECT_FALSE(parser.ReadRecord(&record));
+}
+
+TEST(DnsResponseTest, InitParseWithoutQueryPacketTooShort) {
+ const uint8 response_data[] = {
+ // Header
+ 0xca, 0xfe, // ID
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No question
+ };
+
+ DnsResponse resp;
+ memcpy(resp.io_buffer()->data(), response_data, sizeof(response_data));
+
+ EXPECT_FALSE(resp.InitParseWithoutQuery(sizeof(response_data)));
+}
+
+void VerifyAddressList(const std::vector<const char*>& ip_addresses,
+ const AddressList& addrlist) {
+ ASSERT_EQ(ip_addresses.size(), addrlist.size());
+
+ for (size_t i = 0; i < addrlist.size(); ++i) {
+ EXPECT_EQ(ip_addresses[i], addrlist[i].ToStringWithoutPort());
+ }
+}
+
+TEST(DnsResponseTest, ParseToAddressList) {
+ const struct TestCase {
+ size_t query_size;
+ const uint8* response_data;
+ size_t response_size;
+ const char* const* expected_addresses;
+ size_t num_expected_addresses;
+ const char* expected_cname;
+ int expected_ttl_sec;
+ } cases[] = {
+ {
+ kT0QuerySize,
+ kT0ResponseDatagram, arraysize(kT0ResponseDatagram),
+ kT0IpAddresses, arraysize(kT0IpAddresses),
+ kT0CanonName,
+ kT0TTL,
+ },
+ {
+ kT1QuerySize,
+ kT1ResponseDatagram, arraysize(kT1ResponseDatagram),
+ kT1IpAddresses, arraysize(kT1IpAddresses),
+ kT1CanonName,
+ kT1TTL,
+ },
+ {
+ kT2QuerySize,
+ kT2ResponseDatagram, arraysize(kT2ResponseDatagram),
+ kT2IpAddresses, arraysize(kT2IpAddresses),
+ kT2CanonName,
+ kT2TTL,
+ },
+ {
+ kT3QuerySize,
+ kT3ResponseDatagram, arraysize(kT3ResponseDatagram),
+ kT3IpAddresses, arraysize(kT3IpAddresses),
+ kT3CanonName,
+ kT3TTL,
+ },
+ };
+
+ for (size_t i = 0; i < ARRAYSIZE_UNSAFE(cases); ++i) {
+ const TestCase& t = cases[i];
+ DnsResponse response(t.response_data, t.response_size, t.query_size);
+ AddressList addr_list;
+ base::TimeDelta ttl;
+ EXPECT_EQ(DnsResponse::DNS_PARSE_OK,
+ response.ParseToAddressList(&addr_list, &ttl));
+ std::vector<const char*> expected_addresses(
+ t.expected_addresses,
+ t.expected_addresses + t.num_expected_addresses);
+ VerifyAddressList(expected_addresses, addr_list);
+ EXPECT_EQ(t.expected_cname, addr_list.canonical_name());
+ EXPECT_EQ(base::TimeDelta::FromSeconds(t.expected_ttl_sec), ttl);
+ }
+}
+
+const uint8 kResponseTruncatedRecord[] = {
+ // Header: 1 question, 1 answer RR
+ 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
+ // Question: name = 'a', type = A (0x1)
+ 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01,
+ // Answer: name = 'a', type = A, TTL = 0xFF, RDATA = 10.10.10.10
+ 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF,
+ 0x00, 0x04, 0x0A, 0x0A, 0x0A, // Truncated RDATA.
+};
+
+const uint8 kResponseTruncatedCNAME[] = {
+ // Header: 1 question, 1 answer RR
+ 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
+ // Question: name = 'a', type = A (0x1)
+ 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01,
+ // Answer: name = 'a', type = CNAME, TTL = 0xFF, RDATA = 'foo' (truncated)
+ 0x01, 'a', 0x00, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF,
+ 0x00, 0x03, 0x03, 'f', 'o', // Truncated name.
+};
+
+const uint8 kResponseNameMismatch[] = {
+ // Header: 1 question, 1 answer RR
+ 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
+ // Question: name = 'a', type = A (0x1)
+ 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01,
+ // Answer: name = 'b', type = A, TTL = 0xFF, RDATA = 10.10.10.10
+ 0x01, 'b', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF,
+ 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0A,
+};
+
+const uint8 kResponseNameMismatchInChain[] = {
+ // Header: 1 question, 3 answer RR
+ 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00,
+ // Question: name = 'a', type = A (0x1)
+ 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01,
+ // Answer: name = 'a', type = CNAME, TTL = 0xFF, RDATA = 'b'
+ 0x01, 'a', 0x00, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF,
+ 0x00, 0x03, 0x01, 'b', 0x00,
+ // Answer: name = 'b', type = A, TTL = 0xFF, RDATA = 10.10.10.10
+ 0x01, 'b', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF,
+ 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0A,
+ // Answer: name = 'c', type = A, TTL = 0xFF, RDATA = 10.10.10.11
+ 0x01, 'c', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF,
+ 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0B,
+};
+
+const uint8 kResponseSizeMismatch[] = {
+ // Header: 1 answer RR
+ 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
+ // Question: name = 'a', type = AAAA (0x1c)
+ 0x01, 'a', 0x00, 0x00, 0x1c, 0x00, 0x01,
+ // Answer: name = 'a', type = AAAA, TTL = 0xFF, RDATA = 10.10.10.10
+ 0x01, 'a', 0x00, 0x00, 0x1c, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF,
+ 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0A,
+};
+
+const uint8 kResponseCNAMEAfterAddress[] = {
+ // Header: 2 answer RR
+ 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00,
+ // Question: name = 'a', type = A (0x1)
+ 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01,
+ // Answer: name = 'a', type = A, TTL = 0xFF, RDATA = 10.10.10.10.
+ 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF,
+ 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0A,
+ // Answer: name = 'a', type = CNAME, TTL = 0xFF, RDATA = 'b'
+ 0x01, 'a', 0x00, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF,
+ 0x00, 0x03, 0x01, 'b', 0x00,
+};
+
+const uint8 kResponseNoAddresses[] = {
+ // Header: 1 question, 1 answer RR, 1 authority RR
+ 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
+ // Question: name = 'a', type = A (0x1)
+ 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01,
+ // Answer: name = 'a', type = CNAME, TTL = 0xFF, RDATA = 'b'
+ 0x01, 'a', 0x00, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF,
+ 0x00, 0x03, 0x01, 'b', 0x00,
+ // Authority section
+ // Answer: name = 'b', type = A, TTL = 0xFF, RDATA = 10.10.10.10
+ 0x01, 'b', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF,
+ 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0A,
+};
+
+TEST(DnsResponseTest, ParseToAddressListFail) {
+ const struct TestCase {
+ const uint8* data;
+ size_t size;
+ DnsResponse::Result expected_result;
+ } cases[] = {
+ { kResponseTruncatedRecord, arraysize(kResponseTruncatedRecord),
+ DnsResponse::DNS_MALFORMED_RESPONSE },
+ { kResponseTruncatedCNAME, arraysize(kResponseTruncatedCNAME),
+ DnsResponse::DNS_MALFORMED_CNAME },
+ { kResponseNameMismatch, arraysize(kResponseNameMismatch),
+ DnsResponse::DNS_NAME_MISMATCH },
+ { kResponseNameMismatchInChain, arraysize(kResponseNameMismatchInChain),
+ DnsResponse::DNS_NAME_MISMATCH },
+ { kResponseSizeMismatch, arraysize(kResponseSizeMismatch),
+ DnsResponse::DNS_SIZE_MISMATCH },
+ { kResponseCNAMEAfterAddress, arraysize(kResponseCNAMEAfterAddress),
+ DnsResponse::DNS_CNAME_AFTER_ADDRESS },
+ // Not actually a failure, just an empty result.
+ { kResponseNoAddresses, arraysize(kResponseNoAddresses),
+ DnsResponse::DNS_PARSE_OK },
+ };
+
+ const size_t kQuerySize = 12 + 7;
+
+ for (size_t i = 0; i < ARRAYSIZE_UNSAFE(cases); ++i) {
+ const TestCase& t = cases[i];
+
+ DnsResponse response(t.data, t.size, kQuerySize);
+ AddressList addr_list;
+ base::TimeDelta ttl;
+ EXPECT_EQ(t.expected_result,
+ response.ParseToAddressList(&addr_list, &ttl));
+ }
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/dns/dns_session.cc b/chromium/net/dns/dns_session.cc
new file mode 100644
index 00000000000..ea8b6a14274
--- /dev/null
+++ b/chromium/net/dns/dns_session.cc
@@ -0,0 +1,298 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_session.h"
+
+#include "base/basictypes.h"
+#include "base/bind.h"
+#include "base/lazy_instance.h"
+#include "base/metrics/histogram.h"
+#include "base/metrics/sample_vector.h"
+#include "base/rand_util.h"
+#include "base/stl_util.h"
+#include "base/time/time.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/dns/dns_config_service.h"
+#include "net/dns/dns_socket_pool.h"
+#include "net/socket/stream_socket.h"
+#include "net/udp/datagram_client_socket.h"
+
+namespace net {
+
+namespace {
+// Never exceed max timeout.
+const unsigned kMaxTimeoutMs = 5000;
+// Set min timeout, in case we are talking to a local DNS proxy.
+const unsigned kMinTimeoutMs = 10;
+
+// Number of buckets in the histogram of observed RTTs.
+const size_t kRTTBucketCount = 100;
+// Target percentile in the RTT histogram used for retransmission timeout.
+const unsigned kRTOPercentile = 99;
+} // namespace
+
+// Runtime statistics of DNS server.
+struct DnsSession::ServerStats {
+ ServerStats(base::TimeDelta rtt_estimate_param, RttBuckets* buckets)
+ : last_failure_count(0), rtt_estimate(rtt_estimate_param) {
+ rtt_histogram.reset(new base::SampleVector(buckets));
+ // Seed histogram with 2 samples at |rtt_estimate| timeout.
+ rtt_histogram->Accumulate(rtt_estimate.InMilliseconds(), 2);
+ }
+
+ // Count of consecutive failures after last success.
+ int last_failure_count;
+
+ // Last time when server returned failure or timeout.
+ base::Time last_failure;
+ // Last time when server returned success.
+ base::Time last_success;
+
+ // Estimated RTT using moving average.
+ base::TimeDelta rtt_estimate;
+ // Estimated error in the above.
+ base::TimeDelta rtt_deviation;
+
+ // A histogram of observed RTT .
+ scoped_ptr<base::SampleVector> rtt_histogram;
+
+ DISALLOW_COPY_AND_ASSIGN(ServerStats);
+};
+
+// static
+base::LazyInstance<DnsSession::RttBuckets>::Leaky DnsSession::rtt_buckets_ =
+ LAZY_INSTANCE_INITIALIZER;
+
+DnsSession::RttBuckets::RttBuckets() : base::BucketRanges(kRTTBucketCount + 1) {
+ base::Histogram::InitializeBucketRanges(1, 5000, this);
+}
+
+DnsSession::SocketLease::SocketLease(scoped_refptr<DnsSession> session,
+ unsigned server_index,
+ scoped_ptr<DatagramClientSocket> socket)
+ : session_(session), server_index_(server_index), socket_(socket.Pass()) {}
+
+DnsSession::SocketLease::~SocketLease() {
+ session_->FreeSocket(server_index_, socket_.Pass());
+}
+
+DnsSession::DnsSession(const DnsConfig& config,
+ scoped_ptr<DnsSocketPool> socket_pool,
+ const RandIntCallback& rand_int_callback,
+ NetLog* net_log)
+ : config_(config),
+ socket_pool_(socket_pool.Pass()),
+ rand_callback_(base::Bind(rand_int_callback, 0, kuint16max)),
+ net_log_(net_log),
+ server_index_(0) {
+ socket_pool_->Initialize(&config_.nameservers, net_log);
+ UMA_HISTOGRAM_CUSTOM_COUNTS(
+ "AsyncDNS.ServerCount", config_.nameservers.size(), 0, 10, 10);
+ for (size_t i = 0; i < config_.nameservers.size(); ++i) {
+ server_stats_.push_back(new ServerStats(config_.timeout,
+ rtt_buckets_.Pointer()));
+ }
+}
+
+DnsSession::~DnsSession() {
+ RecordServerStats();
+}
+
+int DnsSession::NextQueryId() const { return rand_callback_.Run(); }
+
+unsigned DnsSession::NextFirstServerIndex() {
+ unsigned index = NextGoodServerIndex(server_index_);
+ if (config_.rotate)
+ server_index_ = (server_index_ + 1) % config_.nameservers.size();
+ return index;
+}
+
+unsigned DnsSession::NextGoodServerIndex(unsigned server_index) {
+ unsigned index = server_index;
+ base::Time oldest_server_failure(base::Time::Now());
+ unsigned oldest_server_failure_index = 0;
+
+ UMA_HISTOGRAM_BOOLEAN("AsyncDNS.ServerIsGood",
+ server_stats_[server_index]->last_failure.is_null());
+
+ do {
+ base::Time cur_server_failure = server_stats_[index]->last_failure;
+ // If number of failures on this server doesn't exceed number of allowed
+ // attempts, return its index.
+ if (server_stats_[server_index]->last_failure_count < config_.attempts) {
+ return index;
+ }
+ // Track oldest failed server.
+ if (cur_server_failure < oldest_server_failure) {
+ oldest_server_failure = cur_server_failure;
+ oldest_server_failure_index = index;
+ }
+ index = (index + 1) % config_.nameservers.size();
+ } while (index != server_index);
+
+ // If we are here it means that there are no successful servers, so we have
+ // to use one that has failed oldest.
+ return oldest_server_failure_index;
+}
+
+void DnsSession::RecordServerFailure(unsigned server_index) {
+ UMA_HISTOGRAM_CUSTOM_COUNTS(
+ "AsyncDNS.ServerFailureIndex", server_index, 0, 10, 10);
+ ++(server_stats_[server_index]->last_failure_count);
+ server_stats_[server_index]->last_failure = base::Time::Now();
+}
+
+void DnsSession::RecordServerSuccess(unsigned server_index) {
+ if (server_stats_[server_index]->last_success.is_null()) {
+ UMA_HISTOGRAM_COUNTS_100("AsyncDNS.ServerFailuresAfterNetworkChange",
+ server_stats_[server_index]->last_failure_count);
+ } else {
+ UMA_HISTOGRAM_COUNTS_100("AsyncDNS.ServerFailuresBeforeSuccess",
+ server_stats_[server_index]->last_failure_count);
+ }
+ server_stats_[server_index]->last_failure_count = 0;
+ server_stats_[server_index]->last_failure = base::Time();
+ server_stats_[server_index]->last_success = base::Time::Now();
+}
+
+void DnsSession::RecordRTT(unsigned server_index, base::TimeDelta rtt) {
+ DCHECK_LT(server_index, server_stats_.size());
+
+ // For measurement, assume it is the first attempt (no backoff).
+ base::TimeDelta timeout_jacobson = NextTimeoutFromJacobson(server_index, 0);
+ base::TimeDelta timeout_histogram = NextTimeoutFromHistogram(server_index, 0);
+ UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutErrorJacobson", rtt - timeout_jacobson);
+ UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutErrorHistogram",
+ rtt - timeout_histogram);
+ UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutErrorJacobsonUnder",
+ timeout_jacobson - rtt);
+ UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutErrorHistogramUnder",
+ timeout_histogram - rtt);
+
+ // Jacobson/Karels algorithm for TCP.
+ // Using parameters: alpha = 1/8, delta = 1/4, beta = 4
+ base::TimeDelta& estimate = server_stats_[server_index]->rtt_estimate;
+ base::TimeDelta& deviation = server_stats_[server_index]->rtt_deviation;
+ base::TimeDelta current_error = rtt - estimate;
+ estimate += current_error / 8; // * alpha
+ base::TimeDelta abs_error = base::TimeDelta::FromInternalValue(
+ std::abs(current_error.ToInternalValue()));
+ deviation += (abs_error - deviation) / 4; // * delta
+
+ // Histogram-based method.
+ server_stats_[server_index]->rtt_histogram
+ ->Accumulate(rtt.InMilliseconds(), 1);
+}
+
+void DnsSession::RecordLostPacket(unsigned server_index, int attempt) {
+ base::TimeDelta timeout_jacobson =
+ NextTimeoutFromJacobson(server_index, attempt);
+ base::TimeDelta timeout_histogram =
+ NextTimeoutFromHistogram(server_index, attempt);
+ UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutSpentJacobson", timeout_jacobson);
+ UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutSpentHistogram", timeout_histogram);
+}
+
+void DnsSession::RecordServerStats() {
+ for (size_t index = 0; index < server_stats_.size(); ++index) {
+ if (server_stats_[index]->last_failure_count) {
+ if (server_stats_[index]->last_success.is_null()) {
+ UMA_HISTOGRAM_COUNTS("AsyncDNS.ServerFailuresWithoutSuccess",
+ server_stats_[index]->last_failure_count);
+ } else {
+ UMA_HISTOGRAM_COUNTS("AsyncDNS.ServerFailuresAfterSuccess",
+ server_stats_[index]->last_failure_count);
+ }
+ }
+ }
+}
+
+
+base::TimeDelta DnsSession::NextTimeout(unsigned server_index, int attempt) {
+ // Respect config timeout if it exceeds |kMaxTimeoutMs|.
+ if (config_.timeout.InMilliseconds() >= kMaxTimeoutMs)
+ return config_.timeout;
+ return NextTimeoutFromHistogram(server_index, attempt);
+}
+
+// Allocate a socket, already connected to the server address.
+scoped_ptr<DnsSession::SocketLease> DnsSession::AllocateSocket(
+ unsigned server_index, const NetLog::Source& source) {
+ scoped_ptr<DatagramClientSocket> socket;
+
+ socket = socket_pool_->AllocateSocket(server_index);
+ if (!socket.get())
+ return scoped_ptr<SocketLease>();
+
+ socket->NetLog().BeginEvent(NetLog::TYPE_SOCKET_IN_USE,
+ source.ToEventParametersCallback());
+
+ SocketLease* lease = new SocketLease(this, server_index, socket.Pass());
+ return scoped_ptr<SocketLease>(lease);
+}
+
+scoped_ptr<StreamSocket> DnsSession::CreateTCPSocket(
+ unsigned server_index, const NetLog::Source& source) {
+ return socket_pool_->CreateTCPSocket(server_index, source);
+}
+
+// Release a socket.
+void DnsSession::FreeSocket(unsigned server_index,
+ scoped_ptr<DatagramClientSocket> socket) {
+ DCHECK(socket.get());
+
+ socket->NetLog().EndEvent(NetLog::TYPE_SOCKET_IN_USE);
+
+ socket_pool_->FreeSocket(server_index, socket.Pass());
+}
+
+base::TimeDelta DnsSession::NextTimeoutFromJacobson(unsigned server_index,
+ int attempt) {
+ DCHECK_LT(server_index, server_stats_.size());
+
+ base::TimeDelta timeout = server_stats_[server_index]->rtt_estimate +
+ 4 * server_stats_[server_index]->rtt_deviation;
+
+ timeout = std::max(timeout, base::TimeDelta::FromMilliseconds(kMinTimeoutMs));
+
+ // The timeout doubles every full round.
+ unsigned num_backoffs = attempt / config_.nameservers.size();
+
+ return std::min(timeout * (1 << num_backoffs),
+ base::TimeDelta::FromMilliseconds(kMaxTimeoutMs));
+}
+
+base::TimeDelta DnsSession::NextTimeoutFromHistogram(unsigned server_index,
+ int attempt) {
+ DCHECK_LT(server_index, server_stats_.size());
+
+ COMPILE_ASSERT(std::numeric_limits<base::HistogramBase::Count>::is_signed,
+ histogram_base_count_assumed_to_be_signed);
+
+ // Use fixed percentile of observed samples.
+ const base::SampleVector& samples =
+ *server_stats_[server_index]->rtt_histogram;
+
+ base::HistogramBase::Count total = samples.TotalCount();
+ base::HistogramBase::Count remaining_count = kRTOPercentile * total / 100;
+ size_t index = 0;
+ while (remaining_count > 0 && index < rtt_buckets_.Get().size()) {
+ remaining_count -= samples.GetCountAtIndex(index);
+ ++index;
+ }
+
+ base::TimeDelta timeout =
+ base::TimeDelta::FromMilliseconds(rtt_buckets_.Get().range(index));
+
+ timeout = std::max(timeout, base::TimeDelta::FromMilliseconds(kMinTimeoutMs));
+
+ // The timeout still doubles every full round.
+ unsigned num_backoffs = attempt / config_.nameservers.size();
+
+ return std::min(timeout * (1 << num_backoffs),
+ base::TimeDelta::FromMilliseconds(kMaxTimeoutMs));
+}
+
+} // namespace net
diff --git a/chromium/net/dns/dns_session.h b/chromium/net/dns/dns_session.h
new file mode 100644
index 00000000000..01ba5e5d154
--- /dev/null
+++ b/chromium/net/dns/dns_session.h
@@ -0,0 +1,147 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_SESSION_H_
+#define NET_DNS_DNS_SESSION_H_
+
+#include <vector>
+
+#include "base/lazy_instance.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/memory/scoped_vector.h"
+#include "base/metrics/bucket_ranges.h"
+#include "base/time/time.h"
+#include "net/base/net_export.h"
+#include "net/base/rand_callback.h"
+#include "net/dns/dns_config_service.h"
+#include "net/dns/dns_socket_pool.h"
+
+namespace base {
+class BucketRanges;
+class SampleVector;
+}
+
+namespace net {
+
+class ClientSocketFactory;
+class DatagramClientSocket;
+class NetLog;
+class StreamSocket;
+
+// Session parameters and state shared between DNS transactions.
+// Ref-counted so that DnsClient::Request can keep working in absence of
+// DnsClient. A DnsSession must be recreated when DnsConfig changes.
+class NET_EXPORT_PRIVATE DnsSession
+ : NON_EXPORTED_BASE(public base::RefCounted<DnsSession>) {
+ public:
+ typedef base::Callback<int()> RandCallback;
+
+ class NET_EXPORT_PRIVATE SocketLease {
+ public:
+ SocketLease(scoped_refptr<DnsSession> session,
+ unsigned server_index,
+ scoped_ptr<DatagramClientSocket> socket);
+ ~SocketLease();
+
+ unsigned server_index() const { return server_index_; }
+
+ DatagramClientSocket* socket() { return socket_.get(); }
+
+ private:
+ scoped_refptr<DnsSession> session_;
+ unsigned server_index_;
+ scoped_ptr<DatagramClientSocket> socket_;
+
+ DISALLOW_COPY_AND_ASSIGN(SocketLease);
+ };
+
+ DnsSession(const DnsConfig& config,
+ scoped_ptr<DnsSocketPool> socket_pool,
+ const RandIntCallback& rand_int_callback,
+ NetLog* net_log);
+
+ const DnsConfig& config() const { return config_; }
+ NetLog* net_log() const { return net_log_; }
+
+ // Return the next random query ID.
+ int NextQueryId() const;
+
+ // Return the index of the first configured server to use on first attempt.
+ unsigned NextFirstServerIndex();
+
+ // Start with |server_index| and find the index of the next known good server
+ // to use on this attempt. Returns |server_index| if this server has no
+ // recorded failures, or if there are no other servers that have not failed
+ // or have failed longer time ago.
+ unsigned NextGoodServerIndex(unsigned server_index);
+
+ // Record that server failed to respond (due to SRV_FAIL or timeout).
+ void RecordServerFailure(unsigned server_index);
+
+ // Record that server responded successfully.
+ void RecordServerSuccess(unsigned server_index);
+
+ // Record how long it took to receive a response from the server.
+ void RecordRTT(unsigned server_index, base::TimeDelta rtt);
+
+ // Record suspected loss of a packet for a specific server.
+ void RecordLostPacket(unsigned server_index, int attempt);
+
+ // Record server stats before it is destroyed.
+ void RecordServerStats();
+
+ // Return the timeout for the next query. |attempt| counts from 0 and is used
+ // for exponential backoff.
+ base::TimeDelta NextTimeout(unsigned server_index, int attempt);
+
+ // Allocate a socket, already connected to the server address.
+ // When the SocketLease is destroyed, the socket will be freed.
+ scoped_ptr<SocketLease> AllocateSocket(unsigned server_index,
+ const NetLog::Source& source);
+
+ // Creates a StreamSocket from the factory for a transaction over TCP. These
+ // sockets are not pooled.
+ scoped_ptr<StreamSocket> CreateTCPSocket(unsigned server_index,
+ const NetLog::Source& source);
+
+ private:
+ friend class base::RefCounted<DnsSession>;
+ ~DnsSession();
+
+ // Release a socket.
+ void FreeSocket(unsigned server_index,
+ scoped_ptr<DatagramClientSocket> socket);
+
+ // Return the timeout using the TCP timeout method.
+ base::TimeDelta NextTimeoutFromJacobson(unsigned server_index, int attempt);
+
+ // Compute the timeout using the histogram method.
+ base::TimeDelta NextTimeoutFromHistogram(unsigned server_index, int attempt);
+
+ const DnsConfig config_;
+ scoped_ptr<DnsSocketPool> socket_pool_;
+ RandCallback rand_callback_;
+ NetLog* net_log_;
+
+ // Current index into |config_.nameservers| to begin resolution with.
+ int server_index_;
+
+ struct ServerStats;
+
+ // Track runtime statistics of each DNS server.
+ ScopedVector<ServerStats> server_stats_;
+
+ // Buckets shared for all |ServerStats::rtt_histogram|.
+ struct RttBuckets : public base::BucketRanges {
+ RttBuckets();
+ };
+ static base::LazyInstance<RttBuckets>::Leaky rtt_buckets_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsSession);
+};
+
+} // namespace net
+
+#endif // NET_DNS_DNS_SESSION_H_
diff --git a/chromium/net/dns/dns_session_unittest.cc b/chromium/net/dns/dns_session_unittest.cc
new file mode 100644
index 00000000000..ed726f23234
--- /dev/null
+++ b/chromium/net/dns/dns_session_unittest.cc
@@ -0,0 +1,252 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_session.h"
+
+#include <list>
+
+#include "base/bind.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/rand_util.h"
+#include "base/stl_util.h"
+#include "net/base/net_log.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_socket_pool.h"
+#include "net/socket/socket_test_util.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/socket/stream_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+class TestClientSocketFactory : public ClientSocketFactory {
+ public:
+ virtual ~TestClientSocketFactory();
+
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source) OVERRIDE;
+
+ virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ const AddressList& addresses,
+ NetLog*, const NetLog::Source&) OVERRIDE {
+ NOTIMPLEMENTED();
+ return scoped_ptr<StreamSocket>();
+ }
+
+ virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context) OVERRIDE {
+ NOTIMPLEMENTED();
+ return scoped_ptr<SSLClientSocket>();
+ }
+
+ virtual void ClearSSLSessionCache() OVERRIDE {
+ NOTIMPLEMENTED();
+ }
+
+ private:
+ std::list<SocketDataProvider*> data_providers_;
+};
+
+struct PoolEvent {
+ enum { ALLOCATE, FREE } action;
+ unsigned server_index;
+};
+
+class DnsSessionTest : public testing::Test {
+ public:
+ void OnSocketAllocated(unsigned server_index);
+ void OnSocketFreed(unsigned server_index);
+
+ protected:
+ void Initialize(unsigned num_servers);
+ scoped_ptr<DnsSession::SocketLease> Allocate(unsigned server_index);
+ bool DidAllocate(unsigned server_index);
+ bool DidFree(unsigned server_index);
+ bool NoMoreEvents();
+
+ DnsConfig config_;
+ scoped_ptr<TestClientSocketFactory> test_client_socket_factory_;
+ scoped_refptr<DnsSession> session_;
+ NetLog::Source source_;
+
+ private:
+ bool ExpectEvent(const PoolEvent& event);
+ std::list<PoolEvent> events_;
+};
+
+class MockDnsSocketPool : public DnsSocketPool {
+ public:
+ MockDnsSocketPool(ClientSocketFactory* factory, DnsSessionTest* test)
+ : DnsSocketPool(factory), test_(test) { }
+
+ virtual ~MockDnsSocketPool() { }
+
+ virtual void Initialize(
+ const std::vector<IPEndPoint>* nameservers,
+ NetLog* net_log) OVERRIDE {
+ InitializeInternal(nameservers, net_log);
+ }
+
+ virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
+ unsigned server_index) OVERRIDE {
+ test_->OnSocketAllocated(server_index);
+ return CreateConnectedSocket(server_index);
+ }
+
+ virtual void FreeSocket(
+ unsigned server_index,
+ scoped_ptr<DatagramClientSocket> socket) OVERRIDE {
+ test_->OnSocketFreed(server_index);
+ }
+
+ private:
+ DnsSessionTest* test_;
+};
+
+void DnsSessionTest::Initialize(unsigned num_servers) {
+ CHECK(num_servers < 256u);
+ config_.nameservers.clear();
+ IPAddressNumber dns_ip;
+ bool rv = ParseIPLiteralToNumber("192.168.1.0", &dns_ip);
+ EXPECT_TRUE(rv);
+ for (unsigned char i = 0; i < num_servers; ++i) {
+ dns_ip[3] = i;
+ IPEndPoint dns_endpoint(dns_ip, dns_protocol::kDefaultPort);
+ config_.nameservers.push_back(dns_endpoint);
+ }
+
+ test_client_socket_factory_.reset(new TestClientSocketFactory());
+
+ DnsSocketPool* dns_socket_pool =
+ new MockDnsSocketPool(test_client_socket_factory_.get(), this);
+
+ session_ = new DnsSession(config_,
+ scoped_ptr<DnsSocketPool>(dns_socket_pool),
+ base::Bind(&base::RandInt),
+ NULL /* NetLog */);
+
+ events_.clear();
+}
+
+scoped_ptr<DnsSession::SocketLease> DnsSessionTest::Allocate(
+ unsigned server_index) {
+ return session_->AllocateSocket(server_index, source_);
+}
+
+bool DnsSessionTest::DidAllocate(unsigned server_index) {
+ PoolEvent expected_event = { PoolEvent::ALLOCATE, server_index };
+ return ExpectEvent(expected_event);
+}
+
+bool DnsSessionTest::DidFree(unsigned server_index) {
+ PoolEvent expected_event = { PoolEvent::FREE, server_index };
+ return ExpectEvent(expected_event);
+}
+
+bool DnsSessionTest::NoMoreEvents() {
+ return events_.empty();
+}
+
+void DnsSessionTest::OnSocketAllocated(unsigned server_index) {
+ PoolEvent event = { PoolEvent::ALLOCATE, server_index };
+ events_.push_back(event);
+}
+
+void DnsSessionTest::OnSocketFreed(unsigned server_index) {
+ PoolEvent event = { PoolEvent::FREE, server_index };
+ events_.push_back(event);
+}
+
+bool DnsSessionTest::ExpectEvent(const PoolEvent& expected) {
+ if (events_.empty()) {
+ return false;
+ }
+
+ const PoolEvent actual = events_.front();
+ if ((expected.action != actual.action)
+ || (expected.server_index != actual.server_index)) {
+ return false;
+ }
+ events_.pop_front();
+
+ return true;
+}
+
+scoped_ptr<DatagramClientSocket>
+TestClientSocketFactory::CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source) {
+ // We're not actually expecting to send or receive any data, so use the
+ // simplest SocketDataProvider with no data supplied.
+ SocketDataProvider* data_provider = new StaticSocketDataProvider();
+ data_providers_.push_back(data_provider);
+ scoped_ptr<MockUDPClientSocket> socket(
+ new MockUDPClientSocket(data_provider, net_log));
+ data_provider->set_socket(socket.get());
+ return socket.PassAs<DatagramClientSocket>();
+}
+
+TestClientSocketFactory::~TestClientSocketFactory() {
+ STLDeleteElements(&data_providers_);
+}
+
+TEST_F(DnsSessionTest, AllocateFree) {
+ scoped_ptr<DnsSession::SocketLease> lease1, lease2;
+
+ Initialize(2);
+ EXPECT_TRUE(NoMoreEvents());
+
+ lease1 = Allocate(0);
+ EXPECT_TRUE(DidAllocate(0));
+ EXPECT_TRUE(NoMoreEvents());
+
+ lease2 = Allocate(1);
+ EXPECT_TRUE(DidAllocate(1));
+ EXPECT_TRUE(NoMoreEvents());
+
+ lease1.reset();
+ EXPECT_TRUE(DidFree(0));
+ EXPECT_TRUE(NoMoreEvents());
+
+ lease2.reset();
+ EXPECT_TRUE(DidFree(1));
+ EXPECT_TRUE(NoMoreEvents());
+}
+
+// Expect default calculated timeout to be within 10ms of in DnsConfig.
+TEST_F(DnsSessionTest, HistogramTimeoutNormal) {
+ Initialize(2);
+ base::TimeDelta timeoutDelta = session_->NextTimeout(0, 0) - config_.timeout;
+ EXPECT_LT(timeoutDelta.InMilliseconds(), 10);
+}
+
+// Expect short calculated timeout to be within 10ms of in DnsConfig.
+TEST_F(DnsSessionTest, HistogramTimeoutShort) {
+ config_.timeout = base::TimeDelta::FromMilliseconds(15);
+ Initialize(2);
+ base::TimeDelta timeoutDelta = session_->NextTimeout(0, 0) - config_.timeout;
+ EXPECT_LT(timeoutDelta.InMilliseconds(), 10);
+}
+
+// Expect long calculated timeout to be equal to one in DnsConfig.
+TEST_F(DnsSessionTest, HistogramTimeoutLong) {
+ config_.timeout = base::TimeDelta::FromSeconds(15);
+ Initialize(2);
+ base::TimeDelta timeout = session_->NextTimeout(0, 0);
+ EXPECT_EQ(config_.timeout.InMilliseconds(), timeout.InMilliseconds());
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/dns/dns_socket_pool.cc b/chromium/net/dns/dns_socket_pool.cc
new file mode 100644
index 00000000000..7a7ecd6ee8f
--- /dev/null
+++ b/chromium/net/dns/dns_socket_pool.cc
@@ -0,0 +1,234 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_socket_pool.h"
+
+#include "base/logging.h"
+#include "base/rand_util.h"
+#include "base/stl_util.h"
+#include "net/base/address_list.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/rand_callback.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/stream_socket.h"
+#include "net/udp/datagram_client_socket.h"
+
+namespace net {
+
+namespace {
+
+// When we initialize the SocketPool, we allocate kInitialPoolSize sockets.
+// When we allocate a socket, we ensure we have at least kAllocateMinSize
+// sockets to choose from. When we free a socket, we retain it if we have
+// less than kRetainMaxSize sockets in the pool.
+
+// On Windows, we can't request specific (random) ports, since that will
+// trigger firewall prompts, so request default ones, but keep a pile of
+// them. Everywhere else, request fresh, random ports each time.
+#if defined(OS_WIN)
+const DatagramSocket::BindType kBindType = DatagramSocket::DEFAULT_BIND;
+const unsigned kInitialPoolSize = 256;
+const unsigned kAllocateMinSize = 256;
+const unsigned kRetainMaxSize = 0;
+#else
+const DatagramSocket::BindType kBindType = DatagramSocket::RANDOM_BIND;
+const unsigned kInitialPoolSize = 0;
+const unsigned kAllocateMinSize = 1;
+const unsigned kRetainMaxSize = 0;
+#endif
+
+} // namespace
+
+DnsSocketPool::DnsSocketPool(ClientSocketFactory* socket_factory)
+ : socket_factory_(socket_factory),
+ net_log_(NULL),
+ nameservers_(NULL),
+ initialized_(false) {
+}
+
+void DnsSocketPool::InitializeInternal(
+ const std::vector<IPEndPoint>* nameservers,
+ NetLog* net_log) {
+ DCHECK(nameservers);
+ DCHECK(!initialized_);
+
+ net_log_ = net_log;
+ nameservers_ = nameservers;
+ initialized_ = true;
+}
+
+scoped_ptr<StreamSocket> DnsSocketPool::CreateTCPSocket(
+ unsigned server_index,
+ const NetLog::Source& source) {
+ DCHECK_LT(server_index, nameservers_->size());
+
+ return scoped_ptr<StreamSocket>(
+ socket_factory_->CreateTransportClientSocket(
+ AddressList((*nameservers_)[server_index]), net_log_, source));
+}
+
+scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket(
+ unsigned server_index) {
+ DCHECK_LT(server_index, nameservers_->size());
+
+ scoped_ptr<DatagramClientSocket> socket;
+
+ NetLog::Source no_source;
+ socket = socket_factory_->CreateDatagramClientSocket(
+ kBindType, base::Bind(&base::RandInt), net_log_, no_source);
+
+ if (socket.get()) {
+ int rv = socket->Connect((*nameservers_)[server_index]);
+ if (rv != OK) {
+ LOG(WARNING) << "Failed to connect socket: " << rv;
+ socket.reset();
+ }
+ } else {
+ LOG(WARNING) << "Failed to create socket.";
+ }
+
+ return socket.Pass();
+}
+
+class NullDnsSocketPool : public DnsSocketPool {
+ public:
+ NullDnsSocketPool(ClientSocketFactory* factory)
+ : DnsSocketPool(factory) {
+ }
+
+ virtual void Initialize(
+ const std::vector<IPEndPoint>* nameservers,
+ NetLog* net_log) OVERRIDE {
+ InitializeInternal(nameservers, net_log);
+ }
+
+ virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
+ unsigned server_index) OVERRIDE {
+ return CreateConnectedSocket(server_index);
+ }
+
+ virtual void FreeSocket(
+ unsigned server_index,
+ scoped_ptr<DatagramClientSocket> socket) OVERRIDE {
+ }
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool);
+};
+
+// static
+scoped_ptr<DnsSocketPool> DnsSocketPool::CreateNull(
+ ClientSocketFactory* factory) {
+ return scoped_ptr<DnsSocketPool>(new NullDnsSocketPool(factory));
+}
+
+class DefaultDnsSocketPool : public DnsSocketPool {
+ public:
+ DefaultDnsSocketPool(ClientSocketFactory* factory)
+ : DnsSocketPool(factory) {
+ };
+
+ virtual ~DefaultDnsSocketPool();
+
+ virtual void Initialize(
+ const std::vector<IPEndPoint>* nameservers,
+ NetLog* net_log) OVERRIDE;
+
+ virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
+ unsigned server_index) OVERRIDE;
+
+ virtual void FreeSocket(
+ unsigned server_index,
+ scoped_ptr<DatagramClientSocket> socket) OVERRIDE;
+
+ private:
+ void FillPool(unsigned server_index, unsigned size);
+
+ typedef std::vector<DatagramClientSocket*> SocketVector;
+
+ std::vector<SocketVector> pools_;
+
+ DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool);
+};
+
+// static
+scoped_ptr<DnsSocketPool> DnsSocketPool::CreateDefault(
+ ClientSocketFactory* factory) {
+ return scoped_ptr<DnsSocketPool>(new DefaultDnsSocketPool(factory));
+}
+
+void DefaultDnsSocketPool::Initialize(
+ const std::vector<IPEndPoint>* nameservers,
+ NetLog* net_log) {
+ InitializeInternal(nameservers, net_log);
+
+ DCHECK(pools_.empty());
+ const unsigned num_servers = nameservers->size();
+ pools_.resize(num_servers);
+ for (unsigned server_index = 0; server_index < num_servers; ++server_index)
+ FillPool(server_index, kInitialPoolSize);
+}
+
+DefaultDnsSocketPool::~DefaultDnsSocketPool() {
+ unsigned num_servers = pools_.size();
+ for (unsigned server_index = 0; server_index < num_servers; ++server_index) {
+ SocketVector& pool = pools_[server_index];
+ STLDeleteElements(&pool);
+ }
+}
+
+scoped_ptr<DatagramClientSocket> DefaultDnsSocketPool::AllocateSocket(
+ unsigned server_index) {
+ DCHECK_LT(server_index, pools_.size());
+ SocketVector& pool = pools_[server_index];
+
+ FillPool(server_index, kAllocateMinSize);
+ if (pool.size() == 0) {
+ LOG(WARNING) << "No DNS sockets available in pool " << server_index << "!";
+ return scoped_ptr<DatagramClientSocket>();
+ }
+
+ if (pool.size() < kAllocateMinSize) {
+ LOG(WARNING) << "Low DNS port entropy: wanted " << kAllocateMinSize
+ << " sockets to choose from, but only have " << pool.size()
+ << " in pool " << server_index << ".";
+ }
+
+ unsigned socket_index = base::RandInt(0, pool.size() - 1);
+ DatagramClientSocket* socket = pool[socket_index];
+ pool[socket_index] = pool.back();
+ pool.pop_back();
+
+ return scoped_ptr<DatagramClientSocket>(socket);
+}
+
+void DefaultDnsSocketPool::FreeSocket(
+ unsigned server_index,
+ scoped_ptr<DatagramClientSocket> socket) {
+ DCHECK_LT(server_index, pools_.size());
+
+ // In some builds, kRetainMaxSize will be 0 if we never reuse sockets.
+ // In that case, don't compile this code to avoid a "tautological
+ // comparison" warning from clang.
+#if kRetainMaxSize > 0
+ SocketVector& pool = pools_[server_index];
+ if (pool.size() < kRetainMaxSize)
+ pool.push_back(socket.release());
+#endif
+}
+
+void DefaultDnsSocketPool::FillPool(unsigned server_index, unsigned size) {
+ SocketVector& pool = pools_[server_index];
+
+ for (unsigned pool_index = pool.size(); pool_index < size; ++pool_index) {
+ DatagramClientSocket* socket =
+ CreateConnectedSocket(server_index).release();
+ if (!socket)
+ break;
+ pool.push_back(socket);
+ }
+}
+
+} // namespace net
diff --git a/chromium/net/dns/dns_socket_pool.h b/chromium/net/dns/dns_socket_pool.h
new file mode 100644
index 00000000000..6bfe474d6a9
--- /dev/null
+++ b/chromium/net/dns/dns_socket_pool.h
@@ -0,0 +1,91 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_SOCKET_POOL_H_
+#define NET_DNS_DNS_SOCKET_POOL_H_
+
+#include <vector>
+
+#include "base/memory/scoped_ptr.h"
+#include "net/base/net_export.h"
+#include "net/base/net_log.h"
+
+namespace net {
+
+class ClientSocketFactory;
+class DatagramClientSocket;
+class IPEndPoint;
+class NetLog;
+class StreamSocket;
+
+// A DnsSocketPool is an abstraction layer around a ClientSocketFactory that
+// allows preallocation, reuse, or other strategies to manage sockets connected
+// to DNS servers.
+class NET_EXPORT_PRIVATE DnsSocketPool {
+ public:
+ virtual ~DnsSocketPool() { }
+
+ // Creates a DnsSocketPool that implements the default strategy for managing
+ // sockets. (This varies by platform; see DnsSocketPoolImpl in
+ // dns_socket_pool.cc for details.)
+ static scoped_ptr<DnsSocketPool> CreateDefault(
+ ClientSocketFactory* factory);
+
+ // Creates a DnsSocketPool that implements a "null" strategy -- no sockets are
+ // preallocated, allocation requests are satisfied by calling the factory
+ // directly, and returned sockets are deleted immediately.
+ static scoped_ptr<DnsSocketPool> CreateNull(
+ ClientSocketFactory* factory);
+
+ // Initializes the DnsSocketPool. |nameservers| is the list of nameservers
+ // for which the DnsSocketPool will manage sockets; |net_log| is the NetLog
+ // used when constructing sockets with the factory.
+ //
+ // Initialize may not be called more than once, and must be called before
+ // calling AllocateSocket or FreeSocket.
+ virtual void Initialize(
+ const std::vector<IPEndPoint>* nameservers,
+ NetLog* net_log) = 0;
+
+ // Allocates a socket that is already connected to the nameserver referenced
+ // by |server_index|. May return a scoped_ptr to NULL if no sockets are
+ // available to reuse and the factory fails to produce a socket (or produces
+ // one on which Connect fails).
+ virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
+ unsigned server_index) = 0;
+
+ // Frees a socket allocated by AllocateSocket. |server_index| must be the
+ // same index passed to AllocateSocket.
+ virtual void FreeSocket(
+ unsigned server_index,
+ scoped_ptr<DatagramClientSocket> socket) = 0;
+
+ // Creates a StreamSocket from the factory for a transaction over TCP. These
+ // sockets are not pooled.
+ scoped_ptr<StreamSocket> CreateTCPSocket(
+ unsigned server_index,
+ const NetLog::Source& source);
+
+ protected:
+ DnsSocketPool(ClientSocketFactory* socket_factory);
+
+ void InitializeInternal(
+ const std::vector<IPEndPoint>* nameservers,
+ NetLog* net_log);
+
+ scoped_ptr<DatagramClientSocket> CreateConnectedSocket(
+ unsigned server_index);
+
+ private:
+ ClientSocketFactory* socket_factory_;
+ NetLog* net_log_;
+ const std::vector<IPEndPoint>* nameservers_;
+ bool initialized_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsSocketPool);
+};
+
+} // namespace net
+
+#endif // NET_DNS_DNS_SOCKET_POOL_H_
diff --git a/chromium/net/dns/dns_test_util.cc b/chromium/net/dns/dns_test_util.cc
new file mode 100644
index 00000000000..37bf855dd10
--- /dev/null
+++ b/chromium/net/dns/dns_test_util.cc
@@ -0,0 +1,210 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_test_util.h"
+
+#include <string>
+
+#include "base/bind.h"
+#include "base/memory/weak_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "base/sys_byteorder.h"
+#include "net/base/big_endian.h"
+#include "net/base/dns_util.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+#include "net/dns/address_sorter.h"
+#include "net/dns/dns_client.h"
+#include "net/dns/dns_config_service.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_query.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/dns_transaction.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+namespace {
+
+// A DnsTransaction which uses MockDnsClientRuleList to determine the response.
+class MockTransaction : public DnsTransaction,
+ public base::SupportsWeakPtr<MockTransaction> {
+ public:
+ MockTransaction(const MockDnsClientRuleList& rules,
+ const std::string& hostname,
+ uint16 qtype,
+ const DnsTransactionFactory::CallbackType& callback)
+ : result_(MockDnsClientRule::FAIL),
+ hostname_(hostname),
+ qtype_(qtype),
+ callback_(callback),
+ started_(false) {
+ // Find the relevant rule which matches |qtype| and prefix of |hostname|.
+ for (size_t i = 0; i < rules.size(); ++i) {
+ const std::string& prefix = rules[i].prefix;
+ if ((rules[i].qtype == qtype) &&
+ (hostname.size() >= prefix.size()) &&
+ (hostname.compare(0, prefix.size(), prefix) == 0)) {
+ result_ = rules[i].result;
+ break;
+ }
+ }
+ }
+
+ virtual const std::string& GetHostname() const OVERRIDE {
+ return hostname_;
+ }
+
+ virtual uint16 GetType() const OVERRIDE {
+ return qtype_;
+ }
+
+ virtual void Start() OVERRIDE {
+ EXPECT_FALSE(started_);
+ started_ = true;
+ // Using WeakPtr to cleanly cancel when transaction is destroyed.
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr()));
+ }
+
+ private:
+ void Finish() {
+ switch (result_) {
+ case MockDnsClientRule::EMPTY:
+ case MockDnsClientRule::OK: {
+ std::string qname;
+ DNSDomainFromDot(hostname_, &qname);
+ DnsQuery query(0, qname, qtype_);
+
+ DnsResponse response;
+ char* buffer = response.io_buffer()->data();
+ int nbytes = query.io_buffer()->size();
+ memcpy(buffer, query.io_buffer()->data(), nbytes);
+ dns_protocol::Header* header =
+ reinterpret_cast<dns_protocol::Header*>(buffer);
+ header->flags |= dns_protocol::kFlagResponse;
+
+ if (MockDnsClientRule::OK == result_) {
+ const uint16 kPointerToQueryName =
+ static_cast<uint16>(0xc000 | sizeof(*header));
+
+ const uint32 kTTL = 86400; // One day.
+
+ // Size of RDATA which is a IPv4 or IPv6 address.
+ size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ?
+ net::kIPv4AddressSize : net::kIPv6AddressSize;
+
+ // 12 is the sum of sizes of the compressed name reference, TYPE,
+ // CLASS, TTL and RDLENGTH.
+ size_t answer_size = 12 + rdata_size;
+
+ // Write answer with loopback IP address.
+ header->ancount = base::HostToNet16(1);
+ BigEndianWriter writer(buffer + nbytes, answer_size);
+ writer.WriteU16(kPointerToQueryName);
+ writer.WriteU16(qtype_);
+ writer.WriteU16(net::dns_protocol::kClassIN);
+ writer.WriteU32(kTTL);
+ writer.WriteU16(rdata_size);
+ if (qtype_ == net::dns_protocol::kTypeA) {
+ char kIPv4Loopback[] = { 0x7f, 0, 0, 1 };
+ writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback));
+ } else {
+ char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 1 };
+ writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback));
+ }
+ nbytes += answer_size;
+ }
+ EXPECT_TRUE(response.InitParse(nbytes, query));
+ callback_.Run(this, OK, &response);
+ } break;
+ case MockDnsClientRule::FAIL:
+ callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL);
+ break;
+ case MockDnsClientRule::TIMEOUT:
+ callback_.Run(this, ERR_DNS_TIMED_OUT, NULL);
+ break;
+ default:
+ NOTREACHED();
+ break;
+ }
+ }
+
+ MockDnsClientRule::Result result_;
+ const std::string hostname_;
+ const uint16 qtype_;
+ DnsTransactionFactory::CallbackType callback_;
+ bool started_;
+};
+
+
+// A DnsTransactionFactory which creates MockTransaction.
+class MockTransactionFactory : public DnsTransactionFactory {
+ public:
+ explicit MockTransactionFactory(const MockDnsClientRuleList& rules)
+ : rules_(rules) {}
+ virtual ~MockTransactionFactory() {}
+
+ virtual scoped_ptr<DnsTransaction> CreateTransaction(
+ const std::string& hostname,
+ uint16 qtype,
+ const DnsTransactionFactory::CallbackType& callback,
+ const BoundNetLog&) OVERRIDE {
+ return scoped_ptr<DnsTransaction>(
+ new MockTransaction(rules_, hostname, qtype, callback));
+ }
+
+ private:
+ MockDnsClientRuleList rules_;
+};
+
+class MockAddressSorter : public AddressSorter {
+ public:
+ virtual ~MockAddressSorter() {}
+ virtual void Sort(const AddressList& list,
+ const CallbackType& callback) const OVERRIDE {
+ // Do nothing.
+ callback.Run(true, list);
+ }
+};
+
+// MockDnsClient provides MockTransactionFactory.
+class MockDnsClient : public DnsClient {
+ public:
+ MockDnsClient(const DnsConfig& config,
+ const MockDnsClientRuleList& rules)
+ : config_(config), factory_(rules) {}
+ virtual ~MockDnsClient() {}
+
+ virtual void SetConfig(const DnsConfig& config) OVERRIDE {
+ config_ = config;
+ }
+
+ virtual const DnsConfig* GetConfig() const OVERRIDE {
+ return config_.IsValid() ? &config_ : NULL;
+ }
+
+ virtual DnsTransactionFactory* GetTransactionFactory() OVERRIDE {
+ return config_.IsValid() ? &factory_ : NULL;
+ }
+
+ virtual AddressSorter* GetAddressSorter() OVERRIDE {
+ return &address_sorter_;
+ }
+
+ private:
+ DnsConfig config_;
+ MockTransactionFactory factory_;
+ MockAddressSorter address_sorter_;
+};
+
+} // namespace
+
+// static
+scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config,
+ const MockDnsClientRuleList& rules) {
+ return scoped_ptr<DnsClient>(new MockDnsClient(config, rules));
+}
+
+} // namespace net
diff --git a/chromium/net/dns/dns_test_util.h b/chromium/net/dns/dns_test_util.h
new file mode 100644
index 00000000000..d447b299c86
--- /dev/null
+++ b/chromium/net/dns/dns_test_util.h
@@ -0,0 +1,205 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_TEST_UTIL_H_
+#define NET_DNS_DNS_TEST_UTIL_H_
+
+#include <string>
+#include <vector>
+
+#include "base/basictypes.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/dns/dns_config_service.h"
+#include "net/dns/dns_protocol.h"
+
+namespace net {
+
+//-----------------------------------------------------------------------------
+// Query/response set for www.google.com, ID is fixed to 0.
+static const char kT0HostName[] = "www.google.com";
+static const uint16 kT0Qtype = dns_protocol::kTypeA;
+static const char kT0DnsName[] = {
+ 0x03, 'w', 'w', 'w',
+ 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00
+};
+static const size_t kT0QuerySize = 32;
+static const uint8 kT0ResponseDatagram[] = {
+ // response contains one CNAME for www.l.google.com and the following
+ // IP addresses: 74.125.226.{179,180,176,177,178}
+ 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x06,
+ 0x00, 0x00, 0x00, 0x00, 0x03, 0x77, 0x77, 0x77,
+ 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03,
+ 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01,
+ 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x01,
+ 0x4d, 0x13, 0x00, 0x08, 0x03, 0x77, 0x77, 0x77,
+ 0x01, 0x6c, 0xc0, 0x10, 0xc0, 0x2c, 0x00, 0x01,
+ 0x00, 0x01, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x04,
+ 0x4a, 0x7d, 0xe2, 0xb3, 0xc0, 0x2c, 0x00, 0x01,
+ 0x00, 0x01, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x04,
+ 0x4a, 0x7d, 0xe2, 0xb4, 0xc0, 0x2c, 0x00, 0x01,
+ 0x00, 0x01, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x04,
+ 0x4a, 0x7d, 0xe2, 0xb0, 0xc0, 0x2c, 0x00, 0x01,
+ 0x00, 0x01, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x04,
+ 0x4a, 0x7d, 0xe2, 0xb1, 0xc0, 0x2c, 0x00, 0x01,
+ 0x00, 0x01, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x04,
+ 0x4a, 0x7d, 0xe2, 0xb2
+};
+static const char* const kT0IpAddresses[] = {
+ "74.125.226.179", "74.125.226.180", "74.125.226.176",
+ "74.125.226.177", "74.125.226.178"
+};
+static const char kT0CanonName[] = "www.l.google.com";
+static const int kT0TTL = 0x000000e4;
+// +1 for the CNAME record.
+static const unsigned kT0RecordCount = arraysize(kT0IpAddresses) + 1;
+
+//-----------------------------------------------------------------------------
+// Query/response set for codereview.chromium.org, ID is fixed to 1.
+static const char kT1HostName[] = "codereview.chromium.org";
+static const uint16 kT1Qtype = dns_protocol::kTypeA;
+static const char kT1DnsName[] = {
+ 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w',
+ 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm',
+ 0x03, 'o', 'r', 'g',
+ 0x00
+};
+static const size_t kT1QuerySize = 41;
+static const uint8 kT1ResponseDatagram[] = {
+ // response contains one CNAME for ghs.l.google.com and the following
+ // IP address: 64.233.169.121
+ 0x00, 0x01, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02,
+ 0x00, 0x00, 0x00, 0x00, 0x0a, 0x63, 0x6f, 0x64,
+ 0x65, 0x72, 0x65, 0x76, 0x69, 0x65, 0x77, 0x08,
+ 0x63, 0x68, 0x72, 0x6f, 0x6d, 0x69, 0x75, 0x6d,
+ 0x03, 0x6f, 0x72, 0x67, 0x00, 0x00, 0x01, 0x00,
+ 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00,
+ 0x01, 0x41, 0x75, 0x00, 0x12, 0x03, 0x67, 0x68,
+ 0x73, 0x01, 0x6c, 0x06, 0x67, 0x6f, 0x6f, 0x67,
+ 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0xc0,
+ 0x35, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01,
+ 0x0b, 0x00, 0x04, 0x40, 0xe9, 0xa9, 0x79
+};
+static const char* const kT1IpAddresses[] = {
+ "64.233.169.121"
+};
+static const char kT1CanonName[] = "ghs.l.google.com";
+static const int kT1TTL = 0x0000010b;
+// +1 for the CNAME record.
+static const unsigned kT1RecordCount = arraysize(kT1IpAddresses) + 1;
+
+//-----------------------------------------------------------------------------
+// Query/response set for www.ccs.neu.edu, ID is fixed to 2.
+static const char kT2HostName[] = "www.ccs.neu.edu";
+static const uint16 kT2Qtype = dns_protocol::kTypeA;
+static const char kT2DnsName[] = {
+ 0x03, 'w', 'w', 'w',
+ 0x03, 'c', 'c', 's',
+ 0x03, 'n', 'e', 'u',
+ 0x03, 'e', 'd', 'u',
+ 0x00
+};
+static const size_t kT2QuerySize = 33;
+static const uint8 kT2ResponseDatagram[] = {
+ // response contains one CNAME for vulcan.ccs.neu.edu and the following
+ // IP address: 129.10.116.81
+ 0x00, 0x02, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02,
+ 0x00, 0x00, 0x00, 0x00, 0x03, 0x77, 0x77, 0x77,
+ 0x03, 0x63, 0x63, 0x73, 0x03, 0x6e, 0x65, 0x75,
+ 0x03, 0x65, 0x64, 0x75, 0x00, 0x00, 0x01, 0x00,
+ 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00,
+ 0x00, 0x01, 0x2c, 0x00, 0x09, 0x06, 0x76, 0x75,
+ 0x6c, 0x63, 0x61, 0x6e, 0xc0, 0x10, 0xc0, 0x2d,
+ 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c,
+ 0x00, 0x04, 0x81, 0x0a, 0x74, 0x51
+};
+static const char* const kT2IpAddresses[] = {
+ "129.10.116.81"
+};
+static const char kT2CanonName[] = "vulcan.ccs.neu.edu";
+static const int kT2TTL = 0x0000012c;
+// +1 for the CNAME record.
+static const unsigned kT2RecordCount = arraysize(kT2IpAddresses) + 1;
+
+//-----------------------------------------------------------------------------
+// Query/response set for www.google.az, ID is fixed to 3.
+static const char kT3HostName[] = "www.google.az";
+static const uint16 kT3Qtype = dns_protocol::kTypeA;
+static const char kT3DnsName[] = {
+ 0x03, 'w', 'w', 'w',
+ 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
+ 0x02, 'a', 'z',
+ 0x00
+};
+static const size_t kT3QuerySize = 31;
+static const uint8 kT3ResponseDatagram[] = {
+ // response contains www.google.com as CNAME for www.google.az and
+ // www.l.google.com as CNAME for www.google.com and the following
+ // IP addresses: 74.125.226.{178,179,180,176,177}
+ // The TTLs on the records are: 0x00015099, 0x00025099, 0x00000415,
+ // 0x00003015, 0x00002015, 0x00000015, 0x00001015.
+ // The last record is an imaginary TXT record for t.google.com.
+ 0x00, 0x03, 0x81, 0x80, 0x00, 0x01, 0x00, 0x08,
+ 0x00, 0x00, 0x00, 0x00, 0x03, 0x77, 0x77, 0x77,
+ 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x02,
+ 0x61, 0x7a, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0,
+ 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x01, 0x50,
+ 0x99, 0x00, 0x10, 0x03, 0x77, 0x77, 0x77, 0x06,
+ 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63,
+ 0x6f, 0x6d, 0x00, 0xc0, 0x2b, 0x00, 0x05, 0x00,
+ 0x01, 0x00, 0x02, 0x50, 0x99, 0x00, 0x08, 0x03,
+ 0x77, 0x77, 0x77, 0x01, 0x6c, 0xc0, 0x2f, 0xc0,
+ 0x47, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x04,
+ 0x15, 0x00, 0x04, 0x4a, 0x7d, 0xe2, 0xb2, 0xc0,
+ 0x47, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x30,
+ 0x15, 0x00, 0x04, 0x4a, 0x7d, 0xe2, 0xb3, 0xc0,
+ 0x47, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x20,
+ 0x15, 0x00, 0x04, 0x4a, 0x7d, 0xe2, 0xb4, 0xc0,
+ 0x47, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x15, 0x00, 0x04, 0x4a, 0x7d, 0xe2, 0xb0, 0xc0,
+ 0x47, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x10,
+ 0x15, 0x00, 0x04, 0x4a, 0x7d, 0xe2, 0xb1, 0x01,
+ 0x74, 0xc0, 0x2f, 0x00, 0x10, 0x00, 0x01, 0x00,
+ 0x00, 0x00, 0x01, 0x00, 0x04, 0xde, 0xad, 0xfe,
+ 0xed
+};
+static const char* const kT3IpAddresses[] = {
+ "74.125.226.178", "74.125.226.179", "74.125.226.180",
+ "74.125.226.176", "74.125.226.177"
+};
+static const char kT3CanonName[] = "www.l.google.com";
+static const int kT3TTL = 0x00000015;
+// +2 for the CNAME records, +1 for TXT record.
+static const unsigned kT3RecordCount = arraysize(kT3IpAddresses) + 3;
+
+class DnsClient;
+
+struct MockDnsClientRule {
+ enum Result {
+ FAIL, // Fail asynchronously with ERR_NAME_NOT_RESOLVED.
+ TIMEOUT, // Fail asynchronously with ERR_DNS_TIMEOUT.
+ EMPTY, // Return an empty response.
+ OK, // Return a response with loopback address.
+ };
+
+ MockDnsClientRule(const std::string& prefix_arg,
+ uint16 qtype_arg,
+ Result result_arg)
+ : result(result_arg), prefix(prefix_arg), qtype(qtype_arg) { }
+
+ Result result;
+ std::string prefix;
+ uint16 qtype;
+};
+
+typedef std::vector<MockDnsClientRule> MockDnsClientRuleList;
+
+// Creates mock DnsClient for testing HostResolverImpl.
+scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config,
+ const MockDnsClientRuleList& rules);
+
+} // namespace net
+
+#endif // NET_DNS_DNS_TEST_UTIL_H_
diff --git a/chromium/net/dns/dns_transaction.cc b/chromium/net/dns/dns_transaction.cc
new file mode 100644
index 00000000000..170ea678d4b
--- /dev/null
+++ b/chromium/net/dns/dns_transaction.cc
@@ -0,0 +1,963 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_transaction.h"
+
+#include <deque>
+#include <string>
+#include <vector>
+
+#include "base/bind.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/memory/scoped_vector.h"
+#include "base/memory/weak_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "base/metrics/histogram.h"
+#include "base/rand_util.h"
+#include "base/stl_util.h"
+#include "base/strings/string_piece.h"
+#include "base/threading/non_thread_safe.h"
+#include "base/timer/timer.h"
+#include "base/values.h"
+#include "net/base/big_endian.h"
+#include "net/base/completion_callback.h"
+#include "net/base/dns_util.h"
+#include "net/base/io_buffer.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_query.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/dns_session.h"
+#include "net/socket/stream_socket.h"
+#include "net/udp/datagram_client_socket.h"
+
+namespace net {
+
+namespace {
+
+// Provide a common macro to simplify code and readability. We must use a
+// macro as the underlying HISTOGRAM macro creates static variables.
+#define DNS_HISTOGRAM(name, time) UMA_HISTOGRAM_CUSTOM_TIMES(name, time, \
+ base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromHours(1), 100)
+
+// Count labels in the fully-qualified name in DNS format.
+int CountLabels(const std::string& name) {
+ size_t count = 0;
+ for (size_t i = 0; i < name.size() && name[i]; i += name[i] + 1)
+ ++count;
+ return count;
+}
+
+bool IsIPLiteral(const std::string& hostname) {
+ IPAddressNumber ip;
+ return ParseIPLiteralToNumber(hostname, &ip);
+}
+
+base::Value* NetLogStartCallback(const std::string* hostname,
+ uint16 qtype,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetString("hostname", *hostname);
+ dict->SetInteger("query_type", qtype);
+ return dict;
+};
+
+// ----------------------------------------------------------------------------
+
+// A single asynchronous DNS exchange, which consists of sending out a
+// DNS query, waiting for a response, and returning the response that it
+// matches. Logging is done in the socket and in the outer DnsTransaction.
+class DnsAttempt {
+ public:
+ explicit DnsAttempt(unsigned server_index)
+ : result_(ERR_FAILED), server_index_(server_index) {}
+
+ virtual ~DnsAttempt() {}
+ // Starts the attempt. Returns ERR_IO_PENDING if cannot complete synchronously
+ // and calls |callback| upon completion.
+ virtual int Start(const CompletionCallback& callback) = 0;
+
+ // Returns the query of this attempt.
+ virtual const DnsQuery* GetQuery() const = 0;
+
+ // Returns the response or NULL if has not received a matching response from
+ // the server.
+ virtual const DnsResponse* GetResponse() const = 0;
+
+ // Returns the net log bound to the source of the socket.
+ virtual const BoundNetLog& GetSocketNetLog() const = 0;
+
+ // Returns the index of the destination server within DnsConfig::nameservers.
+ unsigned server_index() const { return server_index_; }
+
+ // Returns a Value representing the received response, along with a reference
+ // to the NetLog source source of the UDP socket used. The request must have
+ // completed before this is called.
+ base::Value* NetLogResponseCallback(NetLog::LogLevel log_level) const {
+ DCHECK(GetResponse()->IsValid());
+
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetInteger("rcode", GetResponse()->rcode());
+ dict->SetInteger("answer_count", GetResponse()->answer_count());
+ GetSocketNetLog().source().AddToEventParameters(dict);
+ return dict;
+ }
+
+ void set_result(int result) {
+ result_ = result;
+ }
+
+ // True if current attempt is pending (waiting for server response).
+ bool is_pending() const {
+ return result_ == ERR_IO_PENDING;
+ }
+
+ // True if attempt is completed (received server response).
+ bool is_completed() const {
+ return (result_ == OK) || (result_ == ERR_NAME_NOT_RESOLVED) ||
+ (result_ == ERR_DNS_SERVER_REQUIRES_TCP);
+ }
+
+ private:
+ // Result of last operation.
+ int result_;
+
+ const unsigned server_index_;
+};
+
+class DnsUDPAttempt : public DnsAttempt {
+ public:
+ DnsUDPAttempt(unsigned server_index,
+ scoped_ptr<DnsSession::SocketLease> socket_lease,
+ scoped_ptr<DnsQuery> query)
+ : DnsAttempt(server_index),
+ next_state_(STATE_NONE),
+ received_malformed_response_(false),
+ socket_lease_(socket_lease.Pass()),
+ query_(query.Pass()) {}
+
+ // DnsAttempt:
+ virtual int Start(const CompletionCallback& callback) OVERRIDE {
+ DCHECK_EQ(STATE_NONE, next_state_);
+ callback_ = callback;
+ start_time_ = base::TimeTicks::Now();
+ next_state_ = STATE_SEND_QUERY;
+ return DoLoop(OK);
+ }
+
+ virtual const DnsQuery* GetQuery() const OVERRIDE {
+ return query_.get();
+ }
+
+ virtual const DnsResponse* GetResponse() const OVERRIDE {
+ const DnsResponse* resp = response_.get();
+ return (resp != NULL && resp->IsValid()) ? resp : NULL;
+ }
+
+ virtual const BoundNetLog& GetSocketNetLog() const OVERRIDE {
+ return socket_lease_->socket()->NetLog();
+ }
+
+ private:
+ enum State {
+ STATE_SEND_QUERY,
+ STATE_SEND_QUERY_COMPLETE,
+ STATE_READ_RESPONSE,
+ STATE_READ_RESPONSE_COMPLETE,
+ STATE_NONE,
+ };
+
+ DatagramClientSocket* socket() {
+ return socket_lease_->socket();
+ }
+
+ int DoLoop(int result) {
+ CHECK_NE(STATE_NONE, next_state_);
+ int rv = result;
+ do {
+ State state = next_state_;
+ next_state_ = STATE_NONE;
+ switch (state) {
+ case STATE_SEND_QUERY:
+ rv = DoSendQuery();
+ break;
+ case STATE_SEND_QUERY_COMPLETE:
+ rv = DoSendQueryComplete(rv);
+ break;
+ case STATE_READ_RESPONSE:
+ rv = DoReadResponse();
+ break;
+ case STATE_READ_RESPONSE_COMPLETE:
+ rv = DoReadResponseComplete(rv);
+ break;
+ default:
+ NOTREACHED();
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
+
+ set_result(rv);
+ // If we received a malformed response, and are now waiting for another one,
+ // indicate to the transaction that the server might be misbehaving.
+ if (rv == ERR_IO_PENDING && received_malformed_response_)
+ return ERR_DNS_MALFORMED_RESPONSE;
+ if (rv == OK) {
+ DCHECK_EQ(STATE_NONE, next_state_);
+ DNS_HISTOGRAM("AsyncDNS.UDPAttemptSuccess",
+ base::TimeTicks::Now() - start_time_);
+ } else if (rv != ERR_IO_PENDING) {
+ DNS_HISTOGRAM("AsyncDNS.UDPAttemptFail",
+ base::TimeTicks::Now() - start_time_);
+ }
+ return rv;
+ }
+
+ int DoSendQuery() {
+ next_state_ = STATE_SEND_QUERY_COMPLETE;
+ return socket()->Write(query_->io_buffer(),
+ query_->io_buffer()->size(),
+ base::Bind(&DnsUDPAttempt::OnIOComplete,
+ base::Unretained(this)));
+ }
+
+ int DoSendQueryComplete(int rv) {
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ if (rv < 0)
+ return rv;
+
+ // Writing to UDP should not result in a partial datagram.
+ if (rv != query_->io_buffer()->size())
+ return ERR_MSG_TOO_BIG;
+
+ next_state_ = STATE_READ_RESPONSE;
+ return OK;
+ }
+
+ int DoReadResponse() {
+ next_state_ = STATE_READ_RESPONSE_COMPLETE;
+ response_.reset(new DnsResponse());
+ return socket()->Read(response_->io_buffer(),
+ response_->io_buffer()->size(),
+ base::Bind(&DnsUDPAttempt::OnIOComplete,
+ base::Unretained(this)));
+ }
+
+ int DoReadResponseComplete(int rv) {
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ if (rv < 0)
+ return rv;
+
+ DCHECK(rv);
+ if (!response_->InitParse(rv, *query_)) {
+ // Other implementations simply ignore mismatched responses. Since each
+ // DnsUDPAttempt binds to a different port, we might find that responses
+ // to previously timed out queries lead to failures in the future.
+ // Our solution is to make another attempt, in case the query truly
+ // failed, but keep this attempt alive, in case it was a false alarm.
+ received_malformed_response_ = true;
+ next_state_ = STATE_READ_RESPONSE;
+ return OK;
+ }
+ if (response_->flags() & dns_protocol::kFlagTC)
+ return ERR_DNS_SERVER_REQUIRES_TCP;
+ // TODO(szym): Extract TTL for NXDOMAIN results. http://crbug.com/115051
+ if (response_->rcode() == dns_protocol::kRcodeNXDOMAIN)
+ return ERR_NAME_NOT_RESOLVED;
+ if (response_->rcode() != dns_protocol::kRcodeNOERROR)
+ return ERR_DNS_SERVER_FAILED;
+
+ return OK;
+ }
+
+ void OnIOComplete(int rv) {
+ rv = DoLoop(rv);
+ if (rv != ERR_IO_PENDING)
+ callback_.Run(rv);
+ }
+
+ State next_state_;
+ bool received_malformed_response_;
+ base::TimeTicks start_time_;
+
+ scoped_ptr<DnsSession::SocketLease> socket_lease_;
+ scoped_ptr<DnsQuery> query_;
+
+ scoped_ptr<DnsResponse> response_;
+
+ CompletionCallback callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsUDPAttempt);
+};
+
+class DnsTCPAttempt : public DnsAttempt {
+ public:
+ DnsTCPAttempt(unsigned server_index,
+ scoped_ptr<StreamSocket> socket,
+ scoped_ptr<DnsQuery> query)
+ : DnsAttempt(server_index),
+ next_state_(STATE_NONE),
+ socket_(socket.Pass()),
+ query_(query.Pass()),
+ length_buffer_(new IOBufferWithSize(sizeof(uint16))),
+ response_length_(0) {}
+
+ // DnsAttempt:
+ virtual int Start(const CompletionCallback& callback) OVERRIDE {
+ DCHECK_EQ(STATE_NONE, next_state_);
+ callback_ = callback;
+ start_time_ = base::TimeTicks::Now();
+ next_state_ = STATE_CONNECT_COMPLETE;
+ int rv = socket_->Connect(base::Bind(&DnsTCPAttempt::OnIOComplete,
+ base::Unretained(this)));
+ if (rv == ERR_IO_PENDING) {
+ set_result(rv);
+ return rv;
+ }
+ return DoLoop(rv);
+ }
+
+ virtual const DnsQuery* GetQuery() const OVERRIDE {
+ return query_.get();
+ }
+
+ virtual const DnsResponse* GetResponse() const OVERRIDE {
+ const DnsResponse* resp = response_.get();
+ return (resp != NULL && resp->IsValid()) ? resp : NULL;
+ }
+
+ virtual const BoundNetLog& GetSocketNetLog() const OVERRIDE {
+ return socket_->NetLog();
+ }
+
+ private:
+ enum State {
+ STATE_CONNECT_COMPLETE,
+ STATE_SEND_LENGTH,
+ STATE_SEND_QUERY,
+ STATE_READ_LENGTH,
+ STATE_READ_RESPONSE,
+ STATE_NONE,
+ };
+
+ int DoLoop(int result) {
+ CHECK_NE(STATE_NONE, next_state_);
+ int rv = result;
+ do {
+ State state = next_state_;
+ next_state_ = STATE_NONE;
+ switch (state) {
+ case STATE_CONNECT_COMPLETE:
+ rv = DoConnectComplete(rv);
+ break;
+ case STATE_SEND_LENGTH:
+ rv = DoSendLength(rv);
+ break;
+ case STATE_SEND_QUERY:
+ rv = DoSendQuery(rv);
+ break;
+ case STATE_READ_LENGTH:
+ rv = DoReadLength(rv);
+ break;
+ case STATE_READ_RESPONSE:
+ rv = DoReadResponse(rv);
+ break;
+ default:
+ NOTREACHED();
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
+
+ set_result(rv);
+ if (rv == OK) {
+ DCHECK_EQ(STATE_NONE, next_state_);
+ DNS_HISTOGRAM("AsyncDNS.TCPAttemptSuccess",
+ base::TimeTicks::Now() - start_time_);
+ } else if (rv != ERR_IO_PENDING) {
+ DNS_HISTOGRAM("AsyncDNS.TCPAttemptFail",
+ base::TimeTicks::Now() - start_time_);
+ }
+ return rv;
+ }
+
+ int DoConnectComplete(int rv) {
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ if (rv < 0)
+ return rv;
+
+ WriteBigEndian<uint16>(length_buffer_->data(), query_->io_buffer()->size());
+ buffer_ =
+ new DrainableIOBuffer(length_buffer_.get(), length_buffer_->size());
+ next_state_ = STATE_SEND_LENGTH;
+ return OK;
+ }
+
+ int DoSendLength(int rv) {
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ if (rv < 0)
+ return rv;
+
+ buffer_->DidConsume(rv);
+ if (buffer_->BytesRemaining() > 0) {
+ next_state_ = STATE_SEND_LENGTH;
+ return socket_->Write(
+ buffer_.get(),
+ buffer_->BytesRemaining(),
+ base::Bind(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)));
+ }
+ buffer_ = new DrainableIOBuffer(query_->io_buffer(),
+ query_->io_buffer()->size());
+ next_state_ = STATE_SEND_QUERY;
+ return OK;
+ }
+
+ int DoSendQuery(int rv) {
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ if (rv < 0)
+ return rv;
+
+ buffer_->DidConsume(rv);
+ if (buffer_->BytesRemaining() > 0) {
+ next_state_ = STATE_SEND_QUERY;
+ return socket_->Write(
+ buffer_.get(),
+ buffer_->BytesRemaining(),
+ base::Bind(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)));
+ }
+ buffer_ =
+ new DrainableIOBuffer(length_buffer_.get(), length_buffer_->size());
+ next_state_ = STATE_READ_LENGTH;
+ return OK;
+ }
+
+ int DoReadLength(int rv) {
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ if (rv < 0)
+ return rv;
+
+ buffer_->DidConsume(rv);
+ if (buffer_->BytesRemaining() > 0) {
+ next_state_ = STATE_READ_LENGTH;
+ return socket_->Read(
+ buffer_.get(),
+ buffer_->BytesRemaining(),
+ base::Bind(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)));
+ }
+ ReadBigEndian<uint16>(length_buffer_->data(), &response_length_);
+ // Check if advertised response is too short. (Optimization only.)
+ if (response_length_ < query_->io_buffer()->size())
+ return ERR_DNS_MALFORMED_RESPONSE;
+ // Allocate more space so that DnsResponse::InitParse sanity check passes.
+ response_.reset(new DnsResponse(response_length_ + 1));
+ buffer_ = new DrainableIOBuffer(response_->io_buffer(), response_length_);
+ next_state_ = STATE_READ_RESPONSE;
+ return OK;
+ }
+
+ int DoReadResponse(int rv) {
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ if (rv < 0)
+ return rv;
+
+ buffer_->DidConsume(rv);
+ if (buffer_->BytesRemaining() > 0) {
+ next_state_ = STATE_READ_RESPONSE;
+ return socket_->Read(
+ buffer_.get(),
+ buffer_->BytesRemaining(),
+ base::Bind(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)));
+ }
+ if (!response_->InitParse(buffer_->BytesConsumed(), *query_))
+ return ERR_DNS_MALFORMED_RESPONSE;
+ if (response_->flags() & dns_protocol::kFlagTC)
+ return ERR_UNEXPECTED;
+ // TODO(szym): Frankly, none of these are expected.
+ if (response_->rcode() == dns_protocol::kRcodeNXDOMAIN)
+ return ERR_NAME_NOT_RESOLVED;
+ if (response_->rcode() != dns_protocol::kRcodeNOERROR)
+ return ERR_DNS_SERVER_FAILED;
+
+ return OK;
+ }
+
+ void OnIOComplete(int rv) {
+ rv = DoLoop(rv);
+ if (rv != ERR_IO_PENDING)
+ callback_.Run(rv);
+ }
+
+ State next_state_;
+ base::TimeTicks start_time_;
+
+ scoped_ptr<StreamSocket> socket_;
+ scoped_ptr<DnsQuery> query_;
+ scoped_refptr<IOBufferWithSize> length_buffer_;
+ scoped_refptr<DrainableIOBuffer> buffer_;
+
+ uint16 response_length_;
+ scoped_ptr<DnsResponse> response_;
+
+ CompletionCallback callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsTCPAttempt);
+};
+
+// ----------------------------------------------------------------------------
+
+// Implements DnsTransaction. Configuration is supplied by DnsSession.
+// The suffix list is built according to the DnsConfig from the session.
+// The timeout for each DnsUDPAttempt is given by DnsSession::NextTimeout.
+// The first server to attempt on each query is given by
+// DnsSession::NextFirstServerIndex, and the order is round-robin afterwards.
+// Each server is attempted DnsConfig::attempts times.
+class DnsTransactionImpl : public DnsTransaction,
+ public base::NonThreadSafe,
+ public base::SupportsWeakPtr<DnsTransactionImpl> {
+ public:
+ DnsTransactionImpl(DnsSession* session,
+ const std::string& hostname,
+ uint16 qtype,
+ const DnsTransactionFactory::CallbackType& callback,
+ const BoundNetLog& net_log)
+ : session_(session),
+ hostname_(hostname),
+ qtype_(qtype),
+ callback_(callback),
+ net_log_(net_log),
+ qnames_initial_size_(0),
+ attempts_count_(0),
+ had_tcp_attempt_(false),
+ first_server_index_(0) {
+ DCHECK(session_.get());
+ DCHECK(!hostname_.empty());
+ DCHECK(!callback_.is_null());
+ DCHECK(!IsIPLiteral(hostname_));
+ }
+
+ virtual ~DnsTransactionImpl() {
+ if (!callback_.is_null()) {
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_DNS_TRANSACTION,
+ ERR_ABORTED);
+ } // otherwise logged in DoCallback or Start
+ }
+
+ virtual const std::string& GetHostname() const OVERRIDE {
+ DCHECK(CalledOnValidThread());
+ return hostname_;
+ }
+
+ virtual uint16 GetType() const OVERRIDE {
+ DCHECK(CalledOnValidThread());
+ return qtype_;
+ }
+
+ virtual void Start() OVERRIDE {
+ DCHECK(!callback_.is_null());
+ DCHECK(attempts_.empty());
+ net_log_.BeginEvent(NetLog::TYPE_DNS_TRANSACTION,
+ base::Bind(&NetLogStartCallback, &hostname_, qtype_));
+ AttemptResult result(PrepareSearch(), NULL);
+ if (result.rv == OK) {
+ qnames_initial_size_ = qnames_.size();
+ if (qtype_ == dns_protocol::kTypeA)
+ UMA_HISTOGRAM_COUNTS("AsyncDNS.SuffixSearchStart", qnames_.size());
+ result = ProcessAttemptResult(StartQuery());
+ }
+
+ // Must always return result asynchronously, to avoid reentrancy.
+ if (result.rv != ERR_IO_PENDING) {
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(&DnsTransactionImpl::DoCallback, AsWeakPtr(), result));
+ }
+ }
+
+ private:
+ // Wrapper for the result of a DnsUDPAttempt.
+ struct AttemptResult {
+ AttemptResult(int rv, const DnsAttempt* attempt)
+ : rv(rv), attempt(attempt) {}
+
+ int rv;
+ const DnsAttempt* attempt;
+ };
+
+ // Prepares |qnames_| according to the DnsConfig.
+ int PrepareSearch() {
+ const DnsConfig& config = session_->config();
+
+ std::string labeled_hostname;
+ if (!DNSDomainFromDot(hostname_, &labeled_hostname))
+ return ERR_INVALID_ARGUMENT;
+
+ if (hostname_[hostname_.size() - 1] == '.') {
+ // It's a fully-qualified name, no suffix search.
+ qnames_.push_back(labeled_hostname);
+ return OK;
+ }
+
+ int ndots = CountLabels(labeled_hostname) - 1;
+
+ if (ndots > 0 && !config.append_to_multi_label_name) {
+ qnames_.push_back(labeled_hostname);
+ return OK;
+ }
+
+ // Set true when |labeled_hostname| is put on the list.
+ bool had_hostname = false;
+
+ if (ndots >= config.ndots) {
+ qnames_.push_back(labeled_hostname);
+ had_hostname = true;
+ }
+
+ std::string qname;
+ for (size_t i = 0; i < config.search.size(); ++i) {
+ // Ignore invalid (too long) combinations.
+ if (!DNSDomainFromDot(hostname_ + "." + config.search[i], &qname))
+ continue;
+ if (qname.size() == labeled_hostname.size()) {
+ if (had_hostname)
+ continue;
+ had_hostname = true;
+ }
+ qnames_.push_back(qname);
+ }
+
+ if (ndots > 0 && !had_hostname)
+ qnames_.push_back(labeled_hostname);
+
+ return qnames_.empty() ? ERR_DNS_SEARCH_EMPTY : OK;
+ }
+
+ void DoCallback(AttemptResult result) {
+ DCHECK(!callback_.is_null());
+ DCHECK_NE(ERR_IO_PENDING, result.rv);
+ const DnsResponse* response = result.attempt ?
+ result.attempt->GetResponse() : NULL;
+ CHECK(result.rv != OK || response != NULL);
+
+ timer_.Stop();
+ RecordLostPacketsIfAny();
+ if (result.rv == OK)
+ UMA_HISTOGRAM_COUNTS("AsyncDNS.AttemptCountSuccess", attempts_count_);
+ else
+ UMA_HISTOGRAM_COUNTS("AsyncDNS.AttemptCountFail", attempts_count_);
+
+ if (response && qtype_ == dns_protocol::kTypeA) {
+ UMA_HISTOGRAM_COUNTS("AsyncDNS.SuffixSearchRemain", qnames_.size());
+ UMA_HISTOGRAM_COUNTS("AsyncDNS.SuffixSearchDone",
+ qnames_initial_size_ - qnames_.size());
+ }
+
+ DnsTransactionFactory::CallbackType callback = callback_;
+ callback_.Reset();
+
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_DNS_TRANSACTION, result.rv);
+ callback.Run(this, result.rv, response);
+ }
+
+ // Makes another attempt at the current name, |qnames_.front()|, using the
+ // next nameserver.
+ AttemptResult MakeAttempt() {
+ unsigned attempt_number = attempts_.size();
+
+ uint16 id = session_->NextQueryId();
+ scoped_ptr<DnsQuery> query;
+ if (attempts_.empty()) {
+ query.reset(new DnsQuery(id, qnames_.front(), qtype_));
+ } else {
+ query.reset(attempts_[0]->GetQuery()->CloneWithNewId(id));
+ }
+
+ const DnsConfig& config = session_->config();
+
+ unsigned server_index =
+ (first_server_index_ + attempt_number) % config.nameservers.size();
+ // Skip over known failed servers.
+ server_index = session_->NextGoodServerIndex(server_index);
+
+ scoped_ptr<DnsSession::SocketLease> lease =
+ session_->AllocateSocket(server_index, net_log_.source());
+
+ bool got_socket = !!lease.get();
+
+ DnsUDPAttempt* attempt =
+ new DnsUDPAttempt(server_index, lease.Pass(), query.Pass());
+
+ attempts_.push_back(attempt);
+ ++attempts_count_;
+
+ if (!got_socket)
+ return AttemptResult(ERR_CONNECTION_REFUSED, NULL);
+
+ net_log_.AddEvent(
+ NetLog::TYPE_DNS_TRANSACTION_ATTEMPT,
+ attempt->GetSocketNetLog().source().ToEventParametersCallback());
+
+ int rv = attempt->Start(
+ base::Bind(&DnsTransactionImpl::OnUdpAttemptComplete,
+ base::Unretained(this), attempt_number,
+ base::TimeTicks::Now()));
+ if (rv == ERR_IO_PENDING) {
+ base::TimeDelta timeout = session_->NextTimeout(server_index,
+ attempt_number);
+ timer_.Start(FROM_HERE, timeout, this, &DnsTransactionImpl::OnTimeout);
+ }
+ return AttemptResult(rv, attempt);
+ }
+
+ AttemptResult MakeTCPAttempt(const DnsAttempt* previous_attempt) {
+ DCHECK(previous_attempt);
+ DCHECK(!had_tcp_attempt_);
+
+ unsigned server_index = previous_attempt->server_index();
+
+ scoped_ptr<StreamSocket> socket(
+ session_->CreateTCPSocket(server_index, net_log_.source()));
+
+ // TODO(szym): Reuse the same id to help the server?
+ uint16 id = session_->NextQueryId();
+ scoped_ptr<DnsQuery> query(
+ previous_attempt->GetQuery()->CloneWithNewId(id));
+
+ RecordLostPacketsIfAny();
+ // Cancel all other attempts, no point waiting on them.
+ attempts_.clear();
+
+ unsigned attempt_number = attempts_.size();
+
+ DnsTCPAttempt* attempt = new DnsTCPAttempt(server_index, socket.Pass(),
+ query.Pass());
+
+ attempts_.push_back(attempt);
+ ++attempts_count_;
+ had_tcp_attempt_ = true;
+
+ net_log_.AddEvent(
+ NetLog::TYPE_DNS_TRANSACTION_TCP_ATTEMPT,
+ attempt->GetSocketNetLog().source().ToEventParametersCallback());
+
+ int rv = attempt->Start(base::Bind(&DnsTransactionImpl::OnAttemptComplete,
+ base::Unretained(this),
+ attempt_number));
+ if (rv == ERR_IO_PENDING) {
+ // Custom timeout for TCP attempt.
+ base::TimeDelta timeout = timer_.GetCurrentDelay() * 2;
+ timer_.Start(FROM_HERE, timeout, this, &DnsTransactionImpl::OnTimeout);
+ }
+ return AttemptResult(rv, attempt);
+ }
+
+ // Begins query for the current name. Makes the first attempt.
+ AttemptResult StartQuery() {
+ std::string dotted_qname = DNSDomainToString(qnames_.front());
+ net_log_.BeginEvent(NetLog::TYPE_DNS_TRANSACTION_QUERY,
+ NetLog::StringCallback("qname", &dotted_qname));
+
+ first_server_index_ = session_->NextFirstServerIndex();
+ RecordLostPacketsIfAny();
+ attempts_.clear();
+ had_tcp_attempt_ = false;
+ return MakeAttempt();
+ }
+
+ void OnUdpAttemptComplete(unsigned attempt_number,
+ base::TimeTicks start,
+ int rv) {
+ DCHECK_LT(attempt_number, attempts_.size());
+ const DnsAttempt* attempt = attempts_[attempt_number];
+ if (attempt->GetResponse()) {
+ session_->RecordRTT(attempt->server_index(),
+ base::TimeTicks::Now() - start);
+ }
+ OnAttemptComplete(attempt_number, rv);
+ }
+
+ void OnAttemptComplete(unsigned attempt_number, int rv) {
+ if (callback_.is_null())
+ return;
+ DCHECK_LT(attempt_number, attempts_.size());
+ const DnsAttempt* attempt = attempts_[attempt_number];
+ AttemptResult result = ProcessAttemptResult(AttemptResult(rv, attempt));
+ if (result.rv != ERR_IO_PENDING)
+ DoCallback(result);
+ }
+
+ // Record packet loss for any incomplete attempts.
+ void RecordLostPacketsIfAny() {
+ // Loop through attempts until we find first that is completed
+ size_t first_completed = 0;
+ for (first_completed = 0; first_completed < attempts_.size();
+ ++first_completed) {
+ if (attempts_[first_completed]->is_completed())
+ break;
+ }
+ // If there were no completed attempts, then we must be offline, so don't
+ // record any attempts as lost packets.
+ if (first_completed == attempts_.size())
+ return;
+
+ size_t num_servers = session_->config().nameservers.size();
+ std::vector<int> server_attempts(num_servers);
+ for (size_t i = 0; i < first_completed; ++i) {
+ unsigned server_index = attempts_[i]->server_index();
+ int server_attempt = server_attempts[server_index]++;
+ // Don't record lost packet unless attempt is in pending state.
+ if (!attempts_[i]->is_pending())
+ continue;
+ session_->RecordLostPacket(server_index, server_attempt);
+ }
+ }
+
+ void LogResponse(const DnsAttempt* attempt) {
+ if (attempt && attempt->GetResponse()) {
+ net_log_.AddEvent(
+ NetLog::TYPE_DNS_TRANSACTION_RESPONSE,
+ base::Bind(&DnsAttempt::NetLogResponseCallback,
+ base::Unretained(attempt)));
+ }
+ }
+
+ bool MoreAttemptsAllowed() const {
+ if (had_tcp_attempt_)
+ return false;
+ const DnsConfig& config = session_->config();
+ return attempts_.size() < config.attempts * config.nameservers.size();
+ }
+
+ // Resolves the result of a DnsAttempt until a terminal result is reached
+ // or it will complete asynchronously (ERR_IO_PENDING).
+ AttemptResult ProcessAttemptResult(AttemptResult result) {
+ while (result.rv != ERR_IO_PENDING) {
+ LogResponse(result.attempt);
+
+ switch (result.rv) {
+ case OK:
+ session_->RecordServerSuccess(result.attempt->server_index());
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_DNS_TRANSACTION_QUERY,
+ result.rv);
+ DCHECK(result.attempt);
+ DCHECK(result.attempt->GetResponse());
+ return result;
+ case ERR_NAME_NOT_RESOLVED:
+ session_->RecordServerSuccess(result.attempt->server_index());
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_DNS_TRANSACTION_QUERY,
+ result.rv);
+ // Try next suffix.
+ qnames_.pop_front();
+ if (qnames_.empty()) {
+ return AttemptResult(ERR_NAME_NOT_RESOLVED, NULL);
+ } else {
+ result = StartQuery();
+ }
+ break;
+ case ERR_CONNECTION_REFUSED:
+ case ERR_DNS_TIMED_OUT:
+ if (result.attempt)
+ session_->RecordServerFailure(result.attempt->server_index());
+ if (MoreAttemptsAllowed()) {
+ result = MakeAttempt();
+ } else {
+ return result;
+ }
+ break;
+ case ERR_DNS_SERVER_REQUIRES_TCP:
+ result = MakeTCPAttempt(result.attempt);
+ break;
+ default:
+ // Server failure.
+ DCHECK(result.attempt);
+ if (result.attempt != attempts_.back()) {
+ // This attempt already timed out. Ignore it.
+ session_->RecordServerFailure(result.attempt->server_index());
+ return AttemptResult(ERR_IO_PENDING, NULL);
+ }
+ if (MoreAttemptsAllowed()) {
+ result = MakeAttempt();
+ } else if (result.rv == ERR_DNS_MALFORMED_RESPONSE &&
+ !had_tcp_attempt_) {
+ // For UDP only, ignore the response and wait until the last attempt
+ // times out.
+ return AttemptResult(ERR_IO_PENDING, NULL);
+ } else {
+ return AttemptResult(result.rv, NULL);
+ }
+ break;
+ }
+ }
+ return result;
+ }
+
+ void OnTimeout() {
+ if (callback_.is_null())
+ return;
+ DCHECK(!attempts_.empty());
+ AttemptResult result = ProcessAttemptResult(
+ AttemptResult(ERR_DNS_TIMED_OUT, attempts_.back()));
+ if (result.rv != ERR_IO_PENDING)
+ DoCallback(result);
+ }
+
+ scoped_refptr<DnsSession> session_;
+ std::string hostname_;
+ uint16 qtype_;
+ // Cleared in DoCallback.
+ DnsTransactionFactory::CallbackType callback_;
+
+ BoundNetLog net_log_;
+
+ // Search list of fully-qualified DNS names to query next (in DNS format).
+ std::deque<std::string> qnames_;
+ size_t qnames_initial_size_;
+
+ // List of attempts for the current name.
+ ScopedVector<DnsAttempt> attempts_;
+ // Count of attempts, not reset when |attempts_| vector is cleared.
+ int attempts_count_;
+ bool had_tcp_attempt_;
+
+ // Index of the first server to try on each search query.
+ int first_server_index_;
+
+ base::OneShotTimer<DnsTransactionImpl> timer_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsTransactionImpl);
+};
+
+// ----------------------------------------------------------------------------
+
+// Implementation of DnsTransactionFactory that returns instances of
+// DnsTransactionImpl.
+class DnsTransactionFactoryImpl : public DnsTransactionFactory {
+ public:
+ explicit DnsTransactionFactoryImpl(DnsSession* session) {
+ session_ = session;
+ }
+
+ virtual scoped_ptr<DnsTransaction> CreateTransaction(
+ const std::string& hostname,
+ uint16 qtype,
+ const CallbackType& callback,
+ const BoundNetLog& net_log) OVERRIDE {
+ return scoped_ptr<DnsTransaction>(new DnsTransactionImpl(
+ session_.get(), hostname, qtype, callback, net_log));
+ }
+
+ private:
+ scoped_refptr<DnsSession> session_;
+};
+
+} // namespace
+
+// static
+scoped_ptr<DnsTransactionFactory> DnsTransactionFactory::CreateFactory(
+ DnsSession* session) {
+ return scoped_ptr<DnsTransactionFactory>(
+ new DnsTransactionFactoryImpl(session));
+}
+
+} // namespace net
diff --git a/chromium/net/dns/dns_transaction.h b/chromium/net/dns/dns_transaction.h
new file mode 100644
index 00000000000..faf4f64e79d
--- /dev/null
+++ b/chromium/net/dns/dns_transaction.h
@@ -0,0 +1,78 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_DNS_TRANSACTION_H_
+#define NET_DNS_DNS_TRANSACTION_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/callback_forward.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+class BoundNetLog;
+class DnsResponse;
+class DnsSession;
+
+// DnsTransaction implements a stub DNS resolver as defined in RFC 1034.
+// The DnsTransaction takes care of retransmissions, name server fallback (or
+// round-robin), suffix search, and simple response validation ("does it match
+// the query") to fight poisoning.
+//
+// Destroying DnsTransaction cancels the underlying network effort.
+class NET_EXPORT_PRIVATE DnsTransaction {
+ public:
+ virtual ~DnsTransaction() {}
+
+ // Returns the original |hostname|.
+ virtual const std::string& GetHostname() const = 0;
+
+ // Returns the |qtype|.
+ virtual uint16 GetType() const = 0;
+
+ // Starts the transaction. Always completes asynchronously.
+ virtual void Start() = 0;
+};
+
+// Creates DnsTransaction which performs asynchronous DNS search.
+// It does NOT perform caching, aggregation or prioritization of transactions.
+//
+// Destroying the factory does NOT affect any already created DnsTransactions.
+class NET_EXPORT_PRIVATE DnsTransactionFactory {
+ public:
+ // Called with the response or NULL if no matching response was received.
+ // Note that the |GetDottedName()| of the response may be different than the
+ // original |hostname| as a result of suffix search.
+ typedef base::Callback<void(DnsTransaction* transaction,
+ int neterror,
+ const DnsResponse* response)> CallbackType;
+
+ virtual ~DnsTransactionFactory() {}
+
+ // Creates DnsTransaction for the given |hostname| and |qtype| (assuming
+ // QCLASS is IN). |hostname| should be in the dotted form. A dot at the end
+ // implies the domain name is fully-qualified and will be exempt from suffix
+ // search. |hostname| should not be an IP literal.
+ //
+ // The transaction will run |callback| upon asynchronous completion.
+ // The |net_log| is used as the parent log.
+ virtual scoped_ptr<DnsTransaction> CreateTransaction(
+ const std::string& hostname,
+ uint16 qtype,
+ const CallbackType& callback,
+ const BoundNetLog& net_log) WARN_UNUSED_RESULT = 0;
+
+ // Creates a DnsTransactionFactory which creates DnsTransactionImpl using the
+ // |session|.
+ static scoped_ptr<DnsTransactionFactory> CreateFactory(
+ DnsSession* session) WARN_UNUSED_RESULT;
+};
+
+} // namespace net
+
+#endif // NET_DNS_DNS_TRANSACTION_H_
+
diff --git a/chromium/net/dns/dns_transaction_unittest.cc b/chromium/net/dns/dns_transaction_unittest.cc
new file mode 100644
index 00000000000..7040e44be16
--- /dev/null
+++ b/chromium/net/dns/dns_transaction_unittest.cc
@@ -0,0 +1,940 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/dns_transaction.h"
+
+#include "base/bind.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/memory/scoped_vector.h"
+#include "base/rand_util.h"
+#include "base/sys_byteorder.h"
+#include "base/test/test_timeouts.h"
+#include "net/base/big_endian.h"
+#include "net/base/dns_util.h"
+#include "net/base/net_log.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_query.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/dns_session.h"
+#include "net/dns/dns_test_util.h"
+#include "net/socket/socket_test_util.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+std::string DomainFromDot(const base::StringPiece& dotted) {
+ std::string out;
+ EXPECT_TRUE(DNSDomainFromDot(dotted, &out));
+ return out;
+}
+
+// A SocketDataProvider builder.
+class DnsSocketData {
+ public:
+ // The ctor takes parameters for the DnsQuery.
+ DnsSocketData(uint16 id,
+ const char* dotted_name,
+ uint16 qtype,
+ IoMode mode,
+ bool use_tcp)
+ : query_(new DnsQuery(id, DomainFromDot(dotted_name), qtype)),
+ use_tcp_(use_tcp) {
+ if (use_tcp_) {
+ scoped_ptr<uint16> length(new uint16);
+ *length = base::HostToNet16(query_->io_buffer()->size());
+ writes_.push_back(MockWrite(mode,
+ reinterpret_cast<const char*>(length.get()),
+ sizeof(uint16)));
+ lengths_.push_back(length.release());
+ }
+ writes_.push_back(MockWrite(mode,
+ query_->io_buffer()->data(),
+ query_->io_buffer()->size()));
+ }
+ ~DnsSocketData() {}
+
+ // All responses must be added before GetProvider.
+
+ // Adds pre-built DnsResponse. |tcp_length| will be used in TCP mode only.
+ void AddResponseWithLength(scoped_ptr<DnsResponse> response, IoMode mode,
+ uint16 tcp_length) {
+ CHECK(!provider_.get());
+ if (use_tcp_) {
+ scoped_ptr<uint16> length(new uint16);
+ *length = base::HostToNet16(tcp_length);
+ reads_.push_back(MockRead(mode,
+ reinterpret_cast<const char*>(length.get()),
+ sizeof(uint16)));
+ lengths_.push_back(length.release());
+ }
+ reads_.push_back(MockRead(mode,
+ response->io_buffer()->data(),
+ response->io_buffer()->size()));
+ responses_.push_back(response.release());
+ }
+
+ // Adds pre-built DnsResponse.
+ void AddResponse(scoped_ptr<DnsResponse> response, IoMode mode) {
+ uint16 tcp_length = response->io_buffer()->size();
+ AddResponseWithLength(response.Pass(), mode, tcp_length);
+ }
+
+ // Adds pre-built response from |data| buffer.
+ void AddResponseData(const uint8* data, size_t length, IoMode mode) {
+ CHECK(!provider_.get());
+ AddResponse(make_scoped_ptr(
+ new DnsResponse(reinterpret_cast<const char*>(data), length, 0)), mode);
+ }
+
+ // Add no-answer (RCODE only) response matching the query.
+ void AddRcode(int rcode, IoMode mode) {
+ scoped_ptr<DnsResponse> response(
+ new DnsResponse(query_->io_buffer()->data(),
+ query_->io_buffer()->size(),
+ 0));
+ dns_protocol::Header* header =
+ reinterpret_cast<dns_protocol::Header*>(response->io_buffer()->data());
+ header->flags |= base::HostToNet16(dns_protocol::kFlagResponse | rcode);
+ AddResponse(response.Pass(), mode);
+ }
+
+ // Build, if needed, and return the SocketDataProvider. No new responses
+ // should be added afterwards.
+ SocketDataProvider* GetProvider() {
+ if (provider_.get())
+ return provider_.get();
+ // Terminate the reads with ERR_IO_PENDING to prevent overrun and default to
+ // timeout.
+ reads_.push_back(MockRead(ASYNC, ERR_IO_PENDING));
+ provider_.reset(new DelayedSocketData(1, &reads_[0], reads_.size(),
+ &writes_[0], writes_.size()));
+ if (use_tcp_) {
+ provider_->set_connect_data(MockConnect(reads_[0].mode, OK));
+ }
+ return provider_.get();
+ }
+
+ uint16 query_id() const {
+ return query_->id();
+ }
+
+ // Returns true if the expected query was written to the socket.
+ bool was_written() const {
+ CHECK(provider_.get());
+ return provider_->write_index() > 0;
+ }
+
+ private:
+ scoped_ptr<DnsQuery> query_;
+ bool use_tcp_;
+ ScopedVector<uint16> lengths_;
+ ScopedVector<DnsResponse> responses_;
+ std::vector<MockWrite> writes_;
+ std::vector<MockRead> reads_;
+ scoped_ptr<DelayedSocketData> provider_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsSocketData);
+};
+
+class TestSocketFactory;
+
+// A variant of MockUDPClientSocket which always fails to Connect.
+class FailingUDPClientSocket : public MockUDPClientSocket {
+ public:
+ FailingUDPClientSocket(SocketDataProvider* data,
+ net::NetLog* net_log)
+ : MockUDPClientSocket(data, net_log) {
+ }
+ virtual ~FailingUDPClientSocket() {}
+ virtual int Connect(const IPEndPoint& endpoint) OVERRIDE {
+ return ERR_CONNECTION_REFUSED;
+ }
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(FailingUDPClientSocket);
+};
+
+// A variant of MockUDPClientSocket which notifies the factory OnConnect.
+class TestUDPClientSocket : public MockUDPClientSocket {
+ public:
+ TestUDPClientSocket(TestSocketFactory* factory,
+ SocketDataProvider* data,
+ net::NetLog* net_log)
+ : MockUDPClientSocket(data, net_log), factory_(factory) {
+ }
+ virtual ~TestUDPClientSocket() {}
+ virtual int Connect(const IPEndPoint& endpoint) OVERRIDE;
+
+ private:
+ TestSocketFactory* factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(TestUDPClientSocket);
+};
+
+// Creates TestUDPClientSockets and keeps endpoints reported via OnConnect.
+class TestSocketFactory : public MockClientSocketFactory {
+ public:
+ TestSocketFactory() : fail_next_socket_(false) {}
+ virtual ~TestSocketFactory() {}
+
+ virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ net::NetLog* net_log,
+ const net::NetLog::Source& source) OVERRIDE {
+ if (fail_next_socket_) {
+ fail_next_socket_ = false;
+ return scoped_ptr<DatagramClientSocket>(
+ new FailingUDPClientSocket(&empty_data_, net_log));
+ }
+ SocketDataProvider* data_provider = mock_data().GetNext();
+ scoped_ptr<TestUDPClientSocket> socket(
+ new TestUDPClientSocket(this, data_provider, net_log));
+ data_provider->set_socket(socket.get());
+ return socket.PassAs<DatagramClientSocket>();
+ }
+
+ void OnConnect(const IPEndPoint& endpoint) {
+ remote_endpoints_.push_back(endpoint);
+ }
+
+ std::vector<IPEndPoint> remote_endpoints_;
+ bool fail_next_socket_;
+
+ private:
+ StaticSocketDataProvider empty_data_;
+
+ DISALLOW_COPY_AND_ASSIGN(TestSocketFactory);
+};
+
+int TestUDPClientSocket::Connect(const IPEndPoint& endpoint) {
+ factory_->OnConnect(endpoint);
+ return MockUDPClientSocket::Connect(endpoint);
+}
+
+// Helper class that holds a DnsTransaction and handles OnTransactionComplete.
+class TransactionHelper {
+ public:
+ // If |expected_answer_count| < 0 then it is the expected net error.
+ TransactionHelper(const char* hostname,
+ uint16 qtype,
+ int expected_answer_count)
+ : hostname_(hostname),
+ qtype_(qtype),
+ expected_answer_count_(expected_answer_count),
+ cancel_in_callback_(false),
+ quit_in_callback_(false),
+ completed_(false) {
+ }
+
+ // Mark that the transaction shall be destroyed immediately upon callback.
+ void set_cancel_in_callback() {
+ cancel_in_callback_ = true;
+ }
+
+ // Mark to call MessageLoop::Quit() upon callback.
+ void set_quit_in_callback() {
+ quit_in_callback_ = true;
+ }
+
+ void StartTransaction(DnsTransactionFactory* factory) {
+ EXPECT_EQ(NULL, transaction_.get());
+ transaction_ = factory->CreateTransaction(
+ hostname_,
+ qtype_,
+ base::Bind(&TransactionHelper::OnTransactionComplete,
+ base::Unretained(this)),
+ BoundNetLog());
+ EXPECT_EQ(hostname_, transaction_->GetHostname());
+ EXPECT_EQ(qtype_, transaction_->GetType());
+ transaction_->Start();
+ }
+
+ void Cancel() {
+ ASSERT_TRUE(transaction_.get() != NULL);
+ transaction_.reset(NULL);
+ }
+
+ void OnTransactionComplete(DnsTransaction* t,
+ int rv,
+ const DnsResponse* response) {
+ EXPECT_FALSE(completed_);
+ EXPECT_EQ(transaction_.get(), t);
+
+ completed_ = true;
+
+ if (cancel_in_callback_) {
+ Cancel();
+ return;
+ }
+
+ // Tell MessageLoop to quit now, in case any ASSERT_* fails.
+ if (quit_in_callback_)
+ base::MessageLoop::current()->Quit();
+
+ if (expected_answer_count_ >= 0) {
+ ASSERT_EQ(OK, rv);
+ ASSERT_TRUE(response != NULL);
+ EXPECT_EQ(static_cast<unsigned>(expected_answer_count_),
+ response->answer_count());
+ EXPECT_EQ(qtype_, response->qtype());
+
+ DnsRecordParser parser = response->Parser();
+ DnsResourceRecord record;
+ for (int i = 0; i < expected_answer_count_; ++i) {
+ EXPECT_TRUE(parser.ReadRecord(&record));
+ }
+ } else {
+ EXPECT_EQ(expected_answer_count_, rv);
+ }
+ }
+
+ bool has_completed() const {
+ return completed_;
+ }
+
+ // Shorthands for commonly used commands.
+
+ bool Run(DnsTransactionFactory* factory) {
+ StartTransaction(factory);
+ base::MessageLoop::current()->RunUntilIdle();
+ return has_completed();
+ }
+
+ // Use when some of the responses are timeouts.
+ bool RunUntilDone(DnsTransactionFactory* factory) {
+ set_quit_in_callback();
+ StartTransaction(factory);
+ base::MessageLoop::current()->Run();
+ return has_completed();
+ }
+
+ private:
+ std::string hostname_;
+ uint16 qtype_;
+ scoped_ptr<DnsTransaction> transaction_;
+ int expected_answer_count_;
+ bool cancel_in_callback_;
+ bool quit_in_callback_;
+
+ bool completed_;
+};
+
+class DnsTransactionTest : public testing::Test {
+ public:
+ DnsTransactionTest() {}
+
+ // Generates |nameservers| for DnsConfig.
+ void ConfigureNumServers(unsigned num_servers) {
+ CHECK_LE(num_servers, 255u);
+ config_.nameservers.clear();
+ IPAddressNumber dns_ip;
+ {
+ bool rv = ParseIPLiteralToNumber("192.168.1.0", &dns_ip);
+ EXPECT_TRUE(rv);
+ }
+ for (unsigned i = 0; i < num_servers; ++i) {
+ dns_ip[3] = i;
+ config_.nameservers.push_back(IPEndPoint(dns_ip,
+ dns_protocol::kDefaultPort));
+ }
+ }
+
+ // Called after fully configuring |config|.
+ void ConfigureFactory() {
+ socket_factory_.reset(new TestSocketFactory());
+ session_ = new DnsSession(
+ config_,
+ DnsSocketPool::CreateNull(socket_factory_.get()),
+ base::Bind(&DnsTransactionTest::GetNextId, base::Unretained(this)),
+ NULL /* NetLog */);
+ transaction_factory_ = DnsTransactionFactory::CreateFactory(session_.get());
+ }
+
+ void AddSocketData(scoped_ptr<DnsSocketData> data) {
+ CHECK(socket_factory_.get());
+ transaction_ids_.push_back(data->query_id());
+ socket_factory_->AddSocketDataProvider(data->GetProvider());
+ socket_data_.push_back(data.release());
+ }
+
+ // Add expected query for |dotted_name| and |qtype| with |id| and response
+ // taken verbatim from |data| of |data_length| bytes. The transaction id in
+ // |data| should equal |id|, unless testing mismatched response.
+ void AddQueryAndResponse(uint16 id,
+ const char* dotted_name,
+ uint16 qtype,
+ const uint8* response_data,
+ size_t response_length,
+ IoMode mode,
+ bool use_tcp) {
+ CHECK(socket_factory_.get());
+ scoped_ptr<DnsSocketData> data(
+ new DnsSocketData(id, dotted_name, qtype, mode, use_tcp));
+ data->AddResponseData(response_data, response_length, mode);
+ AddSocketData(data.Pass());
+ }
+
+ void AddAsyncQueryAndResponse(uint16 id,
+ const char* dotted_name,
+ uint16 qtype,
+ const uint8* data,
+ size_t data_length) {
+ AddQueryAndResponse(id, dotted_name, qtype, data, data_length, ASYNC,
+ false);
+ }
+
+ void AddSyncQueryAndResponse(uint16 id,
+ const char* dotted_name,
+ uint16 qtype,
+ const uint8* data,
+ size_t data_length) {
+ AddQueryAndResponse(id, dotted_name, qtype, data, data_length, SYNCHRONOUS,
+ false);
+ }
+
+ // Add expected query of |dotted_name| and |qtype| and no response.
+ void AddQueryAndTimeout(const char* dotted_name, uint16 qtype) {
+ uint16 id = base::RandInt(0, kuint16max);
+ scoped_ptr<DnsSocketData> data(
+ new DnsSocketData(id, dotted_name, qtype, ASYNC, false));
+ AddSocketData(data.Pass());
+ }
+
+ // Add expected query of |dotted_name| and |qtype| and matching response with
+ // no answer and RCODE set to |rcode|. The id will be generated randomly.
+ void AddQueryAndRcode(const char* dotted_name,
+ uint16 qtype,
+ int rcode,
+ IoMode mode,
+ bool use_tcp) {
+ CHECK_NE(dns_protocol::kRcodeNOERROR, rcode);
+ uint16 id = base::RandInt(0, kuint16max);
+ scoped_ptr<DnsSocketData> data(
+ new DnsSocketData(id, dotted_name, qtype, mode, use_tcp));
+ data->AddRcode(rcode, mode);
+ AddSocketData(data.Pass());
+ }
+
+ void AddAsyncQueryAndRcode(const char* dotted_name, uint16 qtype, int rcode) {
+ AddQueryAndRcode(dotted_name, qtype, rcode, ASYNC, false);
+ }
+
+ void AddSyncQueryAndRcode(const char* dotted_name, uint16 qtype, int rcode) {
+ AddQueryAndRcode(dotted_name, qtype, rcode, SYNCHRONOUS, false);
+ }
+
+ // Checks if the sockets were connected in the order matching the indices in
+ // |servers|.
+ void CheckServerOrder(const unsigned* servers, size_t num_attempts) {
+ ASSERT_EQ(num_attempts, socket_factory_->remote_endpoints_.size());
+ for (size_t i = 0; i < num_attempts; ++i) {
+ EXPECT_EQ(socket_factory_->remote_endpoints_[i],
+ session_->config().nameservers[servers[i]]);
+ }
+ }
+
+ virtual void SetUp() OVERRIDE {
+ // By default set one server,
+ ConfigureNumServers(1);
+ // and no retransmissions,
+ config_.attempts = 1;
+ // but long enough timeout for memory tests.
+ config_.timeout = TestTimeouts::action_timeout();
+ ConfigureFactory();
+ }
+
+ virtual void TearDown() OVERRIDE {
+ // Check that all socket data was at least written to.
+ for (size_t i = 0; i < socket_data_.size(); ++i) {
+ EXPECT_TRUE(socket_data_[i]->was_written()) << i;
+ }
+ }
+
+ protected:
+ int GetNextId(int min, int max) {
+ EXPECT_FALSE(transaction_ids_.empty());
+ int id = transaction_ids_.front();
+ transaction_ids_.pop_front();
+ EXPECT_GE(id, min);
+ EXPECT_LE(id, max);
+ return id;
+ }
+
+ DnsConfig config_;
+
+ ScopedVector<DnsSocketData> socket_data_;
+
+ std::deque<int> transaction_ids_;
+ scoped_ptr<TestSocketFactory> socket_factory_;
+ scoped_refptr<DnsSession> session_;
+ scoped_ptr<DnsTransactionFactory> transaction_factory_;
+};
+
+TEST_F(DnsTransactionTest, Lookup) {
+ AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype,
+ kT0ResponseDatagram, arraysize(kT0ResponseDatagram));
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+// Concurrent lookup tests assume that DnsTransaction::Start immediately
+// consumes a socket from ClientSocketFactory.
+TEST_F(DnsTransactionTest, ConcurrentLookup) {
+ AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype,
+ kT0ResponseDatagram, arraysize(kT0ResponseDatagram));
+ AddAsyncQueryAndResponse(1 /* id */, kT1HostName, kT1Qtype,
+ kT1ResponseDatagram, arraysize(kT1ResponseDatagram));
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount);
+ helper0.StartTransaction(transaction_factory_.get());
+ TransactionHelper helper1(kT1HostName, kT1Qtype, kT1RecordCount);
+ helper1.StartTransaction(transaction_factory_.get());
+
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_TRUE(helper0.has_completed());
+ EXPECT_TRUE(helper1.has_completed());
+}
+
+TEST_F(DnsTransactionTest, CancelLookup) {
+ AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype,
+ kT0ResponseDatagram, arraysize(kT0ResponseDatagram));
+ AddAsyncQueryAndResponse(1 /* id */, kT1HostName, kT1Qtype,
+ kT1ResponseDatagram, arraysize(kT1ResponseDatagram));
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount);
+ helper0.StartTransaction(transaction_factory_.get());
+ TransactionHelper helper1(kT1HostName, kT1Qtype, kT1RecordCount);
+ helper1.StartTransaction(transaction_factory_.get());
+
+ helper0.Cancel();
+
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_FALSE(helper0.has_completed());
+ EXPECT_TRUE(helper1.has_completed());
+}
+
+TEST_F(DnsTransactionTest, DestroyFactory) {
+ AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype,
+ kT0ResponseDatagram, arraysize(kT0ResponseDatagram));
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount);
+ helper0.StartTransaction(transaction_factory_.get());
+
+ // Destroying the client does not affect running requests.
+ transaction_factory_.reset(NULL);
+
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_TRUE(helper0.has_completed());
+}
+
+TEST_F(DnsTransactionTest, CancelFromCallback) {
+ AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype,
+ kT0ResponseDatagram, arraysize(kT0ResponseDatagram));
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount);
+ helper0.set_cancel_in_callback();
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, MismatchedResponseSync) {
+ config_.attempts = 2;
+ config_.timeout = TestTimeouts::tiny_timeout();
+ ConfigureFactory();
+
+ // Attempt receives mismatched response followed by valid response.
+ scoped_ptr<DnsSocketData> data(
+ new DnsSocketData(0 /* id */, kT0HostName, kT0Qtype, SYNCHRONOUS, false));
+ data->AddResponseData(kT1ResponseDatagram,
+ arraysize(kT1ResponseDatagram), SYNCHRONOUS);
+ data->AddResponseData(kT0ResponseDatagram,
+ arraysize(kT0ResponseDatagram), SYNCHRONOUS);
+ AddSocketData(data.Pass());
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount);
+ EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, MismatchedResponseAsync) {
+ config_.attempts = 2;
+ config_.timeout = TestTimeouts::tiny_timeout();
+ ConfigureFactory();
+
+ // First attempt receives mismatched response followed by valid response.
+ // Second attempt times out.
+ scoped_ptr<DnsSocketData> data(
+ new DnsSocketData(0 /* id */, kT0HostName, kT0Qtype, ASYNC, false));
+ data->AddResponseData(kT1ResponseDatagram,
+ arraysize(kT1ResponseDatagram), ASYNC);
+ data->AddResponseData(kT0ResponseDatagram,
+ arraysize(kT0ResponseDatagram), ASYNC);
+ AddSocketData(data.Pass());
+ AddQueryAndTimeout(kT0HostName, kT0Qtype);
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount);
+ EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, MismatchedResponseFail) {
+ config_.timeout = TestTimeouts::tiny_timeout();
+ ConfigureFactory();
+
+ // Attempt receives mismatched response but times out because only one attempt
+ // is allowed.
+ AddAsyncQueryAndResponse(1 /* id */, kT0HostName, kT0Qtype,
+ kT0ResponseDatagram, arraysize(kT0ResponseDatagram));
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_TIMED_OUT);
+ EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, ServerFail) {
+ AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeSERVFAIL);
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_SERVER_FAILED);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, NoDomain) {
+ AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeNXDOMAIN);
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_NAME_NOT_RESOLVED);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, Timeout) {
+ config_.attempts = 3;
+ // Use short timeout to speed up the test.
+ config_.timeout = TestTimeouts::tiny_timeout();
+ ConfigureFactory();
+
+ AddQueryAndTimeout(kT0HostName, kT0Qtype);
+ AddQueryAndTimeout(kT0HostName, kT0Qtype);
+ AddQueryAndTimeout(kT0HostName, kT0Qtype);
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_TIMED_OUT);
+ EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get()));
+ EXPECT_TRUE(base::MessageLoop::current()->IsIdleForTesting());
+}
+
+TEST_F(DnsTransactionTest, ServerFallbackAndRotate) {
+ // Test that we fallback on both server failure and timeout.
+ config_.attempts = 2;
+ // The next request should start from the next server.
+ config_.rotate = true;
+ ConfigureNumServers(3);
+ // Use short timeout to speed up the test.
+ config_.timeout = TestTimeouts::tiny_timeout();
+ ConfigureFactory();
+
+ // Responses for first request.
+ AddQueryAndTimeout(kT0HostName, kT0Qtype);
+ AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeSERVFAIL);
+ AddQueryAndTimeout(kT0HostName, kT0Qtype);
+ AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeSERVFAIL);
+ AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeNXDOMAIN);
+ // Responses for second request.
+ AddAsyncQueryAndRcode(kT1HostName, kT1Qtype, dns_protocol::kRcodeSERVFAIL);
+ AddAsyncQueryAndRcode(kT1HostName, kT1Qtype, dns_protocol::kRcodeSERVFAIL);
+ AddAsyncQueryAndRcode(kT1HostName, kT1Qtype, dns_protocol::kRcodeNXDOMAIN);
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_NAME_NOT_RESOLVED);
+ TransactionHelper helper1(kT1HostName, kT1Qtype, ERR_NAME_NOT_RESOLVED);
+
+ EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get()));
+ EXPECT_TRUE(helper1.Run(transaction_factory_.get()));
+
+ unsigned kOrder[] = {
+ 0, 1, 2, 0, 1, // The first transaction.
+ 1, 2, 0, // The second transaction starts from the next server.
+ };
+ CheckServerOrder(kOrder, arraysize(kOrder));
+}
+
+TEST_F(DnsTransactionTest, SuffixSearchAboveNdots) {
+ config_.ndots = 2;
+ config_.search.push_back("a");
+ config_.search.push_back("b");
+ config_.search.push_back("c");
+ config_.rotate = true;
+ ConfigureNumServers(2);
+ ConfigureFactory();
+
+ AddAsyncQueryAndRcode("x.y.z", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndRcode("x.y.z.a", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndRcode("x.y.z.b", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndRcode("x.y.z.c", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+
+ TransactionHelper helper0("x.y.z", dns_protocol::kTypeA,
+ ERR_NAME_NOT_RESOLVED);
+
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+
+ // Also check if suffix search causes server rotation.
+ unsigned kOrder0[] = { 0, 1, 0, 1 };
+ CheckServerOrder(kOrder0, arraysize(kOrder0));
+}
+
+TEST_F(DnsTransactionTest, SuffixSearchBelowNdots) {
+ config_.ndots = 2;
+ config_.search.push_back("a");
+ config_.search.push_back("b");
+ config_.search.push_back("c");
+ ConfigureFactory();
+
+ // Responses for first transaction.
+ AddAsyncQueryAndRcode("x.y.a", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndRcode("x.y.b", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndRcode("x.y.c", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndRcode("x.y", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ // Responses for second transaction.
+ AddAsyncQueryAndRcode("x.a", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndRcode("x.b", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndRcode("x.c", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ // Responses for third transaction.
+ AddAsyncQueryAndRcode("x", dns_protocol::kTypeAAAA,
+ dns_protocol::kRcodeNXDOMAIN);
+
+ TransactionHelper helper0("x.y", dns_protocol::kTypeA, ERR_NAME_NOT_RESOLVED);
+
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+
+ // A single-label name.
+ TransactionHelper helper1("x", dns_protocol::kTypeA, ERR_NAME_NOT_RESOLVED);
+
+ EXPECT_TRUE(helper1.Run(transaction_factory_.get()));
+
+ // A fully-qualified name.
+ TransactionHelper helper2("x.", dns_protocol::kTypeAAAA,
+ ERR_NAME_NOT_RESOLVED);
+
+ EXPECT_TRUE(helper2.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, EmptySuffixSearch) {
+ // Responses for first transaction.
+ AddAsyncQueryAndRcode("x", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+
+ // A fully-qualified name.
+ TransactionHelper helper0("x.", dns_protocol::kTypeA, ERR_NAME_NOT_RESOLVED);
+
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+
+ // A single label name is not even attempted.
+ TransactionHelper helper1("singlelabel", dns_protocol::kTypeA,
+ ERR_DNS_SEARCH_EMPTY);
+
+ helper1.Run(transaction_factory_.get());
+ EXPECT_TRUE(helper1.has_completed());
+}
+
+TEST_F(DnsTransactionTest, DontAppendToMultiLabelName) {
+ config_.search.push_back("a");
+ config_.search.push_back("b");
+ config_.search.push_back("c");
+ config_.append_to_multi_label_name = false;
+ ConfigureFactory();
+
+ // Responses for first transaction.
+ AddAsyncQueryAndRcode("x.y.z", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ // Responses for second transaction.
+ AddAsyncQueryAndRcode("x.y", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ // Responses for third transaction.
+ AddAsyncQueryAndRcode("x.a", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndRcode("x.b", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndRcode("x.c", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+
+ TransactionHelper helper0("x.y.z", dns_protocol::kTypeA,
+ ERR_NAME_NOT_RESOLVED);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+
+ TransactionHelper helper1("x.y", dns_protocol::kTypeA, ERR_NAME_NOT_RESOLVED);
+ EXPECT_TRUE(helper1.Run(transaction_factory_.get()));
+
+ TransactionHelper helper2("x", dns_protocol::kTypeA, ERR_NAME_NOT_RESOLVED);
+ EXPECT_TRUE(helper2.Run(transaction_factory_.get()));
+}
+
+const uint8 kResponseNoData[] = {
+ 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
+ // Question
+ 0x01, 'x', 0x01, 'y', 0x01, 'z', 0x01, 'b', 0x00, 0x00, 0x01, 0x00, 0x01,
+ // Authority section, SOA record, TTL 0x3E6
+ 0x01, 'z', 0x00, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x03, 0xE6,
+ // Minimal RDATA, 18 bytes
+ 0x00, 0x12,
+ 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+};
+
+TEST_F(DnsTransactionTest, SuffixSearchStop) {
+ config_.ndots = 2;
+ config_.search.push_back("a");
+ config_.search.push_back("b");
+ config_.search.push_back("c");
+ ConfigureFactory();
+
+ AddAsyncQueryAndRcode("x.y.z", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndRcode("x.y.z.a", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddAsyncQueryAndResponse(0 /* id */, "x.y.z.b", dns_protocol::kTypeA,
+ kResponseNoData, arraysize(kResponseNoData));
+
+ TransactionHelper helper0("x.y.z", dns_protocol::kTypeA, 0 /* answers */);
+
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, SyncFirstQuery) {
+ config_.search.push_back("lab.ccs.neu.edu");
+ config_.search.push_back("ccs.neu.edu");
+ ConfigureFactory();
+
+ AddSyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype,
+ kT0ResponseDatagram, arraysize(kT0ResponseDatagram));
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, SyncFirstQueryWithSearch) {
+ config_.search.push_back("lab.ccs.neu.edu");
+ config_.search.push_back("ccs.neu.edu");
+ ConfigureFactory();
+
+ AddSyncQueryAndRcode("www.lab.ccs.neu.edu", kT2Qtype,
+ dns_protocol::kRcodeNXDOMAIN);
+ // "www.ccs.neu.edu"
+ AddAsyncQueryAndResponse(2 /* id */, kT2HostName, kT2Qtype,
+ kT2ResponseDatagram, arraysize(kT2ResponseDatagram));
+
+ TransactionHelper helper0("www", kT2Qtype, kT2RecordCount);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, SyncSearchQuery) {
+ config_.search.push_back("lab.ccs.neu.edu");
+ config_.search.push_back("ccs.neu.edu");
+ ConfigureFactory();
+
+ AddAsyncQueryAndRcode("www.lab.ccs.neu.edu", dns_protocol::kTypeA,
+ dns_protocol::kRcodeNXDOMAIN);
+ AddSyncQueryAndResponse(2 /* id */, kT2HostName, kT2Qtype,
+ kT2ResponseDatagram, arraysize(kT2ResponseDatagram));
+
+ TransactionHelper helper0("www", kT2Qtype, kT2RecordCount);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, ConnectFailure) {
+ socket_factory_->fail_next_socket_ = true;
+ transaction_ids_.push_back(0); // Needed to make a DnsUDPAttempt.
+ TransactionHelper helper0("www.chromium.org", dns_protocol::kTypeA,
+ ERR_CONNECTION_REFUSED);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, ConnectFailureFollowedBySuccess) {
+ // Retry after server failure.
+ config_.attempts = 2;
+ ConfigureFactory();
+ // First server connection attempt fails.
+ transaction_ids_.push_back(0); // Needed to make a DnsUDPAttempt.
+ socket_factory_->fail_next_socket_ = true;
+ // Second DNS query succeeds.
+ AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype,
+ kT0ResponseDatagram, arraysize(kT0ResponseDatagram));
+ TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, TCPLookup) {
+ AddAsyncQueryAndRcode(kT0HostName, kT0Qtype,
+ dns_protocol::kRcodeNOERROR | dns_protocol::kFlagTC);
+ AddQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype,
+ kT0ResponseDatagram, arraysize(kT0ResponseDatagram),
+ ASYNC, true /* use_tcp */);
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, TCPFailure) {
+ AddAsyncQueryAndRcode(kT0HostName, kT0Qtype,
+ dns_protocol::kRcodeNOERROR | dns_protocol::kFlagTC);
+ AddQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeSERVFAIL,
+ ASYNC, true /* use_tcp */);
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_SERVER_FAILED);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, TCPMalformed) {
+ AddAsyncQueryAndRcode(kT0HostName, kT0Qtype,
+ dns_protocol::kRcodeNOERROR | dns_protocol::kFlagTC);
+ scoped_ptr<DnsSocketData> data(
+ new DnsSocketData(0 /* id */, kT0HostName, kT0Qtype, ASYNC, true));
+ // Valid response but length too short.
+ data->AddResponseWithLength(
+ make_scoped_ptr(
+ new DnsResponse(reinterpret_cast<const char*>(kT0ResponseDatagram),
+ arraysize(kT0ResponseDatagram), 0)),
+ ASYNC,
+ static_cast<uint16>(kT0QuerySize - 1));
+ AddSocketData(data.Pass());
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_MALFORMED_RESPONSE);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, TCPTimeout) {
+ config_.timeout = TestTimeouts::tiny_timeout();
+ ConfigureFactory();
+ AddAsyncQueryAndRcode(kT0HostName, kT0Qtype,
+ dns_protocol::kRcodeNOERROR | dns_protocol::kFlagTC);
+ AddSocketData(make_scoped_ptr(
+ new DnsSocketData(1 /* id */, kT0HostName, kT0Qtype, ASYNC, true)));
+
+ TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_TIMED_OUT);
+ EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get()));
+}
+
+TEST_F(DnsTransactionTest, InvalidQuery) {
+ config_.timeout = TestTimeouts::tiny_timeout();
+ ConfigureFactory();
+
+ TransactionHelper helper0(".", dns_protocol::kTypeA, ERR_INVALID_ARGUMENT);
+ EXPECT_TRUE(helper0.Run(transaction_factory_.get()));
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/dns/host_cache.cc b/chromium/net/dns/host_cache.cc
new file mode 100644
index 00000000000..0e6ff15cd5d
--- /dev/null
+++ b/chromium/net/dns/host_cache.cc
@@ -0,0 +1,122 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/host_cache.h"
+
+#include "base/logging.h"
+#include "base/metrics/field_trial.h"
+#include "base/metrics/histogram.h"
+#include "base/strings/string_number_conversions.h"
+#include "net/base/net_errors.h"
+
+namespace net {
+
+//-----------------------------------------------------------------------------
+
+HostCache::Entry::Entry(int error, const AddressList& addrlist,
+ base::TimeDelta ttl)
+ : error(error),
+ addrlist(addrlist),
+ ttl(ttl) {
+ DCHECK(ttl >= base::TimeDelta());
+}
+
+HostCache::Entry::Entry(int error, const AddressList& addrlist)
+ : error(error),
+ addrlist(addrlist),
+ ttl(base::TimeDelta::FromSeconds(-1)) {
+}
+
+HostCache::Entry::~Entry() {
+}
+
+//-----------------------------------------------------------------------------
+
+HostCache::HostCache(size_t max_entries)
+ : entries_(max_entries) {
+}
+
+HostCache::~HostCache() {
+}
+
+const HostCache::Entry* HostCache::Lookup(const Key& key,
+ base::TimeTicks now) {
+ DCHECK(CalledOnValidThread());
+ if (caching_is_disabled())
+ return NULL;
+
+ return entries_.Get(key, now);
+}
+
+void HostCache::Set(const Key& key,
+ const Entry& entry,
+ base::TimeTicks now,
+ base::TimeDelta ttl) {
+ DCHECK(CalledOnValidThread());
+ if (caching_is_disabled())
+ return;
+
+ entries_.Put(key, entry, now, now + ttl);
+}
+
+void HostCache::clear() {
+ DCHECK(CalledOnValidThread());
+ entries_.Clear();
+}
+
+size_t HostCache::size() const {
+ DCHECK(CalledOnValidThread());
+ return entries_.size();
+}
+
+size_t HostCache::max_entries() const {
+ DCHECK(CalledOnValidThread());
+ return entries_.max_entries();
+}
+
+// Note that this map may contain expired entries.
+const HostCache::EntryMap& HostCache::entries() const {
+ DCHECK(CalledOnValidThread());
+ return entries_;
+}
+
+// static
+scoped_ptr<HostCache> HostCache::CreateDefaultCache() {
+ // Cache capacity is determined by the field trial.
+#if defined(ENABLE_BUILT_IN_DNS)
+ const size_t kDefaultMaxEntries = 1000;
+#else
+ const size_t kDefaultMaxEntries = 100;
+#endif
+ const size_t kSaneMaxEntries = 1 << 20;
+ size_t max_entries = 0;
+ base::StringToSizeT(base::FieldTrialList::FindFullName("HostCacheSize"),
+ &max_entries);
+ if ((max_entries == 0) || (max_entries > kSaneMaxEntries))
+ max_entries = kDefaultMaxEntries;
+ return make_scoped_ptr(new HostCache(max_entries));
+}
+
+void HostCache::EvictionHandler::Handle(
+ const Key& key,
+ const Entry& entry,
+ const base::TimeTicks& expiration,
+ const base::TimeTicks& now,
+ bool on_get) const {
+ if (on_get) {
+ DCHECK(now >= expiration);
+ UMA_HISTOGRAM_CUSTOM_TIMES("DNS.CacheExpiredOnGet", now - expiration,
+ base::TimeDelta::FromSeconds(1), base::TimeDelta::FromDays(1), 100);
+ return;
+ }
+ if (expiration > now) {
+ UMA_HISTOGRAM_CUSTOM_TIMES("DNS.CacheEvicted", expiration - now,
+ base::TimeDelta::FromSeconds(1), base::TimeDelta::FromDays(1), 100);
+ } else {
+ UMA_HISTOGRAM_CUSTOM_TIMES("DNS.CacheExpired", now - expiration,
+ base::TimeDelta::FromSeconds(1), base::TimeDelta::FromDays(1), 100);
+ }
+}
+
+} // namespace net
diff --git a/chromium/net/dns/host_cache.h b/chromium/net/dns/host_cache.h
new file mode 100644
index 00000000000..e8628fb04ec
--- /dev/null
+++ b/chromium/net/dns/host_cache.h
@@ -0,0 +1,124 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_HOST_CACHE_H_
+#define NET_DNS_HOST_CACHE_H_
+
+#include <functional>
+#include <string>
+
+#include "base/gtest_prod_util.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/threading/non_thread_safe.h"
+#include "base/time/time.h"
+#include "net/base/address_family.h"
+#include "net/base/address_list.h"
+#include "net/base/expiring_cache.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+// Cache used by HostResolver to map hostnames to their resolved result.
+class NET_EXPORT HostCache : NON_EXPORTED_BASE(public base::NonThreadSafe) {
+ public:
+ // Stores the latest address list that was looked up for a hostname.
+ struct NET_EXPORT Entry {
+ Entry(int error, const AddressList& addrlist, base::TimeDelta ttl);
+ // Use when |ttl| is unknown.
+ Entry(int error, const AddressList& addrlist);
+ ~Entry();
+
+ bool has_ttl() const { return ttl >= base::TimeDelta(); }
+
+ // The resolve results for this entry.
+ int error;
+ AddressList addrlist;
+ // TTL obtained from the nameserver. Negative if unknown.
+ base::TimeDelta ttl;
+ };
+
+ struct Key {
+ Key(const std::string& hostname, AddressFamily address_family,
+ HostResolverFlags host_resolver_flags)
+ : hostname(hostname),
+ address_family(address_family),
+ host_resolver_flags(host_resolver_flags) {}
+
+ bool operator<(const Key& other) const {
+ // |address_family| and |host_resolver_flags| are compared before
+ // |hostname| under assumption that integer comparisons are faster than
+ // string comparisons.
+ if (address_family != other.address_family)
+ return address_family < other.address_family;
+ if (host_resolver_flags != other.host_resolver_flags)
+ return host_resolver_flags < other.host_resolver_flags;
+ return hostname < other.hostname;
+ }
+
+ std::string hostname;
+ AddressFamily address_family;
+ HostResolverFlags host_resolver_flags;
+ };
+
+ struct EvictionHandler {
+ void Handle(const Key& key,
+ const Entry& entry,
+ const base::TimeTicks& expiration,
+ const base::TimeTicks& now,
+ bool onGet) const;
+ };
+
+ typedef ExpiringCache<Key, Entry, base::TimeTicks,
+ std::less<base::TimeTicks>,
+ EvictionHandler> EntryMap;
+
+ // Constructs a HostCache that stores up to |max_entries|.
+ explicit HostCache(size_t max_entries);
+
+ ~HostCache();
+
+ // Returns a pointer to the entry for |key|, which is valid at time
+ // |now|. If there is no such entry, returns NULL.
+ const Entry* Lookup(const Key& key, base::TimeTicks now);
+
+ // Overwrites or creates an entry for |key|.
+ // |entry| is the value to set, |now| is the current time
+ // |ttl| is the "time to live".
+ void Set(const Key& key,
+ const Entry& entry,
+ base::TimeTicks now,
+ base::TimeDelta ttl);
+
+ // Empties the cache
+ void clear();
+
+ // Returns the number of entries in the cache.
+ size_t size() const;
+
+ // Following are used by net_internals UI.
+ size_t max_entries() const;
+
+ const EntryMap& entries() const;
+
+ // Creates a default cache.
+ static scoped_ptr<HostCache> CreateDefaultCache();
+
+ private:
+ FRIEND_TEST_ALL_PREFIXES(HostCacheTest, NoCache);
+
+ // Returns true if this HostCache can contain no entries.
+ bool caching_is_disabled() const {
+ return entries_.max_entries() == 0;
+ }
+
+ // Map from hostname (presumably in lowercase canonicalized format) to
+ // a resolved result entry.
+ EntryMap entries_;
+
+ DISALLOW_COPY_AND_ASSIGN(HostCache);
+};
+
+} // namespace net
+
+#endif // NET_DNS_HOST_CACHE_H_
diff --git a/chromium/net/dns/host_cache_unittest.cc b/chromium/net/dns/host_cache_unittest.cc
new file mode 100644
index 00000000000..34309c1223c
--- /dev/null
+++ b/chromium/net/dns/host_cache_unittest.cc
@@ -0,0 +1,388 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/host_cache.h"
+
+#include "base/format_macros.h"
+#include "base/stl_util.h"
+#include "base/strings/string_util.h"
+#include "base/strings/stringprintf.h"
+#include "net/base/net_errors.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+const int kMaxCacheEntries = 10;
+
+// Builds a key for |hostname|, defaulting the address family to unspecified.
+HostCache::Key Key(const std::string& hostname) {
+ return HostCache::Key(hostname, ADDRESS_FAMILY_UNSPECIFIED, 0);
+}
+
+} // namespace
+
+TEST(HostCacheTest, Basic) {
+ const base::TimeDelta kTTL = base::TimeDelta::FromSeconds(10);
+
+ HostCache cache(kMaxCacheEntries);
+
+ // Start at t=0.
+ base::TimeTicks now;
+
+ HostCache::Key key1 = Key("foobar.com");
+ HostCache::Key key2 = Key("foobar2.com");
+ HostCache::Entry entry = HostCache::Entry(OK, AddressList());
+
+ EXPECT_EQ(0U, cache.size());
+
+ // Add an entry for "foobar.com" at t=0.
+ EXPECT_FALSE(cache.Lookup(key1, now));
+ cache.Set(key1, entry, now, kTTL);
+ EXPECT_TRUE(cache.Lookup(key1, now));
+ EXPECT_TRUE(cache.Lookup(key1, now)->error == entry.error);
+
+ EXPECT_EQ(1U, cache.size());
+
+ // Advance to t=5.
+ now += base::TimeDelta::FromSeconds(5);
+
+ // Add an entry for "foobar2.com" at t=5.
+ EXPECT_FALSE(cache.Lookup(key2, now));
+ cache.Set(key2, entry, now, kTTL);
+ EXPECT_TRUE(cache.Lookup(key2, now));
+ EXPECT_EQ(2U, cache.size());
+
+ // Advance to t=9
+ now += base::TimeDelta::FromSeconds(4);
+
+ // Verify that the entries we added are still retrievable, and usable.
+ EXPECT_TRUE(cache.Lookup(key1, now));
+ EXPECT_TRUE(cache.Lookup(key2, now));
+ EXPECT_NE(cache.Lookup(key1, now), cache.Lookup(key2, now));
+
+ // Advance to t=10; key is now expired.
+ now += base::TimeDelta::FromSeconds(1);
+
+ EXPECT_FALSE(cache.Lookup(key1, now));
+ EXPECT_TRUE(cache.Lookup(key2, now));
+
+ // Update key1, so it is no longer expired.
+ cache.Set(key1, entry, now, kTTL);
+ EXPECT_TRUE(cache.Lookup(key1, now));
+ EXPECT_EQ(2U, cache.size());
+
+ // Both entries should still be retrievable and usable.
+ EXPECT_TRUE(cache.Lookup(key1, now));
+ EXPECT_TRUE(cache.Lookup(key2, now));
+
+ // Advance to t=20; both entries are now expired.
+ now += base::TimeDelta::FromSeconds(10);
+
+ EXPECT_FALSE(cache.Lookup(key1, now));
+ EXPECT_FALSE(cache.Lookup(key2, now));
+}
+
+// Try caching entries for a failed resolve attempt -- since we set the TTL of
+// such entries to 0 it won't store, but it will kick out the previous result.
+TEST(HostCacheTest, NoCacheZeroTTL) {
+ const base::TimeDelta kSuccessEntryTTL = base::TimeDelta::FromSeconds(10);
+ const base::TimeDelta kFailureEntryTTL = base::TimeDelta::FromSeconds(0);
+
+ HostCache cache(kMaxCacheEntries);
+
+ // Set t=0.
+ base::TimeTicks now;
+
+ HostCache::Key key1 = Key("foobar.com");
+ HostCache::Key key2 = Key("foobar2.com");
+ HostCache::Entry entry = HostCache::Entry(OK, AddressList());
+
+ EXPECT_FALSE(cache.Lookup(key1, now));
+ cache.Set(key1, entry, now, kFailureEntryTTL);
+ EXPECT_EQ(1U, cache.size());
+
+ // We disallow use of negative entries.
+ EXPECT_FALSE(cache.Lookup(key1, now));
+
+ // Now overwrite with a valid entry, and then overwrite with negative entry
+ // again -- the valid entry should be kicked out.
+ cache.Set(key1, entry, now, kSuccessEntryTTL);
+ EXPECT_TRUE(cache.Lookup(key1, now));
+ cache.Set(key1, entry, now, kFailureEntryTTL);
+ EXPECT_FALSE(cache.Lookup(key1, now));
+}
+
+// Try caching entries for a failed resolves for 10 seconds.
+TEST(HostCacheTest, CacheNegativeEntry) {
+ const base::TimeDelta kFailureEntryTTL = base::TimeDelta::FromSeconds(10);
+
+ HostCache cache(kMaxCacheEntries);
+
+ // Start at t=0.
+ base::TimeTicks now;
+
+ HostCache::Key key1 = Key("foobar.com");
+ HostCache::Key key2 = Key("foobar2.com");
+ HostCache::Entry entry = HostCache::Entry(OK, AddressList());
+
+ EXPECT_EQ(0U, cache.size());
+
+ // Add an entry for "foobar.com" at t=0.
+ EXPECT_FALSE(cache.Lookup(key1, now));
+ cache.Set(key1, entry, now, kFailureEntryTTL);
+ EXPECT_TRUE(cache.Lookup(key1, now));
+ EXPECT_EQ(1U, cache.size());
+
+ // Advance to t=5.
+ now += base::TimeDelta::FromSeconds(5);
+
+ // Add an entry for "foobar2.com" at t=5.
+ EXPECT_FALSE(cache.Lookup(key2, now));
+ cache.Set(key2, entry, now, kFailureEntryTTL);
+ EXPECT_TRUE(cache.Lookup(key2, now));
+ EXPECT_EQ(2U, cache.size());
+
+ // Advance to t=9
+ now += base::TimeDelta::FromSeconds(4);
+
+ // Verify that the entries we added are still retrievable, and usable.
+ EXPECT_TRUE(cache.Lookup(key1, now));
+ EXPECT_TRUE(cache.Lookup(key2, now));
+
+ // Advance to t=10; key1 is now expired.
+ now += base::TimeDelta::FromSeconds(1);
+
+ EXPECT_FALSE(cache.Lookup(key1, now));
+ EXPECT_TRUE(cache.Lookup(key2, now));
+
+ // Update key1, so it is no longer expired.
+ cache.Set(key1, entry, now, kFailureEntryTTL);
+ // Re-uses existing entry storage.
+ EXPECT_TRUE(cache.Lookup(key1, now));
+ EXPECT_EQ(2U, cache.size());
+
+ // Both entries should still be retrievable and usable.
+ EXPECT_TRUE(cache.Lookup(key1, now));
+ EXPECT_TRUE(cache.Lookup(key2, now));
+
+ // Advance to t=20; both entries are now expired.
+ now += base::TimeDelta::FromSeconds(10);
+
+ EXPECT_FALSE(cache.Lookup(key1, now));
+ EXPECT_FALSE(cache.Lookup(key2, now));
+}
+
+// Tests that the same hostname can be duplicated in the cache, so long as
+// the address family differs.
+TEST(HostCacheTest, AddressFamilyIsPartOfKey) {
+ const base::TimeDelta kSuccessEntryTTL = base::TimeDelta::FromSeconds(10);
+
+ HostCache cache(kMaxCacheEntries);
+
+ // t=0.
+ base::TimeTicks now;
+
+ HostCache::Key key1("foobar.com", ADDRESS_FAMILY_UNSPECIFIED, 0);
+ HostCache::Key key2("foobar.com", ADDRESS_FAMILY_IPV4, 0);
+ HostCache::Entry entry = HostCache::Entry(OK, AddressList());
+
+ EXPECT_EQ(0U, cache.size());
+
+ // Add an entry for ("foobar.com", UNSPECIFIED) at t=0.
+ EXPECT_FALSE(cache.Lookup(key1, now));
+ cache.Set(key1, entry, now, kSuccessEntryTTL);
+ EXPECT_TRUE(cache.Lookup(key1, now));
+ EXPECT_EQ(1U, cache.size());
+
+ // Add an entry for ("foobar.com", IPV4_ONLY) at t=0.
+ EXPECT_FALSE(cache.Lookup(key2, now));
+ cache.Set(key2, entry, now, kSuccessEntryTTL);
+ EXPECT_TRUE(cache.Lookup(key2, now));
+ EXPECT_EQ(2U, cache.size());
+
+ // Even though the hostnames were the same, we should have two unique
+ // entries (because the address families differ).
+ EXPECT_NE(cache.Lookup(key1, now), cache.Lookup(key2, now));
+}
+
+// Tests that the same hostname can be duplicated in the cache, so long as
+// the HostResolverFlags differ.
+TEST(HostCacheTest, HostResolverFlagsArePartOfKey) {
+ const base::TimeDelta kTTL = base::TimeDelta::FromSeconds(10);
+
+ HostCache cache(kMaxCacheEntries);
+
+ // t=0.
+ base::TimeTicks now;
+
+ HostCache::Key key1("foobar.com", ADDRESS_FAMILY_IPV4, 0);
+ HostCache::Key key2("foobar.com", ADDRESS_FAMILY_IPV4,
+ HOST_RESOLVER_CANONNAME);
+ HostCache::Key key3("foobar.com", ADDRESS_FAMILY_IPV4,
+ HOST_RESOLVER_LOOPBACK_ONLY);
+ HostCache::Entry entry = HostCache::Entry(OK, AddressList());
+
+ EXPECT_EQ(0U, cache.size());
+
+ // Add an entry for ("foobar.com", IPV4, NONE) at t=0.
+ EXPECT_FALSE(cache.Lookup(key1, now));
+ cache.Set(key1, entry, now, kTTL);
+ EXPECT_TRUE(cache.Lookup(key1, now));
+ EXPECT_EQ(1U, cache.size());
+
+ // Add an entry for ("foobar.com", IPV4, CANONNAME) at t=0.
+ EXPECT_FALSE(cache.Lookup(key2, now));
+ cache.Set(key2, entry, now, kTTL);
+ EXPECT_TRUE(cache.Lookup(key2, now));
+ EXPECT_EQ(2U, cache.size());
+
+ // Add an entry for ("foobar.com", IPV4, LOOPBACK_ONLY) at t=0.
+ EXPECT_FALSE(cache.Lookup(key3, now));
+ cache.Set(key3, entry, now, kTTL);
+ EXPECT_TRUE(cache.Lookup(key3, now));
+ EXPECT_EQ(3U, cache.size());
+
+ // Even though the hostnames were the same, we should have two unique
+ // entries (because the HostResolverFlags differ).
+ EXPECT_NE(cache.Lookup(key1, now), cache.Lookup(key2, now));
+ EXPECT_NE(cache.Lookup(key1, now), cache.Lookup(key3, now));
+ EXPECT_NE(cache.Lookup(key2, now), cache.Lookup(key3, now));
+}
+
+TEST(HostCacheTest, NoCache) {
+ // Disable caching.
+ const base::TimeDelta kTTL = base::TimeDelta::FromSeconds(10);
+
+ HostCache cache(0);
+ EXPECT_TRUE(cache.caching_is_disabled());
+
+ // Set t=0.
+ base::TimeTicks now;
+
+ HostCache::Entry entry = HostCache::Entry(OK, AddressList());
+
+ // Lookup and Set should have no effect.
+ EXPECT_FALSE(cache.Lookup(Key("foobar.com"),now));
+ cache.Set(Key("foobar.com"), entry, now, kTTL);
+ EXPECT_FALSE(cache.Lookup(Key("foobar.com"), now));
+
+ EXPECT_EQ(0U, cache.size());
+}
+
+TEST(HostCacheTest, Clear) {
+ const base::TimeDelta kTTL = base::TimeDelta::FromSeconds(10);
+
+ HostCache cache(kMaxCacheEntries);
+
+ // Set t=0.
+ base::TimeTicks now;
+
+ HostCache::Entry entry = HostCache::Entry(OK, AddressList());
+
+ EXPECT_EQ(0u, cache.size());
+
+ // Add three entries.
+ cache.Set(Key("foobar1.com"), entry, now, kTTL);
+ cache.Set(Key("foobar2.com"), entry, now, kTTL);
+ cache.Set(Key("foobar3.com"), entry, now, kTTL);
+
+ EXPECT_EQ(3u, cache.size());
+
+ cache.clear();
+
+ EXPECT_EQ(0u, cache.size());
+}
+
+// Tests the less than and equal operators for HostCache::Key work.
+TEST(HostCacheTest, KeyComparators) {
+ struct {
+ // Inputs.
+ HostCache::Key key1;
+ HostCache::Key key2;
+
+ // Expectation.
+ // -1 means key1 is less than key2
+ // 0 means key1 equals key2
+ // 1 means key1 is greater than key2
+ int expected_comparison;
+ } tests[] = {
+ {
+ HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0),
+ HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0),
+ 0
+ },
+ {
+ HostCache::Key("host1", ADDRESS_FAMILY_IPV4, 0),
+ HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0),
+ 1
+ },
+ {
+ HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0),
+ HostCache::Key("host1", ADDRESS_FAMILY_IPV4, 0),
+ -1
+ },
+ {
+ HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0),
+ HostCache::Key("host2", ADDRESS_FAMILY_UNSPECIFIED, 0),
+ -1
+ },
+ {
+ HostCache::Key("host1", ADDRESS_FAMILY_IPV4, 0),
+ HostCache::Key("host2", ADDRESS_FAMILY_UNSPECIFIED, 0),
+ 1
+ },
+ {
+ HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0),
+ HostCache::Key("host2", ADDRESS_FAMILY_IPV4, 0),
+ -1
+ },
+ {
+ HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0),
+ HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED,
+ HOST_RESOLVER_CANONNAME),
+ -1
+ },
+ {
+ HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED,
+ HOST_RESOLVER_CANONNAME),
+ HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0),
+ 1
+ },
+ {
+ HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED,
+ HOST_RESOLVER_CANONNAME),
+ HostCache::Key("host2", ADDRESS_FAMILY_UNSPECIFIED,
+ HOST_RESOLVER_CANONNAME),
+ -1
+ },
+ };
+
+ for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
+ SCOPED_TRACE(base::StringPrintf("Test[%" PRIuS "]", i));
+
+ const HostCache::Key& key1 = tests[i].key1;
+ const HostCache::Key& key2 = tests[i].key2;
+
+ switch (tests[i].expected_comparison) {
+ case -1:
+ EXPECT_TRUE(key1 < key2);
+ EXPECT_FALSE(key2 < key1);
+ break;
+ case 0:
+ EXPECT_FALSE(key1 < key2);
+ EXPECT_FALSE(key2 < key1);
+ break;
+ case 1:
+ EXPECT_FALSE(key1 < key2);
+ EXPECT_TRUE(key2 < key1);
+ break;
+ default:
+ FAIL() << "Invalid expectation. Can be only -1, 0, 1";
+ }
+ }
+}
+
+} // namespace net
diff --git a/chromium/net/dns/host_resolver.cc b/chromium/net/dns/host_resolver.cc
new file mode 100644
index 00000000000..d74be91beb0
--- /dev/null
+++ b/chromium/net/dns/host_resolver.cc
@@ -0,0 +1,145 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/host_resolver.h"
+
+#include "base/logging.h"
+#include "base/metrics/field_trial.h"
+#include "base/strings/string_number_conversions.h"
+#include "base/strings/string_split.h"
+#include "net/dns/dns_client.h"
+#include "net/dns/dns_config_service.h"
+#include "net/dns/host_cache.h"
+#include "net/dns/host_resolver_impl.h"
+
+namespace net {
+
+namespace {
+
+// Maximum of 6 concurrent resolver threads (excluding retries).
+// Some routers (or resolvers) appear to start to provide host-not-found if
+// too many simultaneous resolutions are pending. This number needs to be
+// further optimized, but 8 is what FF currently does. We found some routers
+// that limit this to 6, so we're temporarily holding it at that level.
+const size_t kDefaultMaxProcTasks = 6u;
+
+// When configuring from field trial, do not allow
+const size_t kSaneMaxProcTasks = 20u;
+
+PrioritizedDispatcher::Limits GetDispatcherLimits(
+ const HostResolver::Options& options) {
+ PrioritizedDispatcher::Limits limits(NUM_PRIORITIES,
+ options.max_concurrent_resolves);
+
+ // If not using default, do not use the field trial.
+ if (limits.total_jobs != HostResolver::kDefaultParallelism)
+ return limits;
+
+ // Default, without trial is no reserved slots.
+ limits.total_jobs = kDefaultMaxProcTasks;
+
+ // Parallelism is determined by the field trial.
+ std::string group = base::FieldTrialList::FindFullName(
+ "HostResolverDispatch");
+
+ if (group.empty())
+ return limits;
+
+ // The format of the group name is a list of non-negative integers separated
+ // by ':'. Each of the elements in the list corresponds to an element in
+ // |reserved_slots|, except the last one which is the |total_jobs|.
+
+ std::vector<std::string> group_parts;
+ base::SplitString(group, ':', &group_parts);
+ if (group_parts.size() != NUM_PRIORITIES + 1) {
+ NOTREACHED();
+ return limits;
+ }
+
+ std::vector<size_t> parsed(group_parts.size());
+ size_t total_reserved_slots = 0;
+
+ for (size_t i = 0; i < group_parts.size(); ++i) {
+ if (!base::StringToSizeT(group_parts[i], &parsed[i])) {
+ NOTREACHED();
+ return limits;
+ }
+ }
+
+ size_t total_jobs = parsed.back();
+ parsed.pop_back();
+ for (size_t i = 0; i < parsed.size(); ++i) {
+ total_reserved_slots += parsed[i];
+ }
+
+ // There must be some unreserved slots available for the all priorities.
+ if (total_reserved_slots > total_jobs ||
+ (total_reserved_slots == total_jobs && parsed[MINIMUM_PRIORITY] == 0)) {
+ NOTREACHED();
+ return limits;
+ }
+
+ limits.total_jobs = total_jobs;
+ limits.reserved_slots = parsed;
+ return limits;
+}
+
+} // namespace
+
+HostResolver::Options::Options()
+ : max_concurrent_resolves(kDefaultParallelism),
+ max_retry_attempts(kDefaultRetryAttempts),
+ enable_caching(true) {
+}
+
+HostResolver::RequestInfo::RequestInfo(const HostPortPair& host_port_pair)
+ : host_port_pair_(host_port_pair),
+ address_family_(ADDRESS_FAMILY_UNSPECIFIED),
+ host_resolver_flags_(0),
+ allow_cached_response_(true),
+ is_speculative_(false),
+ priority_(MEDIUM) {
+}
+
+HostResolver::~HostResolver() {
+}
+
+AddressFamily HostResolver::GetDefaultAddressFamily() const {
+ return ADDRESS_FAMILY_UNSPECIFIED;
+}
+
+void HostResolver::SetDnsClientEnabled(bool enabled) {
+}
+
+HostCache* HostResolver::GetHostCache() {
+ return NULL;
+}
+
+base::Value* HostResolver::GetDnsConfigAsValue() const {
+ return NULL;
+}
+
+// static
+scoped_ptr<HostResolver>
+HostResolver::CreateSystemResolver(const Options& options, NetLog* net_log) {
+ scoped_ptr<HostCache> cache;
+ if (options.enable_caching)
+ cache = HostCache::CreateDefaultCache();
+ return scoped_ptr<HostResolver>(new HostResolverImpl(
+ cache.Pass(),
+ GetDispatcherLimits(options),
+ HostResolverImpl::ProcTaskParams(NULL, options.max_retry_attempts),
+ net_log));
+}
+
+// static
+scoped_ptr<HostResolver>
+HostResolver::CreateDefaultResolver(NetLog* net_log) {
+ return CreateSystemResolver(Options(), net_log);
+}
+
+HostResolver::HostResolver() {
+}
+
+} // namespace net
diff --git a/chromium/net/dns/host_resolver.h b/chromium/net/dns/host_resolver.h
new file mode 100644
index 00000000000..558a1ddc4b2
--- /dev/null
+++ b/chromium/net/dns/host_resolver.h
@@ -0,0 +1,204 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_HOST_RESOLVER_H_
+#define NET_DNS_HOST_RESOLVER_H_
+
+#include <string>
+
+#include "base/memory/scoped_ptr.h"
+#include "net/base/address_family.h"
+#include "net/base/completion_callback.h"
+#include "net/base/host_port_pair.h"
+#include "net/base/net_export.h"
+#include "net/base/net_util.h"
+#include "net/base/request_priority.h"
+
+namespace base {
+class Value;
+}
+
+namespace net {
+
+class AddressList;
+class BoundNetLog;
+class HostCache;
+class HostResolverProc;
+class NetLog;
+
+// This class represents the task of resolving hostnames (or IP address
+// literal) to an AddressList object.
+//
+// HostResolver can handle multiple requests at a time, so when cancelling a
+// request the RequestHandle that was returned by Resolve() needs to be
+// given. A simpler alternative for consumers that only have 1 outstanding
+// request at a time is to create a SingleRequestHostResolver wrapper around
+// HostResolver (which will automatically cancel the single request when it
+// goes out of scope).
+class NET_EXPORT HostResolver {
+ public:
+ // |max_concurrent_resolves| is how many resolve requests will be allowed to
+ // run in parallel. Pass HostResolver::kDefaultParallelism to choose a
+ // default value.
+ // |max_retry_attempts| is the maximum number of times we will retry for host
+ // resolution. Pass HostResolver::kDefaultRetryAttempts to choose a default
+ // value.
+ // |enable_caching| controls whether a HostCache is used.
+ struct NET_EXPORT Options {
+ Options();
+
+ size_t max_concurrent_resolves;
+ size_t max_retry_attempts;
+ bool enable_caching;
+ };
+
+ // The parameters for doing a Resolve(). A hostname and port are required,
+ // the rest are optional (and have reasonable defaults).
+ class NET_EXPORT RequestInfo {
+ public:
+ explicit RequestInfo(const HostPortPair& host_port_pair);
+
+ const HostPortPair& host_port_pair() const { return host_port_pair_; }
+ void set_host_port_pair(const HostPortPair& host_port_pair) {
+ host_port_pair_ = host_port_pair;
+ }
+
+ int port() const { return host_port_pair_.port(); }
+ const std::string& hostname() const { return host_port_pair_.host(); }
+
+ AddressFamily address_family() const { return address_family_; }
+ void set_address_family(AddressFamily address_family) {
+ address_family_ = address_family;
+ }
+
+ HostResolverFlags host_resolver_flags() const {
+ return host_resolver_flags_;
+ }
+ void set_host_resolver_flags(HostResolverFlags host_resolver_flags) {
+ host_resolver_flags_ = host_resolver_flags;
+ }
+
+ bool allow_cached_response() const { return allow_cached_response_; }
+ void set_allow_cached_response(bool b) { allow_cached_response_ = b; }
+
+ bool is_speculative() const { return is_speculative_; }
+ void set_is_speculative(bool b) { is_speculative_ = b; }
+
+ RequestPriority priority() const { return priority_; }
+ void set_priority(RequestPriority priority) { priority_ = priority; }
+
+ private:
+ // The hostname to resolve, and the port to use in resulting sockaddrs.
+ HostPortPair host_port_pair_;
+
+ // The address family to restrict results to.
+ AddressFamily address_family_;
+
+ // Flags to use when resolving this request.
+ HostResolverFlags host_resolver_flags_;
+
+ // Whether it is ok to return a result from the host cache.
+ bool allow_cached_response_;
+
+ // Whether this request was started by the DNS prefetcher.
+ bool is_speculative_;
+
+ // The priority for the request.
+ RequestPriority priority_;
+ };
+
+ // Opaque type used to cancel a request.
+ typedef void* RequestHandle;
+
+ // This value can be passed into CreateSystemResolver as the
+ // |max_concurrent_resolves| parameter. It will select a default level of
+ // concurrency.
+ static const size_t kDefaultParallelism = 0;
+
+ // This value can be passed into CreateSystemResolver as the
+ // |max_retry_attempts| parameter.
+ static const size_t kDefaultRetryAttempts = -1;
+
+ // If any completion callbacks are pending when the resolver is destroyed,
+ // the host resolutions are cancelled, and the completion callbacks will not
+ // be called.
+ virtual ~HostResolver();
+
+ // Resolves the given hostname (or IP address literal), filling out the
+ // |addresses| object upon success. The |info.port| parameter will be set as
+ // the sin(6)_port field of the sockaddr_in{6} struct. Returns OK if
+ // successful or an error code upon failure. Returns
+ // ERR_NAME_NOT_RESOLVED if hostname is invalid, or if it is an
+ // incompatible IP literal (e.g. IPv6 is disabled and it is an IPv6
+ // literal).
+ //
+ // If the operation cannot be completed synchronously, ERR_IO_PENDING will
+ // be returned and the real result code will be passed to the completion
+ // callback. Otherwise the result code is returned immediately from this
+ // call.
+ //
+ // If |out_req| is non-NULL, then |*out_req| will be filled with a handle to
+ // the async request. This handle is not valid after the request has
+ // completed.
+ //
+ // Profiling information for the request is saved to |net_log| if non-NULL.
+ virtual int Resolve(const RequestInfo& info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* out_req,
+ const BoundNetLog& net_log) = 0;
+
+ // Resolves the given hostname (or IP address literal) out of cache or HOSTS
+ // file (if enabled) only. This is guaranteed to complete synchronously.
+ // This acts like |Resolve()| if the hostname is IP literal, or cached value
+ // or HOSTS entry exists. Otherwise, ERR_DNS_CACHE_MISS is returned.
+ virtual int ResolveFromCache(const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& net_log) = 0;
+
+ // Cancels the specified request. |req| is the handle returned by Resolve().
+ // After a request is canceled, its completion callback will not be called.
+ // CancelRequest must NOT be called after the request's completion callback
+ // has already run or the request was canceled.
+ virtual void CancelRequest(RequestHandle req) = 0;
+
+ // Sets the default AddressFamily to use when requests have left it
+ // unspecified. For example, this could be used to restrict resolution
+ // results to AF_INET by passing in ADDRESS_FAMILY_IPV4, or to
+ // AF_INET6 by passing in ADDRESS_FAMILY_IPV6.
+ virtual void SetDefaultAddressFamily(AddressFamily address_family) {}
+ virtual AddressFamily GetDefaultAddressFamily() const;
+
+ // Enable or disable the built-in asynchronous DnsClient.
+ virtual void SetDnsClientEnabled(bool enabled);
+
+ // Returns the HostResolverCache |this| uses, or NULL if there isn't one.
+ // Used primarily to clear the cache and for getting debug information.
+ virtual HostCache* GetHostCache();
+
+ // Returns the current DNS configuration |this| is using, as a Value, or NULL
+ // if it's configured to always use the system host resolver. Caller takes
+ // ownership of the returned Value.
+ virtual base::Value* GetDnsConfigAsValue() const;
+
+ // Creates a HostResolver implementation that queries the underlying system.
+ // (Except if a unit-test has changed the global HostResolverProc using
+ // ScopedHostResolverProc to intercept requests to the system).
+ static scoped_ptr<HostResolver> CreateSystemResolver(
+ const Options& options,
+ NetLog* net_log);
+
+ // As above, but uses default parameters.
+ static scoped_ptr<HostResolver> CreateDefaultResolver(NetLog* net_log);
+
+ protected:
+ HostResolver();
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(HostResolver);
+};
+
+} // namespace net
+
+#endif // NET_DNS_HOST_RESOLVER_H_
diff --git a/chromium/net/dns/host_resolver_impl.cc b/chromium/net/dns/host_resolver_impl.cc
new file mode 100644
index 00000000000..10631773291
--- /dev/null
+++ b/chromium/net/dns/host_resolver_impl.cc
@@ -0,0 +1,2206 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/host_resolver_impl.h"
+
+#if defined(OS_WIN)
+#include <Winsock2.h>
+#elif defined(OS_POSIX)
+#include <netdb.h>
+#endif
+
+#include <cmath>
+#include <utility>
+#include <vector>
+
+#include "base/basictypes.h"
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/callback.h"
+#include "base/compiler_specific.h"
+#include "base/debug/debugger.h"
+#include "base/debug/stack_trace.h"
+#include "base/message_loop/message_loop_proxy.h"
+#include "base/metrics/field_trial.h"
+#include "base/metrics/histogram.h"
+#include "base/stl_util.h"
+#include "base/strings/string_util.h"
+#include "base/strings/utf_string_conversions.h"
+#include "base/threading/worker_pool.h"
+#include "base/time/time.h"
+#include "base/values.h"
+#include "net/base/address_family.h"
+#include "net/base/address_list.h"
+#include "net/base/dns_reloader.h"
+#include "net/base/dns_util.h"
+#include "net/base/host_port_pair.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/net_util.h"
+#include "net/dns/address_sorter.h"
+#include "net/dns/dns_client.h"
+#include "net/dns/dns_config_service.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/dns_transaction.h"
+#include "net/dns/host_resolver_proc.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/udp/datagram_client_socket.h"
+
+#if defined(OS_WIN)
+#include "net/base/winsock_init.h"
+#endif
+
+namespace net {
+
+namespace {
+
+// Limit the size of hostnames that will be resolved to combat issues in
+// some platform's resolvers.
+const size_t kMaxHostLength = 4096;
+
+// Default TTL for successful resolutions with ProcTask.
+const unsigned kCacheEntryTTLSeconds = 60;
+
+// Default TTL for unsuccessful resolutions with ProcTask.
+const unsigned kNegativeCacheEntryTTLSeconds = 0;
+
+// Minimum TTL for successful resolutions with DnsTask.
+const unsigned kMinimumTTLSeconds = kCacheEntryTTLSeconds;
+
+// Number of consecutive failures of DnsTask (with successful fallback) before
+// the DnsClient is disabled until the next DNS change.
+const unsigned kMaximumDnsFailures = 16;
+
+// We use a separate histogram name for each platform to facilitate the
+// display of error codes by their symbolic name (since each platform has
+// different mappings).
+const char kOSErrorsForGetAddrinfoHistogramName[] =
+#if defined(OS_WIN)
+ "Net.OSErrorsForGetAddrinfo_Win";
+#elif defined(OS_MACOSX)
+ "Net.OSErrorsForGetAddrinfo_Mac";
+#elif defined(OS_LINUX)
+ "Net.OSErrorsForGetAddrinfo_Linux";
+#else
+ "Net.OSErrorsForGetAddrinfo";
+#endif
+
+// Gets a list of the likely error codes that getaddrinfo() can return
+// (non-exhaustive). These are the error codes that we will track via
+// a histogram.
+std::vector<int> GetAllGetAddrinfoOSErrors() {
+ int os_errors[] = {
+#if defined(OS_POSIX)
+#if !defined(OS_FREEBSD)
+#if !defined(OS_ANDROID)
+ // EAI_ADDRFAMILY has been declared obsolete in Android's and
+ // FreeBSD's netdb.h.
+ EAI_ADDRFAMILY,
+#endif
+ // EAI_NODATA has been declared obsolete in FreeBSD's netdb.h.
+ EAI_NODATA,
+#endif
+ EAI_AGAIN,
+ EAI_BADFLAGS,
+ EAI_FAIL,
+ EAI_FAMILY,
+ EAI_MEMORY,
+ EAI_NONAME,
+ EAI_SERVICE,
+ EAI_SOCKTYPE,
+ EAI_SYSTEM,
+#elif defined(OS_WIN)
+ // See: http://msdn.microsoft.com/en-us/library/ms738520(VS.85).aspx
+ WSA_NOT_ENOUGH_MEMORY,
+ WSAEAFNOSUPPORT,
+ WSAEINVAL,
+ WSAESOCKTNOSUPPORT,
+ WSAHOST_NOT_FOUND,
+ WSANO_DATA,
+ WSANO_RECOVERY,
+ WSANOTINITIALISED,
+ WSATRY_AGAIN,
+ WSATYPE_NOT_FOUND,
+ // The following are not in doc, but might be to appearing in results :-(.
+ WSA_INVALID_HANDLE,
+#endif
+ };
+
+ // Ensure all errors are positive, as histogram only tracks positive values.
+ for (size_t i = 0; i < arraysize(os_errors); ++i) {
+ os_errors[i] = std::abs(os_errors[i]);
+ }
+
+ return base::CustomHistogram::ArrayToCustomRanges(os_errors,
+ arraysize(os_errors));
+}
+
+enum DnsResolveStatus {
+ RESOLVE_STATUS_DNS_SUCCESS = 0,
+ RESOLVE_STATUS_PROC_SUCCESS,
+ RESOLVE_STATUS_FAIL,
+ RESOLVE_STATUS_SUSPECT_NETBIOS,
+ RESOLVE_STATUS_MAX
+};
+
+void UmaAsyncDnsResolveStatus(DnsResolveStatus result) {
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ResolveStatus",
+ result,
+ RESOLVE_STATUS_MAX);
+}
+
+bool ResemblesNetBIOSName(const std::string& hostname) {
+ return (hostname.size() < 16) && (hostname.find('.') == std::string::npos);
+}
+
+// True if |hostname| ends with either ".local" or ".local.".
+bool ResemblesMulticastDNSName(const std::string& hostname) {
+ DCHECK(!hostname.empty());
+ const char kSuffix[] = ".local.";
+ const size_t kSuffixLen = sizeof(kSuffix) - 1;
+ const size_t kSuffixLenTrimmed = kSuffixLen - 1;
+ if (hostname[hostname.size() - 1] == '.') {
+ return hostname.size() > kSuffixLen &&
+ !hostname.compare(hostname.size() - kSuffixLen, kSuffixLen, kSuffix);
+ }
+ return hostname.size() > kSuffixLenTrimmed &&
+ !hostname.compare(hostname.size() - kSuffixLenTrimmed, kSuffixLenTrimmed,
+ kSuffix, kSuffixLenTrimmed);
+}
+
+// Attempts to connect a UDP socket to |dest|:53.
+bool IsGloballyReachable(const IPAddressNumber& dest,
+ const BoundNetLog& net_log) {
+ scoped_ptr<DatagramClientSocket> socket(
+ ClientSocketFactory::GetDefaultFactory()->CreateDatagramClientSocket(
+ DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ net_log.net_log(),
+ net_log.source()));
+ int rv = socket->Connect(IPEndPoint(dest, 53));
+ if (rv != OK)
+ return false;
+ IPEndPoint endpoint;
+ rv = socket->GetLocalAddress(&endpoint);
+ if (rv != OK)
+ return false;
+ DCHECK(endpoint.GetFamily() == ADDRESS_FAMILY_IPV6);
+ const IPAddressNumber& address = endpoint.address();
+ bool is_link_local = (address[0] == 0xFE) && ((address[1] & 0xC0) == 0x80);
+ if (is_link_local)
+ return false;
+ const uint8 kTeredoPrefix[] = { 0x20, 0x01, 0, 0 };
+ bool is_teredo = std::equal(kTeredoPrefix,
+ kTeredoPrefix + arraysize(kTeredoPrefix),
+ address.begin());
+ if (is_teredo)
+ return false;
+ return true;
+}
+
+// Provide a common macro to simplify code and readability. We must use a
+// macro as the underlying HISTOGRAM macro creates static variables.
+#define DNS_HISTOGRAM(name, time) UMA_HISTOGRAM_CUSTOM_TIMES(name, time, \
+ base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromHours(1), 100)
+
+// A macro to simplify code and readability.
+#define DNS_HISTOGRAM_BY_PRIORITY(basename, priority, time) \
+ do { \
+ switch (priority) { \
+ case HIGHEST: DNS_HISTOGRAM(basename "_HIGHEST", time); break; \
+ case MEDIUM: DNS_HISTOGRAM(basename "_MEDIUM", time); break; \
+ case LOW: DNS_HISTOGRAM(basename "_LOW", time); break; \
+ case LOWEST: DNS_HISTOGRAM(basename "_LOWEST", time); break; \
+ case IDLE: DNS_HISTOGRAM(basename "_IDLE", time); break; \
+ default: NOTREACHED(); break; \
+ } \
+ DNS_HISTOGRAM(basename, time); \
+ } while (0)
+
+// Record time from Request creation until a valid DNS response.
+void RecordTotalTime(bool had_dns_config,
+ bool speculative,
+ base::TimeDelta duration) {
+ if (had_dns_config) {
+ if (speculative) {
+ DNS_HISTOGRAM("AsyncDNS.TotalTime_speculative", duration);
+ } else {
+ DNS_HISTOGRAM("AsyncDNS.TotalTime", duration);
+ }
+ } else {
+ if (speculative) {
+ DNS_HISTOGRAM("DNS.TotalTime_speculative", duration);
+ } else {
+ DNS_HISTOGRAM("DNS.TotalTime", duration);
+ }
+ }
+}
+
+void RecordTTL(base::TimeDelta ttl) {
+ UMA_HISTOGRAM_CUSTOM_TIMES("AsyncDNS.TTL", ttl,
+ base::TimeDelta::FromSeconds(1),
+ base::TimeDelta::FromDays(1), 100);
+}
+
+bool ConfigureAsyncDnsNoFallbackFieldTrial() {
+ const bool kDefault = false;
+
+ // Configure the AsyncDns field trial as follows:
+ // groups AsyncDnsNoFallbackA and AsyncDnsNoFallbackB: return true,
+ // groups AsyncDnsA and AsyncDnsB: return false,
+ // groups SystemDnsA and SystemDnsB: return false,
+ // otherwise (trial absent): return default.
+ std::string group_name = base::FieldTrialList::FindFullName("AsyncDns");
+ if (!group_name.empty())
+ return StartsWithASCII(group_name, "AsyncDnsNoFallback", false);
+ return kDefault;
+}
+
+//-----------------------------------------------------------------------------
+
+AddressList EnsurePortOnAddressList(const AddressList& list, uint16 port) {
+ if (list.empty() || list.front().port() == port)
+ return list;
+ return AddressList::CopyWithPort(list, port);
+}
+
+// Returns true if |addresses| contains only IPv4 loopback addresses.
+bool IsAllIPv4Loopback(const AddressList& addresses) {
+ for (unsigned i = 0; i < addresses.size(); ++i) {
+ const IPAddressNumber& address = addresses[i].address();
+ switch (addresses[i].GetFamily()) {
+ case ADDRESS_FAMILY_IPV4:
+ if (address[0] != 127)
+ return false;
+ break;
+ case ADDRESS_FAMILY_IPV6:
+ return false;
+ default:
+ NOTREACHED();
+ return false;
+ }
+ }
+ return true;
+}
+
+// Creates NetLog parameters when the resolve failed.
+base::Value* NetLogProcTaskFailedCallback(uint32 attempt_number,
+ int net_error,
+ int os_error,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ if (attempt_number)
+ dict->SetInteger("attempt_number", attempt_number);
+
+ dict->SetInteger("net_error", net_error);
+
+ if (os_error) {
+ dict->SetInteger("os_error", os_error);
+#if defined(OS_POSIX)
+ dict->SetString("os_error_string", gai_strerror(os_error));
+#elif defined(OS_WIN)
+ // Map the error code to a human-readable string.
+ LPWSTR error_string = NULL;
+ int size = FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER |
+ FORMAT_MESSAGE_FROM_SYSTEM,
+ 0, // Use the internal message table.
+ os_error,
+ 0, // Use default language.
+ (LPWSTR)&error_string,
+ 0, // Buffer size.
+ 0); // Arguments (unused).
+ dict->SetString("os_error_string", WideToUTF8(error_string));
+ LocalFree(error_string);
+#endif
+ }
+
+ return dict;
+}
+
+// Creates NetLog parameters when the DnsTask failed.
+base::Value* NetLogDnsTaskFailedCallback(int net_error,
+ int dns_error,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetInteger("net_error", net_error);
+ if (dns_error)
+ dict->SetInteger("dns_error", dns_error);
+ return dict;
+};
+
+// Creates NetLog parameters containing the information in a RequestInfo object,
+// along with the associated NetLog::Source.
+base::Value* NetLogRequestInfoCallback(const NetLog::Source& source,
+ const HostResolver::RequestInfo* info,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ source.AddToEventParameters(dict);
+
+ dict->SetString("host", info->host_port_pair().ToString());
+ dict->SetInteger("address_family",
+ static_cast<int>(info->address_family()));
+ dict->SetBoolean("allow_cached_response", info->allow_cached_response());
+ dict->SetBoolean("is_speculative", info->is_speculative());
+ dict->SetInteger("priority", info->priority());
+ return dict;
+}
+
+// Creates NetLog parameters for the creation of a HostResolverImpl::Job.
+base::Value* NetLogJobCreationCallback(const NetLog::Source& source,
+ const std::string* host,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ source.AddToEventParameters(dict);
+ dict->SetString("host", *host);
+ return dict;
+}
+
+// Creates NetLog parameters for HOST_RESOLVER_IMPL_JOB_ATTACH/DETACH events.
+base::Value* NetLogJobAttachCallback(const NetLog::Source& source,
+ RequestPriority priority,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ source.AddToEventParameters(dict);
+ dict->SetInteger("priority", priority);
+ return dict;
+}
+
+// Creates NetLog parameters for the DNS_CONFIG_CHANGED event.
+base::Value* NetLogDnsConfigCallback(const DnsConfig* config,
+ NetLog::LogLevel /* log_level */) {
+ return config->ToValue();
+}
+
+// The logging routines are defined here because some requests are resolved
+// without a Request object.
+
+// Logs when a request has just been started.
+void LogStartRequest(const BoundNetLog& source_net_log,
+ const BoundNetLog& request_net_log,
+ const HostResolver::RequestInfo& info) {
+ source_net_log.BeginEvent(
+ NetLog::TYPE_HOST_RESOLVER_IMPL,
+ request_net_log.source().ToEventParametersCallback());
+
+ request_net_log.BeginEvent(
+ NetLog::TYPE_HOST_RESOLVER_IMPL_REQUEST,
+ base::Bind(&NetLogRequestInfoCallback, source_net_log.source(), &info));
+}
+
+// Logs when a request has just completed (before its callback is run).
+void LogFinishRequest(const BoundNetLog& source_net_log,
+ const BoundNetLog& request_net_log,
+ const HostResolver::RequestInfo& info,
+ int net_error) {
+ request_net_log.EndEventWithNetErrorCode(
+ NetLog::TYPE_HOST_RESOLVER_IMPL_REQUEST, net_error);
+ source_net_log.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL);
+}
+
+// Logs when a request has been cancelled.
+void LogCancelRequest(const BoundNetLog& source_net_log,
+ const BoundNetLog& request_net_log,
+ const HostResolverImpl::RequestInfo& info) {
+ request_net_log.AddEvent(NetLog::TYPE_CANCELLED);
+ request_net_log.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_REQUEST);
+ source_net_log.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL);
+}
+
+//-----------------------------------------------------------------------------
+
+// Keeps track of the highest priority.
+class PriorityTracker {
+ public:
+ explicit PriorityTracker(RequestPriority initial_priority)
+ : highest_priority_(initial_priority), total_count_(0) {
+ memset(counts_, 0, sizeof(counts_));
+ }
+
+ RequestPriority highest_priority() const {
+ return highest_priority_;
+ }
+
+ size_t total_count() const {
+ return total_count_;
+ }
+
+ void Add(RequestPriority req_priority) {
+ ++total_count_;
+ ++counts_[req_priority];
+ if (highest_priority_ < req_priority)
+ highest_priority_ = req_priority;
+ }
+
+ void Remove(RequestPriority req_priority) {
+ DCHECK_GT(total_count_, 0u);
+ DCHECK_GT(counts_[req_priority], 0u);
+ --total_count_;
+ --counts_[req_priority];
+ size_t i;
+ for (i = highest_priority_; i > MINIMUM_PRIORITY && !counts_[i]; --i);
+ highest_priority_ = static_cast<RequestPriority>(i);
+
+ // In absence of requests, default to MINIMUM_PRIORITY.
+ if (total_count_ == 0)
+ DCHECK_EQ(MINIMUM_PRIORITY, highest_priority_);
+ }
+
+ private:
+ RequestPriority highest_priority_;
+ size_t total_count_;
+ size_t counts_[NUM_PRIORITIES];
+};
+
+} // namespace
+
+//-----------------------------------------------------------------------------
+
+// Holds the data for a request that could not be completed synchronously.
+// It is owned by a Job. Canceled Requests are only marked as canceled rather
+// than removed from the Job's |requests_| list.
+class HostResolverImpl::Request {
+ public:
+ Request(const BoundNetLog& source_net_log,
+ const BoundNetLog& request_net_log,
+ const RequestInfo& info,
+ const CompletionCallback& callback,
+ AddressList* addresses)
+ : source_net_log_(source_net_log),
+ request_net_log_(request_net_log),
+ info_(info),
+ job_(NULL),
+ callback_(callback),
+ addresses_(addresses),
+ request_time_(base::TimeTicks::Now()) {
+ }
+
+ // Mark the request as canceled.
+ void MarkAsCanceled() {
+ job_ = NULL;
+ addresses_ = NULL;
+ callback_.Reset();
+ }
+
+ bool was_canceled() const {
+ return callback_.is_null();
+ }
+
+ void set_job(Job* job) {
+ DCHECK(job);
+ // Identify which job the request is waiting on.
+ job_ = job;
+ }
+
+ // Prepare final AddressList and call completion callback.
+ void OnComplete(int error, const AddressList& addr_list) {
+ DCHECK(!was_canceled());
+ if (error == OK)
+ *addresses_ = EnsurePortOnAddressList(addr_list, info_.port());
+ CompletionCallback callback = callback_;
+ MarkAsCanceled();
+ callback.Run(error);
+ }
+
+ Job* job() const {
+ return job_;
+ }
+
+ // NetLog for the source, passed in HostResolver::Resolve.
+ const BoundNetLog& source_net_log() {
+ return source_net_log_;
+ }
+
+ // NetLog for this request.
+ const BoundNetLog& request_net_log() {
+ return request_net_log_;
+ }
+
+ const RequestInfo& info() const {
+ return info_;
+ }
+
+ base::TimeTicks request_time() const {
+ return request_time_;
+ }
+
+ private:
+ BoundNetLog source_net_log_;
+ BoundNetLog request_net_log_;
+
+ // The request info that started the request.
+ RequestInfo info_;
+
+ // The resolve job that this request is dependent on.
+ Job* job_;
+
+ // The user's callback to invoke when the request completes.
+ CompletionCallback callback_;
+
+ // The address list to save result into.
+ AddressList* addresses_;
+
+ const base::TimeTicks request_time_;
+
+ DISALLOW_COPY_AND_ASSIGN(Request);
+};
+
+//------------------------------------------------------------------------------
+
+// Calls HostResolverProc on the WorkerPool. Performs retries if necessary.
+//
+// Whenever we try to resolve the host, we post a delayed task to check if host
+// resolution (OnLookupComplete) is completed or not. If the original attempt
+// hasn't completed, then we start another attempt for host resolution. We take
+// the results from the first attempt that finishes and ignore the results from
+// all other attempts.
+//
+// TODO(szym): Move to separate source file for testing and mocking.
+//
+class HostResolverImpl::ProcTask
+ : public base::RefCountedThreadSafe<HostResolverImpl::ProcTask> {
+ public:
+ typedef base::Callback<void(int net_error,
+ const AddressList& addr_list)> Callback;
+
+ ProcTask(const Key& key,
+ const ProcTaskParams& params,
+ const Callback& callback,
+ const BoundNetLog& job_net_log)
+ : key_(key),
+ params_(params),
+ callback_(callback),
+ origin_loop_(base::MessageLoopProxy::current()),
+ attempt_number_(0),
+ completed_attempt_number_(0),
+ completed_attempt_error_(ERR_UNEXPECTED),
+ had_non_speculative_request_(false),
+ net_log_(job_net_log) {
+ if (!params_.resolver_proc.get())
+ params_.resolver_proc = HostResolverProc::GetDefault();
+ // If default is unset, use the system proc.
+ if (!params_.resolver_proc.get())
+ params_.resolver_proc = new SystemHostResolverProc();
+ }
+
+ void Start() {
+ DCHECK(origin_loop_->BelongsToCurrentThread());
+ net_log_.BeginEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_PROC_TASK);
+ StartLookupAttempt();
+ }
+
+ // Cancels this ProcTask. It will be orphaned. Any outstanding resolve
+ // attempts running on worker threads will continue running. Only once all the
+ // attempts complete will the final reference to this ProcTask be released.
+ void Cancel() {
+ DCHECK(origin_loop_->BelongsToCurrentThread());
+
+ if (was_canceled() || was_completed())
+ return;
+
+ callback_.Reset();
+ net_log_.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_PROC_TASK);
+ }
+
+ void set_had_non_speculative_request() {
+ DCHECK(origin_loop_->BelongsToCurrentThread());
+ had_non_speculative_request_ = true;
+ }
+
+ bool was_canceled() const {
+ DCHECK(origin_loop_->BelongsToCurrentThread());
+ return callback_.is_null();
+ }
+
+ bool was_completed() const {
+ DCHECK(origin_loop_->BelongsToCurrentThread());
+ return completed_attempt_number_ > 0;
+ }
+
+ private:
+ friend class base::RefCountedThreadSafe<ProcTask>;
+ ~ProcTask() {}
+
+ void StartLookupAttempt() {
+ DCHECK(origin_loop_->BelongsToCurrentThread());
+ base::TimeTicks start_time = base::TimeTicks::Now();
+ ++attempt_number_;
+ // Dispatch the lookup attempt to a worker thread.
+ if (!base::WorkerPool::PostTask(
+ FROM_HERE,
+ base::Bind(&ProcTask::DoLookup, this, start_time, attempt_number_),
+ true)) {
+ NOTREACHED();
+
+ // Since we could be running within Resolve() right now, we can't just
+ // call OnLookupComplete(). Instead we must wait until Resolve() has
+ // returned (IO_PENDING).
+ origin_loop_->PostTask(
+ FROM_HERE,
+ base::Bind(&ProcTask::OnLookupComplete, this, AddressList(),
+ start_time, attempt_number_, ERR_UNEXPECTED, 0));
+ return;
+ }
+
+ net_log_.AddEvent(
+ NetLog::TYPE_HOST_RESOLVER_IMPL_ATTEMPT_STARTED,
+ NetLog::IntegerCallback("attempt_number", attempt_number_));
+
+ // If we don't get the results within a given time, RetryIfNotComplete
+ // will start a new attempt on a different worker thread if none of our
+ // outstanding attempts have completed yet.
+ if (attempt_number_ <= params_.max_retry_attempts) {
+ origin_loop_->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(&ProcTask::RetryIfNotComplete, this),
+ params_.unresponsive_delay);
+ }
+ }
+
+ // WARNING: This code runs inside a worker pool. The shutdown code cannot
+ // wait for it to finish, so we must be very careful here about using other
+ // objects (like MessageLoops, Singletons, etc). During shutdown these objects
+ // may no longer exist. Multiple DoLookups() could be running in parallel, so
+ // any state inside of |this| must not mutate .
+ void DoLookup(const base::TimeTicks& start_time,
+ const uint32 attempt_number) {
+ AddressList results;
+ int os_error = 0;
+ // Running on the worker thread
+ int error = params_.resolver_proc->Resolve(key_.hostname,
+ key_.address_family,
+ key_.host_resolver_flags,
+ &results,
+ &os_error);
+
+ origin_loop_->PostTask(
+ FROM_HERE,
+ base::Bind(&ProcTask::OnLookupComplete, this, results, start_time,
+ attempt_number, error, os_error));
+ }
+
+ // Makes next attempt if DoLookup() has not finished (runs on origin thread).
+ void RetryIfNotComplete() {
+ DCHECK(origin_loop_->BelongsToCurrentThread());
+
+ if (was_completed() || was_canceled())
+ return;
+
+ params_.unresponsive_delay *= params_.retry_factor;
+ StartLookupAttempt();
+ }
+
+ // Callback for when DoLookup() completes (runs on origin thread).
+ void OnLookupComplete(const AddressList& results,
+ const base::TimeTicks& start_time,
+ const uint32 attempt_number,
+ int error,
+ const int os_error) {
+ DCHECK(origin_loop_->BelongsToCurrentThread());
+ // If results are empty, we should return an error.
+ bool empty_list_on_ok = (error == OK && results.empty());
+ UMA_HISTOGRAM_BOOLEAN("DNS.EmptyAddressListAndNoError", empty_list_on_ok);
+ if (empty_list_on_ok)
+ error = ERR_NAME_NOT_RESOLVED;
+
+ bool was_retry_attempt = attempt_number > 1;
+
+ // Ideally the following code would be part of host_resolver_proc.cc,
+ // however it isn't safe to call NetworkChangeNotifier from worker threads.
+ // So we do it here on the IO thread instead.
+ if (error != OK && NetworkChangeNotifier::IsOffline())
+ error = ERR_INTERNET_DISCONNECTED;
+
+ // If this is the first attempt that is finishing later, then record data
+ // for the first attempt. Won't contaminate with retry attempt's data.
+ if (!was_retry_attempt)
+ RecordPerformanceHistograms(start_time, error, os_error);
+
+ RecordAttemptHistograms(start_time, attempt_number, error, os_error);
+
+ if (was_canceled())
+ return;
+
+ NetLog::ParametersCallback net_log_callback;
+ if (error != OK) {
+ net_log_callback = base::Bind(&NetLogProcTaskFailedCallback,
+ attempt_number,
+ error,
+ os_error);
+ } else {
+ net_log_callback = NetLog::IntegerCallback("attempt_number",
+ attempt_number);
+ }
+ net_log_.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_ATTEMPT_FINISHED,
+ net_log_callback);
+
+ if (was_completed())
+ return;
+
+ // Copy the results from the first worker thread that resolves the host.
+ results_ = results;
+ completed_attempt_number_ = attempt_number;
+ completed_attempt_error_ = error;
+
+ if (was_retry_attempt) {
+ // If retry attempt finishes before 1st attempt, then get stats on how
+ // much time is saved by having spawned an extra attempt.
+ retry_attempt_finished_time_ = base::TimeTicks::Now();
+ }
+
+ if (error != OK) {
+ net_log_callback = base::Bind(&NetLogProcTaskFailedCallback,
+ 0, error, os_error);
+ } else {
+ net_log_callback = results_.CreateNetLogCallback();
+ }
+ net_log_.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_PROC_TASK,
+ net_log_callback);
+
+ callback_.Run(error, results_);
+ }
+
+ void RecordPerformanceHistograms(const base::TimeTicks& start_time,
+ const int error,
+ const int os_error) const {
+ DCHECK(origin_loop_->BelongsToCurrentThread());
+ enum Category { // Used in HISTOGRAM_ENUMERATION.
+ RESOLVE_SUCCESS,
+ RESOLVE_FAIL,
+ RESOLVE_SPECULATIVE_SUCCESS,
+ RESOLVE_SPECULATIVE_FAIL,
+ RESOLVE_MAX, // Bounding value.
+ };
+ int category = RESOLVE_MAX; // Illegal value for later DCHECK only.
+
+ base::TimeDelta duration = base::TimeTicks::Now() - start_time;
+ if (error == OK) {
+ if (had_non_speculative_request_) {
+ category = RESOLVE_SUCCESS;
+ DNS_HISTOGRAM("DNS.ResolveSuccess", duration);
+ } else {
+ category = RESOLVE_SPECULATIVE_SUCCESS;
+ DNS_HISTOGRAM("DNS.ResolveSpeculativeSuccess", duration);
+ }
+
+ // Log DNS lookups based on |address_family|. This will help us determine
+ // if IPv4 or IPv4/6 lookups are faster or slower.
+ switch(key_.address_family) {
+ case ADDRESS_FAMILY_IPV4:
+ DNS_HISTOGRAM("DNS.ResolveSuccess_FAMILY_IPV4", duration);
+ break;
+ case ADDRESS_FAMILY_IPV6:
+ DNS_HISTOGRAM("DNS.ResolveSuccess_FAMILY_IPV6", duration);
+ break;
+ case ADDRESS_FAMILY_UNSPECIFIED:
+ DNS_HISTOGRAM("DNS.ResolveSuccess_FAMILY_UNSPEC", duration);
+ break;
+ }
+ } else {
+ if (had_non_speculative_request_) {
+ category = RESOLVE_FAIL;
+ DNS_HISTOGRAM("DNS.ResolveFail", duration);
+ } else {
+ category = RESOLVE_SPECULATIVE_FAIL;
+ DNS_HISTOGRAM("DNS.ResolveSpeculativeFail", duration);
+ }
+ // Log DNS lookups based on |address_family|. This will help us determine
+ // if IPv4 or IPv4/6 lookups are faster or slower.
+ switch(key_.address_family) {
+ case ADDRESS_FAMILY_IPV4:
+ DNS_HISTOGRAM("DNS.ResolveFail_FAMILY_IPV4", duration);
+ break;
+ case ADDRESS_FAMILY_IPV6:
+ DNS_HISTOGRAM("DNS.ResolveFail_FAMILY_IPV6", duration);
+ break;
+ case ADDRESS_FAMILY_UNSPECIFIED:
+ DNS_HISTOGRAM("DNS.ResolveFail_FAMILY_UNSPEC", duration);
+ break;
+ }
+ UMA_HISTOGRAM_CUSTOM_ENUMERATION(kOSErrorsForGetAddrinfoHistogramName,
+ std::abs(os_error),
+ GetAllGetAddrinfoOSErrors());
+ }
+ DCHECK_LT(category, static_cast<int>(RESOLVE_MAX)); // Be sure it was set.
+
+ UMA_HISTOGRAM_ENUMERATION("DNS.ResolveCategory", category, RESOLVE_MAX);
+ }
+
+ void RecordAttemptHistograms(const base::TimeTicks& start_time,
+ const uint32 attempt_number,
+ const int error,
+ const int os_error) const {
+ DCHECK(origin_loop_->BelongsToCurrentThread());
+ bool first_attempt_to_complete =
+ completed_attempt_number_ == attempt_number;
+ bool is_first_attempt = (attempt_number == 1);
+
+ if (first_attempt_to_complete) {
+ // If this was first attempt to complete, then record the resolution
+ // status of the attempt.
+ if (completed_attempt_error_ == OK) {
+ UMA_HISTOGRAM_ENUMERATION(
+ "DNS.AttemptFirstSuccess", attempt_number, 100);
+ } else {
+ UMA_HISTOGRAM_ENUMERATION(
+ "DNS.AttemptFirstFailure", attempt_number, 100);
+ }
+ }
+
+ if (error == OK)
+ UMA_HISTOGRAM_ENUMERATION("DNS.AttemptSuccess", attempt_number, 100);
+ else
+ UMA_HISTOGRAM_ENUMERATION("DNS.AttemptFailure", attempt_number, 100);
+
+ // If first attempt didn't finish before retry attempt, then calculate stats
+ // on how much time is saved by having spawned an extra attempt.
+ if (!first_attempt_to_complete && is_first_attempt && !was_canceled()) {
+ DNS_HISTOGRAM("DNS.AttemptTimeSavedByRetry",
+ base::TimeTicks::Now() - retry_attempt_finished_time_);
+ }
+
+ if (was_canceled() || !first_attempt_to_complete) {
+ // Count those attempts which completed after the job was already canceled
+ // OR after the job was already completed by an earlier attempt (so in
+ // effect).
+ UMA_HISTOGRAM_ENUMERATION("DNS.AttemptDiscarded", attempt_number, 100);
+
+ // Record if job is canceled.
+ if (was_canceled())
+ UMA_HISTOGRAM_ENUMERATION("DNS.AttemptCancelled", attempt_number, 100);
+ }
+
+ base::TimeDelta duration = base::TimeTicks::Now() - start_time;
+ if (error == OK)
+ DNS_HISTOGRAM("DNS.AttemptSuccessDuration", duration);
+ else
+ DNS_HISTOGRAM("DNS.AttemptFailDuration", duration);
+ }
+
+ // Set on the origin thread, read on the worker thread.
+ Key key_;
+
+ // Holds an owning reference to the HostResolverProc that we are going to use.
+ // This may not be the current resolver procedure by the time we call
+ // ResolveAddrInfo, but that's OK... we'll use it anyways, and the owning
+ // reference ensures that it remains valid until we are done.
+ ProcTaskParams params_;
+
+ // The listener to the results of this ProcTask.
+ Callback callback_;
+
+ // Used to post ourselves onto the origin thread.
+ scoped_refptr<base::MessageLoopProxy> origin_loop_;
+
+ // Keeps track of the number of attempts we have made so far to resolve the
+ // host. Whenever we start an attempt to resolve the host, we increase this
+ // number.
+ uint32 attempt_number_;
+
+ // The index of the attempt which finished first (or 0 if the job is still in
+ // progress).
+ uint32 completed_attempt_number_;
+
+ // The result (a net error code) from the first attempt to complete.
+ int completed_attempt_error_;
+
+ // The time when retry attempt was finished.
+ base::TimeTicks retry_attempt_finished_time_;
+
+ // True if a non-speculative request was ever attached to this job
+ // (regardless of whether or not it was later canceled.
+ // This boolean is used for histogramming the duration of jobs used to
+ // service non-speculative requests.
+ bool had_non_speculative_request_;
+
+ AddressList results_;
+
+ BoundNetLog net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(ProcTask);
+};
+
+//-----------------------------------------------------------------------------
+
+// Wraps a call to HaveOnlyLoopbackAddresses to be executed on the WorkerPool as
+// it takes 40-100ms and should not block initialization.
+class HostResolverImpl::LoopbackProbeJob {
+ public:
+ explicit LoopbackProbeJob(const base::WeakPtr<HostResolverImpl>& resolver)
+ : resolver_(resolver),
+ result_(false) {
+ DCHECK(resolver.get());
+ const bool kIsSlow = true;
+ base::WorkerPool::PostTaskAndReply(
+ FROM_HERE,
+ base::Bind(&LoopbackProbeJob::DoProbe, base::Unretained(this)),
+ base::Bind(&LoopbackProbeJob::OnProbeComplete, base::Owned(this)),
+ kIsSlow);
+ }
+
+ virtual ~LoopbackProbeJob() {}
+
+ private:
+ // Runs on worker thread.
+ void DoProbe() {
+ result_ = HaveOnlyLoopbackAddresses();
+ }
+
+ void OnProbeComplete() {
+ if (!resolver_.get())
+ return;
+ resolver_->SetHaveOnlyLoopbackAddresses(result_);
+ }
+
+ // Used/set only on origin thread.
+ base::WeakPtr<HostResolverImpl> resolver_;
+
+ bool result_;
+
+ DISALLOW_COPY_AND_ASSIGN(LoopbackProbeJob);
+};
+
+//-----------------------------------------------------------------------------
+
+// Resolves the hostname using DnsTransaction.
+// TODO(szym): This could be moved to separate source file as well.
+class HostResolverImpl::DnsTask : public base::SupportsWeakPtr<DnsTask> {
+ public:
+ typedef base::Callback<void(int net_error,
+ const AddressList& addr_list,
+ base::TimeDelta ttl)> Callback;
+
+ DnsTask(DnsClient* client,
+ const Key& key,
+ const Callback& callback,
+ const BoundNetLog& job_net_log)
+ : client_(client),
+ family_(key.address_family),
+ callback_(callback),
+ net_log_(job_net_log) {
+ DCHECK(client);
+ DCHECK(!callback.is_null());
+
+ // If unspecified, do IPv4 first, because suffix search will be faster.
+ uint16 qtype = (family_ == ADDRESS_FAMILY_IPV6) ?
+ dns_protocol::kTypeAAAA :
+ dns_protocol::kTypeA;
+ transaction_ = client_->GetTransactionFactory()->CreateTransaction(
+ key.hostname,
+ qtype,
+ base::Bind(&DnsTask::OnTransactionComplete, base::Unretained(this),
+ true /* first_query */, base::TimeTicks::Now()),
+ net_log_);
+ }
+
+ void Start() {
+ net_log_.BeginEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK);
+ transaction_->Start();
+ }
+
+ private:
+ void OnTransactionComplete(bool first_query,
+ const base::TimeTicks& start_time,
+ DnsTransaction* transaction,
+ int net_error,
+ const DnsResponse* response) {
+ DCHECK(transaction);
+ base::TimeDelta duration = base::TimeTicks::Now() - start_time;
+ // Run |callback_| last since the owning Job will then delete this DnsTask.
+ if (net_error != OK) {
+ DNS_HISTOGRAM("AsyncDNS.TransactionFailure", duration);
+ OnFailure(net_error, DnsResponse::DNS_PARSE_OK);
+ return;
+ }
+
+ CHECK(response);
+ DNS_HISTOGRAM("AsyncDNS.TransactionSuccess", duration);
+ switch (transaction->GetType()) {
+ case dns_protocol::kTypeA:
+ DNS_HISTOGRAM("AsyncDNS.TransactionSuccess_A", duration);
+ break;
+ case dns_protocol::kTypeAAAA:
+ DNS_HISTOGRAM("AsyncDNS.TransactionSuccess_AAAA", duration);
+ break;
+ }
+ AddressList addr_list;
+ base::TimeDelta ttl;
+ DnsResponse::Result result = response->ParseToAddressList(&addr_list, &ttl);
+ UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ParseToAddressList",
+ result,
+ DnsResponse::DNS_PARSE_RESULT_MAX);
+ if (result != DnsResponse::DNS_PARSE_OK) {
+ // Fail even if the other query succeeds.
+ OnFailure(ERR_DNS_MALFORMED_RESPONSE, result);
+ return;
+ }
+
+ bool needs_sort = false;
+ if (first_query) {
+ DCHECK(client_->GetConfig()) <<
+ "Transaction should have been aborted when config changed!";
+ if (family_ == ADDRESS_FAMILY_IPV6) {
+ needs_sort = (addr_list.size() > 1);
+ } else if (family_ == ADDRESS_FAMILY_UNSPECIFIED) {
+ first_addr_list_ = addr_list;
+ first_ttl_ = ttl;
+ // Use fully-qualified domain name to avoid search.
+ transaction_ = client_->GetTransactionFactory()->CreateTransaction(
+ response->GetDottedName() + ".",
+ dns_protocol::kTypeAAAA,
+ base::Bind(&DnsTask::OnTransactionComplete, base::Unretained(this),
+ false /* first_query */, base::TimeTicks::Now()),
+ net_log_);
+ transaction_->Start();
+ return;
+ }
+ } else {
+ DCHECK_EQ(ADDRESS_FAMILY_UNSPECIFIED, family_);
+ bool has_ipv6_addresses = !addr_list.empty();
+ if (!first_addr_list_.empty()) {
+ ttl = std::min(ttl, first_ttl_);
+ // Place IPv4 addresses after IPv6.
+ addr_list.insert(addr_list.end(), first_addr_list_.begin(),
+ first_addr_list_.end());
+ }
+ needs_sort = (has_ipv6_addresses && addr_list.size() > 1);
+ }
+
+ if (addr_list.empty()) {
+ // TODO(szym): Don't fallback to ProcTask in this case.
+ OnFailure(ERR_NAME_NOT_RESOLVED, DnsResponse::DNS_PARSE_OK);
+ return;
+ }
+
+ if (needs_sort) {
+ // Sort could complete synchronously.
+ client_->GetAddressSorter()->Sort(
+ addr_list,
+ base::Bind(&DnsTask::OnSortComplete,
+ AsWeakPtr(),
+ base::TimeTicks::Now(),
+ ttl));
+ } else {
+ OnSuccess(addr_list, ttl);
+ }
+ }
+
+ void OnSortComplete(base::TimeTicks start_time,
+ base::TimeDelta ttl,
+ bool success,
+ const AddressList& addr_list) {
+ if (!success) {
+ DNS_HISTOGRAM("AsyncDNS.SortFailure",
+ base::TimeTicks::Now() - start_time);
+ OnFailure(ERR_DNS_SORT_ERROR, DnsResponse::DNS_PARSE_OK);
+ return;
+ }
+
+ DNS_HISTOGRAM("AsyncDNS.SortSuccess",
+ base::TimeTicks::Now() - start_time);
+
+ // AddressSorter prunes unusable destinations.
+ if (addr_list.empty()) {
+ LOG(WARNING) << "Address list empty after RFC3484 sort";
+ OnFailure(ERR_NAME_NOT_RESOLVED, DnsResponse::DNS_PARSE_OK);
+ return;
+ }
+
+ OnSuccess(addr_list, ttl);
+ }
+
+ void OnFailure(int net_error, DnsResponse::Result result) {
+ DCHECK_NE(OK, net_error);
+ net_log_.EndEvent(
+ NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK,
+ base::Bind(&NetLogDnsTaskFailedCallback, net_error, result));
+ callback_.Run(net_error, AddressList(), base::TimeDelta());
+ }
+
+ void OnSuccess(const AddressList& addr_list, base::TimeDelta ttl) {
+ net_log_.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK,
+ addr_list.CreateNetLogCallback());
+ callback_.Run(OK, addr_list, ttl);
+ }
+
+ DnsClient* client_;
+ AddressFamily family_;
+ // The listener to the results of this DnsTask.
+ Callback callback_;
+ const BoundNetLog net_log_;
+
+ scoped_ptr<DnsTransaction> transaction_;
+
+ // Results from the first transaction. Used only if |family_| is unspecified.
+ AddressList first_addr_list_;
+ base::TimeDelta first_ttl_;
+
+ DISALLOW_COPY_AND_ASSIGN(DnsTask);
+};
+
+//-----------------------------------------------------------------------------
+
+// Aggregates all Requests for the same Key. Dispatched via PriorityDispatch.
+class HostResolverImpl::Job : public PrioritizedDispatcher::Job {
+ public:
+ // Creates new job for |key| where |request_net_log| is bound to the
+ // request that spawned it.
+ Job(const base::WeakPtr<HostResolverImpl>& resolver,
+ const Key& key,
+ RequestPriority priority,
+ const BoundNetLog& request_net_log)
+ : resolver_(resolver),
+ key_(key),
+ priority_tracker_(priority),
+ had_non_speculative_request_(false),
+ had_dns_config_(false),
+ dns_task_error_(OK),
+ creation_time_(base::TimeTicks::Now()),
+ priority_change_time_(creation_time_),
+ net_log_(BoundNetLog::Make(request_net_log.net_log(),
+ NetLog::SOURCE_HOST_RESOLVER_IMPL_JOB)) {
+ request_net_log.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_CREATE_JOB);
+
+ net_log_.BeginEvent(
+ NetLog::TYPE_HOST_RESOLVER_IMPL_JOB,
+ base::Bind(&NetLogJobCreationCallback,
+ request_net_log.source(),
+ &key_.hostname));
+ }
+
+ virtual ~Job() {
+ if (is_running()) {
+ // |resolver_| was destroyed with this Job still in flight.
+ // Clean-up, record in the log, but don't run any callbacks.
+ if (is_proc_running()) {
+ proc_task_->Cancel();
+ proc_task_ = NULL;
+ }
+ // Clean up now for nice NetLog.
+ dns_task_.reset(NULL);
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB,
+ ERR_ABORTED);
+ } else if (is_queued()) {
+ // |resolver_| was destroyed without running this Job.
+ // TODO(szym): is there any benefit in having this distinction?
+ net_log_.AddEvent(NetLog::TYPE_CANCELLED);
+ net_log_.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB);
+ }
+ // else CompleteRequests logged EndEvent.
+
+ // Log any remaining Requests as cancelled.
+ for (RequestsList::const_iterator it = requests_.begin();
+ it != requests_.end(); ++it) {
+ Request* req = *it;
+ if (req->was_canceled())
+ continue;
+ DCHECK_EQ(this, req->job());
+ LogCancelRequest(req->source_net_log(), req->request_net_log(),
+ req->info());
+ }
+ }
+
+ // Add this job to the dispatcher.
+ void Schedule() {
+ handle_ = resolver_->dispatcher_.Add(this, priority());
+ }
+
+ void AddRequest(scoped_ptr<Request> req) {
+ DCHECK_EQ(key_.hostname, req->info().hostname());
+
+ req->set_job(this);
+ priority_tracker_.Add(req->info().priority());
+
+ req->request_net_log().AddEvent(
+ NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_ATTACH,
+ net_log_.source().ToEventParametersCallback());
+
+ net_log_.AddEvent(
+ NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_REQUEST_ATTACH,
+ base::Bind(&NetLogJobAttachCallback,
+ req->request_net_log().source(),
+ priority()));
+
+ // TODO(szym): Check if this is still needed.
+ if (!req->info().is_speculative()) {
+ had_non_speculative_request_ = true;
+ if (proc_task_.get())
+ proc_task_->set_had_non_speculative_request();
+ }
+
+ requests_.push_back(req.release());
+
+ UpdatePriority();
+ }
+
+ // Marks |req| as cancelled. If it was the last active Request, also finishes
+ // this Job, marking it as cancelled, and deletes it.
+ void CancelRequest(Request* req) {
+ DCHECK_EQ(key_.hostname, req->info().hostname());
+ DCHECK(!req->was_canceled());
+
+ // Don't remove it from |requests_| just mark it canceled.
+ req->MarkAsCanceled();
+ LogCancelRequest(req->source_net_log(), req->request_net_log(),
+ req->info());
+
+ priority_tracker_.Remove(req->info().priority());
+ net_log_.AddEvent(
+ NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_REQUEST_DETACH,
+ base::Bind(&NetLogJobAttachCallback,
+ req->request_net_log().source(),
+ priority()));
+
+ if (num_active_requests() > 0) {
+ UpdatePriority();
+ } else {
+ // If we were called from a Request's callback within CompleteRequests,
+ // that Request could not have been cancelled, so num_active_requests()
+ // could not be 0. Therefore, we are not in CompleteRequests().
+ CompleteRequestsWithError(OK /* cancelled */);
+ }
+ }
+
+ // Called from AbortAllInProgressJobs. Completes all requests and destroys
+ // the job. This currently assumes the abort is due to a network change.
+ void Abort() {
+ DCHECK(is_running());
+ CompleteRequestsWithError(ERR_NETWORK_CHANGED);
+ }
+
+ // If DnsTask present, abort it and fall back to ProcTask.
+ void AbortDnsTask() {
+ if (dns_task_) {
+ dns_task_.reset();
+ dns_task_error_ = OK;
+ StartProcTask();
+ }
+ }
+
+ // Called by HostResolverImpl when this job is evicted due to queue overflow.
+ // Completes all requests and destroys the job.
+ void OnEvicted() {
+ DCHECK(!is_running());
+ DCHECK(is_queued());
+ handle_.Reset();
+
+ net_log_.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_EVICTED);
+
+ // This signals to CompleteRequests that this job never ran.
+ CompleteRequestsWithError(ERR_HOST_RESOLVER_QUEUE_TOO_LARGE);
+ }
+
+ // Attempts to serve the job from HOSTS. Returns true if succeeded and
+ // this Job was destroyed.
+ bool ServeFromHosts() {
+ DCHECK_GT(num_active_requests(), 0u);
+ AddressList addr_list;
+ if (resolver_->ServeFromHosts(key(),
+ requests_.front()->info(),
+ &addr_list)) {
+ // This will destroy the Job.
+ CompleteRequests(
+ HostCache::Entry(OK, MakeAddressListForRequest(addr_list)),
+ base::TimeDelta());
+ return true;
+ }
+ return false;
+ }
+
+ const Key key() const {
+ return key_;
+ }
+
+ bool is_queued() const {
+ return !handle_.is_null();
+ }
+
+ bool is_running() const {
+ return is_dns_running() || is_proc_running();
+ }
+
+ private:
+ void UpdatePriority() {
+ if (is_queued()) {
+ if (priority() != static_cast<RequestPriority>(handle_.priority()))
+ priority_change_time_ = base::TimeTicks::Now();
+ handle_ = resolver_->dispatcher_.ChangePriority(handle_, priority());
+ }
+ }
+
+ AddressList MakeAddressListForRequest(const AddressList& list) const {
+ if (requests_.empty())
+ return list;
+ return AddressList::CopyWithPort(list, requests_.front()->info().port());
+ }
+
+ // PriorityDispatch::Job:
+ virtual void Start() OVERRIDE {
+ DCHECK(!is_running());
+ handle_.Reset();
+
+ net_log_.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_STARTED);
+
+ had_dns_config_ = resolver_->HaveDnsConfig();
+
+ base::TimeTicks now = base::TimeTicks::Now();
+ base::TimeDelta queue_time = now - creation_time_;
+ base::TimeDelta queue_time_after_change = now - priority_change_time_;
+
+ if (had_dns_config_) {
+ DNS_HISTOGRAM_BY_PRIORITY("AsyncDNS.JobQueueTime", priority(),
+ queue_time);
+ DNS_HISTOGRAM_BY_PRIORITY("AsyncDNS.JobQueueTimeAfterChange", priority(),
+ queue_time_after_change);
+ } else {
+ DNS_HISTOGRAM_BY_PRIORITY("DNS.JobQueueTime", priority(), queue_time);
+ DNS_HISTOGRAM_BY_PRIORITY("DNS.JobQueueTimeAfterChange", priority(),
+ queue_time_after_change);
+ }
+
+ // Caution: Job::Start must not complete synchronously.
+ if (had_dns_config_ && !ResemblesMulticastDNSName(key_.hostname)) {
+ StartDnsTask();
+ } else {
+ StartProcTask();
+ }
+ }
+
+ // TODO(szym): Since DnsTransaction does not consume threads, we can increase
+ // the limits on |dispatcher_|. But in order to keep the number of WorkerPool
+ // threads low, we will need to use an "inner" PrioritizedDispatcher with
+ // tighter limits.
+ void StartProcTask() {
+ DCHECK(!is_dns_running());
+ proc_task_ = new ProcTask(
+ key_,
+ resolver_->proc_params_,
+ base::Bind(&Job::OnProcTaskComplete, base::Unretained(this),
+ base::TimeTicks::Now()),
+ net_log_);
+
+ if (had_non_speculative_request_)
+ proc_task_->set_had_non_speculative_request();
+ // Start() could be called from within Resolve(), hence it must NOT directly
+ // call OnProcTaskComplete, for example, on synchronous failure.
+ proc_task_->Start();
+ }
+
+ // Called by ProcTask when it completes.
+ void OnProcTaskComplete(base::TimeTicks start_time,
+ int net_error,
+ const AddressList& addr_list) {
+ DCHECK(is_proc_running());
+
+ if (!resolver_->resolved_known_ipv6_hostname_ &&
+ net_error == OK &&
+ key_.address_family == ADDRESS_FAMILY_UNSPECIFIED) {
+ if (key_.hostname == "www.google.com") {
+ resolver_->resolved_known_ipv6_hostname_ = true;
+ bool got_ipv6_address = false;
+ for (size_t i = 0; i < addr_list.size(); ++i) {
+ if (addr_list[i].GetFamily() == ADDRESS_FAMILY_IPV6) {
+ got_ipv6_address = true;
+ break;
+ }
+ }
+ UMA_HISTOGRAM_BOOLEAN("Net.UnspecResolvedIPv6", got_ipv6_address);
+ }
+ }
+
+ if (dns_task_error_ != OK) {
+ base::TimeDelta duration = base::TimeTicks::Now() - start_time;
+ if (net_error == OK) {
+ DNS_HISTOGRAM("AsyncDNS.FallbackSuccess", duration);
+ if ((dns_task_error_ == ERR_NAME_NOT_RESOLVED) &&
+ ResemblesNetBIOSName(key_.hostname)) {
+ UmaAsyncDnsResolveStatus(RESOLVE_STATUS_SUSPECT_NETBIOS);
+ } else {
+ UmaAsyncDnsResolveStatus(RESOLVE_STATUS_PROC_SUCCESS);
+ }
+ UMA_HISTOGRAM_CUSTOM_ENUMERATION("AsyncDNS.ResolveError",
+ std::abs(dns_task_error_),
+ GetAllErrorCodesForUma());
+ resolver_->OnDnsTaskResolve(dns_task_error_);
+ } else {
+ DNS_HISTOGRAM("AsyncDNS.FallbackFail", duration);
+ UmaAsyncDnsResolveStatus(RESOLVE_STATUS_FAIL);
+ }
+ }
+
+ base::TimeDelta ttl =
+ base::TimeDelta::FromSeconds(kNegativeCacheEntryTTLSeconds);
+ if (net_error == OK)
+ ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds);
+
+ // Don't store the |ttl| in cache since it's not obtained from the server.
+ CompleteRequests(
+ HostCache::Entry(net_error, MakeAddressListForRequest(addr_list)),
+ ttl);
+ }
+
+ void StartDnsTask() {
+ DCHECK(resolver_->HaveDnsConfig());
+ base::TimeTicks start_time = base::TimeTicks::Now();
+ dns_task_.reset(new DnsTask(
+ resolver_->dns_client_.get(),
+ key_,
+ base::Bind(&Job::OnDnsTaskComplete, base::Unretained(this), start_time),
+ net_log_));
+
+ dns_task_->Start();
+ }
+
+ // Called if DnsTask fails. It is posted from StartDnsTask, so Job may be
+ // deleted before this callback. In this case dns_task is deleted as well,
+ // so we use it as indicator whether Job is still valid.
+ void OnDnsTaskFailure(const base::WeakPtr<DnsTask>& dns_task,
+ base::TimeDelta duration,
+ int net_error) {
+ DNS_HISTOGRAM("AsyncDNS.ResolveFail", duration);
+
+ if (dns_task == NULL)
+ return;
+
+ dns_task_error_ = net_error;
+
+ // TODO(szym): Run ServeFromHosts now if nsswitch.conf says so.
+ // http://crbug.com/117655
+
+ // TODO(szym): Some net errors indicate lack of connectivity. Starting
+ // ProcTask in that case is a waste of time.
+ if (resolver_->fallback_to_proctask_) {
+ dns_task_.reset();
+ StartProcTask();
+ } else {
+ UmaAsyncDnsResolveStatus(RESOLVE_STATUS_FAIL);
+ CompleteRequestsWithError(net_error);
+ }
+ }
+
+ // Called by DnsTask when it completes.
+ void OnDnsTaskComplete(base::TimeTicks start_time,
+ int net_error,
+ const AddressList& addr_list,
+ base::TimeDelta ttl) {
+ DCHECK(is_dns_running());
+
+ base::TimeDelta duration = base::TimeTicks::Now() - start_time;
+ if (net_error != OK) {
+ OnDnsTaskFailure(dns_task_->AsWeakPtr(), duration, net_error);
+ return;
+ }
+ DNS_HISTOGRAM("AsyncDNS.ResolveSuccess", duration);
+ // Log DNS lookups based on |address_family|.
+ switch(key_.address_family) {
+ case ADDRESS_FAMILY_IPV4:
+ DNS_HISTOGRAM("AsyncDNS.ResolveSuccess_FAMILY_IPV4", duration);
+ break;
+ case ADDRESS_FAMILY_IPV6:
+ DNS_HISTOGRAM("AsyncDNS.ResolveSuccess_FAMILY_IPV6", duration);
+ break;
+ case ADDRESS_FAMILY_UNSPECIFIED:
+ DNS_HISTOGRAM("AsyncDNS.ResolveSuccess_FAMILY_UNSPEC", duration);
+ break;
+ }
+
+ UmaAsyncDnsResolveStatus(RESOLVE_STATUS_DNS_SUCCESS);
+ RecordTTL(ttl);
+
+ resolver_->OnDnsTaskResolve(OK);
+
+ base::TimeDelta bounded_ttl =
+ std::max(ttl, base::TimeDelta::FromSeconds(kMinimumTTLSeconds));
+
+ CompleteRequests(
+ HostCache::Entry(net_error, MakeAddressListForRequest(addr_list), ttl),
+ bounded_ttl);
+ }
+
+ // Performs Job's last rites. Completes all Requests. Deletes this.
+ void CompleteRequests(const HostCache::Entry& entry,
+ base::TimeDelta ttl) {
+ CHECK(resolver_.get());
+
+ // This job must be removed from resolver's |jobs_| now to make room for a
+ // new job with the same key in case one of the OnComplete callbacks decides
+ // to spawn one. Consequently, the job deletes itself when CompleteRequests
+ // is done.
+ scoped_ptr<Job> self_deleter(this);
+
+ resolver_->RemoveJob(this);
+
+ if (is_running()) {
+ DCHECK(!is_queued());
+ if (is_proc_running()) {
+ proc_task_->Cancel();
+ proc_task_ = NULL;
+ }
+ dns_task_.reset();
+
+ // Signal dispatcher that a slot has opened.
+ resolver_->dispatcher_.OnJobFinished();
+ } else if (is_queued()) {
+ resolver_->dispatcher_.Cancel(handle_);
+ handle_.Reset();
+ }
+
+ if (num_active_requests() == 0) {
+ net_log_.AddEvent(NetLog::TYPE_CANCELLED);
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB,
+ OK);
+ return;
+ }
+
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB,
+ entry.error);
+
+ DCHECK(!requests_.empty());
+
+ if (entry.error == OK) {
+ // Record this histogram here, when we know the system has a valid DNS
+ // configuration.
+ UMA_HISTOGRAM_BOOLEAN("AsyncDNS.HaveDnsConfig",
+ resolver_->received_dns_config_);
+ }
+
+ bool did_complete = (entry.error != ERR_NETWORK_CHANGED) &&
+ (entry.error != ERR_HOST_RESOLVER_QUEUE_TOO_LARGE);
+ if (did_complete)
+ resolver_->CacheResult(key_, entry, ttl);
+
+ // Complete all of the requests that were attached to the job.
+ for (RequestsList::const_iterator it = requests_.begin();
+ it != requests_.end(); ++it) {
+ Request* req = *it;
+
+ if (req->was_canceled())
+ continue;
+
+ DCHECK_EQ(this, req->job());
+ // Update the net log and notify registered observers.
+ LogFinishRequest(req->source_net_log(), req->request_net_log(),
+ req->info(), entry.error);
+ if (did_complete) {
+ // Record effective total time from creation to completion.
+ RecordTotalTime(had_dns_config_, req->info().is_speculative(),
+ base::TimeTicks::Now() - req->request_time());
+ }
+ req->OnComplete(entry.error, entry.addrlist);
+
+ // Check if the resolver was destroyed as a result of running the
+ // callback. If it was, we could continue, but we choose to bail.
+ if (!resolver_.get())
+ return;
+ }
+ }
+
+ // Convenience wrapper for CompleteRequests in case of failure.
+ void CompleteRequestsWithError(int net_error) {
+ CompleteRequests(HostCache::Entry(net_error, AddressList()),
+ base::TimeDelta());
+ }
+
+ RequestPriority priority() const {
+ return priority_tracker_.highest_priority();
+ }
+
+ // Number of non-canceled requests in |requests_|.
+ size_t num_active_requests() const {
+ return priority_tracker_.total_count();
+ }
+
+ bool is_dns_running() const {
+ return dns_task_.get() != NULL;
+ }
+
+ bool is_proc_running() const {
+ return proc_task_.get() != NULL;
+ }
+
+ base::WeakPtr<HostResolverImpl> resolver_;
+
+ Key key_;
+
+ // Tracks the highest priority across |requests_|.
+ PriorityTracker priority_tracker_;
+
+ bool had_non_speculative_request_;
+
+ // Distinguishes measurements taken while DnsClient was fully configured.
+ bool had_dns_config_;
+
+ // Result of DnsTask.
+ int dns_task_error_;
+
+ const base::TimeTicks creation_time_;
+ base::TimeTicks priority_change_time_;
+
+ BoundNetLog net_log_;
+
+ // Resolves the host using a HostResolverProc.
+ scoped_refptr<ProcTask> proc_task_;
+
+ // Resolves the host using a DnsTransaction.
+ scoped_ptr<DnsTask> dns_task_;
+
+ // All Requests waiting for the result of this Job. Some can be canceled.
+ RequestsList requests_;
+
+ // A handle used in |HostResolverImpl::dispatcher_|.
+ PrioritizedDispatcher::Handle handle_;
+};
+
+//-----------------------------------------------------------------------------
+
+HostResolverImpl::ProcTaskParams::ProcTaskParams(
+ HostResolverProc* resolver_proc,
+ size_t max_retry_attempts)
+ : resolver_proc(resolver_proc),
+ max_retry_attempts(max_retry_attempts),
+ unresponsive_delay(base::TimeDelta::FromMilliseconds(6000)),
+ retry_factor(2) {
+}
+
+HostResolverImpl::ProcTaskParams::~ProcTaskParams() {}
+
+HostResolverImpl::HostResolverImpl(
+ scoped_ptr<HostCache> cache,
+ const PrioritizedDispatcher::Limits& job_limits,
+ const ProcTaskParams& proc_params,
+ NetLog* net_log)
+ : cache_(cache.Pass()),
+ dispatcher_(job_limits),
+ max_queued_jobs_(job_limits.total_jobs * 100u),
+ proc_params_(proc_params),
+ net_log_(net_log),
+ default_address_family_(ADDRESS_FAMILY_UNSPECIFIED),
+ weak_ptr_factory_(this),
+ probe_weak_ptr_factory_(this),
+ received_dns_config_(false),
+ num_dns_failures_(0),
+ probe_ipv6_support_(true),
+ resolved_known_ipv6_hostname_(false),
+ additional_resolver_flags_(0),
+ fallback_to_proctask_(true) {
+
+ DCHECK_GE(dispatcher_.num_priorities(), static_cast<size_t>(NUM_PRIORITIES));
+
+ // Maximum of 4 retry attempts for host resolution.
+ static const size_t kDefaultMaxRetryAttempts = 4u;
+
+ if (proc_params_.max_retry_attempts == HostResolver::kDefaultRetryAttempts)
+ proc_params_.max_retry_attempts = kDefaultMaxRetryAttempts;
+
+#if defined(OS_WIN)
+ EnsureWinsockInit();
+#endif
+#if defined(OS_POSIX) && !defined(OS_MACOSX) && !defined(OS_ANDROID)
+ new LoopbackProbeJob(weak_ptr_factory_.GetWeakPtr());
+#endif
+ NetworkChangeNotifier::AddIPAddressObserver(this);
+ NetworkChangeNotifier::AddDNSObserver(this);
+#if defined(OS_POSIX) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) && \
+ !defined(OS_ANDROID)
+ EnsureDnsReloaderInit();
+#endif
+
+ // TODO(szym): Remove when received_dns_config_ is removed, once
+ // http://crbug.com/137914 is resolved.
+ {
+ DnsConfig dns_config;
+ NetworkChangeNotifier::GetDnsConfig(&dns_config);
+ received_dns_config_ = dns_config.IsValid();
+ }
+
+ fallback_to_proctask_ = !ConfigureAsyncDnsNoFallbackFieldTrial();
+}
+
+HostResolverImpl::~HostResolverImpl() {
+ // This will also cancel all outstanding requests.
+ STLDeleteValues(&jobs_);
+
+ NetworkChangeNotifier::RemoveIPAddressObserver(this);
+ NetworkChangeNotifier::RemoveDNSObserver(this);
+}
+
+void HostResolverImpl::SetMaxQueuedJobs(size_t value) {
+ DCHECK_EQ(0u, dispatcher_.num_queued_jobs());
+ DCHECK_GT(value, 0u);
+ max_queued_jobs_ = value;
+}
+
+int HostResolverImpl::Resolve(const RequestInfo& info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* out_req,
+ const BoundNetLog& source_net_log) {
+ DCHECK(addresses);
+ DCHECK(CalledOnValidThread());
+ DCHECK_EQ(false, callback.is_null());
+
+ // Check that the caller supplied a valid hostname to resolve.
+ std::string labeled_hostname;
+ if (!DNSDomainFromDot(info.hostname(), &labeled_hostname))
+ return ERR_NAME_NOT_RESOLVED;
+
+ // Make a log item for the request.
+ BoundNetLog request_net_log = BoundNetLog::Make(net_log_,
+ NetLog::SOURCE_HOST_RESOLVER_IMPL_REQUEST);
+
+ LogStartRequest(source_net_log, request_net_log, info);
+
+ // Build a key that identifies the request in the cache and in the
+ // outstanding jobs map.
+ Key key = GetEffectiveKeyForRequest(info, request_net_log);
+
+ int rv = ResolveHelper(key, info, addresses, request_net_log);
+ if (rv != ERR_DNS_CACHE_MISS) {
+ LogFinishRequest(source_net_log, request_net_log, info, rv);
+ RecordTotalTime(HaveDnsConfig(), info.is_speculative(), base::TimeDelta());
+ return rv;
+ }
+
+ // Next we need to attach our request to a "job". This job is responsible for
+ // calling "getaddrinfo(hostname)" on a worker thread.
+
+ JobMap::iterator jobit = jobs_.find(key);
+ Job* job;
+ if (jobit == jobs_.end()) {
+ job = new Job(weak_ptr_factory_.GetWeakPtr(), key, info.priority(),
+ request_net_log);
+ job->Schedule();
+
+ // Check for queue overflow.
+ if (dispatcher_.num_queued_jobs() > max_queued_jobs_) {
+ Job* evicted = static_cast<Job*>(dispatcher_.EvictOldestLowest());
+ DCHECK(evicted);
+ evicted->OnEvicted(); // Deletes |evicted|.
+ if (evicted == job) {
+ rv = ERR_HOST_RESOLVER_QUEUE_TOO_LARGE;
+ LogFinishRequest(source_net_log, request_net_log, info, rv);
+ return rv;
+ }
+ }
+ jobs_.insert(jobit, std::make_pair(key, job));
+ } else {
+ job = jobit->second;
+ }
+
+ // Can't complete synchronously. Create and attach request.
+ scoped_ptr<Request> req(new Request(source_net_log,
+ request_net_log,
+ info,
+ callback,
+ addresses));
+ if (out_req)
+ *out_req = reinterpret_cast<RequestHandle>(req.get());
+
+ job->AddRequest(req.Pass());
+ // Completion happens during Job::CompleteRequests().
+ return ERR_IO_PENDING;
+}
+
+int HostResolverImpl::ResolveHelper(const Key& key,
+ const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& request_net_log) {
+ // The result of |getaddrinfo| for empty hosts is inconsistent across systems.
+ // On Windows it gives the default interface's address, whereas on Linux it
+ // gives an error. We will make it fail on all platforms for consistency.
+ if (info.hostname().empty() || info.hostname().size() > kMaxHostLength)
+ return ERR_NAME_NOT_RESOLVED;
+
+ int net_error = ERR_UNEXPECTED;
+ if (ResolveAsIP(key, info, &net_error, addresses))
+ return net_error;
+ if (ServeFromCache(key, info, &net_error, addresses)) {
+ request_net_log.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_CACHE_HIT);
+ return net_error;
+ }
+ // TODO(szym): Do not do this if nsswitch.conf instructs not to.
+ // http://crbug.com/117655
+ if (ServeFromHosts(key, info, addresses)) {
+ request_net_log.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_HOSTS_HIT);
+ return OK;
+ }
+ return ERR_DNS_CACHE_MISS;
+}
+
+int HostResolverImpl::ResolveFromCache(const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& source_net_log) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(addresses);
+
+ // Make a log item for the request.
+ BoundNetLog request_net_log = BoundNetLog::Make(net_log_,
+ NetLog::SOURCE_HOST_RESOLVER_IMPL_REQUEST);
+
+ // Update the net log and notify registered observers.
+ LogStartRequest(source_net_log, request_net_log, info);
+
+ Key key = GetEffectiveKeyForRequest(info, request_net_log);
+
+ int rv = ResolveHelper(key, info, addresses, request_net_log);
+ LogFinishRequest(source_net_log, request_net_log, info, rv);
+ return rv;
+}
+
+void HostResolverImpl::CancelRequest(RequestHandle req_handle) {
+ DCHECK(CalledOnValidThread());
+ Request* req = reinterpret_cast<Request*>(req_handle);
+ DCHECK(req);
+ Job* job = req->job();
+ DCHECK(job);
+ job->CancelRequest(req);
+}
+
+void HostResolverImpl::SetDefaultAddressFamily(AddressFamily address_family) {
+ DCHECK(CalledOnValidThread());
+ default_address_family_ = address_family;
+ probe_ipv6_support_ = false;
+}
+
+AddressFamily HostResolverImpl::GetDefaultAddressFamily() const {
+ return default_address_family_;
+}
+
+void HostResolverImpl::SetDnsClientEnabled(bool enabled) {
+ DCHECK(CalledOnValidThread());
+#if defined(ENABLE_BUILT_IN_DNS)
+ if (enabled && !dns_client_) {
+ SetDnsClient(DnsClient::CreateClient(net_log_));
+ } else if (!enabled && dns_client_) {
+ SetDnsClient(scoped_ptr<DnsClient>());
+ }
+#endif
+}
+
+HostCache* HostResolverImpl::GetHostCache() {
+ return cache_.get();
+}
+
+base::Value* HostResolverImpl::GetDnsConfigAsValue() const {
+ // Check if async DNS is disabled.
+ if (!dns_client_.get())
+ return NULL;
+
+ // Check if async DNS is enabled, but we currently have no configuration
+ // for it.
+ const DnsConfig* dns_config = dns_client_->GetConfig();
+ if (dns_config == NULL)
+ return new base::DictionaryValue();
+
+ return dns_config->ToValue();
+}
+
+bool HostResolverImpl::ResolveAsIP(const Key& key,
+ const RequestInfo& info,
+ int* net_error,
+ AddressList* addresses) {
+ DCHECK(addresses);
+ DCHECK(net_error);
+ IPAddressNumber ip_number;
+ if (!ParseIPLiteralToNumber(key.hostname, &ip_number))
+ return false;
+
+ DCHECK_EQ(key.host_resolver_flags &
+ ~(HOST_RESOLVER_CANONNAME | HOST_RESOLVER_LOOPBACK_ONLY |
+ HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6),
+ 0) << " Unhandled flag";
+ bool ipv6_disabled = (default_address_family_ == ADDRESS_FAMILY_IPV4) &&
+ !probe_ipv6_support_;
+ *net_error = OK;
+ if ((ip_number.size() == kIPv6AddressSize) && ipv6_disabled) {
+ *net_error = ERR_NAME_NOT_RESOLVED;
+ } else {
+ *addresses = AddressList::CreateFromIPAddress(ip_number, info.port());
+ if (key.host_resolver_flags & HOST_RESOLVER_CANONNAME)
+ addresses->SetDefaultCanonicalName();
+ }
+ return true;
+}
+
+bool HostResolverImpl::ServeFromCache(const Key& key,
+ const RequestInfo& info,
+ int* net_error,
+ AddressList* addresses) {
+ DCHECK(addresses);
+ DCHECK(net_error);
+ if (!info.allow_cached_response() || !cache_.get())
+ return false;
+
+ const HostCache::Entry* cache_entry = cache_->Lookup(
+ key, base::TimeTicks::Now());
+ if (!cache_entry)
+ return false;
+
+ *net_error = cache_entry->error;
+ if (*net_error == OK) {
+ if (cache_entry->has_ttl())
+ RecordTTL(cache_entry->ttl);
+ *addresses = EnsurePortOnAddressList(cache_entry->addrlist, info.port());
+ }
+ return true;
+}
+
+bool HostResolverImpl::ServeFromHosts(const Key& key,
+ const RequestInfo& info,
+ AddressList* addresses) {
+ DCHECK(addresses);
+ if (!HaveDnsConfig())
+ return false;
+ addresses->clear();
+
+ // HOSTS lookups are case-insensitive.
+ std::string hostname = StringToLowerASCII(key.hostname);
+
+ const DnsHosts& hosts = dns_client_->GetConfig()->hosts;
+
+ // If |address_family| is ADDRESS_FAMILY_UNSPECIFIED other implementations
+ // (glibc and c-ares) return the first matching line. We have more
+ // flexibility, but lose implicit ordering.
+ // We prefer IPv6 because "happy eyeballs" will fall back to IPv4 if
+ // necessary.
+ if (key.address_family == ADDRESS_FAMILY_IPV6 ||
+ key.address_family == ADDRESS_FAMILY_UNSPECIFIED) {
+ DnsHosts::const_iterator it = hosts.find(
+ DnsHostsKey(hostname, ADDRESS_FAMILY_IPV6));
+ if (it != hosts.end())
+ addresses->push_back(IPEndPoint(it->second, info.port()));
+ }
+
+ if (key.address_family == ADDRESS_FAMILY_IPV4 ||
+ key.address_family == ADDRESS_FAMILY_UNSPECIFIED) {
+ DnsHosts::const_iterator it = hosts.find(
+ DnsHostsKey(hostname, ADDRESS_FAMILY_IPV4));
+ if (it != hosts.end())
+ addresses->push_back(IPEndPoint(it->second, info.port()));
+ }
+
+ // If got only loopback addresses and the family was restricted, resolve
+ // again, without restrictions. See SystemHostResolverCall for rationale.
+ if ((key.host_resolver_flags &
+ HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6) &&
+ IsAllIPv4Loopback(*addresses)) {
+ Key new_key(key);
+ new_key.address_family = ADDRESS_FAMILY_UNSPECIFIED;
+ new_key.host_resolver_flags &=
+ ~HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
+ return ServeFromHosts(new_key, info, addresses);
+ }
+ return !addresses->empty();
+}
+
+void HostResolverImpl::CacheResult(const Key& key,
+ const HostCache::Entry& entry,
+ base::TimeDelta ttl) {
+ if (cache_.get())
+ cache_->Set(key, entry, base::TimeTicks::Now(), ttl);
+}
+
+void HostResolverImpl::RemoveJob(Job* job) {
+ DCHECK(job);
+ JobMap::iterator it = jobs_.find(job->key());
+ if (it != jobs_.end() && it->second == job)
+ jobs_.erase(it);
+}
+
+void HostResolverImpl::SetHaveOnlyLoopbackAddresses(bool result) {
+ if (result) {
+ additional_resolver_flags_ |= HOST_RESOLVER_LOOPBACK_ONLY;
+ } else {
+ additional_resolver_flags_ &= ~HOST_RESOLVER_LOOPBACK_ONLY;
+ }
+}
+
+HostResolverImpl::Key HostResolverImpl::GetEffectiveKeyForRequest(
+ const RequestInfo& info, const BoundNetLog& net_log) const {
+ HostResolverFlags effective_flags =
+ info.host_resolver_flags() | additional_resolver_flags_;
+ AddressFamily effective_address_family = info.address_family();
+
+ if (info.address_family() == ADDRESS_FAMILY_UNSPECIFIED) {
+ if (probe_ipv6_support_) {
+ base::TimeTicks start_time = base::TimeTicks::Now();
+ // Google DNS address.
+ const uint8 kIPv6Address[] =
+ { 0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88 };
+ IPAddressNumber address(kIPv6Address,
+ kIPv6Address + arraysize(kIPv6Address));
+ bool rv6 = IsGloballyReachable(address, net_log);
+ if (rv6)
+ net_log.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_IPV6_SUPPORTED);
+
+ UMA_HISTOGRAM_TIMES("Net.IPv6ConnectDuration",
+ base::TimeTicks::Now() - start_time);
+ if (rv6) {
+ UMA_HISTOGRAM_BOOLEAN("Net.IPv6ConnectSuccessMatch",
+ default_address_family_ == ADDRESS_FAMILY_UNSPECIFIED);
+ } else {
+ UMA_HISTOGRAM_BOOLEAN("Net.IPv6ConnectFailureMatch",
+ default_address_family_ != ADDRESS_FAMILY_UNSPECIFIED);
+
+ effective_address_family = ADDRESS_FAMILY_IPV4;
+ effective_flags |= HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
+ }
+ } else {
+ effective_address_family = default_address_family_;
+ }
+ }
+
+ return Key(info.hostname(), effective_address_family, effective_flags);
+}
+
+void HostResolverImpl::AbortAllInProgressJobs() {
+ // In Abort, a Request callback could spawn new Jobs with matching keys, so
+ // first collect and remove all running jobs from |jobs_|.
+ ScopedVector<Job> jobs_to_abort;
+ for (JobMap::iterator it = jobs_.begin(); it != jobs_.end(); ) {
+ Job* job = it->second;
+ if (job->is_running()) {
+ jobs_to_abort.push_back(job);
+ jobs_.erase(it++);
+ } else {
+ DCHECK(job->is_queued());
+ ++it;
+ }
+ }
+
+ // Check if no dispatcher slots leaked out.
+ DCHECK_EQ(dispatcher_.num_running_jobs(), jobs_to_abort.size());
+
+ // Life check to bail once |this| is deleted.
+ base::WeakPtr<HostResolverImpl> self = weak_ptr_factory_.GetWeakPtr();
+
+ // Then Abort them.
+ for (size_t i = 0; self.get() && i < jobs_to_abort.size(); ++i) {
+ jobs_to_abort[i]->Abort();
+ jobs_to_abort[i] = NULL;
+ }
+}
+
+void HostResolverImpl::TryServingAllJobsFromHosts() {
+ if (!HaveDnsConfig())
+ return;
+
+ // TODO(szym): Do not do this if nsswitch.conf instructs not to.
+ // http://crbug.com/117655
+
+ // Life check to bail once |this| is deleted.
+ base::WeakPtr<HostResolverImpl> self = weak_ptr_factory_.GetWeakPtr();
+
+ for (JobMap::iterator it = jobs_.begin(); self.get() && it != jobs_.end();) {
+ Job* job = it->second;
+ ++it;
+ // This could remove |job| from |jobs_|, but iterator will remain valid.
+ job->ServeFromHosts();
+ }
+}
+
+void HostResolverImpl::OnIPAddressChanged() {
+ resolved_known_ipv6_hostname_ = false;
+ // Abandon all ProbeJobs.
+ probe_weak_ptr_factory_.InvalidateWeakPtrs();
+ if (cache_.get())
+ cache_->clear();
+#if defined(OS_POSIX) && !defined(OS_MACOSX) && !defined(OS_ANDROID)
+ new LoopbackProbeJob(probe_weak_ptr_factory_.GetWeakPtr());
+#endif
+ AbortAllInProgressJobs();
+ // |this| may be deleted inside AbortAllInProgressJobs().
+}
+
+void HostResolverImpl::OnDNSChanged() {
+ DnsConfig dns_config;
+ NetworkChangeNotifier::GetDnsConfig(&dns_config);
+
+ if (net_log_) {
+ net_log_->AddGlobalEntry(
+ NetLog::TYPE_DNS_CONFIG_CHANGED,
+ base::Bind(&NetLogDnsConfigCallback, &dns_config));
+ }
+
+ // TODO(szym): Remove once http://crbug.com/137914 is resolved.
+ received_dns_config_ = dns_config.IsValid();
+
+ num_dns_failures_ = 0;
+
+ // We want a new DnsSession in place, before we Abort running Jobs, so that
+ // the newly started jobs use the new config.
+ if (dns_client_.get()) {
+ dns_client_->SetConfig(dns_config);
+ if (dns_config.IsValid())
+ UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", true);
+ }
+
+ // If the DNS server has changed, existing cached info could be wrong so we
+ // have to drop our internal cache :( Note that OS level DNS caches, such
+ // as NSCD's cache should be dropped automatically by the OS when
+ // resolv.conf changes so we don't need to do anything to clear that cache.
+ if (cache_.get())
+ cache_->clear();
+
+ // Life check to bail once |this| is deleted.
+ base::WeakPtr<HostResolverImpl> self = weak_ptr_factory_.GetWeakPtr();
+
+ // Existing jobs will have been sent to the original server so they need to
+ // be aborted.
+ AbortAllInProgressJobs();
+
+ // |this| may be deleted inside AbortAllInProgressJobs().
+ if (self.get())
+ TryServingAllJobsFromHosts();
+}
+
+bool HostResolverImpl::HaveDnsConfig() const {
+ // Use DnsClient only if it's fully configured and there is no override by
+ // ScopedDefaultHostResolverProc.
+ // The alternative is to use NetworkChangeNotifier to override DnsConfig,
+ // but that would introduce construction order requirements for NCN and SDHRP.
+ return (dns_client_.get() != NULL) && (dns_client_->GetConfig() != NULL) &&
+ !(proc_params_.resolver_proc.get() == NULL &&
+ HostResolverProc::GetDefault() != NULL);
+}
+
+void HostResolverImpl::OnDnsTaskResolve(int net_error) {
+ DCHECK(dns_client_);
+ if (net_error == OK) {
+ num_dns_failures_ = 0;
+ return;
+ }
+ ++num_dns_failures_;
+ if (num_dns_failures_ < kMaximumDnsFailures)
+ return;
+ // Disable DnsClient until the next DNS change.
+ for (JobMap::iterator it = jobs_.begin(); it != jobs_.end(); ++it)
+ it->second->AbortDnsTask();
+ dns_client_->SetConfig(DnsConfig());
+ UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", false);
+ UMA_HISTOGRAM_CUSTOM_ENUMERATION("AsyncDNS.DnsClientDisabledReason",
+ std::abs(net_error),
+ GetAllErrorCodesForUma());
+}
+
+void HostResolverImpl::SetDnsClient(scoped_ptr<DnsClient> dns_client) {
+ if (HaveDnsConfig()) {
+ for (JobMap::iterator it = jobs_.begin(); it != jobs_.end(); ++it)
+ it->second->AbortDnsTask();
+ }
+ dns_client_ = dns_client.Pass();
+ if (!dns_client_ || dns_client_->GetConfig() ||
+ num_dns_failures_ >= kMaximumDnsFailures) {
+ return;
+ }
+ DnsConfig dns_config;
+ NetworkChangeNotifier::GetDnsConfig(&dns_config);
+ dns_client_->SetConfig(dns_config);
+ num_dns_failures_ = 0;
+ if (dns_config.IsValid())
+ UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", true);
+}
+
+} // namespace net
diff --git a/chromium/net/dns/host_resolver_impl.h b/chromium/net/dns/host_resolver_impl.h
new file mode 100644
index 00000000000..928d07af8b8
--- /dev/null
+++ b/chromium/net/dns/host_resolver_impl.h
@@ -0,0 +1,285 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_HOST_RESOLVER_IMPL_H_
+#define NET_DNS_HOST_RESOLVER_IMPL_H_
+
+#include <map>
+
+#include "base/basictypes.h"
+#include "base/gtest_prod_util.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/memory/scoped_vector.h"
+#include "base/memory/weak_ptr.h"
+#include "base/threading/non_thread_safe.h"
+#include "base/time/time.h"
+#include "net/base/capturing_net_log.h"
+#include "net/base/net_export.h"
+#include "net/base/network_change_notifier.h"
+#include "net/base/prioritized_dispatcher.h"
+#include "net/dns/host_cache.h"
+#include "net/dns/host_resolver.h"
+#include "net/dns/host_resolver_proc.h"
+
+namespace net {
+
+class BoundNetLog;
+class DnsClient;
+class NetLog;
+
+// For each hostname that is requested, HostResolver creates a
+// HostResolverImpl::Job. When this job gets dispatched it creates a ProcTask
+// which runs the given HostResolverProc on a WorkerPool thread. If requests for
+// that same host are made during the job's lifetime, they are attached to the
+// existing job rather than creating a new one. This avoids doing parallel
+// resolves for the same host.
+//
+// The way these classes fit together is illustrated by:
+//
+//
+// +----------- HostResolverImpl -------------+
+// | | |
+// Job Job Job
+// (for host1, fam1) (for host2, fam2) (for hostx, famx)
+// / | | / | | / | |
+// Request ... Request Request ... Request Request ... Request
+// (port1) (port2) (port3) (port4) (port5) (portX)
+//
+// When a HostResolverImpl::Job finishes, the callbacks of each waiting request
+// are run on the origin thread.
+//
+// Thread safety: This class is not threadsafe, and must only be called
+// from one thread!
+//
+// The HostResolverImpl enforces limits on the maximum number of concurrent
+// threads using PrioritizedDispatcher::Limits.
+//
+// Jobs are ordered in the queue based on their priority and order of arrival.
+class NET_EXPORT HostResolverImpl
+ : public HostResolver,
+ NON_EXPORTED_BASE(public base::NonThreadSafe),
+ public NetworkChangeNotifier::IPAddressObserver,
+ public NetworkChangeNotifier::DNSObserver {
+ public:
+ // Parameters for ProcTask which resolves hostnames using HostResolveProc.
+ //
+ // |resolver_proc| is used to perform the actual resolves; it must be
+ // thread-safe since it is run from multiple worker threads. If
+ // |resolver_proc| is NULL then the default host resolver procedure is
+ // used (which is SystemHostResolverProc except if overridden).
+ //
+ // For each attempt, we could start another attempt if host is not resolved
+ // within |unresponsive_delay| time. We keep attempting to resolve the host
+ // for |max_retry_attempts|. For every retry attempt, we grow the
+ // |unresponsive_delay| by the |retry_factor| amount (that is retry interval
+ // is multiplied by the retry factor each time). Once we have retried
+ // |max_retry_attempts|, we give up on additional attempts.
+ //
+ struct NET_EXPORT_PRIVATE ProcTaskParams {
+ // Sets up defaults.
+ ProcTaskParams(HostResolverProc* resolver_proc, size_t max_retry_attempts);
+
+ ~ProcTaskParams();
+
+ // The procedure to use for resolving host names. This will be NULL, except
+ // in the case of unit-tests which inject custom host resolving behaviors.
+ scoped_refptr<HostResolverProc> resolver_proc;
+
+ // Maximum number retry attempts to resolve the hostname.
+ // Pass HostResolver::kDefaultRetryAttempts to choose a default value.
+ size_t max_retry_attempts;
+
+ // This is the limit after which we make another attempt to resolve the host
+ // if the worker thread has not responded yet.
+ base::TimeDelta unresponsive_delay;
+
+ // Factor to grow |unresponsive_delay| when we re-re-try.
+ uint32 retry_factor;
+ };
+
+ // Creates a HostResolver that first uses the local cache |cache|, and then
+ // falls back to |proc_params.resolver_proc|.
+ //
+ // If |cache| is NULL, then no caching is used. Otherwise we take
+ // ownership of the |cache| pointer, and will free it during destruction.
+ //
+ // |job_limits| specifies the maximum number of jobs that the resolver will
+ // run at once. This upper-bounds the total number of outstanding
+ // DNS transactions (not counting retransmissions and retries).
+ //
+ // |net_log| must remain valid for the life of the HostResolverImpl.
+ HostResolverImpl(scoped_ptr<HostCache> cache,
+ const PrioritizedDispatcher::Limits& job_limits,
+ const ProcTaskParams& proc_params,
+ NetLog* net_log);
+
+ // If any completion callbacks are pending when the resolver is destroyed,
+ // the host resolutions are cancelled, and the completion callbacks will not
+ // be called.
+ virtual ~HostResolverImpl();
+
+ // Configures maximum number of Jobs in the queue. Exposed for testing.
+ // Only allowed when the queue is empty.
+ void SetMaxQueuedJobs(size_t value);
+
+ // Set the DnsClient to be used for resolution. In case of failure, the
+ // HostResolverProc from ProcTaskParams will be queried. If the DnsClient is
+ // not pre-configured with a valid DnsConfig, a new config is fetched from
+ // NetworkChangeNotifier.
+ void SetDnsClient(scoped_ptr<DnsClient> dns_client);
+
+ // HostResolver methods:
+ virtual int Resolve(const RequestInfo& info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* out_req,
+ const BoundNetLog& source_net_log) OVERRIDE;
+ virtual int ResolveFromCache(const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& source_net_log) OVERRIDE;
+ virtual void CancelRequest(RequestHandle req) OVERRIDE;
+ virtual void SetDefaultAddressFamily(AddressFamily address_family) OVERRIDE;
+ virtual AddressFamily GetDefaultAddressFamily() const OVERRIDE;
+ virtual void SetDnsClientEnabled(bool enabled) OVERRIDE;
+ virtual HostCache* GetHostCache() OVERRIDE;
+ virtual base::Value* GetDnsConfigAsValue() const OVERRIDE;
+
+ private:
+ friend class HostResolverImplTest;
+ class Job;
+ class ProcTask;
+ class LoopbackProbeJob;
+ class DnsTask;
+ class Request;
+ typedef HostCache::Key Key;
+ typedef std::map<Key, Job*> JobMap;
+ typedef ScopedVector<Request> RequestsList;
+
+ // Helper used by |Resolve()| and |ResolveFromCache()|. Performs IP
+ // literal, cache and HOSTS lookup (if enabled), returns OK if successful,
+ // ERR_NAME_NOT_RESOLVED if either hostname is invalid or IP literal is
+ // incompatible, ERR_DNS_CACHE_MISS if entry was not found in cache and HOSTS.
+ int ResolveHelper(const Key& key,
+ const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& request_net_log);
+
+ // Tries to resolve |key| as an IP, returns true and sets |net_error| if
+ // succeeds, returns false otherwise.
+ bool ResolveAsIP(const Key& key,
+ const RequestInfo& info,
+ int* net_error,
+ AddressList* addresses);
+
+ // If |key| is not found in cache returns false, otherwise returns
+ // true, sets |net_error| to the cached error code and fills |addresses|
+ // if it is a positive entry.
+ bool ServeFromCache(const Key& key,
+ const RequestInfo& info,
+ int* net_error,
+ AddressList* addresses);
+
+ // If we have a DnsClient with a valid DnsConfig, and |key| is found in the
+ // HOSTS file, returns true and fills |addresses|. Otherwise returns false.
+ bool ServeFromHosts(const Key& key,
+ const RequestInfo& info,
+ AddressList* addresses);
+
+ // Callback from HaveOnlyLoopbackAddresses probe.
+ void SetHaveOnlyLoopbackAddresses(bool result);
+
+ // Returns the (hostname, address_family) key to use for |info|, choosing an
+ // "effective" address family by inheriting the resolver's default address
+ // family when the request leaves it unspecified.
+ Key GetEffectiveKeyForRequest(const RequestInfo& info,
+ const BoundNetLog& net_log) const;
+
+ // Records the result in cache if cache is present.
+ void CacheResult(const Key& key,
+ const HostCache::Entry& entry,
+ base::TimeDelta ttl);
+
+ // Removes |job| from |jobs_|, only if it exists.
+ void RemoveJob(Job* job);
+
+ // Aborts all in progress jobs with ERR_NETWORK_CHANGED and notifies their
+ // requests. Might start new jobs.
+ void AbortAllInProgressJobs();
+
+ // Attempts to serve each Job in |jobs_| from the HOSTS file if we have
+ // a DnsClient with a valid DnsConfig.
+ void TryServingAllJobsFromHosts();
+
+ // NetworkChangeNotifier::IPAddressObserver:
+ virtual void OnIPAddressChanged() OVERRIDE;
+
+ // NetworkChangeNotifier::DNSObserver:
+ virtual void OnDNSChanged() OVERRIDE;
+
+ // True if have a DnsClient with a valid DnsConfig.
+ bool HaveDnsConfig() const;
+
+ // Called when a host name is successfully resolved and DnsTask was run on it
+ // and resulted in |net_error|.
+ void OnDnsTaskResolve(int net_error);
+
+ // Allows the tests to catch slots leaking out of the dispatcher.
+ size_t num_running_jobs_for_tests() const {
+ return dispatcher_.num_running_jobs();
+ }
+
+ // Cache of host resolution results.
+ scoped_ptr<HostCache> cache_;
+
+ // Map from HostCache::Key to a Job.
+ JobMap jobs_;
+
+ // Starts Jobs according to their priority and the configured limits.
+ PrioritizedDispatcher dispatcher_;
+
+ // Limit on the maximum number of jobs queued in |dispatcher_|.
+ size_t max_queued_jobs_;
+
+ // Parameters for ProcTask.
+ ProcTaskParams proc_params_;
+
+ NetLog* net_log_;
+
+ // Address family to use when the request doesn't specify one.
+ AddressFamily default_address_family_;
+
+ base::WeakPtrFactory<HostResolverImpl> weak_ptr_factory_;
+
+ base::WeakPtrFactory<HostResolverImpl> probe_weak_ptr_factory_;
+
+ // If present, used by DnsTask and ServeFromHosts to resolve requests.
+ scoped_ptr<DnsClient> dns_client_;
+
+ // True if received valid config from |dns_config_service_|. Temporary, used
+ // to measure performance of DnsConfigService: http://crbug.com/125599
+ bool received_dns_config_;
+
+ // Number of consecutive failures of DnsTask, counted when fallback succeeds.
+ unsigned num_dns_failures_;
+
+ // True if probing is done for each Request to set address family. When false,
+ // explicit setting in |default_address_family_| is used.
+ bool probe_ipv6_support_;
+
+ // True iff ProcTask has successfully resolved a hostname known to have IPv6
+ // addresses using ADDRESS_FAMILY_UNSPECIFIED. Reset on IP address change.
+ bool resolved_known_ipv6_hostname_;
+
+ // Any resolver flags that should be added to a request by default.
+ HostResolverFlags additional_resolver_flags_;
+
+ // Allow fallback to ProcTask if DnsTask fails.
+ bool fallback_to_proctask_;
+
+ DISALLOW_COPY_AND_ASSIGN(HostResolverImpl);
+};
+
+} // namespace net
+
+#endif // NET_DNS_HOST_RESOLVER_IMPL_H_
diff --git a/chromium/net/dns/host_resolver_impl_unittest.cc b/chromium/net/dns/host_resolver_impl_unittest.cc
new file mode 100644
index 00000000000..f6b7f690675
--- /dev/null
+++ b/chromium/net/dns/host_resolver_impl_unittest.cc
@@ -0,0 +1,1641 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/host_resolver_impl.h"
+
+#include <algorithm>
+#include <string>
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_vector.h"
+#include "base/message_loop/message_loop.h"
+#include "base/strings/string_util.h"
+#include "base/strings/stringprintf.h"
+#include "base/synchronization/condition_variable.h"
+#include "base/synchronization/lock.h"
+#include "base/test/test_timeouts.h"
+#include "base/time/time.h"
+#include "net/base/address_list.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+#include "net/dns/dns_client.h"
+#include "net/dns/dns_test_util.h"
+#include "net/dns/host_cache.h"
+#include "net/dns/mock_host_resolver.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+const size_t kMaxJobs = 10u;
+const size_t kMaxRetryAttempts = 4u;
+
+PrioritizedDispatcher::Limits DefaultLimits() {
+ PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, kMaxJobs);
+ return limits;
+}
+
+HostResolverImpl::ProcTaskParams DefaultParams(
+ HostResolverProc* resolver_proc) {
+ return HostResolverImpl::ProcTaskParams(resolver_proc, kMaxRetryAttempts);
+}
+
+// A HostResolverProc that pushes each host mapped into a list and allows
+// waiting for a specific number of requests. Unlike RuleBasedHostResolverProc
+// it never calls SystemHostResolverCall. By default resolves all hostnames to
+// "127.0.0.1". After AddRule(), it resolves only names explicitly specified.
+class MockHostResolverProc : public HostResolverProc {
+ public:
+ struct ResolveKey {
+ ResolveKey(const std::string& hostname, AddressFamily address_family)
+ : hostname(hostname), address_family(address_family) {}
+ bool operator<(const ResolveKey& other) const {
+ return address_family < other.address_family ||
+ (address_family == other.address_family && hostname < other.hostname);
+ }
+ std::string hostname;
+ AddressFamily address_family;
+ };
+
+ typedef std::vector<ResolveKey> CaptureList;
+
+ MockHostResolverProc()
+ : HostResolverProc(NULL),
+ num_requests_waiting_(0),
+ num_slots_available_(0),
+ requests_waiting_(&lock_),
+ slots_available_(&lock_) {
+ }
+
+ // Waits until |count| calls to |Resolve| are blocked. Returns false when
+ // timed out.
+ bool WaitFor(unsigned count) {
+ base::AutoLock lock(lock_);
+ base::Time start_time = base::Time::Now();
+ while (num_requests_waiting_ < count) {
+ requests_waiting_.TimedWait(TestTimeouts::action_timeout());
+ if (base::Time::Now() > start_time + TestTimeouts::action_timeout())
+ return false;
+ }
+ return true;
+ }
+
+ // Signals |count| waiting calls to |Resolve|. First come first served.
+ void SignalMultiple(unsigned count) {
+ base::AutoLock lock(lock_);
+ num_slots_available_ += count;
+ slots_available_.Broadcast();
+ }
+
+ // Signals all waiting calls to |Resolve|. Beware of races.
+ void SignalAll() {
+ base::AutoLock lock(lock_);
+ num_slots_available_ = num_requests_waiting_;
+ slots_available_.Broadcast();
+ }
+
+ void AddRule(const std::string& hostname, AddressFamily family,
+ const AddressList& result) {
+ base::AutoLock lock(lock_);
+ rules_[ResolveKey(hostname, family)] = result;
+ }
+
+ void AddRule(const std::string& hostname, AddressFamily family,
+ const std::string& ip_list) {
+ AddressList result;
+ int rv = ParseAddressList(ip_list, std::string(), &result);
+ DCHECK_EQ(OK, rv);
+ AddRule(hostname, family, result);
+ }
+
+ void AddRuleForAllFamilies(const std::string& hostname,
+ const std::string& ip_list) {
+ AddressList result;
+ int rv = ParseAddressList(ip_list, std::string(), &result);
+ DCHECK_EQ(OK, rv);
+ AddRule(hostname, ADDRESS_FAMILY_UNSPECIFIED, result);
+ AddRule(hostname, ADDRESS_FAMILY_IPV4, result);
+ AddRule(hostname, ADDRESS_FAMILY_IPV6, result);
+ }
+
+ virtual int Resolve(const std::string& hostname,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ AddressList* addrlist,
+ int* os_error) OVERRIDE {
+ base::AutoLock lock(lock_);
+ capture_list_.push_back(ResolveKey(hostname, address_family));
+ ++num_requests_waiting_;
+ requests_waiting_.Broadcast();
+ while (!num_slots_available_)
+ slots_available_.Wait();
+ DCHECK_GT(num_requests_waiting_, 0u);
+ --num_slots_available_;
+ --num_requests_waiting_;
+ if (rules_.empty()) {
+ int rv = ParseAddressList("127.0.0.1", std::string(), addrlist);
+ DCHECK_EQ(OK, rv);
+ return OK;
+ }
+ ResolveKey key(hostname, address_family);
+ if (rules_.count(key) == 0)
+ return ERR_NAME_NOT_RESOLVED;
+ *addrlist = rules_[key];
+ return OK;
+ }
+
+ CaptureList GetCaptureList() const {
+ CaptureList copy;
+ {
+ base::AutoLock lock(lock_);
+ copy = capture_list_;
+ }
+ return copy;
+ }
+
+ bool HasBlockedRequests() const {
+ base::AutoLock lock(lock_);
+ return num_requests_waiting_ > num_slots_available_;
+ }
+
+ protected:
+ virtual ~MockHostResolverProc() {}
+
+ private:
+ mutable base::Lock lock_;
+ std::map<ResolveKey, AddressList> rules_;
+ CaptureList capture_list_;
+ unsigned num_requests_waiting_;
+ unsigned num_slots_available_;
+ base::ConditionVariable requests_waiting_;
+ base::ConditionVariable slots_available_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockHostResolverProc);
+};
+
+bool AddressListContains(const AddressList& list, const std::string& address,
+ int port) {
+ IPAddressNumber ip;
+ bool rv = ParseIPLiteralToNumber(address, &ip);
+ DCHECK(rv);
+ return std::find(list.begin(),
+ list.end(),
+ IPEndPoint(ip, port)) != list.end();
+}
+
+// A wrapper for requests to a HostResolver.
+class Request {
+ public:
+ // Base class of handlers to be executed on completion of requests.
+ struct Handler {
+ virtual ~Handler() {}
+ virtual void Handle(Request* request) = 0;
+ };
+
+ Request(const HostResolver::RequestInfo& info,
+ size_t index,
+ HostResolver* resolver,
+ Handler* handler)
+ : info_(info),
+ index_(index),
+ resolver_(resolver),
+ handler_(handler),
+ quit_on_complete_(false),
+ result_(ERR_UNEXPECTED),
+ handle_(NULL) {}
+
+ int Resolve() {
+ DCHECK(resolver_);
+ DCHECK(!handle_);
+ list_ = AddressList();
+ result_ = resolver_->Resolve(
+ info_, &list_, base::Bind(&Request::OnComplete, base::Unretained(this)),
+ &handle_, BoundNetLog());
+ if (!list_.empty())
+ EXPECT_EQ(OK, result_);
+ return result_;
+ }
+
+ int ResolveFromCache() {
+ DCHECK(resolver_);
+ DCHECK(!handle_);
+ return resolver_->ResolveFromCache(info_, &list_, BoundNetLog());
+ }
+
+ void Cancel() {
+ DCHECK(resolver_);
+ DCHECK(handle_);
+ resolver_->CancelRequest(handle_);
+ handle_ = NULL;
+ }
+
+ const HostResolver::RequestInfo& info() const { return info_; }
+ size_t index() const { return index_; }
+ const AddressList& list() const { return list_; }
+ int result() const { return result_; }
+ bool completed() const { return result_ != ERR_IO_PENDING; }
+ bool pending() const { return handle_ != NULL; }
+
+ bool HasAddress(const std::string& address, int port) const {
+ return AddressListContains(list_, address, port);
+ }
+
+ // Returns the number of addresses in |list_|.
+ unsigned NumberOfAddresses() const {
+ return list_.size();
+ }
+
+ bool HasOneAddress(const std::string& address, int port) const {
+ return HasAddress(address, port) && (NumberOfAddresses() == 1u);
+ }
+
+ // Returns ERR_UNEXPECTED if timed out.
+ int WaitForResult() {
+ if (completed())
+ return result_;
+ base::CancelableClosure closure(base::MessageLoop::QuitClosure());
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE, closure.callback(), TestTimeouts::action_max_timeout());
+ quit_on_complete_ = true;
+ base::MessageLoop::current()->Run();
+ bool did_quit = !quit_on_complete_;
+ quit_on_complete_ = false;
+ closure.Cancel();
+ if (did_quit)
+ return result_;
+ else
+ return ERR_UNEXPECTED;
+ }
+
+ private:
+ void OnComplete(int rv) {
+ EXPECT_TRUE(pending());
+ EXPECT_EQ(ERR_IO_PENDING, result_);
+ EXPECT_NE(ERR_IO_PENDING, rv);
+ result_ = rv;
+ handle_ = NULL;
+ if (!list_.empty()) {
+ EXPECT_EQ(OK, result_);
+ EXPECT_EQ(info_.port(), list_.front().port());
+ }
+ if (handler_)
+ handler_->Handle(this);
+ if (quit_on_complete_) {
+ base::MessageLoop::current()->Quit();
+ quit_on_complete_ = false;
+ }
+ }
+
+ HostResolver::RequestInfo info_;
+ size_t index_;
+ HostResolver* resolver_;
+ Handler* handler_;
+ bool quit_on_complete_;
+
+ AddressList list_;
+ int result_;
+ HostResolver::RequestHandle handle_;
+
+ DISALLOW_COPY_AND_ASSIGN(Request);
+};
+
+// Using LookupAttemptHostResolverProc simulate very long lookups, and control
+// which attempt resolves the host.
+class LookupAttemptHostResolverProc : public HostResolverProc {
+ public:
+ LookupAttemptHostResolverProc(HostResolverProc* previous,
+ int attempt_number_to_resolve,
+ int total_attempts)
+ : HostResolverProc(previous),
+ attempt_number_to_resolve_(attempt_number_to_resolve),
+ current_attempt_number_(0),
+ total_attempts_(total_attempts),
+ total_attempts_resolved_(0),
+ resolved_attempt_number_(0),
+ all_done_(&lock_) {
+ }
+
+ // Test harness will wait for all attempts to finish before checking the
+ // results.
+ void WaitForAllAttemptsToFinish(const base::TimeDelta& wait_time) {
+ base::TimeTicks end_time = base::TimeTicks::Now() + wait_time;
+ {
+ base::AutoLock auto_lock(lock_);
+ while (total_attempts_resolved_ != total_attempts_ &&
+ base::TimeTicks::Now() < end_time) {
+ all_done_.TimedWait(end_time - base::TimeTicks::Now());
+ }
+ }
+ }
+
+ // All attempts will wait for an attempt to resolve the host.
+ void WaitForAnAttemptToComplete() {
+ base::TimeDelta wait_time = base::TimeDelta::FromSeconds(60);
+ base::TimeTicks end_time = base::TimeTicks::Now() + wait_time;
+ {
+ base::AutoLock auto_lock(lock_);
+ while (resolved_attempt_number_ == 0 && base::TimeTicks::Now() < end_time)
+ all_done_.TimedWait(end_time - base::TimeTicks::Now());
+ }
+ all_done_.Broadcast(); // Tell all waiting attempts to proceed.
+ }
+
+ // Returns the number of attempts that have finished the Resolve() method.
+ int total_attempts_resolved() { return total_attempts_resolved_; }
+
+ // Returns the first attempt that that has resolved the host.
+ int resolved_attempt_number() { return resolved_attempt_number_; }
+
+ // HostResolverProc methods.
+ virtual int Resolve(const std::string& host,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ AddressList* addrlist,
+ int* os_error) OVERRIDE {
+ bool wait_for_right_attempt_to_complete = true;
+ {
+ base::AutoLock auto_lock(lock_);
+ ++current_attempt_number_;
+ if (current_attempt_number_ == attempt_number_to_resolve_) {
+ resolved_attempt_number_ = current_attempt_number_;
+ wait_for_right_attempt_to_complete = false;
+ }
+ }
+
+ if (wait_for_right_attempt_to_complete)
+ // Wait for the attempt_number_to_resolve_ attempt to resolve.
+ WaitForAnAttemptToComplete();
+
+ int result = ResolveUsingPrevious(host, address_family, host_resolver_flags,
+ addrlist, os_error);
+
+ {
+ base::AutoLock auto_lock(lock_);
+ ++total_attempts_resolved_;
+ }
+
+ all_done_.Broadcast(); // Tell all attempts to proceed.
+
+ // Since any negative number is considered a network error, with -1 having
+ // special meaning (ERR_IO_PENDING). We could return the attempt that has
+ // resolved the host as a negative number. For example, if attempt number 3
+ // resolves the host, then this method returns -4.
+ if (result == OK)
+ return -1 - resolved_attempt_number_;
+ else
+ return result;
+ }
+
+ protected:
+ virtual ~LookupAttemptHostResolverProc() {}
+
+ private:
+ int attempt_number_to_resolve_;
+ int current_attempt_number_; // Incremented whenever Resolve is called.
+ int total_attempts_;
+ int total_attempts_resolved_;
+ int resolved_attempt_number_;
+
+ // All attempts wait for right attempt to be resolve.
+ base::Lock lock_;
+ base::ConditionVariable all_done_;
+};
+
+} // namespace
+
+class HostResolverImplTest : public testing::Test {
+ public:
+ static const int kDefaultPort = 80;
+
+ HostResolverImplTest() : proc_(new MockHostResolverProc()) {}
+
+ protected:
+ // A Request::Handler which is a proxy to the HostResolverImplTest fixture.
+ struct Handler : public Request::Handler {
+ virtual ~Handler() {}
+
+ // Proxy functions so that classes derived from Handler can access them.
+ Request* CreateRequest(const HostResolver::RequestInfo& info) {
+ return test->CreateRequest(info);
+ }
+ Request* CreateRequest(const std::string& hostname, int port) {
+ return test->CreateRequest(hostname, port);
+ }
+ Request* CreateRequest(const std::string& hostname) {
+ return test->CreateRequest(hostname);
+ }
+ ScopedVector<Request>& requests() { return test->requests_; }
+
+ void DeleteResolver() { test->resolver_.reset(); }
+
+ HostResolverImplTest* test;
+ };
+
+ void CreateResolver() {
+ resolver_.reset(new HostResolverImpl(HostCache::CreateDefaultCache(),
+ DefaultLimits(),
+ DefaultParams(proc_.get()),
+ NULL));
+ }
+
+ // This HostResolverImpl will only allow 1 outstanding resolve at a time and
+ // perform no retries.
+ void CreateSerialResolver() {
+ HostResolverImpl::ProcTaskParams params = DefaultParams(proc_.get());
+ params.max_retry_attempts = 0u;
+ PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 1);
+ resolver_.reset(new HostResolverImpl(
+ HostCache::CreateDefaultCache(),
+ limits,
+ params,
+ NULL));
+ }
+
+ // The Request will not be made until a call to |Resolve()|, and the Job will
+ // not start until released by |proc_->SignalXXX|.
+ Request* CreateRequest(const HostResolver::RequestInfo& info) {
+ Request* req = new Request(info, requests_.size(), resolver_.get(),
+ handler_.get());
+ requests_.push_back(req);
+ return req;
+ }
+
+ Request* CreateRequest(const std::string& hostname,
+ int port,
+ RequestPriority priority,
+ AddressFamily family) {
+ HostResolver::RequestInfo info(HostPortPair(hostname, port));
+ info.set_priority(priority);
+ info.set_address_family(family);
+ return CreateRequest(info);
+ }
+
+ Request* CreateRequest(const std::string& hostname,
+ int port,
+ RequestPriority priority) {
+ return CreateRequest(hostname, port, priority, ADDRESS_FAMILY_UNSPECIFIED);
+ }
+
+ Request* CreateRequest(const std::string& hostname, int port) {
+ return CreateRequest(hostname, port, MEDIUM);
+ }
+
+ Request* CreateRequest(const std::string& hostname) {
+ return CreateRequest(hostname, kDefaultPort);
+ }
+
+ virtual void SetUp() OVERRIDE {
+ CreateResolver();
+ }
+
+ virtual void TearDown() OVERRIDE {
+ if (resolver_.get())
+ EXPECT_EQ(0u, resolver_->num_running_jobs_for_tests());
+ EXPECT_FALSE(proc_->HasBlockedRequests());
+ }
+
+ void set_handler(Handler* handler) {
+ handler_.reset(handler);
+ handler_->test = this;
+ }
+
+ // Friendship is not inherited, so use proxies to access those.
+ size_t num_running_jobs() const {
+ DCHECK(resolver_.get());
+ return resolver_->num_running_jobs_for_tests();
+ }
+
+ void set_fallback_to_proctask(bool fallback_to_proctask) {
+ DCHECK(resolver_.get());
+ resolver_->fallback_to_proctask_ = fallback_to_proctask;
+ }
+
+ scoped_refptr<MockHostResolverProc> proc_;
+ scoped_ptr<HostResolverImpl> resolver_;
+ ScopedVector<Request> requests_;
+
+ scoped_ptr<Handler> handler_;
+};
+
+TEST_F(HostResolverImplTest, AsynchronousLookup) {
+ proc_->AddRuleForAllFamilies("just.testing", "192.168.1.42");
+ proc_->SignalMultiple(1u);
+
+ Request* req = CreateRequest("just.testing", 80);
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+ EXPECT_EQ(OK, req->WaitForResult());
+
+ EXPECT_TRUE(req->HasOneAddress("192.168.1.42", 80));
+
+ EXPECT_EQ("just.testing", proc_->GetCaptureList()[0].hostname);
+}
+
+TEST_F(HostResolverImplTest, EmptyListMeansNameNotResolved) {
+ proc_->AddRuleForAllFamilies("just.testing", "");
+ proc_->SignalMultiple(1u);
+
+ Request* req = CreateRequest("just.testing", 80);
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->WaitForResult());
+ EXPECT_EQ(0u, req->NumberOfAddresses());
+ EXPECT_EQ("just.testing", proc_->GetCaptureList()[0].hostname);
+}
+
+TEST_F(HostResolverImplTest, FailedAsynchronousLookup) {
+ proc_->AddRuleForAllFamilies(std::string(),
+ "0.0.0.0"); // Default to failures.
+ proc_->SignalMultiple(1u);
+
+ Request* req = CreateRequest("just.testing", 80);
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->WaitForResult());
+
+ EXPECT_EQ("just.testing", proc_->GetCaptureList()[0].hostname);
+
+ // Also test that the error is not cached.
+ EXPECT_EQ(ERR_DNS_CACHE_MISS, req->ResolveFromCache());
+}
+
+TEST_F(HostResolverImplTest, AbortedAsynchronousLookup) {
+ Request* req0 = CreateRequest("just.testing", 80);
+ EXPECT_EQ(ERR_IO_PENDING, req0->Resolve());
+
+ EXPECT_TRUE(proc_->WaitFor(1u));
+
+ // Resolver is destroyed while job is running on WorkerPool.
+ resolver_.reset();
+
+ proc_->SignalAll();
+
+ // To ensure there was no spurious callback, complete with a new resolver.
+ CreateResolver();
+ Request* req1 = CreateRequest("just.testing", 80);
+ EXPECT_EQ(ERR_IO_PENDING, req1->Resolve());
+
+ proc_->SignalMultiple(2u);
+
+ EXPECT_EQ(OK, req1->WaitForResult());
+
+ // This request was canceled.
+ EXPECT_FALSE(req0->completed());
+}
+
+TEST_F(HostResolverImplTest, NumericIPv4Address) {
+ // Stevens says dotted quads with AI_UNSPEC resolve to a single sockaddr_in.
+ Request* req = CreateRequest("127.1.2.3", 5555);
+ EXPECT_EQ(OK, req->Resolve());
+
+ EXPECT_TRUE(req->HasOneAddress("127.1.2.3", 5555));
+}
+
+TEST_F(HostResolverImplTest, NumericIPv6Address) {
+ // Resolve a plain IPv6 address. Don't worry about [brackets], because
+ // the caller should have removed them.
+ Request* req = CreateRequest("2001:db8::1", 5555);
+ EXPECT_EQ(OK, req->Resolve());
+
+ EXPECT_TRUE(req->HasOneAddress("2001:db8::1", 5555));
+}
+
+TEST_F(HostResolverImplTest, EmptyHost) {
+ Request* req = CreateRequest(std::string(), 5555);
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->Resolve());
+}
+
+TEST_F(HostResolverImplTest, EmptyDotsHost) {
+ for (int i = 0; i < 16; ++i) {
+ Request* req = CreateRequest(std::string(i, '.'), 5555);
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->Resolve());
+ }
+}
+
+TEST_F(HostResolverImplTest, LongHost) {
+ Request* req = CreateRequest(std::string(4097, 'a'), 5555);
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->Resolve());
+}
+
+TEST_F(HostResolverImplTest, DeDupeRequests) {
+ // Start 5 requests, duplicating hosts "a" and "b". Since the resolver_proc is
+ // blocked, these should all pile up until we signal it.
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 81)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 82)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 83)->Resolve());
+
+ proc_->SignalMultiple(2u); // One for "a", one for "b".
+
+ for (size_t i = 0; i < requests_.size(); ++i) {
+ EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i;
+ }
+}
+
+TEST_F(HostResolverImplTest, CancelMultipleRequests) {
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 81)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 82)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 83)->Resolve());
+
+ // Cancel everything except request for ("a", 82).
+ requests_[0]->Cancel();
+ requests_[1]->Cancel();
+ requests_[2]->Cancel();
+ requests_[4]->Cancel();
+
+ proc_->SignalMultiple(2u); // One for "a", one for "b".
+
+ EXPECT_EQ(OK, requests_[3]->WaitForResult());
+}
+
+TEST_F(HostResolverImplTest, CanceledRequestsReleaseJobSlots) {
+ // Fill up the dispatcher and queue.
+ for (unsigned i = 0; i < kMaxJobs + 1; ++i) {
+ std::string hostname = "a_";
+ hostname[1] = 'a' + i;
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest(hostname, 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest(hostname, 81)->Resolve());
+ }
+
+ EXPECT_TRUE(proc_->WaitFor(kMaxJobs));
+
+ // Cancel all but last two.
+ for (unsigned i = 0; i < requests_.size() - 2; ++i) {
+ requests_[i]->Cancel();
+ }
+
+ EXPECT_TRUE(proc_->WaitFor(kMaxJobs + 1));
+
+ proc_->SignalAll();
+
+ size_t num_requests = requests_.size();
+ EXPECT_EQ(OK, requests_[num_requests - 1]->WaitForResult());
+ EXPECT_EQ(OK, requests_[num_requests - 2]->result());
+}
+
+TEST_F(HostResolverImplTest, CancelWithinCallback) {
+ struct MyHandler : public Handler {
+ virtual void Handle(Request* req) OVERRIDE {
+ // Port 80 is the first request that the callback will be invoked for.
+ // While we are executing within that callback, cancel the other requests
+ // in the job and start another request.
+ if (req->index() == 0) {
+ // Once "a:80" completes, it will cancel "a:81" and "a:82".
+ requests()[1]->Cancel();
+ requests()[2]->Cancel();
+ }
+ }
+ };
+ set_handler(new MyHandler());
+
+ for (size_t i = 0; i < 4; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80 + i)->Resolve()) << i;
+ }
+
+ proc_->SignalMultiple(2u); // One for "a". One for "finalrequest".
+
+ EXPECT_EQ(OK, requests_[0]->WaitForResult());
+
+ Request* final_request = CreateRequest("finalrequest", 70);
+ EXPECT_EQ(ERR_IO_PENDING, final_request->Resolve());
+ EXPECT_EQ(OK, final_request->WaitForResult());
+ EXPECT_TRUE(requests_[3]->completed());
+}
+
+TEST_F(HostResolverImplTest, DeleteWithinCallback) {
+ struct MyHandler : public Handler {
+ virtual void Handle(Request* req) OVERRIDE {
+ EXPECT_EQ("a", req->info().hostname());
+ EXPECT_EQ(80, req->info().port());
+
+ DeleteResolver();
+
+ // Quit after returning from OnCompleted (to give it a chance at
+ // incorrectly running the cancelled tasks).
+ base::MessageLoop::current()->PostTask(FROM_HERE,
+ base::MessageLoop::QuitClosure());
+ }
+ };
+ set_handler(new MyHandler());
+
+ for (size_t i = 0; i < 4; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80 + i)->Resolve()) << i;
+ }
+
+ proc_->SignalMultiple(1u); // One for "a".
+
+ // |MyHandler| will send quit message once all the requests have finished.
+ base::MessageLoop::current()->Run();
+}
+
+TEST_F(HostResolverImplTest, DeleteWithinAbortedCallback) {
+ struct MyHandler : public Handler {
+ virtual void Handle(Request* req) OVERRIDE {
+ EXPECT_EQ("a", req->info().hostname());
+ EXPECT_EQ(80, req->info().port());
+
+ DeleteResolver();
+
+ // Quit after returning from OnCompleted (to give it a chance at
+ // incorrectly running the cancelled tasks).
+ base::MessageLoop::current()->PostTask(FROM_HERE,
+ base::MessageLoop::QuitClosure());
+ }
+ };
+ set_handler(new MyHandler());
+
+ // This test assumes that the Jobs will be Aborted in order ["a", "b"]
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80)->Resolve());
+ // HostResolverImpl will be deleted before later Requests can complete.
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 81)->Resolve());
+ // Job for 'b' will be aborted before it can complete.
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 82)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 83)->Resolve());
+
+ EXPECT_TRUE(proc_->WaitFor(1u));
+
+ // Triggering an IP address change.
+ NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests();
+
+ // |MyHandler| will send quit message once all the requests have finished.
+ base::MessageLoop::current()->Run();
+
+ EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[0]->result());
+ EXPECT_EQ(ERR_IO_PENDING, requests_[1]->result());
+ EXPECT_EQ(ERR_IO_PENDING, requests_[2]->result());
+ EXPECT_EQ(ERR_IO_PENDING, requests_[3]->result());
+ // Clean up.
+ proc_->SignalMultiple(requests_.size());
+}
+
+TEST_F(HostResolverImplTest, StartWithinCallback) {
+ struct MyHandler : public Handler {
+ virtual void Handle(Request* req) OVERRIDE {
+ if (req->index() == 0) {
+ // On completing the first request, start another request for "a".
+ // Since caching is disabled, this will result in another async request.
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 70)->Resolve());
+ }
+ }
+ };
+ set_handler(new MyHandler());
+
+ // Turn off caching for this host resolver.
+ resolver_.reset(new HostResolverImpl(scoped_ptr<HostCache>(),
+ DefaultLimits(),
+ DefaultParams(proc_.get()),
+ NULL));
+
+ for (size_t i = 0; i < 4; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80 + i)->Resolve()) << i;
+ }
+
+ proc_->SignalMultiple(2u); // One for "a". One for the second "a".
+
+ EXPECT_EQ(OK, requests_[0]->WaitForResult());
+ ASSERT_EQ(5u, requests_.size());
+ EXPECT_EQ(OK, requests_.back()->WaitForResult());
+
+ EXPECT_EQ(2u, proc_->GetCaptureList().size());
+}
+
+TEST_F(HostResolverImplTest, BypassCache) {
+ struct MyHandler : public Handler {
+ virtual void Handle(Request* req) OVERRIDE {
+ if (req->index() == 0) {
+ // On completing the first request, start another request for "a".
+ // Since caching is enabled, this should complete synchronously.
+ std::string hostname = req->info().hostname();
+ EXPECT_EQ(OK, CreateRequest(hostname, 70)->Resolve());
+ EXPECT_EQ(OK, CreateRequest(hostname, 75)->ResolveFromCache());
+
+ // Ok good. Now make sure that if we ask to bypass the cache, it can no
+ // longer service the request synchronously.
+ HostResolver::RequestInfo info(HostPortPair(hostname, 71));
+ info.set_allow_cached_response(false);
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest(info)->Resolve());
+ } else if (71 == req->info().port()) {
+ // Test is done.
+ base::MessageLoop::current()->Quit();
+ } else {
+ FAIL() << "Unexpected request";
+ }
+ }
+ };
+ set_handler(new MyHandler());
+
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80)->Resolve());
+ proc_->SignalMultiple(3u); // Only need two, but be generous.
+
+ // |verifier| will send quit message once all the requests have finished.
+ base::MessageLoop::current()->Run();
+ EXPECT_EQ(2u, proc_->GetCaptureList().size());
+}
+
+// Test that IP address changes flush the cache.
+TEST_F(HostResolverImplTest, FlushCacheOnIPAddressChange) {
+ proc_->SignalMultiple(2u); // One before the flush, one after.
+
+ Request* req = CreateRequest("host1", 70);
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+ EXPECT_EQ(OK, req->WaitForResult());
+
+ req = CreateRequest("host1", 75);
+ EXPECT_EQ(OK, req->Resolve()); // Should complete synchronously.
+
+ // Flush cache by triggering an IP address change.
+ NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests();
+ base::MessageLoop::current()->RunUntilIdle(); // Notification happens async.
+
+ // Resolve "host1" again -- this time it won't be served from cache, so it
+ // will complete asynchronously.
+ req = CreateRequest("host1", 80);
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+ EXPECT_EQ(OK, req->WaitForResult());
+}
+
+// Test that IP address changes send ERR_NETWORK_CHANGED to pending requests.
+TEST_F(HostResolverImplTest, AbortOnIPAddressChanged) {
+ Request* req = CreateRequest("host1", 70);
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+
+ EXPECT_TRUE(proc_->WaitFor(1u));
+ // Triggering an IP address change.
+ NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests();
+ base::MessageLoop::current()->RunUntilIdle(); // Notification happens async.
+ proc_->SignalAll();
+
+ EXPECT_EQ(ERR_NETWORK_CHANGED, req->WaitForResult());
+ EXPECT_EQ(0u, resolver_->GetHostCache()->size());
+}
+
+// Obey pool constraints after IP address has changed.
+TEST_F(HostResolverImplTest, ObeyPoolConstraintsAfterIPAddressChange) {
+ // Runs at most one job at a time.
+ CreateSerialResolver();
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a")->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b")->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("c")->Resolve());
+
+ EXPECT_TRUE(proc_->WaitFor(1u));
+ // Triggering an IP address change.
+ NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests();
+ base::MessageLoop::current()->RunUntilIdle(); // Notification happens async.
+ proc_->SignalMultiple(3u); // Let the false-start go so that we can catch it.
+
+ EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[0]->WaitForResult());
+
+ EXPECT_EQ(1u, num_running_jobs());
+
+ EXPECT_FALSE(requests_[1]->completed());
+ EXPECT_FALSE(requests_[2]->completed());
+
+ EXPECT_EQ(OK, requests_[2]->WaitForResult());
+ EXPECT_EQ(OK, requests_[1]->result());
+}
+
+// Tests that a new Request made from the callback of a previously aborted one
+// will not be aborted.
+TEST_F(HostResolverImplTest, AbortOnlyExistingRequestsOnIPAddressChange) {
+ struct MyHandler : public Handler {
+ virtual void Handle(Request* req) OVERRIDE {
+ // Start new request for a different hostname to ensure that the order
+ // of jobs in HostResolverImpl is not stable.
+ std::string hostname;
+ if (req->index() == 0)
+ hostname = "zzz";
+ else if (req->index() == 1)
+ hostname = "aaa";
+ else if (req->index() == 2)
+ hostname = "eee";
+ else
+ return; // A request started from within MyHandler.
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest(hostname)->Resolve()) << hostname;
+ }
+ };
+ set_handler(new MyHandler());
+
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("bbb")->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("eee")->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ccc")->Resolve());
+
+ // Wait until all are blocked;
+ EXPECT_TRUE(proc_->WaitFor(3u));
+ // Trigger an IP address change.
+ NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests();
+ // This should abort all running jobs.
+ base::MessageLoop::current()->RunUntilIdle();
+ EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[0]->result());
+ EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[1]->result());
+ EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[2]->result());
+ ASSERT_EQ(6u, requests_.size());
+ // Unblock all calls to proc.
+ proc_->SignalMultiple(requests_.size());
+ // Run until the re-started requests finish.
+ EXPECT_EQ(OK, requests_[3]->WaitForResult());
+ EXPECT_EQ(OK, requests_[4]->WaitForResult());
+ EXPECT_EQ(OK, requests_[5]->WaitForResult());
+ // Verify that results of aborted Jobs were not cached.
+ EXPECT_EQ(6u, proc_->GetCaptureList().size());
+ EXPECT_EQ(3u, resolver_->GetHostCache()->size());
+}
+
+// Tests that when the maximum threads is set to 1, requests are dequeued
+// in order of priority.
+TEST_F(HostResolverImplTest, HigherPriorityRequestsStartedFirst) {
+ CreateSerialResolver();
+
+ // Note that at this point the MockHostResolverProc is blocked, so any
+ // requests we make will not complete.
+ CreateRequest("req0", 80, LOW);
+ CreateRequest("req1", 80, MEDIUM);
+ CreateRequest("req2", 80, MEDIUM);
+ CreateRequest("req3", 80, LOW);
+ CreateRequest("req4", 80, HIGHEST);
+ CreateRequest("req5", 80, LOW);
+ CreateRequest("req6", 80, LOW);
+ CreateRequest("req5", 80, HIGHEST);
+
+ for (size_t i = 0; i < requests_.size(); ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, requests_[i]->Resolve()) << i;
+ }
+
+ // Unblock the resolver thread so the requests can run.
+ proc_->SignalMultiple(requests_.size()); // More than needed.
+
+ // Wait for all the requests to complete succesfully.
+ for (size_t i = 0; i < requests_.size(); ++i) {
+ EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i;
+ }
+
+ // Since we have restricted to a single concurrent thread in the jobpool,
+ // the requests should complete in order of priority (with the exception
+ // of the first request, which gets started right away, since there is
+ // nothing outstanding).
+ MockHostResolverProc::CaptureList capture_list = proc_->GetCaptureList();
+ ASSERT_EQ(7u, capture_list.size());
+
+ EXPECT_EQ("req0", capture_list[0].hostname);
+ EXPECT_EQ("req4", capture_list[1].hostname);
+ EXPECT_EQ("req5", capture_list[2].hostname);
+ EXPECT_EQ("req1", capture_list[3].hostname);
+ EXPECT_EQ("req2", capture_list[4].hostname);
+ EXPECT_EQ("req3", capture_list[5].hostname);
+ EXPECT_EQ("req6", capture_list[6].hostname);
+}
+
+// Try cancelling a job which has not started yet.
+TEST_F(HostResolverImplTest, CancelPendingRequest) {
+ CreateSerialResolver();
+
+ CreateRequest("req0", 80, LOWEST);
+ CreateRequest("req1", 80, HIGHEST); // Will cancel.
+ CreateRequest("req2", 80, MEDIUM);
+ CreateRequest("req3", 80, LOW);
+ CreateRequest("req4", 80, HIGHEST); // Will cancel.
+ CreateRequest("req5", 80, LOWEST); // Will cancel.
+ CreateRequest("req6", 80, MEDIUM);
+
+ // Start all of the requests.
+ for (size_t i = 0; i < requests_.size(); ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, requests_[i]->Resolve()) << i;
+ }
+
+ // Cancel some requests
+ requests_[1]->Cancel();
+ requests_[4]->Cancel();
+ requests_[5]->Cancel();
+
+ // Unblock the resolver thread so the requests can run.
+ proc_->SignalMultiple(requests_.size()); // More than needed.
+
+ // Wait for all the requests to complete succesfully.
+ for (size_t i = 0; i < requests_.size(); ++i) {
+ if (!requests_[i]->pending())
+ continue; // Don't wait for the requests we cancelled.
+ EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i;
+ }
+
+ // Verify that they called out the the resolver proc (which runs on the
+ // resolver thread) in the expected order.
+ MockHostResolverProc::CaptureList capture_list = proc_->GetCaptureList();
+ ASSERT_EQ(4u, capture_list.size());
+
+ EXPECT_EQ("req0", capture_list[0].hostname);
+ EXPECT_EQ("req2", capture_list[1].hostname);
+ EXPECT_EQ("req6", capture_list[2].hostname);
+ EXPECT_EQ("req3", capture_list[3].hostname);
+}
+
+// Test that when too many requests are enqueued, old ones start to be aborted.
+TEST_F(HostResolverImplTest, QueueOverflow) {
+ CreateSerialResolver();
+
+ // Allow only 3 queued jobs.
+ const size_t kMaxPendingJobs = 3u;
+ resolver_->SetMaxQueuedJobs(kMaxPendingJobs);
+
+ // Note that at this point the MockHostResolverProc is blocked, so any
+ // requests we make will not complete.
+
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req0", 80, LOWEST)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req1", 80, HIGHEST)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req2", 80, MEDIUM)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req3", 80, MEDIUM)->Resolve());
+
+ // At this point, there are 3 enqueued jobs.
+ // Insertion of subsequent requests will cause evictions
+ // based on priority.
+
+ EXPECT_EQ(ERR_HOST_RESOLVER_QUEUE_TOO_LARGE,
+ CreateRequest("req4", 80, LOW)->Resolve()); // Evicts itself!
+
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req5", 80, MEDIUM)->Resolve());
+ EXPECT_EQ(ERR_HOST_RESOLVER_QUEUE_TOO_LARGE, requests_[2]->result());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req6", 80, HIGHEST)->Resolve());
+ EXPECT_EQ(ERR_HOST_RESOLVER_QUEUE_TOO_LARGE, requests_[3]->result());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req7", 80, MEDIUM)->Resolve());
+ EXPECT_EQ(ERR_HOST_RESOLVER_QUEUE_TOO_LARGE, requests_[5]->result());
+
+ // Unblock the resolver thread so the requests can run.
+ proc_->SignalMultiple(4u);
+
+ // The rest should succeed.
+ EXPECT_EQ(OK, requests_[7]->WaitForResult());
+ EXPECT_EQ(OK, requests_[0]->result());
+ EXPECT_EQ(OK, requests_[1]->result());
+ EXPECT_EQ(OK, requests_[6]->result());
+
+ // Verify that they called out the the resolver proc (which runs on the
+ // resolver thread) in the expected order.
+ MockHostResolverProc::CaptureList capture_list = proc_->GetCaptureList();
+ ASSERT_EQ(4u, capture_list.size());
+
+ EXPECT_EQ("req0", capture_list[0].hostname);
+ EXPECT_EQ("req1", capture_list[1].hostname);
+ EXPECT_EQ("req6", capture_list[2].hostname);
+ EXPECT_EQ("req7", capture_list[3].hostname);
+
+ // Verify that the evicted (incomplete) requests were not cached.
+ EXPECT_EQ(4u, resolver_->GetHostCache()->size());
+
+ for (size_t i = 0; i < requests_.size(); ++i) {
+ EXPECT_TRUE(requests_[i]->completed()) << i;
+ }
+}
+
+// Tests that after changing the default AddressFamily to IPV4, requests
+// with UNSPECIFIED address family map to IPV4.
+TEST_F(HostResolverImplTest, SetDefaultAddressFamily_IPv4) {
+ CreateSerialResolver(); // To guarantee order of resolutions.
+
+ proc_->AddRule("h1", ADDRESS_FAMILY_IPV4, "1.0.0.1");
+ proc_->AddRule("h1", ADDRESS_FAMILY_IPV6, "::2");
+
+ resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4);
+
+ CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_UNSPECIFIED);
+ CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_IPV4);
+ CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_IPV6);
+
+ // Start all of the requests.
+ for (size_t i = 0; i < requests_.size(); ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, requests_[i]->Resolve()) << i;
+ }
+
+ proc_->SignalMultiple(requests_.size());
+
+ // Wait for all the requests to complete.
+ for (size_t i = 0u; i < requests_.size(); ++i) {
+ EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i;
+ }
+
+ // Since the requests all had the same priority and we limited the thread
+ // count to 1, they should have completed in the same order as they were
+ // requested. Moreover, request0 and request1 will have been serviced by
+ // the same job.
+
+ MockHostResolverProc::CaptureList capture_list = proc_->GetCaptureList();
+ ASSERT_EQ(2u, capture_list.size());
+
+ EXPECT_EQ("h1", capture_list[0].hostname);
+ EXPECT_EQ(ADDRESS_FAMILY_IPV4, capture_list[0].address_family);
+
+ EXPECT_EQ("h1", capture_list[1].hostname);
+ EXPECT_EQ(ADDRESS_FAMILY_IPV6, capture_list[1].address_family);
+
+ // Now check that the correct resolved IP addresses were returned.
+ EXPECT_TRUE(requests_[0]->HasOneAddress("1.0.0.1", 80));
+ EXPECT_TRUE(requests_[1]->HasOneAddress("1.0.0.1", 80));
+ EXPECT_TRUE(requests_[2]->HasOneAddress("::2", 80));
+}
+
+// This is the exact same test as SetDefaultAddressFamily_IPv4, except the
+// default family is set to IPv6 and the family of requests is flipped where
+// specified.
+TEST_F(HostResolverImplTest, SetDefaultAddressFamily_IPv6) {
+ CreateSerialResolver(); // To guarantee order of resolutions.
+
+ // Don't use IPv6 replacements here since some systems don't support it.
+ proc_->AddRule("h1", ADDRESS_FAMILY_IPV4, "1.0.0.1");
+ proc_->AddRule("h1", ADDRESS_FAMILY_IPV6, "::2");
+
+ resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV6);
+
+ CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_UNSPECIFIED);
+ CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_IPV6);
+ CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_IPV4);
+
+ // Start all of the requests.
+ for (size_t i = 0; i < requests_.size(); ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, requests_[i]->Resolve()) << i;
+ }
+
+ proc_->SignalMultiple(requests_.size());
+
+ // Wait for all the requests to complete.
+ for (size_t i = 0u; i < requests_.size(); ++i) {
+ EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i;
+ }
+
+ // Since the requests all had the same priority and we limited the thread
+ // count to 1, they should have completed in the same order as they were
+ // requested. Moreover, request0 and request1 will have been serviced by
+ // the same job.
+
+ MockHostResolverProc::CaptureList capture_list = proc_->GetCaptureList();
+ ASSERT_EQ(2u, capture_list.size());
+
+ EXPECT_EQ("h1", capture_list[0].hostname);
+ EXPECT_EQ(ADDRESS_FAMILY_IPV6, capture_list[0].address_family);
+
+ EXPECT_EQ("h1", capture_list[1].hostname);
+ EXPECT_EQ(ADDRESS_FAMILY_IPV4, capture_list[1].address_family);
+
+ // Now check that the correct resolved IP addresses were returned.
+ EXPECT_TRUE(requests_[0]->HasOneAddress("::2", 80));
+ EXPECT_TRUE(requests_[1]->HasOneAddress("::2", 80));
+ EXPECT_TRUE(requests_[2]->HasOneAddress("1.0.0.1", 80));
+}
+
+TEST_F(HostResolverImplTest, ResolveFromCache) {
+ proc_->AddRuleForAllFamilies("just.testing", "192.168.1.42");
+ proc_->SignalMultiple(1u); // Need only one.
+
+ HostResolver::RequestInfo info(HostPortPair("just.testing", 80));
+
+ // First hit will miss the cache.
+ EXPECT_EQ(ERR_DNS_CACHE_MISS, CreateRequest(info)->ResolveFromCache());
+
+ // This time, we fetch normally.
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest(info)->Resolve());
+ EXPECT_EQ(OK, requests_[1]->WaitForResult());
+
+ // Now we should be able to fetch from the cache.
+ EXPECT_EQ(OK, CreateRequest(info)->ResolveFromCache());
+ EXPECT_TRUE(requests_[2]->HasOneAddress("192.168.1.42", 80));
+}
+
+// Test the retry attempts simulating host resolver proc that takes too long.
+TEST_F(HostResolverImplTest, MultipleAttempts) {
+ // Total number of attempts would be 3 and we want the 3rd attempt to resolve
+ // the host. First and second attempt will be forced to sleep until they get
+ // word that a resolution has completed. The 3rd resolution attempt will try
+ // to get done ASAP, and won't sleep..
+ int kAttemptNumberToResolve = 3;
+ int kTotalAttempts = 3;
+
+ scoped_refptr<LookupAttemptHostResolverProc> resolver_proc(
+ new LookupAttemptHostResolverProc(
+ NULL, kAttemptNumberToResolve, kTotalAttempts));
+
+ HostResolverImpl::ProcTaskParams params = DefaultParams(resolver_proc.get());
+
+ // Specify smaller interval for unresponsive_delay_ for HostResolverImpl so
+ // that unit test runs faster. For example, this test finishes in 1.5 secs
+ // (500ms * 3).
+ params.unresponsive_delay = base::TimeDelta::FromMilliseconds(500);
+
+ resolver_.reset(
+ new HostResolverImpl(HostCache::CreateDefaultCache(),
+ DefaultLimits(),
+ params,
+ NULL));
+
+ // Resolve "host1".
+ HostResolver::RequestInfo info(HostPortPair("host1", 70));
+ Request* req = CreateRequest(info);
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+
+ // Resolve returns -4 to indicate that 3rd attempt has resolved the host.
+ EXPECT_EQ(-4, req->WaitForResult());
+
+ resolver_proc->WaitForAllAttemptsToFinish(
+ base::TimeDelta::FromMilliseconds(60000));
+ base::MessageLoop::current()->RunUntilIdle();
+
+ EXPECT_EQ(resolver_proc->total_attempts_resolved(), kTotalAttempts);
+ EXPECT_EQ(resolver_proc->resolved_attempt_number(), kAttemptNumberToResolve);
+}
+
+DnsConfig CreateValidDnsConfig() {
+ IPAddressNumber dns_ip;
+ bool rv = ParseIPLiteralToNumber("192.168.1.0", &dns_ip);
+ EXPECT_TRUE(rv);
+
+ DnsConfig config;
+ config.nameservers.push_back(IPEndPoint(dns_ip, dns_protocol::kDefaultPort));
+ EXPECT_TRUE(config.IsValid());
+ return config;
+}
+
+// Specialized fixture for tests of DnsTask.
+class HostResolverImplDnsTest : public HostResolverImplTest {
+ protected:
+ virtual void SetUp() OVERRIDE {
+ AddDnsRule("nx", dns_protocol::kTypeA, MockDnsClientRule::FAIL);
+ AddDnsRule("nx", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL);
+ AddDnsRule("ok", dns_protocol::kTypeA, MockDnsClientRule::OK);
+ AddDnsRule("ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK);
+ AddDnsRule("4ok", dns_protocol::kTypeA, MockDnsClientRule::OK);
+ AddDnsRule("4ok", dns_protocol::kTypeAAAA, MockDnsClientRule::EMPTY);
+ AddDnsRule("6ok", dns_protocol::kTypeA, MockDnsClientRule::EMPTY);
+ AddDnsRule("6ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK);
+ AddDnsRule("4nx", dns_protocol::kTypeA, MockDnsClientRule::OK);
+ AddDnsRule("4nx", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL);
+ CreateResolver();
+ }
+
+ void CreateResolver() {
+ resolver_.reset(new HostResolverImpl(HostCache::CreateDefaultCache(),
+ DefaultLimits(),
+ DefaultParams(proc_.get()),
+ NULL));
+ // Disable IPv6 support probing.
+ resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED);
+ resolver_->SetDnsClient(CreateMockDnsClient(DnsConfig(), dns_rules_));
+ }
+
+ // Adds a rule to |dns_rules_|. Must be followed by |CreateResolver| to apply.
+ void AddDnsRule(const std::string& prefix,
+ uint16 qtype,
+ MockDnsClientRule::Result result) {
+ dns_rules_.push_back(MockDnsClientRule(prefix, qtype, result));
+ }
+
+ void ChangeDnsConfig(const DnsConfig& config) {
+ NetworkChangeNotifier::SetDnsConfig(config);
+ // Notification is delivered asynchronously.
+ base::MessageLoop::current()->RunUntilIdle();
+ }
+
+ MockDnsClientRuleList dns_rules_;
+};
+
+// TODO(szym): Test AbortAllInProgressJobs due to DnsConfig change.
+
+// TODO(cbentzel): Test a mix of requests with different HostResolverFlags.
+
+// Test successful and fallback resolutions in HostResolverImpl::DnsTask.
+TEST_F(HostResolverImplDnsTest, DnsTask) {
+ resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4);
+
+ proc_->AddRuleForAllFamilies("nx_succeed", "192.168.1.102");
+ // All other hostnames will fail in proc_.
+
+ // Initially there is no config, so client should not be invoked.
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok_fail", 80)->Resolve());
+ proc_->SignalMultiple(requests_.size());
+
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, requests_[0]->WaitForResult());
+
+ ChangeDnsConfig(CreateValidDnsConfig());
+
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok_fail", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_fail", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_succeed", 80)->Resolve());
+
+ proc_->SignalMultiple(requests_.size());
+
+ for (size_t i = 1; i < requests_.size(); ++i)
+ EXPECT_NE(ERR_UNEXPECTED, requests_[i]->WaitForResult()) << i;
+
+ EXPECT_EQ(OK, requests_[1]->result());
+ // Resolved by MockDnsClient.
+ EXPECT_TRUE(requests_[1]->HasOneAddress("127.0.0.1", 80));
+ // Fallback to ProcTask.
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, requests_[2]->result());
+ EXPECT_EQ(OK, requests_[3]->result());
+ EXPECT_TRUE(requests_[3]->HasOneAddress("192.168.1.102", 80));
+}
+
+// Test successful and failing resolutions in HostResolverImpl::DnsTask when
+// fallback to ProcTask is disabled.
+TEST_F(HostResolverImplDnsTest, NoFallbackToProcTask) {
+ resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4);
+ set_fallback_to_proctask(false);
+
+ proc_->AddRuleForAllFamilies("nx_succeed", "192.168.1.102");
+ // All other hostnames will fail in proc_.
+
+ // Set empty DnsConfig.
+ ChangeDnsConfig(DnsConfig());
+ // Initially there is no config, so client should not be invoked.
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok_fail", 80)->Resolve());
+ // There is no config, so fallback to ProcTask must work.
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_succeed", 80)->Resolve());
+ proc_->SignalMultiple(requests_.size());
+
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, requests_[0]->WaitForResult());
+ EXPECT_EQ(OK, requests_[1]->WaitForResult());
+ EXPECT_TRUE(requests_[1]->HasOneAddress("192.168.1.102", 80));
+
+ ChangeDnsConfig(CreateValidDnsConfig());
+
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok_abort", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_abort", 80)->Resolve());
+
+ // Simulate the case when the preference or policy has disabled the DNS client
+ // causing AbortDnsTasks.
+ resolver_->SetDnsClient(CreateMockDnsClient(DnsConfig(), dns_rules_));
+ ChangeDnsConfig(CreateValidDnsConfig());
+
+ // First request is resolved by MockDnsClient, others should fail due to
+ // disabled fallback to ProcTask.
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok_fail", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_fail", 80)->Resolve());
+ proc_->SignalMultiple(requests_.size());
+
+ // Aborted due to Network Change.
+ EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[2]->WaitForResult());
+ EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[3]->WaitForResult());
+ // Resolved by MockDnsClient.
+ EXPECT_EQ(OK, requests_[4]->WaitForResult());
+ EXPECT_TRUE(requests_[4]->HasOneAddress("127.0.0.1", 80));
+ // Fallback to ProcTask is disabled.
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, requests_[5]->WaitForResult());
+}
+
+// Test behavior of OnDnsTaskFailure when Job is aborted.
+TEST_F(HostResolverImplDnsTest, OnDnsTaskFailureAbortedJob) {
+ resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4);
+ ChangeDnsConfig(CreateValidDnsConfig());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_abort", 80)->Resolve());
+ // Abort all jobs here.
+ CreateResolver();
+ proc_->SignalMultiple(requests_.size());
+ // Run to completion.
+ base::MessageLoop::current()->RunUntilIdle(); // Notification happens async.
+ // It shouldn't crash during OnDnsTaskFailure callbacks.
+ EXPECT_EQ(ERR_IO_PENDING, requests_[0]->result());
+
+ // Repeat test with Fallback to ProcTask disabled
+ resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4);
+ set_fallback_to_proctask(false);
+ ChangeDnsConfig(CreateValidDnsConfig());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_abort", 80)->Resolve());
+ // Abort all jobs here.
+ CreateResolver();
+ // Run to completion.
+ base::MessageLoop::current()->RunUntilIdle(); // Notification happens async.
+ // It shouldn't crash during OnDnsTaskFailure callbacks.
+ EXPECT_EQ(ERR_IO_PENDING, requests_[1]->result());
+}
+
+TEST_F(HostResolverImplDnsTest, DnsTaskUnspec) {
+ ChangeDnsConfig(CreateValidDnsConfig());
+
+ proc_->AddRuleForAllFamilies("4nx", "192.168.1.101");
+ // All other hostnames will fail in proc_.
+
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("4ok", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("6ok", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("4nx", 80)->Resolve());
+
+ proc_->SignalMultiple(requests_.size());
+
+ for (size_t i = 0; i < requests_.size(); ++i)
+ EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i;
+
+ EXPECT_EQ(2u, requests_[0]->NumberOfAddresses());
+ EXPECT_TRUE(requests_[0]->HasAddress("127.0.0.1", 80));
+ EXPECT_TRUE(requests_[0]->HasAddress("::1", 80));
+ EXPECT_EQ(1u, requests_[1]->NumberOfAddresses());
+ EXPECT_TRUE(requests_[1]->HasAddress("127.0.0.1", 80));
+ EXPECT_EQ(1u, requests_[2]->NumberOfAddresses());
+ EXPECT_TRUE(requests_[2]->HasAddress("::1", 80));
+ EXPECT_EQ(1u, requests_[3]->NumberOfAddresses());
+ EXPECT_TRUE(requests_[3]->HasAddress("192.168.1.101", 80));
+}
+
+TEST_F(HostResolverImplDnsTest, ServeFromHosts) {
+ // Initially, use empty HOSTS file.
+ DnsConfig config = CreateValidDnsConfig();
+ ChangeDnsConfig(config);
+
+ proc_->AddRuleForAllFamilies(std::string(),
+ std::string()); // Default to failures.
+ proc_->SignalMultiple(1u); // For the first request which misses.
+
+ Request* req0 = CreateRequest("nx_ipv4", 80);
+ EXPECT_EQ(ERR_IO_PENDING, req0->Resolve());
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req0->WaitForResult());
+
+ IPAddressNumber local_ipv4, local_ipv6;
+ ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &local_ipv4));
+ ASSERT_TRUE(ParseIPLiteralToNumber("::1", &local_ipv6));
+
+ DnsHosts hosts;
+ hosts[DnsHostsKey("nx_ipv4", ADDRESS_FAMILY_IPV4)] = local_ipv4;
+ hosts[DnsHostsKey("nx_ipv6", ADDRESS_FAMILY_IPV6)] = local_ipv6;
+ hosts[DnsHostsKey("nx_both", ADDRESS_FAMILY_IPV4)] = local_ipv4;
+ hosts[DnsHostsKey("nx_both", ADDRESS_FAMILY_IPV6)] = local_ipv6;
+
+ // Update HOSTS file.
+ config.hosts = hosts;
+ ChangeDnsConfig(config);
+
+ Request* req1 = CreateRequest("nx_ipv4", 80);
+ EXPECT_EQ(OK, req1->Resolve());
+ EXPECT_TRUE(req1->HasOneAddress("127.0.0.1", 80));
+
+ Request* req2 = CreateRequest("nx_ipv6", 80);
+ EXPECT_EQ(OK, req2->Resolve());
+ EXPECT_TRUE(req2->HasOneAddress("::1", 80));
+
+ Request* req3 = CreateRequest("nx_both", 80);
+ EXPECT_EQ(OK, req3->Resolve());
+ EXPECT_TRUE(req3->HasAddress("127.0.0.1", 80) &&
+ req3->HasAddress("::1", 80));
+
+ // Requests with specified AddressFamily.
+ Request* req4 = CreateRequest("nx_ipv4", 80, MEDIUM, ADDRESS_FAMILY_IPV4);
+ EXPECT_EQ(OK, req4->Resolve());
+ EXPECT_TRUE(req4->HasOneAddress("127.0.0.1", 80));
+
+ Request* req5 = CreateRequest("nx_ipv6", 80, MEDIUM, ADDRESS_FAMILY_IPV6);
+ EXPECT_EQ(OK, req5->Resolve());
+ EXPECT_TRUE(req5->HasOneAddress("::1", 80));
+
+ // Request with upper case.
+ Request* req6 = CreateRequest("nx_IPV4", 80);
+ EXPECT_EQ(OK, req6->Resolve());
+ EXPECT_TRUE(req6->HasOneAddress("127.0.0.1", 80));
+}
+
+TEST_F(HostResolverImplDnsTest, BypassDnsTask) {
+ ChangeDnsConfig(CreateValidDnsConfig());
+
+ proc_->AddRuleForAllFamilies(std::string(),
+ std::string()); // Default to failures.
+
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok.local", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok.local.", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("oklocal", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("oklocal.", 80)->Resolve());
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok", 80)->Resolve());
+
+ proc_->SignalMultiple(requests_.size());
+
+ for (size_t i = 0; i < 2; ++i)
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, requests_[i]->WaitForResult()) << i;
+
+ for (size_t i = 2; i < requests_.size(); ++i)
+ EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i;
+}
+
+TEST_F(HostResolverImplDnsTest, DisableDnsClientOnPersistentFailure) {
+ ChangeDnsConfig(CreateValidDnsConfig());
+
+ proc_->AddRuleForAllFamilies(std::string(),
+ std::string()); // Default to failures.
+
+ // Check that DnsTask works.
+ Request* req = CreateRequest("ok_1", 80);
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+ EXPECT_EQ(OK, req->WaitForResult());
+
+ for (unsigned i = 0; i < 20; ++i) {
+ // Use custom names to require separate Jobs.
+ std::string hostname = base::StringPrintf("nx_%u", i);
+ // Ensure fallback to ProcTask succeeds.
+ proc_->AddRuleForAllFamilies(hostname, "192.168.1.101");
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest(hostname, 80)->Resolve()) << i;
+ }
+
+ proc_->SignalMultiple(requests_.size());
+
+ for (size_t i = 0; i < requests_.size(); ++i)
+ EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i;
+
+ ASSERT_FALSE(proc_->HasBlockedRequests());
+
+ // DnsTask should be disabled by now.
+ req = CreateRequest("ok_2", 80);
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+ proc_->SignalMultiple(1u);
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->WaitForResult());
+
+ // Check that it is re-enabled after DNS change.
+ ChangeDnsConfig(CreateValidDnsConfig());
+ req = CreateRequest("ok_3", 80);
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+ EXPECT_EQ(OK, req->WaitForResult());
+}
+
+TEST_F(HostResolverImplDnsTest, DontDisableDnsClientOnSporadicFailure) {
+ ChangeDnsConfig(CreateValidDnsConfig());
+
+ // |proc_| defaults to successes.
+
+ // 20 failures interleaved with 20 successes.
+ for (unsigned i = 0; i < 40; ++i) {
+ // Use custom names to require separate Jobs.
+ std::string hostname = (i % 2) == 0 ? base::StringPrintf("nx_%u", i)
+ : base::StringPrintf("ok_%u", i);
+ EXPECT_EQ(ERR_IO_PENDING, CreateRequest(hostname, 80)->Resolve()) << i;
+ }
+
+ proc_->SignalMultiple(requests_.size());
+
+ for (size_t i = 0; i < requests_.size(); ++i)
+ EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i;
+
+ // Make |proc_| default to failures.
+ proc_->AddRuleForAllFamilies(std::string(), std::string());
+
+ // DnsTask should still be enabled.
+ Request* req = CreateRequest("ok_last", 80);
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+ EXPECT_EQ(OK, req->WaitForResult());
+}
+
+// Confirm that resolving "localhost" is unrestricted even if there are no
+// global IPv6 address. See SystemHostResolverCall for rationale.
+// Test both the DnsClient and system host resolver paths.
+TEST_F(HostResolverImplDnsTest, DualFamilyLocalhost) {
+ // Use regular SystemHostResolverCall!
+ scoped_refptr<HostResolverProc> proc(new SystemHostResolverProc());
+ resolver_.reset(new HostResolverImpl(HostCache::CreateDefaultCache(),
+ DefaultLimits(),
+ DefaultParams(proc.get()),
+ NULL));
+ resolver_->SetDnsClient(CreateMockDnsClient(DnsConfig(), dns_rules_));
+ resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4);
+
+ // Get the expected output.
+ AddressList addrlist;
+ int rv = proc->Resolve("localhost", ADDRESS_FAMILY_UNSPECIFIED, 0, &addrlist,
+ NULL);
+ if (rv != OK)
+ return;
+
+ for (unsigned i = 0; i < addrlist.size(); ++i)
+ LOG(WARNING) << addrlist[i].ToString();
+
+ bool saw_ipv4 = AddressListContains(addrlist, "127.0.0.1", 0);
+ bool saw_ipv6 = AddressListContains(addrlist, "::1", 0);
+ if (!saw_ipv4 && !saw_ipv6)
+ return;
+
+ HostResolver::RequestInfo info(HostPortPair("localhost", 80));
+ info.set_address_family(ADDRESS_FAMILY_UNSPECIFIED);
+ info.set_host_resolver_flags(HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6);
+
+ // Try without DnsClient.
+ ChangeDnsConfig(DnsConfig());
+ Request* req = CreateRequest(info);
+ // It is resolved via getaddrinfo, so expect asynchronous result.
+ EXPECT_EQ(ERR_IO_PENDING, req->Resolve());
+ EXPECT_EQ(OK, req->WaitForResult());
+
+ EXPECT_EQ(saw_ipv4, req->HasAddress("127.0.0.1", 80));
+ EXPECT_EQ(saw_ipv6, req->HasAddress("::1", 80));
+
+ // Configure DnsClient with dual-host HOSTS file.
+ DnsConfig config = CreateValidDnsConfig();
+ DnsHosts hosts;
+ IPAddressNumber local_ipv4, local_ipv6;
+ ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &local_ipv4));
+ ASSERT_TRUE(ParseIPLiteralToNumber("::1", &local_ipv6));
+ if (saw_ipv4)
+ hosts[DnsHostsKey("localhost", ADDRESS_FAMILY_IPV4)] = local_ipv4;
+ if (saw_ipv6)
+ hosts[DnsHostsKey("localhost", ADDRESS_FAMILY_IPV6)] = local_ipv6;
+ config.hosts = hosts;
+
+ ChangeDnsConfig(config);
+ req = CreateRequest(info);
+ // Expect synchronous resolution from DnsHosts.
+ EXPECT_EQ(OK, req->Resolve());
+
+ EXPECT_EQ(saw_ipv4, req->HasAddress("127.0.0.1", 80));
+ EXPECT_EQ(saw_ipv6, req->HasAddress("::1", 80));
+}
+
+} // namespace net
diff --git a/chromium/net/dns/host_resolver_proc.cc b/chromium/net/dns/host_resolver_proc.cc
new file mode 100644
index 00000000000..f2b10c649ba
--- /dev/null
+++ b/chromium/net/dns/host_resolver_proc.cc
@@ -0,0 +1,267 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/host_resolver_proc.h"
+
+#include "build/build_config.h"
+
+#include "base/logging.h"
+#include "base/sys_byteorder.h"
+#include "net/base/address_list.h"
+#include "net/base/dns_reloader.h"
+#include "net/base/net_errors.h"
+#include "net/base/sys_addrinfo.h"
+
+#if defined(OS_OPENBSD)
+#define AI_ADDRCONFIG 0
+#endif
+
+namespace net {
+
+namespace {
+
+bool IsAllLocalhostOfOneFamily(const struct addrinfo* ai) {
+ bool saw_v4_localhost = false;
+ bool saw_v6_localhost = false;
+ for (; ai != NULL; ai = ai->ai_next) {
+ switch (ai->ai_family) {
+ case AF_INET: {
+ const struct sockaddr_in* addr_in =
+ reinterpret_cast<struct sockaddr_in*>(ai->ai_addr);
+ if ((base::NetToHost32(addr_in->sin_addr.s_addr) & 0xff000000) ==
+ 0x7f000000)
+ saw_v4_localhost = true;
+ else
+ return false;
+ break;
+ }
+ case AF_INET6: {
+ const struct sockaddr_in6* addr_in6 =
+ reinterpret_cast<struct sockaddr_in6*>(ai->ai_addr);
+ if (IN6_IS_ADDR_LOOPBACK(&addr_in6->sin6_addr))
+ saw_v6_localhost = true;
+ else
+ return false;
+ break;
+ }
+ default:
+ NOTREACHED();
+ return false;
+ }
+ }
+
+ return saw_v4_localhost != saw_v6_localhost;
+}
+
+} // namespace
+
+HostResolverProc* HostResolverProc::default_proc_ = NULL;
+
+HostResolverProc::HostResolverProc(HostResolverProc* previous) {
+ SetPreviousProc(previous);
+
+ // Implicitly fall-back to the global default procedure.
+ if (!previous)
+ SetPreviousProc(default_proc_);
+}
+
+HostResolverProc::~HostResolverProc() {
+}
+
+int HostResolverProc::ResolveUsingPrevious(
+ const std::string& host,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ AddressList* addrlist,
+ int* os_error) {
+ if (previous_proc_.get()) {
+ return previous_proc_->Resolve(
+ host, address_family, host_resolver_flags, addrlist, os_error);
+ }
+
+ // Final fallback is the system resolver.
+ return SystemHostResolverCall(host, address_family, host_resolver_flags,
+ addrlist, os_error);
+}
+
+void HostResolverProc::SetPreviousProc(HostResolverProc* proc) {
+ HostResolverProc* current_previous = previous_proc_.get();
+ previous_proc_ = NULL;
+ // Now that we've guaranteed |this| is the last proc in a chain, we can
+ // detect potential cycles using GetLastProc().
+ previous_proc_ = (GetLastProc(proc) == this) ? current_previous : proc;
+}
+
+void HostResolverProc::SetLastProc(HostResolverProc* proc) {
+ GetLastProc(this)->SetPreviousProc(proc);
+}
+
+// static
+HostResolverProc* HostResolverProc::GetLastProc(HostResolverProc* proc) {
+ if (proc == NULL)
+ return NULL;
+ HostResolverProc* last_proc = proc;
+ while (last_proc->previous_proc_.get() != NULL)
+ last_proc = last_proc->previous_proc_.get();
+ return last_proc;
+}
+
+// static
+HostResolverProc* HostResolverProc::SetDefault(HostResolverProc* proc) {
+ HostResolverProc* old = default_proc_;
+ default_proc_ = proc;
+ return old;
+}
+
+// static
+HostResolverProc* HostResolverProc::GetDefault() {
+ return default_proc_;
+}
+
+int SystemHostResolverCall(const std::string& host,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ AddressList* addrlist,
+ int* os_error) {
+ if (os_error)
+ *os_error = 0;
+
+ struct addrinfo* ai = NULL;
+ struct addrinfo hints = {0};
+
+ switch (address_family) {
+ case ADDRESS_FAMILY_IPV4:
+ hints.ai_family = AF_INET;
+ break;
+ case ADDRESS_FAMILY_IPV6:
+ hints.ai_family = AF_INET6;
+ break;
+ case ADDRESS_FAMILY_UNSPECIFIED:
+ hints.ai_family = AF_UNSPEC;
+ break;
+ default:
+ NOTREACHED();
+ hints.ai_family = AF_UNSPEC;
+ }
+
+#if defined(OS_WIN)
+ // DO NOT USE AI_ADDRCONFIG ON WINDOWS.
+ //
+ // The following comment in <winsock2.h> is the best documentation I found
+ // on AI_ADDRCONFIG for Windows:
+ // Flags used in "hints" argument to getaddrinfo()
+ // - AI_ADDRCONFIG is supported starting with Vista
+ // - default is AI_ADDRCONFIG ON whether the flag is set or not
+ // because the performance penalty in not having ADDRCONFIG in
+ // the multi-protocol stack environment is severe;
+ // this defaulting may be disabled by specifying the AI_ALL flag,
+ // in that case AI_ADDRCONFIG must be EXPLICITLY specified to
+ // enable ADDRCONFIG behavior
+ //
+ // Not only is AI_ADDRCONFIG unnecessary, but it can be harmful. If the
+ // computer is not connected to a network, AI_ADDRCONFIG causes getaddrinfo
+ // to fail with WSANO_DATA (11004) for "localhost", probably because of the
+ // following note on AI_ADDRCONFIG in the MSDN getaddrinfo page:
+ // The IPv4 or IPv6 loopback address is not considered a valid global
+ // address.
+ // See http://crbug.com/5234.
+ //
+ // OpenBSD does not support it, either.
+ hints.ai_flags = 0;
+#else
+ hints.ai_flags = AI_ADDRCONFIG;
+#endif
+
+ // On Linux AI_ADDRCONFIG doesn't consider loopback addreses, even if only
+ // loopback addresses are configured. So don't use it when there are only
+ // loopback addresses.
+ if (host_resolver_flags & HOST_RESOLVER_LOOPBACK_ONLY)
+ hints.ai_flags &= ~AI_ADDRCONFIG;
+
+ if (host_resolver_flags & HOST_RESOLVER_CANONNAME)
+ hints.ai_flags |= AI_CANONNAME;
+
+ // Restrict result set to only this socket type to avoid duplicates.
+ hints.ai_socktype = SOCK_STREAM;
+
+#if defined(OS_POSIX) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) && \
+ !defined(OS_ANDROID)
+ DnsReloaderMaybeReload();
+#endif
+ int err = getaddrinfo(host.c_str(), NULL, &hints, &ai);
+ bool should_retry = false;
+ // If the lookup was restricted (either by address family, or address
+ // detection), and the results where all localhost of a single family,
+ // maybe we should retry. There were several bugs related to these
+ // issues, for example http://crbug.com/42058 and http://crbug.com/49024
+ if ((hints.ai_family != AF_UNSPEC || hints.ai_flags & AI_ADDRCONFIG) &&
+ err == 0 && IsAllLocalhostOfOneFamily(ai)) {
+ if (host_resolver_flags & HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6) {
+ hints.ai_family = AF_UNSPEC;
+ should_retry = true;
+ }
+ if (hints.ai_flags & AI_ADDRCONFIG) {
+ hints.ai_flags &= ~AI_ADDRCONFIG;
+ should_retry = true;
+ }
+ }
+ if (should_retry) {
+ if (ai != NULL) {
+ freeaddrinfo(ai);
+ ai = NULL;
+ }
+ err = getaddrinfo(host.c_str(), NULL, &hints, &ai);
+ }
+
+ if (err) {
+#if defined(OS_WIN)
+ err = WSAGetLastError();
+#endif
+
+ // Return the OS error to the caller.
+ if (os_error)
+ *os_error = err;
+
+ // If the call to getaddrinfo() failed because of a system error, report
+ // it separately from ERR_NAME_NOT_RESOLVED.
+#if defined(OS_WIN)
+ if (err != WSAHOST_NOT_FOUND && err != WSANO_DATA)
+ return ERR_NAME_RESOLUTION_FAILED;
+#elif defined(OS_POSIX) && !defined(OS_FREEBSD)
+ if (err != EAI_NONAME && err != EAI_NODATA)
+ return ERR_NAME_RESOLUTION_FAILED;
+#endif
+
+ return ERR_NAME_NOT_RESOLVED;
+ }
+
+#if defined(OS_ANDROID)
+ // Workaround for Android's getaddrinfo leaving ai==NULL without an error.
+ // http://crbug.com/134142
+ if (ai == NULL)
+ return ERR_NAME_NOT_RESOLVED;
+#endif
+
+ *addrlist = AddressList::CreateFromAddrinfo(ai);
+ freeaddrinfo(ai);
+ return OK;
+}
+
+SystemHostResolverProc::SystemHostResolverProc() : HostResolverProc(NULL) {}
+
+int SystemHostResolverProc::Resolve(const std::string& hostname,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ AddressList* addr_list,
+ int* os_error) {
+ return SystemHostResolverCall(hostname,
+ address_family,
+ host_resolver_flags,
+ addr_list,
+ os_error);
+}
+
+SystemHostResolverProc::~SystemHostResolverProc() {}
+
+} // namespace net
diff --git a/chromium/net/dns/host_resolver_proc.h b/chromium/net/dns/host_resolver_proc.h
new file mode 100644
index 00000000000..014a720e375
--- /dev/null
+++ b/chromium/net/dns/host_resolver_proc.h
@@ -0,0 +1,111 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_HOST_RESOLVER_PROC_H_
+#define NET_DNS_HOST_RESOLVER_PROC_H_
+
+#include <string>
+
+#include "base/memory/ref_counted.h"
+#include "net/base/address_family.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+class AddressList;
+
+// Interface for a getaddrinfo()-like procedure. This is used by unit-tests
+// to control the underlying resolutions in HostResolverImpl. HostResolverProcs
+// can be chained together; they fallback to the next procedure in the chain
+// by calling ResolveUsingPrevious().
+//
+// Note that implementations of HostResolverProc *MUST BE THREADSAFE*, since
+// the HostResolver implementation using them can be multi-threaded.
+class NET_EXPORT HostResolverProc
+ : public base::RefCountedThreadSafe<HostResolverProc> {
+ public:
+ explicit HostResolverProc(HostResolverProc* previous);
+
+ // Resolves |host| to an address list, restricting the results to addresses
+ // in |address_family|. If successful returns OK and fills |addrlist| with
+ // a list of socket addresses. Otherwise returns a network error code, and
+ // fills |os_error| with a more specific error if it was non-NULL.
+ virtual int Resolve(const std::string& host,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ AddressList* addrlist,
+ int* os_error) = 0;
+
+ protected:
+ friend class base::RefCountedThreadSafe<HostResolverProc>;
+
+ virtual ~HostResolverProc();
+
+ // Asks the fallback procedure (if set) to do the resolve.
+ int ResolveUsingPrevious(const std::string& host,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ AddressList* addrlist,
+ int* os_error);
+
+ private:
+ friend class HostResolverImpl;
+ friend class MockHostResolverBase;
+ friend class ScopedDefaultHostResolverProc;
+
+ // Sets the previous procedure in the chain. Aborts if this would result in a
+ // cycle.
+ void SetPreviousProc(HostResolverProc* proc);
+
+ // Sets the last procedure in the chain, i.e. appends |proc| to the end of the
+ // current chain. Aborts if this would result in a cycle.
+ void SetLastProc(HostResolverProc* proc);
+
+ // Returns the last procedure in the chain starting at |proc|. Will return
+ // NULL iff |proc| is NULL.
+ static HostResolverProc* GetLastProc(HostResolverProc* proc);
+
+ // Sets the default host resolver procedure that is used by HostResolverImpl.
+ // This can be used through ScopedDefaultHostResolverProc to set a catch-all
+ // DNS block in unit-tests (individual tests should use MockHostResolver to
+ // prevent hitting the network).
+ static HostResolverProc* SetDefault(HostResolverProc* proc);
+ static HostResolverProc* GetDefault();
+
+ scoped_refptr<HostResolverProc> previous_proc_;
+ static HostResolverProc* default_proc_;
+
+ DISALLOW_COPY_AND_ASSIGN(HostResolverProc);
+};
+
+// Resolves |host| to an address list, using the system's default host resolver.
+// (i.e. this calls out to getaddrinfo()). If successful returns OK and fills
+// |addrlist| with a list of socket addresses. Otherwise returns a
+// network error code, and fills |os_error| with a more specific error if it
+// was non-NULL.
+NET_EXPORT_PRIVATE int SystemHostResolverCall(
+ const std::string& host,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ AddressList* addrlist,
+ int* os_error);
+
+// Wraps call to SystemHostResolverCall as an instance of HostResolverProc.
+class NET_EXPORT_PRIVATE SystemHostResolverProc : public HostResolverProc {
+ public:
+ SystemHostResolverProc();
+ virtual int Resolve(const std::string& hostname,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ AddressList* addr_list,
+ int* os_error) OVERRIDE;
+ protected:
+ virtual ~SystemHostResolverProc();
+
+ DISALLOW_COPY_AND_ASSIGN(SystemHostResolverProc);
+};
+
+} // namespace net
+
+#endif // NET_DNS_HOST_RESOLVER_PROC_H_
diff --git a/chromium/net/dns/mapped_host_resolver.cc b/chromium/net/dns/mapped_host_resolver.cc
new file mode 100644
index 00000000000..4db7bc97928
--- /dev/null
+++ b/chromium/net/dns/mapped_host_resolver.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/mapped_host_resolver.h"
+
+#include "base/strings/string_util.h"
+#include "net/base/host_port_pair.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+
+namespace net {
+
+MappedHostResolver::MappedHostResolver(scoped_ptr<HostResolver> impl)
+ : impl_(impl.Pass()) {
+}
+
+MappedHostResolver::~MappedHostResolver() {
+}
+
+int MappedHostResolver::Resolve(const RequestInfo& original_info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* out_req,
+ const BoundNetLog& net_log) {
+ RequestInfo info = original_info;
+ int rv = ApplyRules(&info);
+ if (rv != OK)
+ return rv;
+
+ return impl_->Resolve(info, addresses, callback, out_req, net_log);
+}
+
+int MappedHostResolver::ResolveFromCache(const RequestInfo& original_info,
+ AddressList* addresses,
+ const BoundNetLog& net_log) {
+ RequestInfo info = original_info;
+ int rv = ApplyRules(&info);
+ if (rv != OK)
+ return rv;
+
+ return impl_->ResolveFromCache(info, addresses, net_log);
+}
+
+void MappedHostResolver::CancelRequest(RequestHandle req) {
+ impl_->CancelRequest(req);
+}
+
+HostCache* MappedHostResolver::GetHostCache() {
+ return impl_->GetHostCache();
+}
+
+int MappedHostResolver::ApplyRules(RequestInfo* info) const {
+ HostPortPair host_port(info->host_port_pair());
+ if (rules_.RewriteHost(&host_port)) {
+ if (host_port.host() == "~NOTFOUND")
+ return ERR_NAME_NOT_RESOLVED;
+ info->set_host_port_pair(host_port);
+ }
+ return OK;
+}
+
+} // namespace net
diff --git a/chromium/net/dns/mapped_host_resolver.h b/chromium/net/dns/mapped_host_resolver.h
new file mode 100644
index 00000000000..50062a9848e
--- /dev/null
+++ b/chromium/net/dns/mapped_host_resolver.h
@@ -0,0 +1,71 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_MAPPED_HOST_RESOLVER_H_
+#define NET_DNS_MAPPED_HOST_RESOLVER_H_
+
+#include <string>
+
+#include "base/memory/scoped_ptr.h"
+#include "net/base/host_mapping_rules.h"
+#include "net/base/net_export.h"
+#include "net/dns/host_resolver.h"
+
+namespace net {
+
+// This class wraps an existing HostResolver instance, but modifies the
+// request before passing it off to |impl|. This is different from
+// MockHostResolver which does the remapping at the HostResolverProc
+// layer, so it is able to preserve the effectiveness of the cache.
+class NET_EXPORT MappedHostResolver : public HostResolver {
+ public:
+ // Creates a MappedHostResolver that forwards all of its requests through
+ // |impl|.
+ explicit MappedHostResolver(scoped_ptr<HostResolver> impl);
+ virtual ~MappedHostResolver();
+
+ // Adds a rule to this mapper. The format of the rule can be one of:
+ //
+ // "MAP" <hostname_pattern> <replacement_host> [":" <replacement_port>]
+ // "EXCLUDE" <hostname_pattern>
+ //
+ // The <replacement_host> can be either a hostname, or an IP address literal,
+ // or "~NOTFOUND". If it is "~NOTFOUND" then all matched hostnames will fail
+ // to be resolved with ERR_NAME_NOT_RESOLVED.
+ //
+ // Returns true if the rule was successfully parsed and added.
+ bool AddRuleFromString(const std::string& rule_string) {
+ return rules_.AddRuleFromString(rule_string);
+ }
+
+ // Takes a comma separated list of rules, and assigns them to this resolver.
+ void SetRulesFromString(const std::string& rules_string) {
+ rules_.SetRulesFromString(rules_string);
+ }
+
+ // HostResolver methods:
+ virtual int Resolve(const RequestInfo& info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* out_req,
+ const BoundNetLog& net_log) OVERRIDE;
+ virtual int ResolveFromCache(const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& net_log) OVERRIDE;
+ virtual void CancelRequest(RequestHandle req) OVERRIDE;
+ virtual HostCache* GetHostCache() OVERRIDE;
+
+ private:
+ // Modify the request |info| according to |rules_|. Returns either OK or
+ // the network error code that the hostname's resolution mapped to.
+ int ApplyRules(RequestInfo* info) const;
+
+ scoped_ptr<HostResolver> impl_;
+
+ HostMappingRules rules_;
+};
+
+} // namespace net
+
+#endif // NET_DNS_MAPPED_HOST_RESOLVER_H_
diff --git a/chromium/net/dns/mapped_host_resolver_unittest.cc b/chromium/net/dns/mapped_host_resolver_unittest.cc
new file mode 100644
index 00000000000..d8594663c9d
--- /dev/null
+++ b/chromium/net/dns/mapped_host_resolver_unittest.cc
@@ -0,0 +1,219 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/mapped_host_resolver.h"
+
+#include "net/base/address_list.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/net_util.h"
+#include "net/base/test_completion_callback.h"
+#include "net/dns/mock_host_resolver.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+std::string FirstAddress(const AddressList& address_list) {
+ if (address_list.empty())
+ return std::string();
+ return address_list.front().ToString();
+}
+
+TEST(MappedHostResolverTest, Inclusion) {
+ // Create a mock host resolver, with specific hostname to IP mappings.
+ scoped_ptr<MockHostResolver> resolver_impl(new MockHostResolver());
+ resolver_impl->rules()->AddSimulatedFailure("*google.com");
+ resolver_impl->rules()->AddRule("baz.com", "192.168.1.5");
+ resolver_impl->rules()->AddRule("foo.com", "192.168.1.8");
+ resolver_impl->rules()->AddRule("proxy", "192.168.1.11");
+
+ // Create a remapped resolver that uses |resolver_impl|.
+ scoped_ptr<MappedHostResolver> resolver(
+ new MappedHostResolver(resolver_impl.PassAs<HostResolver>()));
+
+ int rv;
+ AddressList address_list;
+
+ // Try resolving "www.google.com:80". There are no mappings yet, so this
+ // hits |resolver_impl| and fails.
+ TestCompletionCallback callback;
+ rv = resolver->Resolve(HostResolver::RequestInfo(
+ HostPortPair("www.google.com", 80)),
+ &address_list, callback.callback(), NULL,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv);
+
+ // Remap *.google.com to baz.com.
+ EXPECT_TRUE(resolver->AddRuleFromString("map *.google.com baz.com"));
+
+ // Try resolving "www.google.com:80". Should be remapped to "baz.com:80".
+ rv = resolver->Resolve(HostResolver::RequestInfo(
+ HostPortPair("www.google.com", 80)),
+ &address_list, callback.callback(), NULL,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_EQ("192.168.1.5:80", FirstAddress(address_list));
+
+ // Try resolving "foo.com:77". This will NOT be remapped, so result
+ // is "foo.com:77".
+ rv = resolver->Resolve(HostResolver::RequestInfo(HostPortPair("foo.com", 77)),
+ &address_list, callback.callback(), NULL,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_EQ("192.168.1.8:77", FirstAddress(address_list));
+
+ // Remap "*.org" to "proxy:99".
+ EXPECT_TRUE(resolver->AddRuleFromString("Map *.org proxy:99"));
+
+ // Try resolving "chromium.org:61". Should be remapped to "proxy:99".
+ rv = resolver->Resolve(HostResolver::RequestInfo
+ (HostPortPair("chromium.org", 61)),
+ &address_list, callback.callback(), NULL,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_EQ("192.168.1.11:99", FirstAddress(address_list));
+}
+
+// Tests that exclusions are respected.
+TEST(MappedHostResolverTest, Exclusion) {
+ // Create a mock host resolver, with specific hostname to IP mappings.
+ scoped_ptr<MockHostResolver> resolver_impl(new MockHostResolver());
+ resolver_impl->rules()->AddRule("baz", "192.168.1.5");
+ resolver_impl->rules()->AddRule("www.google.com", "192.168.1.3");
+
+ // Create a remapped resolver that uses |resolver_impl|.
+ scoped_ptr<MappedHostResolver> resolver(
+ new MappedHostResolver(resolver_impl.PassAs<HostResolver>()));
+
+ int rv;
+ AddressList address_list;
+ TestCompletionCallback callback;
+
+ // Remap "*.com" to "baz".
+ EXPECT_TRUE(resolver->AddRuleFromString("map *.com baz"));
+
+ // Add an exclusion for "*.google.com".
+ EXPECT_TRUE(resolver->AddRuleFromString("EXCLUDE *.google.com"));
+
+ // Try resolving "www.google.com". Should not be remapped due to exclusion).
+ rv = resolver->Resolve(HostResolver::RequestInfo(
+ HostPortPair("www.google.com", 80)),
+ &address_list, callback.callback(), NULL,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_EQ("192.168.1.3:80", FirstAddress(address_list));
+
+ // Try resolving "chrome.com:80". Should be remapped to "baz:80".
+ rv = resolver->Resolve(HostResolver::RequestInfo(
+ HostPortPair("chrome.com", 80)),
+ &address_list, callback.callback(), NULL,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_EQ("192.168.1.5:80", FirstAddress(address_list));
+}
+
+TEST(MappedHostResolverTest, SetRulesFromString) {
+ // Create a mock host resolver, with specific hostname to IP mappings.
+ scoped_ptr<MockHostResolver> resolver_impl(new MockHostResolver());
+ resolver_impl->rules()->AddRule("baz", "192.168.1.7");
+ resolver_impl->rules()->AddRule("bar", "192.168.1.9");
+
+ // Create a remapped resolver that uses |resolver_impl|.
+ scoped_ptr<MappedHostResolver> resolver(
+ new MappedHostResolver(resolver_impl.PassAs<HostResolver>()));
+
+ int rv;
+ AddressList address_list;
+ TestCompletionCallback callback;
+
+ // Remap "*.com" to "baz", and *.net to "bar:60".
+ resolver->SetRulesFromString("map *.com baz , map *.net bar:60");
+
+ // Try resolving "www.google.com". Should be remapped to "baz".
+ rv = resolver->Resolve(HostResolver::RequestInfo(
+ HostPortPair("www.google.com", 80)),
+ &address_list, callback.callback(), NULL,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_EQ("192.168.1.7:80", FirstAddress(address_list));
+
+ // Try resolving "chrome.net:80". Should be remapped to "bar:60".
+ rv = resolver->Resolve(HostResolver::RequestInfo(
+ HostPortPair("chrome.net", 80)),
+ &address_list, callback.callback(), NULL,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_EQ("192.168.1.9:60", FirstAddress(address_list));
+}
+
+// Parsing bad rules should silently discard the rule (and never crash).
+TEST(MappedHostResolverTest, ParseInvalidRules) {
+ scoped_ptr<MappedHostResolver> resolver(
+ new MappedHostResolver(scoped_ptr<HostResolver>()));
+
+ EXPECT_FALSE(resolver->AddRuleFromString("xyz"));
+ EXPECT_FALSE(resolver->AddRuleFromString(std::string()));
+ EXPECT_FALSE(resolver->AddRuleFromString(" "));
+ EXPECT_FALSE(resolver->AddRuleFromString("EXCLUDE"));
+ EXPECT_FALSE(resolver->AddRuleFromString("EXCLUDE foo bar"));
+ EXPECT_FALSE(resolver->AddRuleFromString("INCLUDE"));
+ EXPECT_FALSE(resolver->AddRuleFromString("INCLUDE x"));
+ EXPECT_FALSE(resolver->AddRuleFromString("INCLUDE x :10"));
+}
+
+// Test mapping hostnames to resolving failures.
+TEST(MappedHostResolverTest, MapToError) {
+ scoped_ptr<MockHostResolver> resolver_impl(new MockHostResolver());
+ resolver_impl->rules()->AddRule("*", "192.168.1.5");
+
+ scoped_ptr<MappedHostResolver> resolver(
+ new MappedHostResolver(resolver_impl.PassAs<HostResolver>()));
+
+ int rv;
+ AddressList address_list;
+
+ // Remap *.google.com to resolving failures.
+ EXPECT_TRUE(resolver->AddRuleFromString("MAP *.google.com ~NOTFOUND"));
+
+ // Try resolving www.google.com --> Should give an error.
+ TestCompletionCallback callback1;
+ rv = resolver->Resolve(HostResolver::RequestInfo(
+ HostPortPair("www.google.com", 80)),
+ &address_list, callback1.callback(), NULL,
+ BoundNetLog());
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv);
+
+ // Try resolving www.foo.com --> Should succeed.
+ TestCompletionCallback callback2;
+ rv = resolver->Resolve(HostResolver::RequestInfo(
+ HostPortPair("www.foo.com", 80)),
+ &address_list, callback2.callback(), NULL,
+ BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ rv = callback2.WaitForResult();
+ EXPECT_EQ(OK, rv);
+ EXPECT_EQ("192.168.1.5:80", FirstAddress(address_list));
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/dns/mdns_cache.cc b/chromium/net/dns/mdns_cache.cc
new file mode 100644
index 00000000000..010a34f45d4
--- /dev/null
+++ b/chromium/net/dns/mdns_cache.cc
@@ -0,0 +1,212 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/mdns_cache.h"
+
+#include <algorithm>
+#include <utility>
+
+#include "base/stl_util.h"
+#include "base/strings/string_number_conversions.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/record_parsed.h"
+#include "net/dns/record_rdata.h"
+
+// TODO(noamsml): Recursive CNAME closure (backwards and forwards).
+
+namespace net {
+
+// The effective TTL given to records with a nominal zero TTL.
+// Allows time for hosts to send updated records, as detailed in RFC 6762
+// Section 10.1.
+static const unsigned kZeroTTLSeconds = 1;
+
+MDnsCache::Key::Key(unsigned type, const std::string& name,
+ const std::string& optional)
+ : type_(type), name_(name), optional_(optional) {
+}
+
+MDnsCache::Key::Key(
+ const MDnsCache::Key& other)
+ : type_(other.type_), name_(other.name_), optional_(other.optional_) {
+}
+
+
+MDnsCache::Key& MDnsCache::Key::operator=(
+ const MDnsCache::Key& other) {
+ type_ = other.type_;
+ name_ = other.name_;
+ optional_ = other.optional_;
+ return *this;
+}
+
+MDnsCache::Key::~Key() {
+}
+
+bool MDnsCache::Key::operator<(const MDnsCache::Key& key) const {
+ if (name_ != key.name_)
+ return name_ < key.name_;
+
+ if (type_ != key.type_)
+ return type_ < key.type_;
+
+ if (optional_ != key.optional_)
+ return optional_ < key.optional_;
+ return false; // keys are equal
+}
+
+bool MDnsCache::Key::operator==(const MDnsCache::Key& key) const {
+ return type_ == key.type_ && name_ == key.name_ && optional_ == key.optional_;
+}
+
+// static
+MDnsCache::Key MDnsCache::Key::CreateFor(const RecordParsed* record) {
+ return Key(record->type(),
+ record->name(),
+ GetOptionalFieldForRecord(record));
+}
+
+
+MDnsCache::MDnsCache() {
+}
+
+MDnsCache::~MDnsCache() {
+ Clear();
+}
+
+void MDnsCache::Clear() {
+ next_expiration_ = base::Time();
+ STLDeleteValues(&mdns_cache_);
+}
+
+const RecordParsed* MDnsCache::LookupKey(const Key& key) {
+ RecordMap::iterator found = mdns_cache_.find(key);
+ if (found != mdns_cache_.end()) {
+ return found->second;
+ }
+ return NULL;
+}
+
+MDnsCache::UpdateType MDnsCache::UpdateDnsRecord(
+ scoped_ptr<const RecordParsed> record) {
+ Key cache_key = Key::CreateFor(record.get());
+
+ // Ignore "goodbye" packets for records not in cache.
+ if (record->ttl() == 0 && mdns_cache_.find(cache_key) == mdns_cache_.end())
+ return NoChange;
+
+ base::Time new_expiration = GetEffectiveExpiration(record.get());
+ if (next_expiration_ != base::Time())
+ new_expiration = std::min(new_expiration, next_expiration_);
+
+ std::pair<RecordMap::iterator, bool> insert_result =
+ mdns_cache_.insert(std::make_pair(cache_key, (const RecordParsed*)NULL));
+ UpdateType type = NoChange;
+ if (insert_result.second) {
+ type = RecordAdded;
+ } else {
+ const RecordParsed* other_record = insert_result.first->second;
+
+ if (record->ttl() != 0 && !record->IsEqual(other_record, true)) {
+ type = RecordChanged;
+ }
+ delete other_record;
+ }
+
+ insert_result.first->second = record.release();
+ next_expiration_ = new_expiration;
+ return type;
+}
+
+void MDnsCache::CleanupRecords(
+ base::Time now,
+ const RecordRemovedCallback& record_removed_callback) {
+ base::Time next_expiration;
+
+ // We are guaranteed that |next_expiration_| will be at or before the next
+ // expiration. This allows clients to eagrely call CleanupRecords with
+ // impunity.
+ if (now < next_expiration_) return;
+
+ for (RecordMap::iterator i = mdns_cache_.begin();
+ i != mdns_cache_.end(); ) {
+ base::Time expiration = GetEffectiveExpiration(i->second);
+ if (now >= expiration) {
+ record_removed_callback.Run(i->second);
+ delete i->second;
+ mdns_cache_.erase(i++);
+ } else {
+ if (next_expiration == base::Time() || expiration < next_expiration) {
+ next_expiration = expiration;
+ }
+ ++i;
+ }
+ }
+
+ next_expiration_ = next_expiration;
+}
+
+void MDnsCache::FindDnsRecords(unsigned type,
+ const std::string& name,
+ std::vector<const RecordParsed*>* results,
+ base::Time now) const {
+ DCHECK(results);
+ results->clear();
+
+ RecordMap::const_iterator i = mdns_cache_.lower_bound(Key(type, name, ""));
+ for (; i != mdns_cache_.end(); ++i) {
+ if (i->first.name() != name ||
+ (type != 0 && i->first.type() != type)) {
+ break;
+ }
+
+ const RecordParsed* record = i->second;
+
+ // Records are deleted only upon request.
+ if (now >= GetEffectiveExpiration(record)) continue;
+
+ results->push_back(record);
+ }
+}
+
+scoped_ptr<const RecordParsed> MDnsCache::RemoveRecord(
+ const RecordParsed* record) {
+ Key key = Key::CreateFor(record);
+ RecordMap::iterator found = mdns_cache_.find(key);
+
+ if (found != mdns_cache_.end() && found->second == record) {
+ mdns_cache_.erase(key);
+ return scoped_ptr<const RecordParsed>(record);
+ }
+
+ return scoped_ptr<const RecordParsed>();
+}
+
+// static
+std::string MDnsCache::GetOptionalFieldForRecord(
+ const RecordParsed* record) {
+ switch (record->type()) {
+ case PtrRecordRdata::kType: {
+ const PtrRecordRdata* rdata = record->rdata<PtrRecordRdata>();
+ return rdata->ptrdomain();
+ }
+ default: // Most records are considered unique for our purposes
+ return "";
+ }
+}
+
+// static
+base::Time MDnsCache::GetEffectiveExpiration(const RecordParsed* record) {
+ base::TimeDelta ttl;
+
+ if (record->ttl()) {
+ ttl = base::TimeDelta::FromSeconds(record->ttl());
+ } else {
+ ttl = base::TimeDelta::FromSeconds(kZeroTTLSeconds);
+ }
+
+ return record->time_created() + ttl;
+}
+
+} // namespace net
diff --git a/chromium/net/dns/mdns_cache.h b/chromium/net/dns/mdns_cache.h
new file mode 100644
index 00000000000..27a14d8fa44
--- /dev/null
+++ b/chromium/net/dns/mdns_cache.h
@@ -0,0 +1,119 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_MDNS_CACHE_H_
+#define NET_DNS_MDNS_CACHE_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "base/callback.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/time/time.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+class ParsedDnsRecord;
+class RecordParsed;
+
+// mDNS Cache
+// This is a cache of mDNS records. It keeps track of expiration times and is
+// guaranteed not to return expired records. It also has facilities for timely
+// record expiration.
+class NET_EXPORT_PRIVATE MDnsCache {
+ public:
+ // Key type for the record map. It is a 3-tuple of type, name and optional
+ // value ordered by type, then name, then optional value. This allows us to
+ // query for all records of a certain type and name, while also allowing us
+ // to set records of a certain type, name and optionally value as unique.
+ class Key {
+ public:
+ Key(unsigned type, const std::string& name, const std::string& optional);
+ Key(const Key&);
+ Key& operator=(const Key&);
+ ~Key();
+ bool operator<(const Key& key) const;
+ bool operator==(const Key& key) const;
+
+ unsigned type() const { return type_; }
+ const std::string& name() const { return name_; }
+ const std::string& optional() const { return optional_; }
+
+ // Create the cache key corresponding to |record|.
+ static Key CreateFor(const RecordParsed* record);
+ private:
+ unsigned type_;
+ std::string name_;
+ std::string optional_;
+ };
+
+ typedef base::Callback<void(const RecordParsed*)> RecordRemovedCallback;
+
+ enum UpdateType {
+ RecordAdded,
+ RecordChanged,
+ NoChange
+ };
+
+ MDnsCache();
+ ~MDnsCache();
+
+ // Return value indicates whether the record was added, changed
+ // (existed previously with different value) or not changed (existed
+ // previously with same value).
+ UpdateType UpdateDnsRecord(scoped_ptr<const RecordParsed> record);
+
+ // Check cache for record with key |key|. Return the record if it exists, or
+ // NULL if it doesn't.
+ const RecordParsed* LookupKey(const Key& key);
+
+ // Return records with type |type| and name |name|. Expired records will not
+ // be returned. If |type| is zero, return all records with name |name|.
+ void FindDnsRecords(unsigned type,
+ const std::string& name,
+ std::vector<const RecordParsed*>* records,
+ base::Time now) const;
+
+ // Remove expired records, call |record_removed_callback| for every removed
+ // record.
+ void CleanupRecords(base::Time now,
+ const RecordRemovedCallback& record_removed_callback);
+
+ // Returns a time less than or equal to the next time a record will expire.
+ // Is updated when CleanupRecords or UpdateDnsRecord are called. Returns
+ // base::Time when the cache is empty.
+ base::Time next_expiration() const { return next_expiration_; }
+
+ // Remove a record from the cache. Returns a scoped version of the pointer
+ // passed in if it was removed, scoped null otherwise.
+ scoped_ptr<const RecordParsed> RemoveRecord(const RecordParsed* record);
+
+ void Clear();
+
+ private:
+ typedef std::map<Key, const RecordParsed*> RecordMap;
+
+ // Get the effective expiration of a cache entry, based on its creation time
+ // and TTL. Does adjustments so entries with a TTL of zero will have a
+ // nonzero TTL, as explained in RFC 6762 Section 10.1.
+ static base::Time GetEffectiveExpiration(const RecordParsed* entry);
+
+ // Get optional part of the DNS key for shared records. For example, in PTR
+ // records this is the pointed domain, since multiple PTR records may exist
+ // for the same name.
+ static std::string GetOptionalFieldForRecord(
+ const RecordParsed* record);
+
+ RecordMap mdns_cache_;
+
+ base::Time next_expiration_;
+
+ DISALLOW_COPY_AND_ASSIGN(MDnsCache);
+};
+
+} // namespace net
+
+#endif // NET_DNS_MDNS_CACHE_H_
diff --git a/chromium/net/dns/mdns_cache_unittest.cc b/chromium/net/dns/mdns_cache_unittest.cc
new file mode 100644
index 00000000000..c12ad6b6ec3
--- /dev/null
+++ b/chromium/net/dns/mdns_cache_unittest.cc
@@ -0,0 +1,375 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <algorithm>
+
+#include "base/bind.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/dns_test_util.h"
+#include "net/dns/mdns_cache.h"
+#include "net/dns/record_parsed.h"
+#include "net/dns/record_rdata.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+using ::testing::Return;
+using ::testing::StrictMock;
+
+namespace net {
+
+static const uint8 kTestResponsesDifferentAnswers[] = {
+ // Answer 1
+ // ghs.l.google.com in DNS format.
+ 3, 'g', 'h', 's',
+ 1, 'l',
+ 6, 'g', 'o', 'o', 'g', 'l', 'e',
+ 3, 'c', 'o', 'm',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds.
+ 0, 4, // RDLENGTH is 4 bytes.
+ 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121
+
+ // Answer 2
+ // Pointer to answer 1
+ 0xc0, 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds.
+ 0, 4, // RDLENGTH is 4 bytes.
+ 74, 125, 95, 122, // RDATA is the IP: 74.125.95.122
+};
+
+static const uint8 kTestResponsesSameAnswers[] = {
+ // Answer 1
+ // ghs.l.google.com in DNS format.
+ 3, 'g', 'h', 's',
+ 1, 'l',
+ 6, 'g', 'o', 'o', 'g', 'l', 'e',
+ 3, 'c', 'o', 'm',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds.
+ 0, 4, // RDLENGTH is 4 bytes.
+ 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121
+
+ // Answer 2
+ // Pointer to answer 1
+ 0xc0, 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0, 0, 0, 112, // TTL (4 bytes) is 112 seconds.
+ 0, 4, // RDLENGTH is 4 bytes.
+ 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121
+};
+
+static const uint8 kTestResponseTwoRecords[] = {
+ // Answer 1
+ // ghs.l.google.com in DNS format. (A)
+ 3, 'g', 'h', 's',
+ 1, 'l',
+ 6, 'g', 'o', 'o', 'g', 'l', 'e',
+ 3, 'c', 'o', 'm',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds.
+ 0, 4, // RDLENGTH is 4 bytes.
+ 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121
+
+ // Answer 2
+ // ghs.l.google.com in DNS format. (AAAA)
+ 3, 'g', 'h', 's',
+ 1, 'l',
+ 6, 'g', 'o', 'o', 'g', 'l', 'e',
+ 3, 'c', 'o', 'm',
+ 0x00,
+ 0x00, 0x1c, // TYPE is AAA.
+ 0x00, 0x01, // CLASS is IN.
+ 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds.
+ 0, 16, // RDLENGTH is 16 bytes.
+ 0x4a, 0x7d, 0x4a, 0x7d,
+ 0x5f, 0x79, 0x5f, 0x79,
+ 0x5f, 0x79, 0x5f, 0x79,
+ 0x5f, 0x79, 0x5f, 0x79,
+};
+
+static const uint8 kTestResponsesGoodbyePacket[] = {
+ // Answer 1
+ // ghs.l.google.com in DNS format. (Goodbye packet)
+ 3, 'g', 'h', 's',
+ 1, 'l',
+ 6, 'g', 'o', 'o', 'g', 'l', 'e',
+ 3, 'c', 'o', 'm',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0, 0, 0, 0, // TTL (4 bytes) is zero.
+ 0, 4, // RDLENGTH is 4 bytes.
+ 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121
+
+ // Answer 2
+ // ghs.l.google.com in DNS format.
+ 3, 'g', 'h', 's',
+ 1, 'l',
+ 6, 'g', 'o', 'o', 'g', 'l', 'e',
+ 3, 'c', 'o', 'm',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds.
+ 0, 4, // RDLENGTH is 4 bytes.
+ 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121
+};
+
+class RecordRemovalMock {
+ public:
+ MOCK_METHOD1(OnRecordRemoved, void(const RecordParsed*));
+};
+
+class MDnsCacheTest : public ::testing::Test {
+ public:
+ MDnsCacheTest()
+ : default_time_(base::Time::FromDoubleT(1234.0)) {}
+ virtual ~MDnsCacheTest() {}
+
+ protected:
+ base::Time default_time_;
+ StrictMock<RecordRemovalMock> record_removal_;
+ MDnsCache cache_;
+};
+
+// Test a single insert, corresponding lookup, and unsuccessful lookup.
+TEST_F(MDnsCacheTest, InsertLookupSingle) {
+ DnsRecordParser parser(kT1ResponseDatagram, sizeof(kT1ResponseDatagram),
+ sizeof(dns_protocol::Header));
+ parser.SkipQuestion();
+
+ scoped_ptr<const RecordParsed> record1;
+ scoped_ptr<const RecordParsed> record2;
+ std::vector<const RecordParsed*> results;
+
+ record1 = RecordParsed::CreateFrom(&parser, default_time_);
+ record2 = RecordParsed::CreateFrom(&parser, default_time_);
+
+ EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass()));
+
+ EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record2.Pass()));
+
+ cache_.FindDnsRecords(ARecordRdata::kType, "ghs.l.google.com", &results,
+ default_time_);
+
+ EXPECT_EQ(1u, results.size());
+ EXPECT_EQ(default_time_, results.front()->time_created());
+
+ EXPECT_EQ("ghs.l.google.com", results.front()->name());
+
+ results.clear();
+ cache_.FindDnsRecords(PtrRecordRdata::kType, "ghs.l.google.com", &results,
+ default_time_);
+
+ EXPECT_EQ(0u, results.size());
+}
+
+// Test that records expire when their ttl has passed.
+TEST_F(MDnsCacheTest, Expiration) {
+ DnsRecordParser parser(kT1ResponseDatagram, sizeof(kT1ResponseDatagram),
+ sizeof(dns_protocol::Header));
+ parser.SkipQuestion();
+ scoped_ptr<const RecordParsed> record1;
+ scoped_ptr<const RecordParsed> record2;
+
+ std::vector<const RecordParsed*> results;
+ const RecordParsed* record_to_be_deleted;
+
+ record1 = RecordParsed::CreateFrom(&parser, default_time_);
+ base::TimeDelta ttl1 = base::TimeDelta::FromSeconds(record1->ttl());
+
+ record2 = RecordParsed::CreateFrom(&parser, default_time_);
+ base::TimeDelta ttl2 = base::TimeDelta::FromSeconds(record2->ttl());
+ record_to_be_deleted = record2.get();
+
+ EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass()));
+ EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record2.Pass()));
+
+ cache_.FindDnsRecords(ARecordRdata::kType, "ghs.l.google.com", &results,
+ default_time_);
+
+ EXPECT_EQ(1u, results.size());
+
+ EXPECT_EQ(default_time_ + ttl2, cache_.next_expiration());
+
+
+ cache_.FindDnsRecords(ARecordRdata::kType, "ghs.l.google.com", &results,
+ default_time_ + ttl2);
+
+ EXPECT_EQ(0u, results.size());
+
+ EXPECT_CALL(record_removal_, OnRecordRemoved(record_to_be_deleted));
+
+ cache_.CleanupRecords(default_time_ + ttl2, base::Bind(
+ &RecordRemovalMock::OnRecordRemoved, base::Unretained(&record_removal_)));
+
+ // To make sure that we've indeed removed them from the map, check no funny
+ // business happens once they're deleted for good.
+
+ EXPECT_EQ(default_time_ + ttl1, cache_.next_expiration());
+ cache_.FindDnsRecords(ARecordRdata::kType, "ghs.l.google.com", &results,
+ default_time_ + ttl2);
+
+ EXPECT_EQ(0u, results.size());
+}
+
+// Test that a new record replacing one with the same identity (name/rrtype for
+// unique records) causes the cache to output a "record changed" event.
+TEST_F(MDnsCacheTest, RecordChange) {
+ DnsRecordParser parser(kTestResponsesDifferentAnswers,
+ sizeof(kTestResponsesDifferentAnswers),
+ 0);
+
+ scoped_ptr<const RecordParsed> record1;
+ scoped_ptr<const RecordParsed> record2;
+ std::vector<const RecordParsed*> results;
+
+ record1 = RecordParsed::CreateFrom(&parser, default_time_);
+ record2 = RecordParsed::CreateFrom(&parser, default_time_);
+
+ EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass()));
+ EXPECT_EQ(MDnsCache::RecordChanged,
+ cache_.UpdateDnsRecord(record2.Pass()));
+}
+
+// Test that a new record replacing an otherwise identical one already in the
+// cache causes the cache to output a "no change" event.
+TEST_F(MDnsCacheTest, RecordNoChange) {
+ DnsRecordParser parser(kTestResponsesSameAnswers,
+ sizeof(kTestResponsesSameAnswers),
+ 0);
+
+ scoped_ptr<const RecordParsed> record1;
+ scoped_ptr<const RecordParsed> record2;
+ std::vector<const RecordParsed*> results;
+
+ record1 = RecordParsed::CreateFrom(&parser, default_time_);
+ record2 = RecordParsed::CreateFrom(&parser, default_time_ +
+ base::TimeDelta::FromSeconds(1));
+
+ EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass()));
+ EXPECT_EQ(MDnsCache::NoChange, cache_.UpdateDnsRecord(record2.Pass()));
+}
+
+// Test that the next expiration time of the cache is updated properly on record
+// insertion.
+TEST_F(MDnsCacheTest, RecordPreemptExpirationTime) {
+ DnsRecordParser parser(kTestResponsesSameAnswers,
+ sizeof(kTestResponsesSameAnswers),
+ 0);
+
+ scoped_ptr<const RecordParsed> record1;
+ scoped_ptr<const RecordParsed> record2;
+ std::vector<const RecordParsed*> results;
+
+ record1 = RecordParsed::CreateFrom(&parser, default_time_);
+ record2 = RecordParsed::CreateFrom(&parser, default_time_);
+ base::TimeDelta ttl1 = base::TimeDelta::FromSeconds(record1->ttl());
+ base::TimeDelta ttl2 = base::TimeDelta::FromSeconds(record2->ttl());
+
+ EXPECT_EQ(base::Time(), cache_.next_expiration());
+ EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record2.Pass()));
+ EXPECT_EQ(default_time_ + ttl2, cache_.next_expiration());
+ EXPECT_EQ(MDnsCache::NoChange, cache_.UpdateDnsRecord(record1.Pass()));
+ EXPECT_EQ(default_time_ + ttl1, cache_.next_expiration());
+}
+
+// Test that the cache handles mDNS "goodbye" packets correctly, not adding the
+// records to the cache if they are not already there, and eventually removing
+// records from the cache if they are.
+TEST_F(MDnsCacheTest, GoodbyePacket) {
+ DnsRecordParser parser(kTestResponsesGoodbyePacket,
+ sizeof(kTestResponsesGoodbyePacket),
+ 0);
+
+ scoped_ptr<const RecordParsed> record_goodbye;
+ scoped_ptr<const RecordParsed> record_hello;
+ scoped_ptr<const RecordParsed> record_goodbye2;
+ std::vector<const RecordParsed*> results;
+
+ record_goodbye = RecordParsed::CreateFrom(&parser, default_time_);
+ record_hello = RecordParsed::CreateFrom(&parser, default_time_);
+ parser = DnsRecordParser(kTestResponsesGoodbyePacket,
+ sizeof(kTestResponsesGoodbyePacket),
+ 0);
+ record_goodbye2 = RecordParsed::CreateFrom(&parser, default_time_);
+
+ base::TimeDelta ttl = base::TimeDelta::FromSeconds(record_hello->ttl());
+
+ EXPECT_EQ(base::Time(), cache_.next_expiration());
+ EXPECT_EQ(MDnsCache::NoChange, cache_.UpdateDnsRecord(record_goodbye.Pass()));
+ EXPECT_EQ(base::Time(), cache_.next_expiration());
+ EXPECT_EQ(MDnsCache::RecordAdded,
+ cache_.UpdateDnsRecord(record_hello.Pass()));
+ EXPECT_EQ(default_time_ + ttl, cache_.next_expiration());
+ EXPECT_EQ(MDnsCache::NoChange,
+ cache_.UpdateDnsRecord(record_goodbye2.Pass()));
+ EXPECT_EQ(default_time_ + base::TimeDelta::FromSeconds(1),
+ cache_.next_expiration());
+}
+
+TEST_F(MDnsCacheTest, AnyRRType) {
+ DnsRecordParser parser(kTestResponseTwoRecords,
+ sizeof(kTestResponseTwoRecords),
+ 0);
+
+ scoped_ptr<const RecordParsed> record1;
+ scoped_ptr<const RecordParsed> record2;
+ std::vector<const RecordParsed*> results;
+
+ record1 = RecordParsed::CreateFrom(&parser, default_time_);
+ record2 = RecordParsed::CreateFrom(&parser, default_time_);
+ EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass()));
+ EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record2.Pass()));
+
+ cache_.FindDnsRecords(0, "ghs.l.google.com", &results, default_time_);
+
+ EXPECT_EQ(2u, results.size());
+ EXPECT_EQ(default_time_, results.front()->time_created());
+
+ EXPECT_EQ("ghs.l.google.com", results[0]->name());
+ EXPECT_EQ("ghs.l.google.com", results[1]->name());
+ EXPECT_EQ(dns_protocol::kTypeA,
+ std::min(results[0]->type(), results[1]->type()));
+ EXPECT_EQ(dns_protocol::kTypeAAAA,
+ std::max(results[0]->type(), results[1]->type()));
+}
+
+TEST_F(MDnsCacheTest, RemoveRecord) {
+ DnsRecordParser parser(kT1ResponseDatagram, sizeof(kT1ResponseDatagram),
+ sizeof(dns_protocol::Header));
+ parser.SkipQuestion();
+
+ scoped_ptr<const RecordParsed> record1;
+ std::vector<const RecordParsed*> results;
+
+ record1 = RecordParsed::CreateFrom(&parser, default_time_);
+ EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass()));
+
+ cache_.FindDnsRecords(dns_protocol::kTypeCNAME, "codereview.chromium.org",
+ &results, default_time_);
+
+ EXPECT_EQ(1u, results.size());
+
+ scoped_ptr<const RecordParsed> record_out =
+ cache_.RemoveRecord(results.front());
+
+ EXPECT_EQ(record_out.get(), results.front());
+
+ cache_.FindDnsRecords(dns_protocol::kTypeCNAME, "codereview.chromium.org",
+ &results, default_time_);
+
+ EXPECT_EQ(0u, results.size());
+}
+
+} // namespace net
diff --git a/chromium/net/dns/mdns_client.cc b/chromium/net/dns/mdns_client.cc
new file mode 100644
index 00000000000..631b01a706e
--- /dev/null
+++ b/chromium/net/dns/mdns_client.cc
@@ -0,0 +1,17 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/mdns_client.h"
+
+#include "net/dns/mdns_client_impl.h"
+
+namespace net {
+
+// static
+scoped_ptr<MDnsClient> MDnsClient::CreateDefault() {
+ return scoped_ptr<MDnsClient>(
+ new MDnsClientImpl(MDnsConnection::SocketFactory::CreateDefault()));
+}
+
+} // namespace net
diff --git a/chromium/net/dns/mdns_client.h b/chromium/net/dns/mdns_client.h
new file mode 100644
index 00000000000..f12c6e292cd
--- /dev/null
+++ b/chromium/net/dns/mdns_client.h
@@ -0,0 +1,158 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_MDNS_CLIENT_H_
+#define NET_DNS_MDNS_CLIENT_H_
+
+#include <string>
+#include <vector>
+
+#include "base/callback.h"
+#include "net/dns/dns_query.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/record_parsed.h"
+
+namespace net {
+
+class RecordParsed;
+
+// Represents a one-time record lookup. A transaction takes one
+// associated callback (see |MDnsClient::CreateTransaction|) and calls it
+// whenever a matching record has been found, either from the cache or
+// by querying the network (it may choose to query either or both based on its
+// creation flags, see MDnsTransactionFlags). Network-based transactions will
+// time out after a reasonable number of seconds.
+class NET_EXPORT MDnsTransaction {
+ public:
+ // Used to signify what type of result the transaction has recieved.
+ enum Result {
+ // Passed whenever a record is found.
+ RESULT_RECORD,
+ // The transaction is done. Applies to non-single-valued transactions. Is
+ // called when the transaction has finished (this is the last call to the
+ // callback).
+ RESULT_DONE,
+ // No results have been found. Applies to single-valued transactions. Is
+ // called when the transaction has finished without finding any results.
+ // For transactions that use the network, this happens when a timeout
+ // occurs, for transactions that are cache-only, this happens when no
+ // results are in the cache.
+ RESULT_NO_RESULTS,
+ // Called when an NSec record is read for this transaction's
+ // query. This means there cannot possibly be a record of the type
+ // and name for this transaction.
+ RESULT_NSEC
+ };
+
+ // Used when creating an MDnsTransaction.
+ enum Flags {
+ // Transaction should return only one result, and stop listening after it.
+ // Note that single result transactions will signal when their timeout is
+ // reached, whereas multi-result transactions will not.
+ SINGLE_RESULT = 1 << 0,
+ // Query the cache or the network. May both be used. One must be present.
+ QUERY_CACHE = 1 << 1,
+ QUERY_NETWORK = 1 << 2,
+ // TODO(noamsml): Add flag for flushing cache when feature is implemented
+ // Mask of all possible flags on MDnsTransaction.
+ FLAG_MASK = (1 << 3) - 1,
+ };
+
+ typedef base::Callback<void(Result, const RecordParsed*)>
+ ResultCallback;
+
+ // Destroying the transaction cancels it.
+ virtual ~MDnsTransaction() {}
+
+ // Start the transaction. Return true on success. Cache-based transactions
+ // will execute the callback synchronously.
+ virtual bool Start() = 0;
+
+ // Get the host or service name for the transaction.
+ virtual const std::string& GetName() const = 0;
+
+ // Get the type for this transaction (SRV, TXT, A, AAA, etc)
+ virtual uint16 GetType() const = 0;
+};
+
+// A listener listens for updates regarding a specific record or set of records.
+// Created by the MDnsClient (see |MDnsClient::CreateListener|) and used to keep
+// track of listeners.
+class NET_EXPORT MDnsListener {
+ public:
+ // Used in the MDnsListener delegate to signify what type of change has been
+ // made to a record.
+ enum UpdateType {
+ RECORD_ADDED,
+ RECORD_CHANGED,
+ RECORD_REMOVED
+ };
+
+ class Delegate {
+ public:
+ virtual ~Delegate() {}
+
+ // Called when a record is added, removed or updated.
+ virtual void OnRecordUpdate(UpdateType update,
+ const RecordParsed* record) = 0;
+
+ // Called when a record is marked nonexistent by an NSEC record.
+ virtual void OnNsecRecord(const std::string& name, unsigned type) = 0;
+
+ // Called when the cache is purged (due, for example, ot the network
+ // disconnecting).
+ virtual void OnCachePurged() = 0;
+ };
+
+ // Destroying the listener stops listening.
+ virtual ~MDnsListener() {}
+
+ // Start the listener. Return true on success.
+ virtual bool Start() = 0;
+
+ // Get the host or service name for this query.
+ // Return an empty string for no name.
+ virtual const std::string& GetName() const = 0;
+
+ // Get the type for this query (SRV, TXT, A, AAA, etc)
+ virtual uint16 GetType() const = 0;
+};
+
+// Listens for Multicast DNS on the local network. You can access information
+// regarding multicast DNS either by creating an |MDnsListener| to be notified
+// of new records, or by creating an |MDnsTransaction| to look up the value of a
+// specific records. When all listeners and active transactions are destroyed,
+// the client stops listening on the network and destroys the cache.
+class NET_EXPORT MDnsClient {
+ public:
+ virtual ~MDnsClient() {}
+
+ // Create listener object for RRType |rrtype| and name |name|.
+ virtual scoped_ptr<MDnsListener> CreateListener(
+ uint16 rrtype,
+ const std::string& name,
+ MDnsListener::Delegate* delegate) = 0;
+
+ // Create a transaction that can be used to query either the MDns cache, the
+ // network, or both for records of type |rrtype| and name |name|. |flags| is
+ // defined by MDnsTransactionFlags.
+ virtual scoped_ptr<MDnsTransaction> CreateTransaction(
+ uint16 rrtype,
+ const std::string& name,
+ int flags,
+ const MDnsTransaction::ResultCallback& callback) = 0;
+
+ virtual bool StartListening() = 0;
+
+ // Do not call this inside callbacks from related MDnsListener and
+ // MDnsTransaction objects.
+ virtual void StopListening() = 0;
+ virtual bool IsListening() const = 0;
+
+ // Create the default MDnsClient
+ static scoped_ptr<MDnsClient> CreateDefault();
+};
+
+} // namespace net
+#endif // NET_DNS_MDNS_CLIENT_H_
diff --git a/chromium/net/dns/mdns_client_impl.cc b/chromium/net/dns/mdns_client_impl.cc
new file mode 100644
index 00000000000..8f79edf4cdf
--- /dev/null
+++ b/chromium/net/dns/mdns_client_impl.cc
@@ -0,0 +1,671 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/mdns_client_impl.h"
+
+#include "base/bind.h"
+#include "base/message_loop/message_loop_proxy.h"
+#include "base/stl_util.h"
+#include "base/time/default_clock.h"
+#include "net/base/dns_util.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/rand_callback.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/record_rdata.h"
+#include "net/udp/datagram_socket.h"
+
+// TODO(gene): Remove this temporary method of disabling NSEC support once it
+// becomes clear whether this feature should be
+// supported. http://crbug.com/255232
+#define ENABLE_NSEC
+
+namespace net {
+
+namespace {
+const char kMDnsMulticastGroupIPv4[] = "224.0.0.251";
+const char kMDnsMulticastGroupIPv6[] = "FF02::FB";
+const unsigned MDnsTransactionTimeoutSeconds = 3;
+}
+
+MDnsConnection::SocketHandler::SocketHandler(
+ MDnsConnection* connection, const IPEndPoint& multicast_addr,
+ MDnsConnection::SocketFactory* socket_factory)
+ : socket_(socket_factory->CreateSocket()), connection_(connection),
+ response_(new DnsResponse(dns_protocol::kMaxMulticastSize)),
+ multicast_addr_(multicast_addr) {
+}
+
+MDnsConnection::SocketHandler::~SocketHandler() {
+}
+
+int MDnsConnection::SocketHandler::Start() {
+ int rv = BindSocket();
+ if (rv != OK) {
+ return rv;
+ }
+
+ return DoLoop(0);
+}
+
+int MDnsConnection::SocketHandler::DoLoop(int rv) {
+ do {
+ if (rv > 0)
+ connection_->OnDatagramReceived(response_.get(), recv_addr_, rv);
+
+ rv = socket_->RecvFrom(
+ response_->io_buffer(),
+ response_->io_buffer()->size(),
+ &recv_addr_,
+ base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived,
+ base::Unretained(this)));
+ } while (rv > 0);
+
+ if (rv != ERR_IO_PENDING)
+ return rv;
+
+ return OK;
+}
+
+void MDnsConnection::SocketHandler::OnDatagramReceived(int rv) {
+ if (rv >= OK)
+ rv = DoLoop(rv);
+
+ if (rv != OK)
+ connection_->OnError(this, rv);
+}
+
+int MDnsConnection::SocketHandler::Send(IOBuffer* buffer, unsigned size) {
+ return socket_->SendTo(
+ buffer, size, multicast_addr_,
+ base::Bind(&MDnsConnection::SocketHandler::SendDone,
+ base::Unretained(this) ));
+}
+
+void MDnsConnection::SocketHandler::SendDone(int rv) {
+ // TODO(noamsml): Retry logic.
+}
+
+int MDnsConnection::SocketHandler::BindSocket() {
+ IPAddressNumber address_any(multicast_addr_.address().size());
+
+ IPEndPoint bind_endpoint(address_any, multicast_addr_.port());
+
+ socket_->AllowAddressReuse();
+ int rv = socket_->Listen(bind_endpoint);
+
+ if (rv < OK) return rv;
+
+ socket_->SetMulticastLoopbackMode(false);
+
+ return socket_->JoinGroup(multicast_addr_.address());
+}
+
+MDnsConnection::MDnsConnection(MDnsConnection::SocketFactory* socket_factory,
+ MDnsConnection::Delegate* delegate)
+ : socket_handler_ipv4_(this,
+ GetMDnsIPEndPoint(kMDnsMulticastGroupIPv4),
+ socket_factory),
+ socket_handler_ipv6_(this,
+ GetMDnsIPEndPoint(kMDnsMulticastGroupIPv6),
+ socket_factory),
+ delegate_(delegate) {
+}
+
+MDnsConnection::~MDnsConnection() {
+}
+
+int MDnsConnection::Init() {
+ int rv;
+
+ rv = socket_handler_ipv4_.Start();
+ if (rv != OK) return rv;
+ rv = socket_handler_ipv6_.Start();
+ if (rv != OK) return rv;
+
+ return OK;
+}
+
+int MDnsConnection::Send(IOBuffer* buffer, unsigned size) {
+ int rv;
+
+ rv = socket_handler_ipv4_.Send(buffer, size);
+ if (rv < OK && rv != ERR_IO_PENDING) return rv;
+
+ rv = socket_handler_ipv6_.Send(buffer, size);
+ if (rv < OK && rv != ERR_IO_PENDING) return rv;
+
+ return OK;
+}
+
+void MDnsConnection::OnError(SocketHandler* loop,
+ int error) {
+ // TODO(noamsml): Specific handling of intermittent errors that can be handled
+ // in the connection.
+ delegate_->OnConnectionError(error);
+}
+
+IPEndPoint MDnsConnection::GetMDnsIPEndPoint(const char* address) {
+ IPAddressNumber multicast_group_number;
+ bool success = ParseIPLiteralToNumber(address,
+ &multicast_group_number);
+ DCHECK(success);
+ return IPEndPoint(multicast_group_number,
+ dns_protocol::kDefaultPortMulticast);
+}
+
+void MDnsConnection::OnDatagramReceived(
+ DnsResponse* response,
+ const IPEndPoint& recv_addr,
+ int bytes_read) {
+ // TODO(noamsml): More sophisticated error handling.
+ DCHECK_GT(bytes_read, 0);
+ delegate_->HandlePacket(response, bytes_read);
+}
+
+class MDnsConnectionSocketFactoryImpl
+ : public MDnsConnection::SocketFactory {
+ public:
+ MDnsConnectionSocketFactoryImpl();
+ virtual ~MDnsConnectionSocketFactoryImpl();
+
+ virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE;
+};
+
+MDnsConnectionSocketFactoryImpl::MDnsConnectionSocketFactoryImpl() {
+}
+
+MDnsConnectionSocketFactoryImpl::~MDnsConnectionSocketFactoryImpl() {
+}
+
+scoped_ptr<DatagramServerSocket>
+MDnsConnectionSocketFactoryImpl::CreateSocket() {
+ return scoped_ptr<DatagramServerSocket>(new UDPServerSocket(
+ NULL, NetLog::Source()));
+}
+
+// static
+scoped_ptr<MDnsConnection::SocketFactory>
+MDnsConnection::SocketFactory::CreateDefault() {
+ return scoped_ptr<MDnsConnection::SocketFactory>(
+ new MDnsConnectionSocketFactoryImpl);
+}
+
+MDnsClientImpl::Core::Core(MDnsClientImpl* client,
+ MDnsConnection::SocketFactory* socket_factory)
+ : client_(client), connection_(new MDnsConnection(socket_factory, this)) {
+}
+
+MDnsClientImpl::Core::~Core() {
+ STLDeleteValues(&listeners_);
+}
+
+bool MDnsClientImpl::Core::Init() {
+ return connection_->Init() == OK;
+}
+
+bool MDnsClientImpl::Core::SendQuery(uint16 rrtype, std::string name) {
+ std::string name_dns;
+ if (!DNSDomainFromDot(name, &name_dns))
+ return false;
+
+ DnsQuery query(0, name_dns, rrtype);
+ query.set_flags(0); // Remove the RD flag from the query. It is unneeded.
+
+ return connection_->Send(query.io_buffer(), query.io_buffer()->size()) == OK;
+}
+
+void MDnsClientImpl::Core::HandlePacket(DnsResponse* response,
+ int bytes_read) {
+ unsigned offset;
+ // Note: We store cache keys rather than record pointers to avoid
+ // erroneous behavior in case a packet contains multiple exclusive
+ // records with the same type and name.
+ std::map<MDnsCache::Key, MDnsListener::UpdateType> update_keys;
+
+ if (!response->InitParseWithoutQuery(bytes_read)) {
+ LOG(WARNING) << "Could not understand an mDNS packet.";
+ return; // Message is unreadable.
+ }
+
+ // TODO(noamsml): duplicate query suppression.
+ if (!(response->flags() & dns_protocol::kFlagResponse))
+ return; // Message is a query. ignore it.
+
+ DnsRecordParser parser = response->Parser();
+ unsigned answer_count = response->answer_count() +
+ response->additional_answer_count();
+
+ for (unsigned i = 0; i < answer_count; i++) {
+ offset = parser.GetOffset();
+ scoped_ptr<const RecordParsed> record = RecordParsed::CreateFrom(
+ &parser, base::Time::Now());
+
+ if (!record) {
+ LOG(WARNING) << "Could not understand an mDNS record.";
+
+ if (offset == parser.GetOffset()) {
+ LOG(WARNING) << "Abandoned parsing the rest of the packet.";
+ return; // The parser did not advance, abort reading the packet.
+ } else {
+ continue; // We may be able to extract other records from the packet.
+ }
+ }
+
+ if ((record->klass() & dns_protocol::kMDnsClassMask) !=
+ dns_protocol::kClassIN) {
+ LOG(WARNING) << "Received an mDNS record with non-IN class. Ignoring.";
+ continue; // Ignore all records not in the IN class.
+ }
+
+ MDnsCache::Key update_key = MDnsCache::Key::CreateFor(record.get());
+ MDnsCache::UpdateType update = cache_.UpdateDnsRecord(record.Pass());
+
+ // Cleanup time may have changed.
+ ScheduleCleanup(cache_.next_expiration());
+
+ if (update != MDnsCache::NoChange) {
+ MDnsListener::UpdateType update_external;
+
+ switch (update) {
+ case MDnsCache::RecordAdded:
+ update_external = MDnsListener::RECORD_ADDED;
+ break;
+ case MDnsCache::RecordChanged:
+ update_external = MDnsListener::RECORD_CHANGED;
+ break;
+ case MDnsCache::NoChange:
+ default:
+ NOTREACHED();
+ // Dummy assignment to suppress compiler warning.
+ update_external = MDnsListener::RECORD_CHANGED;
+ break;
+ }
+
+ update_keys.insert(std::make_pair(update_key, update_external));
+ }
+ }
+
+ for (std::map<MDnsCache::Key, MDnsListener::UpdateType>::iterator i =
+ update_keys.begin(); i != update_keys.end(); i++) {
+ const RecordParsed* record = cache_.LookupKey(i->first);
+ if (!record)
+ continue;
+
+ if (record->type() == dns_protocol::kTypeNSEC) {
+#if defined(ENABLE_NSEC)
+ NotifyNsecRecord(record);
+#endif
+ } else {
+ AlertListeners(i->second, ListenerKey(record->name(), record->type()),
+ record);
+ }
+ }
+}
+
+void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) {
+ DCHECK_EQ(dns_protocol::kTypeNSEC, record->type());
+ const NsecRecordRdata* rdata = record->rdata<NsecRecordRdata>();
+ DCHECK(rdata);
+
+ // Remove all cached records matching the nonexistent RR types.
+ std::vector<const RecordParsed*> records_to_remove;
+
+ cache_.FindDnsRecords(0, record->name(), &records_to_remove,
+ base::Time::Now());
+
+ for (std::vector<const RecordParsed*>::iterator i = records_to_remove.begin();
+ i != records_to_remove.end(); i++) {
+ if ((*i)->type() == dns_protocol::kTypeNSEC)
+ continue;
+ if (!rdata->GetBit((*i)->type())) {
+ scoped_ptr<const RecordParsed> record_removed = cache_.RemoveRecord((*i));
+ DCHECK(record_removed);
+ OnRecordRemoved(record_removed.get());
+ }
+ }
+
+ // Alert all listeners waiting for the nonexistent RR types.
+ ListenerMap::iterator i =
+ listeners_.upper_bound(ListenerKey(record->name(), 0));
+ for (; i != listeners_.end() && i->first.first == record->name(); i++) {
+ if (!rdata->GetBit(i->first.second)) {
+ FOR_EACH_OBSERVER(MDnsListenerImpl, *i->second, AlertNsecRecord());
+ }
+ }
+}
+
+void MDnsClientImpl::Core::OnConnectionError(int error) {
+ // TODO(noamsml): On connection error, recreate connection and flush cache.
+}
+
+void MDnsClientImpl::Core::AlertListeners(
+ MDnsListener::UpdateType update_type,
+ const ListenerKey& key,
+ const RecordParsed* record) {
+ ListenerMap::iterator listener_map_iterator = listeners_.find(key);
+ if (listener_map_iterator == listeners_.end()) return;
+
+ FOR_EACH_OBSERVER(MDnsListenerImpl, *listener_map_iterator->second,
+ AlertDelegate(update_type, record));
+}
+
+void MDnsClientImpl::Core::AddListener(
+ MDnsListenerImpl* listener) {
+ ListenerKey key(listener->GetName(), listener->GetType());
+ std::pair<ListenerMap::iterator, bool> observer_insert_result =
+ listeners_.insert(
+ make_pair(key, static_cast<ObserverList<MDnsListenerImpl>*>(NULL)));
+
+ // If an equivalent key does not exist, actually create the observer list.
+ if (observer_insert_result.second)
+ observer_insert_result.first->second = new ObserverList<MDnsListenerImpl>();
+
+ ObserverList<MDnsListenerImpl>* observer_list =
+ observer_insert_result.first->second;
+
+ observer_list->AddObserver(listener);
+}
+
+void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) {
+ ListenerKey key(listener->GetName(), listener->GetType());
+ ListenerMap::iterator observer_list_iterator = listeners_.find(key);
+
+ DCHECK(observer_list_iterator != listeners_.end());
+ DCHECK(observer_list_iterator->second->HasObserver(listener));
+
+ observer_list_iterator->second->RemoveObserver(listener);
+
+ // Remove the observer list from the map if it is empty
+ if (observer_list_iterator->second->size() == 0) {
+ // Schedule the actual removal for later in case the listener removal
+ // happens while iterating over the observer list.
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE, base::Bind(
+ &MDnsClientImpl::Core::CleanupObserverList, AsWeakPtr(), key));
+ }
+}
+
+void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) {
+ ListenerMap::iterator found = listeners_.find(key);
+ if (found != listeners_.end() && found->second->size() == 0) {
+ delete found->second;
+ listeners_.erase(found);
+ }
+}
+
+void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) {
+ // Cleanup is already scheduled, no need to do anything.
+ if (cleanup == scheduled_cleanup_) return;
+ scheduled_cleanup_ = cleanup;
+
+ // This cancels the previously scheduled cleanup.
+ cleanup_callback_.Reset(base::Bind(
+ &MDnsClientImpl::Core::DoCleanup, base::Unretained(this)));
+
+ // If |cleanup| is empty, then no cleanup necessary.
+ if (cleanup != base::Time()) {
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ cleanup_callback_.callback(),
+ cleanup - base::Time::Now());
+ }
+}
+
+void MDnsClientImpl::Core::DoCleanup() {
+ cache_.CleanupRecords(base::Time::Now(), base::Bind(
+ &MDnsClientImpl::Core::OnRecordRemoved, base::Unretained(this)));
+
+ ScheduleCleanup(cache_.next_expiration());
+}
+
+void MDnsClientImpl::Core::OnRecordRemoved(
+ const RecordParsed* record) {
+ AlertListeners(MDnsListener::RECORD_REMOVED,
+ ListenerKey(record->name(), record->type()), record);
+}
+
+void MDnsClientImpl::Core::QueryCache(
+ uint16 rrtype, const std::string& name,
+ std::vector<const RecordParsed*>* records) const {
+ cache_.FindDnsRecords(rrtype, name, records, base::Time::Now());
+}
+
+MDnsClientImpl::MDnsClientImpl(
+ scoped_ptr<MDnsConnection::SocketFactory> socket_factory)
+ : socket_factory_(socket_factory.Pass()) {
+}
+
+MDnsClientImpl::~MDnsClientImpl() {
+}
+
+bool MDnsClientImpl::StartListening() {
+ DCHECK(!core_.get());
+ core_.reset(new Core(this, socket_factory_.get()));
+ if (!core_->Init()) {
+ core_.reset();
+ return false;
+ }
+ return true;
+}
+
+void MDnsClientImpl::StopListening() {
+ core_.reset();
+}
+
+bool MDnsClientImpl::IsListening() const {
+ return core_.get() != NULL;
+}
+
+scoped_ptr<MDnsListener> MDnsClientImpl::CreateListener(
+ uint16 rrtype,
+ const std::string& name,
+ MDnsListener::Delegate* delegate) {
+ return scoped_ptr<net::MDnsListener>(
+ new MDnsListenerImpl(rrtype, name, delegate, this));
+}
+
+scoped_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction(
+ uint16 rrtype,
+ const std::string& name,
+ int flags,
+ const MDnsTransaction::ResultCallback& callback) {
+ return scoped_ptr<MDnsTransaction>(
+ new MDnsTransactionImpl(rrtype, name, flags, callback, this));
+}
+
+MDnsListenerImpl::MDnsListenerImpl(
+ uint16 rrtype,
+ const std::string& name,
+ MDnsListener::Delegate* delegate,
+ MDnsClientImpl* client)
+ : rrtype_(rrtype), name_(name), client_(client), delegate_(delegate),
+ started_(false) {
+}
+
+bool MDnsListenerImpl::Start() {
+ DCHECK(!started_);
+
+ started_ = true;
+
+ DCHECK(client_->core());
+ client_->core()->AddListener(this);
+
+ return true;
+}
+
+MDnsListenerImpl::~MDnsListenerImpl() {
+ if (started_) {
+ DCHECK(client_->core());
+ client_->core()->RemoveListener(this);
+ }
+}
+
+const std::string& MDnsListenerImpl::GetName() const {
+ return name_;
+}
+
+uint16 MDnsListenerImpl::GetType() const {
+ return rrtype_;
+}
+
+void MDnsListenerImpl::AlertDelegate(MDnsListener::UpdateType update_type,
+ const RecordParsed* record) {
+ DCHECK(started_);
+ delegate_->OnRecordUpdate(update_type, record);
+}
+
+void MDnsListenerImpl::AlertNsecRecord() {
+ DCHECK(started_);
+ delegate_->OnNsecRecord(name_, rrtype_);
+}
+
+MDnsTransactionImpl::MDnsTransactionImpl(
+ uint16 rrtype,
+ const std::string& name,
+ int flags,
+ const MDnsTransaction::ResultCallback& callback,
+ MDnsClientImpl* client)
+ : rrtype_(rrtype), name_(name), callback_(callback), client_(client),
+ started_(false), flags_(flags) {
+ DCHECK((flags_ & MDnsTransaction::FLAG_MASK) == flags_);
+ DCHECK(flags_ & MDnsTransaction::QUERY_CACHE ||
+ flags_ & MDnsTransaction::QUERY_NETWORK);
+}
+
+MDnsTransactionImpl::~MDnsTransactionImpl() {
+ timeout_.Cancel();
+}
+
+bool MDnsTransactionImpl::Start() {
+ DCHECK(!started_);
+ started_ = true;
+
+ base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr();
+ if (flags_ & MDnsTransaction::QUERY_CACHE) {
+ ServeRecordsFromCache();
+
+ if (!weak_this || !is_active()) return true;
+ }
+
+ if (flags_ & MDnsTransaction::QUERY_NETWORK) {
+ return QueryAndListen();
+ }
+
+ // If this is a cache only query, signal that the transaction is over
+ // immediately.
+ SignalTransactionOver();
+ return true;
+}
+
+const std::string& MDnsTransactionImpl::GetName() const {
+ return name_;
+}
+
+uint16 MDnsTransactionImpl::GetType() const {
+ return rrtype_;
+}
+
+void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) {
+ DCHECK(started_);
+ OnRecordUpdate(MDnsListener::RECORD_ADDED, record);
+}
+
+void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result,
+ const RecordParsed* record) {
+ DCHECK(started_);
+ if (!is_active()) return;
+
+ // Ensure callback is run after touching all class state, so that
+ // the callback can delete the transaction.
+ MDnsTransaction::ResultCallback callback = callback_;
+
+ // Reset the transaction if it expects a single result, or if the result
+ // is a final one (everything except for a record).
+ if (flags_ & MDnsTransaction::SINGLE_RESULT ||
+ result != MDnsTransaction::RESULT_RECORD) {
+ Reset();
+ }
+
+ callback.Run(result, record);
+}
+
+void MDnsTransactionImpl::Reset() {
+ callback_.Reset();
+ listener_.reset();
+ timeout_.Cancel();
+}
+
+void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update,
+ const RecordParsed* record) {
+ DCHECK(started_);
+ if (update == MDnsListener::RECORD_ADDED ||
+ update == MDnsListener::RECORD_CHANGED)
+ TriggerCallback(MDnsTransaction::RESULT_RECORD, record);
+}
+
+void MDnsTransactionImpl::SignalTransactionOver() {
+ DCHECK(started_);
+ if (flags_ & MDnsTransaction::SINGLE_RESULT) {
+ TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS, NULL);
+ } else {
+ TriggerCallback(MDnsTransaction::RESULT_DONE, NULL);
+ }
+}
+
+void MDnsTransactionImpl::ServeRecordsFromCache() {
+ std::vector<const RecordParsed*> records;
+ base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr();
+
+ if (client_->core()) {
+ client_->core()->QueryCache(rrtype_, name_, &records);
+ for (std::vector<const RecordParsed*>::iterator i = records.begin();
+ i != records.end() && weak_this; ++i) {
+ weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD, *i);
+ }
+
+#if defined(ENABLE_NSEC)
+ if (records.empty()) {
+ DCHECK(weak_this);
+ client_->core()->QueryCache(dns_protocol::kTypeNSEC, name_, &records);
+ if (!records.empty()) {
+ const NsecRecordRdata* rdata =
+ records.front()->rdata<NsecRecordRdata>();
+ DCHECK(rdata);
+ if (!rdata->GetBit(rrtype_))
+ weak_this->TriggerCallback(MDnsTransaction::RESULT_NSEC, NULL);
+ }
+ }
+#endif
+ }
+}
+
+bool MDnsTransactionImpl::QueryAndListen() {
+ listener_ = client_->CreateListener(rrtype_, name_, this);
+ if (!listener_->Start())
+ return false;
+
+ DCHECK(client_->core());
+ if (!client_->core()->SendQuery(rrtype_, name_))
+ return false;
+
+ timeout_.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver,
+ AsWeakPtr()));
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE,
+ timeout_.callback(),
+ base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds));
+
+ return true;
+}
+
+void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) {
+ TriggerCallback(RESULT_NSEC, NULL);
+}
+
+void MDnsTransactionImpl::OnCachePurged() {
+ // TODO(noamsml): Cache purge situations not yet implemented
+}
+
+} // namespace net
diff --git a/chromium/net/dns/mdns_client_impl.h b/chromium/net/dns/mdns_client_impl.h
new file mode 100644
index 00000000000..9fe3f99e7dd
--- /dev/null
+++ b/chromium/net/dns/mdns_client_impl.h
@@ -0,0 +1,298 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_MDNS_CLIENT_IMPL_H_
+#define NET_DNS_MDNS_CLIENT_IMPL_H_
+
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "base/cancelable_callback.h"
+#include "base/observer_list.h"
+#include "net/base/io_buffer.h"
+#include "net/base/ip_endpoint.h"
+#include "net/dns/mdns_cache.h"
+#include "net/dns/mdns_client.h"
+#include "net/udp/datagram_server_socket.h"
+#include "net/udp/udp_server_socket.h"
+#include "net/udp/udp_socket.h"
+
+namespace net {
+
+// A connection to the network for multicast DNS clients. It reads data into
+// DnsResponse objects and alerts the delegate that a packet has been received.
+class NET_EXPORT_PRIVATE MDnsConnection {
+ public:
+ class SocketFactory {
+ public:
+ virtual ~SocketFactory() {}
+
+ virtual scoped_ptr<DatagramServerSocket> CreateSocket() = 0;
+
+ static scoped_ptr<SocketFactory> CreateDefault();
+ };
+
+ class Delegate {
+ public:
+ // Handle an mDNS packet buffered in |response| with a size of |bytes_read|.
+ virtual void HandlePacket(DnsResponse* response, int bytes_read) = 0;
+ virtual void OnConnectionError(int error) = 0;
+ virtual ~Delegate() {}
+ };
+
+ explicit MDnsConnection(SocketFactory* socket_factory,
+ MDnsConnection::Delegate* delegate);
+
+ virtual ~MDnsConnection();
+
+ int Init();
+ int Send(IOBuffer* buffer, unsigned size);
+
+ private:
+ class SocketHandler {
+ public:
+ SocketHandler(MDnsConnection* connection,
+ const IPEndPoint& multicast_addr,
+ SocketFactory* socket_factory);
+ ~SocketHandler();
+ int DoLoop(int rv);
+ int Start();
+
+ int Send(IOBuffer* buffer, unsigned size);
+
+ private:
+ int BindSocket();
+ void OnDatagramReceived(int rv);
+
+ // Callback for when sending a query has finished.
+ void SendDone(int rv);
+
+ scoped_ptr<DatagramServerSocket> socket_;
+
+ MDnsConnection* connection_;
+ IPEndPoint recv_addr_;
+ scoped_ptr<DnsResponse> response_;
+ IPEndPoint multicast_addr_;
+ };
+
+ // Callback for handling a datagram being received on either ipv4 or ipv6.
+ void OnDatagramReceived(DnsResponse* response,
+ const IPEndPoint& recv_addr,
+ int bytes_read);
+
+ void OnError(SocketHandler* loop, int error);
+
+ IPEndPoint GetMDnsIPEndPoint(const char* address);
+
+ SocketHandler socket_handler_ipv4_;
+ SocketHandler socket_handler_ipv6_;
+
+ Delegate* delegate_;
+
+ DISALLOW_COPY_AND_ASSIGN(MDnsConnection);
+};
+
+class MDnsListenerImpl;
+
+class NET_EXPORT_PRIVATE MDnsClientImpl : public MDnsClient {
+ public:
+ // The core object exists while the MDnsClient is listening, and is deleted
+ // whenever the number of listeners reaches zero. The deletion happens
+ // asychronously, so destroying the last listener does not immediately
+ // invalidate the core.
+ class Core : public base::SupportsWeakPtr<Core>, MDnsConnection::Delegate {
+ public:
+ Core(MDnsClientImpl* client,
+ MDnsConnection::SocketFactory* socket_factory);
+ virtual ~Core();
+
+ // Initialize the core. Returns true on success.
+ bool Init();
+
+ // Send a query with a specific rrtype and name. Returns true on success.
+ bool SendQuery(uint16 rrtype, std::string name);
+
+ // Add/remove a listener to the list of listeners.
+ void AddListener(MDnsListenerImpl* listener);
+ void RemoveListener(MDnsListenerImpl* listener);
+
+ // Query the cache for records of a specific type and name.
+ void QueryCache(uint16 rrtype, const std::string& name,
+ std::vector<const RecordParsed*>* records) const;
+
+ // Parse the response and alert relevant listeners.
+ virtual void HandlePacket(DnsResponse* response, int bytes_read) OVERRIDE;
+
+ virtual void OnConnectionError(int error) OVERRIDE;
+
+ private:
+ typedef std::pair<std::string, uint16> ListenerKey;
+ typedef std::map<ListenerKey, ObserverList<MDnsListenerImpl>* >
+ ListenerMap;
+
+ // Alert listeners of an update to the cache.
+ void AlertListeners(MDnsListener::UpdateType update_type,
+ const ListenerKey& key, const RecordParsed* record);
+
+ // Schedule a cache cleanup to a specific time, cancelling other cleanups.
+ void ScheduleCleanup(base::Time cleanup);
+
+ // Clean up the cache and schedule a new cleanup.
+ void DoCleanup();
+
+ // Callback for when a record is removed from the cache.
+ void OnRecordRemoved(const RecordParsed* record);
+
+ void NotifyNsecRecord(const RecordParsed* record);
+
+ // Delete and erase the observer list for |key|. Only deletes the observer
+ // list if is empty.
+ void CleanupObserverList(const ListenerKey& key);
+
+ ListenerMap listeners_;
+
+ MDnsClientImpl* client_;
+ MDnsCache cache_;
+
+ base::CancelableCallback<void()> cleanup_callback_;
+ base::Time scheduled_cleanup_;
+
+ scoped_ptr<MDnsConnection> connection_;
+
+ DISALLOW_COPY_AND_ASSIGN(Core);
+ };
+
+ explicit MDnsClientImpl(
+ scoped_ptr<MDnsConnection::SocketFactory> socket_factory_);
+ virtual ~MDnsClientImpl();
+
+ // MDnsClient implementation:
+ virtual scoped_ptr<MDnsListener> CreateListener(
+ uint16 rrtype,
+ const std::string& name,
+ MDnsListener::Delegate* delegate) OVERRIDE;
+
+ virtual scoped_ptr<MDnsTransaction> CreateTransaction(
+ uint16 rrtype,
+ const std::string& name,
+ int flags,
+ const MDnsTransaction::ResultCallback& callback) OVERRIDE;
+
+ virtual bool StartListening() OVERRIDE;
+ virtual void StopListening() OVERRIDE;
+ virtual bool IsListening() const OVERRIDE;
+
+ Core* core() { return core_.get(); }
+
+ private:
+ scoped_ptr<Core> core_;
+
+ scoped_ptr<MDnsConnection::SocketFactory> socket_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(MDnsClientImpl);
+};
+
+class MDnsListenerImpl : public MDnsListener,
+ public base::SupportsWeakPtr<MDnsListenerImpl> {
+ public:
+ MDnsListenerImpl(uint16 rrtype,
+ const std::string& name,
+ MDnsListener::Delegate* delegate,
+ MDnsClientImpl* client);
+
+ virtual ~MDnsListenerImpl();
+
+ // MDnsListener implementation:
+ virtual bool Start() OVERRIDE;
+
+ virtual const std::string& GetName() const OVERRIDE;
+
+ virtual uint16 GetType() const OVERRIDE;
+
+ MDnsListener::Delegate* delegate() { return delegate_; }
+
+ // Alert the delegate of a record update.
+ void AlertDelegate(MDnsListener::UpdateType update_type,
+ const RecordParsed* record_parsed);
+
+ // Alert the delegate of the existence of an Nsec record.
+ void AlertNsecRecord();
+
+ private:
+ uint16 rrtype_;
+ std::string name_;
+ MDnsClientImpl* client_;
+ MDnsListener::Delegate* delegate_;
+
+ bool started_;
+ DISALLOW_COPY_AND_ASSIGN(MDnsListenerImpl);
+};
+
+class MDnsTransactionImpl : public base::SupportsWeakPtr<MDnsTransactionImpl>,
+ public MDnsTransaction,
+ public MDnsListener::Delegate {
+ public:
+ MDnsTransactionImpl(uint16 rrtype,
+ const std::string& name,
+ int flags,
+ const MDnsTransaction::ResultCallback& callback,
+ MDnsClientImpl* client);
+ virtual ~MDnsTransactionImpl();
+
+ // MDnsTransaction implementation:
+ virtual bool Start() OVERRIDE;
+
+ virtual const std::string& GetName() const OVERRIDE;
+ virtual uint16 GetType() const OVERRIDE;
+
+ // MDnsListener::Delegate implementation:
+ virtual void OnRecordUpdate(MDnsListener::UpdateType update,
+ const RecordParsed* record) OVERRIDE;
+ virtual void OnNsecRecord(const std::string& name, unsigned type) OVERRIDE;
+
+ virtual void OnCachePurged() OVERRIDE;
+
+ private:
+ bool is_active() { return !callback_.is_null(); }
+
+ void Reset();
+
+ // Trigger the callback and reset all related variables.
+ void TriggerCallback(MDnsTransaction::Result result,
+ const RecordParsed* record);
+
+ // Internal callback for when a cache record is found.
+ void CacheRecordFound(const RecordParsed* record);
+
+ // Signal the transactionis over and release all related resources.
+ void SignalTransactionOver();
+
+ // Reads records from the cache and calls the callback for every
+ // record read.
+ void ServeRecordsFromCache();
+
+ // Send a query to the network and set up a timeout to time out the
+ // transaction. Returns false if it fails to start listening on the network
+ // or if it fails to send a query.
+ bool QueryAndListen();
+
+ uint16 rrtype_;
+ std::string name_;
+ MDnsTransaction::ResultCallback callback_;
+
+ scoped_ptr<MDnsListener> listener_;
+ base::CancelableCallback<void()> timeout_;
+
+ MDnsClientImpl* client_;
+
+ bool started_;
+ int flags_;
+
+ DISALLOW_COPY_AND_ASSIGN(MDnsTransactionImpl);
+};
+
+} // namespace net
+#endif // NET_DNS_MDNS_CLIENT_IMPL_H_
diff --git a/chromium/net/dns/mdns_client_unittest.cc b/chromium/net/dns/mdns_client_unittest.cc
new file mode 100644
index 00000000000..324b4dfbee0
--- /dev/null
+++ b/chromium/net/dns/mdns_client_unittest.cc
@@ -0,0 +1,1176 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <queue>
+
+#include "base/memory/ref_counted.h"
+#include "base/message_loop/message_loop.h"
+#include "net/base/rand_callback.h"
+#include "net/base/test_completion_callback.h"
+#include "net/dns/mdns_client_impl.h"
+#include "net/dns/mock_mdns_socket_factory.h"
+#include "net/dns/record_rdata.h"
+#include "net/udp/udp_client_socket.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+using ::testing::Invoke;
+using ::testing::InvokeWithoutArgs;
+using ::testing::StrictMock;
+using ::testing::NiceMock;
+using ::testing::Exactly;
+using ::testing::Return;
+using ::testing::SaveArg;
+using ::testing::_;
+
+namespace net {
+
+namespace {
+
+const uint8 kSamplePacket1[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is 1 second;
+ 0x00, 0x01,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x0c,
+
+ // Answer 2
+ 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r',
+ 0xc0, 0x14, // Pointer to "._tcp.local"
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds.
+ 0x24, 0x75,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x32
+};
+
+const uint8 kCorruptedPacketBadQuestion[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x01, // One question
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Question is corrupted and cannot be read.
+ 0x99, 'h', 'e', 'l', 'l', 'o',
+ 0x00,
+ 0x00, 0x00,
+ 0x00, 0x00,
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x99, // RDLENGTH is impossible
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x0c,
+
+ // Answer 2
+ 0x08, '_', 'p', 'r', // Useless trailing data.
+};
+
+const uint8 kCorruptedPacketUnsalvagable[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x99, // RDLENGTH is impossible
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x0c,
+
+ // Answer 2
+ 0x08, '_', 'p', 'r', // Useless trailing data.
+};
+
+const uint8 kCorruptedPacketDoubleRecord[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x06, 'p', 'r', 'i', 'v', 'e', 't',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x04, // RDLENGTH is 4
+ 0x05, 0x03,
+ 0xc0, 0x0c,
+
+ // Answer 2 -- Same key
+ 0x06, 'p', 'r', 'i', 'v', 'e', 't',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x04, // RDLENGTH is 4
+ 0x02, 0x03,
+ 0x04, 0x05,
+};
+
+const uint8 kCorruptedPacketSalvagable[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x99, 'h', 'e', 'l', 'l', 'o', // Bad RDATA format.
+ 0xc0, 0x0c,
+
+ // Answer 2
+ 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r',
+ 0xc0, 0x14, // Pointer to "._tcp.local"
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds.
+ 0x24, 0x75,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x32
+};
+
+const uint8 kSamplePacket2[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x02, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'z', 'z', 'z', 'z', 'z',
+ 0xc0, 0x0c,
+
+ // Answer 2
+ 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r',
+ 0xc0, 0x14, // Pointer to "._tcp.local"
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'z', 'z', 'z', 'z', 'z',
+ 0xc0, 0x32
+};
+
+const uint8 kQueryPacketPrivet[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x00, 0x00, // No flags.
+ 0x00, 0x01, // One question.
+ 0x00, 0x00, // 0 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Question
+ // This part is echoed back from the respective query.
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+};
+
+const uint8 kSamplePacketAdditionalOnly[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x00, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x01, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'h', 'e', 'l', 'l', 'o',
+ 0xc0, 0x0c,
+};
+
+const uint8 kSamplePacketNsec[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x01, // 1 RR (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x2f, // TYPE is NSEC.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x06, // RDLENGTH is 6 bytes.
+ 0xc0, 0x0c,
+ 0x00, 0x02, 0x00, 0x08 // Only A record present
+};
+
+const uint8 kSamplePacketAPrivet[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x01, // 1 RR (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x01, // TYPE is A.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x04, // RDLENGTH is 4 bytes.
+ 0xc0, 0x0c,
+ 0x00, 0x02,
+};
+
+const uint8 kSamplePacketGoodbye[] = {
+ // Header
+ 0x00, 0x00, // ID is zeroed out
+ 0x81, 0x80, // Standard query response, RA, no error
+ 0x00, 0x00, // No questions (for simplicity)
+ 0x00, 0x01, // 2 RRs (answers)
+ 0x00, 0x00, // 0 authority RRs
+ 0x00, 0x00, // 0 additional RRs
+
+ // Answer 1
+ 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
+ 0x04, '_', 't', 'c', 'p',
+ 0x05, 'l', 'o', 'c', 'a', 'l',
+ 0x00,
+ 0x00, 0x0c, // TYPE is PTR.
+ 0x00, 0x01, // CLASS is IN.
+ 0x00, 0x00, // TTL (4 bytes) is zero;
+ 0x00, 0x00,
+ 0x00, 0x08, // RDLENGTH is 8 bytes.
+ 0x05, 'z', 'z', 'z', 'z', 'z',
+ 0xc0, 0x0c,
+};
+
+std::string MakeString(const uint8* data, unsigned size) {
+ return std::string(reinterpret_cast<const char*>(data), size);
+}
+
+class PtrRecordCopyContainer {
+ public:
+ PtrRecordCopyContainer() {}
+ ~PtrRecordCopyContainer() {}
+
+ bool is_set() const { return set_; }
+
+ void SaveWithDummyArg(int unused, const RecordParsed* value) {
+ Save(value);
+ }
+
+ void Save(const RecordParsed* value) {
+ set_ = true;
+ name_ = value->name();
+ ptrdomain_ = value->rdata<PtrRecordRdata>()->ptrdomain();
+ ttl_ = value->ttl();
+ }
+
+ bool IsRecordWith(std::string name, std::string ptrdomain) {
+ return set_ && name_ == name && ptrdomain_ == ptrdomain;
+ }
+
+ const std::string& name() { return name_; }
+ const std::string& ptrdomain() { return ptrdomain_; }
+ int ttl() { return ttl_; }
+
+ private:
+ bool set_;
+ std::string name_;
+ std::string ptrdomain_;
+ int ttl_;
+};
+
+class MDnsTest : public ::testing::Test {
+ public:
+ MDnsTest();
+ virtual ~MDnsTest();
+ virtual void SetUp() OVERRIDE;
+ virtual void TearDown() OVERRIDE;
+ void DeleteTransaction();
+ void DeleteBothListeners();
+ void RunFor(base::TimeDelta time_period);
+ void Stop();
+
+ MOCK_METHOD2(MockableRecordCallback, void(MDnsTransaction::Result result,
+ const RecordParsed* record));
+
+ MOCK_METHOD2(MockableRecordCallback2, void(MDnsTransaction::Result result,
+ const RecordParsed* record));
+
+
+ protected:
+ void ExpectPacket(const uint8* packet, unsigned size);
+ void SimulatePacketReceive(const uint8* packet, unsigned size);
+
+ scoped_ptr<MDnsClientImpl> test_client_;
+ IPEndPoint mdns_ipv4_endpoint_;
+ StrictMock<MockMDnsSocketFactory>* socket_factory_;
+
+ // Transactions and listeners that can be deleted by class methods for
+ // reentrancy tests.
+ scoped_ptr<MDnsTransaction> transaction_;
+ scoped_ptr<MDnsListener> listener1_;
+ scoped_ptr<MDnsListener> listener2_;
+};
+
+class MockListenerDelegate : public MDnsListener::Delegate {
+ public:
+ MOCK_METHOD2(OnRecordUpdate,
+ void(MDnsListener::UpdateType update,
+ const RecordParsed* records));
+ MOCK_METHOD2(OnNsecRecord, void(const std::string&, unsigned));
+ MOCK_METHOD0(OnCachePurged, void());
+};
+
+MDnsTest::MDnsTest() {
+ socket_factory_ = new StrictMock<MockMDnsSocketFactory>();
+ test_client_.reset(new MDnsClientImpl(
+ scoped_ptr<MDnsConnection::SocketFactory>(socket_factory_)));
+}
+
+MDnsTest::~MDnsTest() {
+}
+
+void MDnsTest::SetUp() {
+ test_client_->StartListening();
+}
+
+void MDnsTest::TearDown() {
+}
+
+void MDnsTest::SimulatePacketReceive(const uint8* packet, unsigned size) {
+ socket_factory_->SimulateReceive(packet, size);
+}
+
+void MDnsTest::ExpectPacket(const uint8* packet, unsigned size) {
+ EXPECT_CALL(*socket_factory_, OnSendTo(MakeString(packet, size)))
+ .Times(2);
+}
+
+void MDnsTest::DeleteTransaction() {
+ transaction_.reset();
+}
+
+void MDnsTest::DeleteBothListeners() {
+ listener1_.reset();
+ listener2_.reset();
+}
+
+void MDnsTest::RunFor(base::TimeDelta time_period) {
+ base::CancelableCallback<void()> callback(base::Bind(&MDnsTest::Stop,
+ base::Unretained(this)));
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE, callback.callback(), time_period);
+
+ base::MessageLoop::current()->Run();
+ callback.Cancel();
+}
+
+void MDnsTest::Stop() {
+ base::MessageLoop::current()->Quit();
+}
+
+TEST_F(MDnsTest, PassiveListeners) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+ StrictMock<MockListenerDelegate> delegate_printer;
+
+ PtrRecordCopyContainer record_privet;
+ PtrRecordCopyContainer record_printer;
+
+ scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
+ scoped_ptr<MDnsListener> listener_printer = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer);
+
+ ASSERT_TRUE(listener_privet->Start());
+ ASSERT_TRUE(listener_printer->Start());
+
+ // Send the same packet twice to ensure no records are double-counted.
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(
+ &record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(
+ &record_printer,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+
+ EXPECT_TRUE(record_printer.IsRecordWith("_printer._tcp.local",
+ "hello._printer._tcp.local"));
+
+ listener_privet.reset();
+ listener_printer.reset();
+}
+
+TEST_F(MDnsTest, PassiveListenersCacheCleanup) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+
+ PtrRecordCopyContainer record_privet;
+ PtrRecordCopyContainer record_privet2;
+
+ scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
+
+ ASSERT_TRUE(listener_privet->Start());
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(
+ &record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+
+ // Expect record is removed when its TTL expires.
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
+ .Times(Exactly(1))
+ .WillOnce(DoAll(InvokeWithoutArgs(this, &MDnsTest::Stop),
+ Invoke(&record_privet2,
+ &PtrRecordCopyContainer::SaveWithDummyArg)));
+
+ RunFor(base::TimeDelta::FromSeconds(record_privet.ttl() + 1));
+
+ EXPECT_TRUE(record_privet2.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+}
+
+TEST_F(MDnsTest, MalformedPacket) {
+ StrictMock<MockListenerDelegate> delegate_printer;
+
+ PtrRecordCopyContainer record_printer;
+
+ scoped_ptr<MDnsListener> listener_printer = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer);
+
+ ASSERT_TRUE(listener_printer->Start());
+
+ EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(
+ &record_printer,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ // First, send unsalvagable packet to ensure we can deal with it.
+ SimulatePacketReceive(kCorruptedPacketUnsalvagable,
+ sizeof(kCorruptedPacketUnsalvagable));
+
+ // Regression test: send a packet where the question cannot be read.
+ SimulatePacketReceive(kCorruptedPacketBadQuestion,
+ sizeof(kCorruptedPacketBadQuestion));
+
+ // Then send salvagable packet to ensure we can extract useful records.
+ SimulatePacketReceive(kCorruptedPacketSalvagable,
+ sizeof(kCorruptedPacketSalvagable));
+
+ EXPECT_TRUE(record_printer.IsRecordWith("_printer._tcp.local",
+ "hello._printer._tcp.local"));
+}
+
+TEST_F(MDnsTest, TransactionWithEmptyCache) {
+ ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));
+
+ scoped_ptr<MDnsTransaction> transaction_privet =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ ASSERT_TRUE(transaction_privet->Start());
+
+ PtrRecordCopyContainer record_privet;
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(&record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+}
+
+TEST_F(MDnsTest, TransactionCacheOnlyNoResult) {
+ scoped_ptr<MDnsTransaction> transaction_privet =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ EXPECT_CALL(*this,
+ MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, _))
+ .Times(Exactly(1));
+
+ ASSERT_TRUE(transaction_privet->Start());
+}
+
+TEST_F(MDnsTest, TransactionWithCache) {
+ // Listener to force the client to listen
+ StrictMock<MockListenerDelegate> delegate_irrelevant;
+ scoped_ptr<MDnsListener> listener_irrelevant = test_client_->CreateListener(
+ dns_protocol::kTypeA, "codereview.chromium.local",
+ &delegate_irrelevant);
+
+ ASSERT_TRUE(listener_irrelevant->Start());
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+
+ PtrRecordCopyContainer record_privet;
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
+ .WillOnce(Invoke(&record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ scoped_ptr<MDnsTransaction> transaction_privet =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ ASSERT_TRUE(transaction_privet->Start());
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+}
+
+TEST_F(MDnsTest, AdditionalRecords) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+
+ PtrRecordCopyContainer record_privet;
+
+ scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ &delegate_privet);
+
+ ASSERT_TRUE(listener_privet->Start());
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(Invoke(
+ &record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ SimulatePacketReceive(kSamplePacketAdditionalOnly,
+ sizeof(kSamplePacketAdditionalOnly));
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+}
+
+TEST_F(MDnsTest, TransactionTimeout) {
+ ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));
+
+ scoped_ptr<MDnsTransaction> transaction_privet =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ ASSERT_TRUE(transaction_privet->Start());
+
+ EXPECT_CALL(*this,
+ MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, NULL))
+ .Times(Exactly(1))
+ .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop));
+
+ RunFor(base::TimeDelta::FromSeconds(4));
+}
+
+TEST_F(MDnsTest, TransactionMultipleRecords) {
+ ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));
+
+ scoped_ptr<MDnsTransaction> transaction_privet =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE ,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ ASSERT_TRUE(transaction_privet->Start());
+
+ PtrRecordCopyContainer record_privet;
+ PtrRecordCopyContainer record_privet2;
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
+ .Times(Exactly(2))
+ .WillOnce(Invoke(&record_privet,
+ &PtrRecordCopyContainer::SaveWithDummyArg))
+ .WillOnce(Invoke(&record_privet2,
+ &PtrRecordCopyContainer::SaveWithDummyArg));
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+ SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2));
+
+ EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local",
+ "hello._privet._tcp.local"));
+
+ EXPECT_TRUE(record_privet2.IsRecordWith("_privet._tcp.local",
+ "zzzzz._privet._tcp.local"));
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_DONE, NULL))
+ .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop));
+
+ RunFor(base::TimeDelta::FromSeconds(4));
+}
+
+TEST_F(MDnsTest, TransactionReentrantDelete) {
+ ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));
+
+ transaction_ = test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ ASSERT_TRUE(transaction_->Start());
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS,
+ NULL))
+ .Times(Exactly(1))
+ .WillOnce(DoAll(InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction),
+ InvokeWithoutArgs(this, &MDnsTest::Stop)));
+
+ RunFor(base::TimeDelta::FromSeconds(4));
+
+ EXPECT_EQ(NULL, transaction_.get());
+}
+
+TEST_F(MDnsTest, TransactionReentrantDeleteFromCache) {
+ StrictMock<MockListenerDelegate> delegate_irrelevant;
+ scoped_ptr<MDnsListener> listener_irrelevant = test_client_->CreateListener(
+ dns_protocol::kTypeA, "codereview.chromium.local",
+ &delegate_irrelevant);
+ ASSERT_TRUE(listener_irrelevant->Start());
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+ transaction_ = test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _))
+ .Times(Exactly(1))
+ .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction));
+
+ ASSERT_TRUE(transaction_->Start());
+
+ EXPECT_EQ(NULL, transaction_.get());
+}
+
+TEST_F(MDnsTest, TransactionReentrantCacheLookupStart) {
+ ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet));
+
+ scoped_ptr<MDnsTransaction> transaction1 = test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ scoped_ptr<MDnsTransaction> transaction2 = test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_printer._tcp.local",
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback2,
+ base::Unretained(this)));
+
+ EXPECT_CALL(*this, MockableRecordCallback2(MDnsTransaction::RESULT_RECORD,
+ _))
+ .Times(Exactly(1));
+
+ EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD,
+ _))
+ .Times(Exactly(1))
+ .WillOnce(IgnoreResult(InvokeWithoutArgs(transaction2.get(),
+ &MDnsTransaction::Start)));
+
+ ASSERT_TRUE(transaction1->Start());
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+}
+
+TEST_F(MDnsTest, GoodbyePacketNotification) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+
+ scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
+ ASSERT_TRUE(listener_privet->Start());
+
+ SimulatePacketReceive(kSamplePacketGoodbye, sizeof(kSamplePacketGoodbye));
+
+ RunFor(base::TimeDelta::FromSeconds(2));
+}
+
+TEST_F(MDnsTest, GoodbyePacketRemoval) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+
+ scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet);
+ ASSERT_TRUE(listener_privet->Start());
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1));
+
+ SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2));
+
+ SimulatePacketReceive(kSamplePacketGoodbye, sizeof(kSamplePacketGoodbye));
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
+ .Times(Exactly(1));
+
+ RunFor(base::TimeDelta::FromSeconds(2));
+}
+
+// In order to reliably test reentrant listener deletes, we create two listeners
+// and have each of them delete both, so we're guaranteed to try and deliver a
+// callback to at least one deleted listener.
+
+TEST_F(MDnsTest, ListenerReentrantDelete) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+
+ listener1_ = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ &delegate_privet);
+
+ listener2_ = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ &delegate_privet);
+
+ ASSERT_TRUE(listener1_->Start());
+
+ ASSERT_TRUE(listener2_->Start());
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteBothListeners));
+
+ SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1));
+
+ EXPECT_EQ(NULL, listener1_.get());
+ EXPECT_EQ(NULL, listener2_.get());
+}
+
+ACTION_P(SaveIPAddress, ip_container) {
+ ::testing::StaticAssertTypeEq<const RecordParsed*, arg1_type>();
+ ::testing::StaticAssertTypeEq<IPAddressNumber*, ip_container_type>();
+
+ *ip_container = arg1->template rdata<ARecordRdata>()->address();
+}
+
+TEST_F(MDnsTest, DoubleRecordDisagreeing) {
+ IPAddressNumber address;
+ StrictMock<MockListenerDelegate> delegate_privet;
+
+ scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
+ dns_protocol::kTypeA, "privet.local", &delegate_privet);
+
+ ASSERT_TRUE(listener_privet->Start());
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .Times(Exactly(1))
+ .WillOnce(SaveIPAddress(&address));
+
+ SimulatePacketReceive(kCorruptedPacketDoubleRecord,
+ sizeof(kCorruptedPacketDoubleRecord));
+
+ EXPECT_EQ("2.3.4.5", IPAddressToString(address));
+}
+
+TEST_F(MDnsTest, NsecWithListener) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+ scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
+ dns_protocol::kTypeA, "_privet._tcp.local", &delegate_privet);
+
+ // Test to make sure nsec callback is NOT called for PTR
+ // (which is marked as existing).
+ StrictMock<MockListenerDelegate> delegate_privet2;
+ scoped_ptr<MDnsListener> listener_privet2 = test_client_->CreateListener(
+ dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet2);
+
+ ASSERT_TRUE(listener_privet->Start());
+
+ EXPECT_CALL(delegate_privet,
+ OnNsecRecord("_privet._tcp.local", dns_protocol::kTypeA));
+
+ SimulatePacketReceive(kSamplePacketNsec,
+ sizeof(kSamplePacketNsec));
+}
+
+TEST_F(MDnsTest, NsecWithTransactionFromNetwork) {
+ scoped_ptr<MDnsTransaction> transaction_privet =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypeA, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ EXPECT_CALL(*socket_factory_, OnSendTo(_))
+ .Times(2);
+
+ ASSERT_TRUE(transaction_privet->Start());
+
+ EXPECT_CALL(*this,
+ MockableRecordCallback(MDnsTransaction::RESULT_NSEC, NULL));
+
+ SimulatePacketReceive(kSamplePacketNsec,
+ sizeof(kSamplePacketNsec));
+}
+
+TEST_F(MDnsTest, NsecWithTransactionFromCache) {
+ // Force mDNS to listen.
+ StrictMock<MockListenerDelegate> delegate_irrelevant;
+ scoped_ptr<MDnsListener> listener_irrelevant =
+ test_client_->CreateListener(dns_protocol::kTypePTR, "_privet._tcp.local",
+ &delegate_irrelevant);
+ listener_irrelevant->Start();
+
+ SimulatePacketReceive(kSamplePacketNsec,
+ sizeof(kSamplePacketNsec));
+
+ EXPECT_CALL(*this,
+ MockableRecordCallback(MDnsTransaction::RESULT_NSEC, NULL));
+
+ scoped_ptr<MDnsTransaction> transaction_privet_a =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypeA, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ ASSERT_TRUE(transaction_privet_a->Start());
+
+ // Test that a PTR transaction does NOT consider the same NSEC record to be a
+ // valid answer to the query
+
+ scoped_ptr<MDnsTransaction> transaction_privet_ptr =
+ test_client_->CreateTransaction(
+ dns_protocol::kTypePTR, "_privet._tcp.local",
+ MDnsTransaction::QUERY_NETWORK |
+ MDnsTransaction::QUERY_CACHE |
+ MDnsTransaction::SINGLE_RESULT,
+ base::Bind(&MDnsTest::MockableRecordCallback,
+ base::Unretained(this)));
+
+ EXPECT_CALL(*socket_factory_, OnSendTo(_))
+ .Times(2);
+
+ ASSERT_TRUE(transaction_privet_ptr->Start());
+}
+
+TEST_F(MDnsTest, NsecConflictRemoval) {
+ StrictMock<MockListenerDelegate> delegate_privet;
+ scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener(
+ dns_protocol::kTypeA, "_privet._tcp.local", &delegate_privet);
+
+ ASSERT_TRUE(listener_privet->Start());
+
+ const RecordParsed* record1;
+ const RecordParsed* record2;
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _))
+ .WillOnce(SaveArg<1>(&record1));
+
+ SimulatePacketReceive(kSamplePacketAPrivet,
+ sizeof(kSamplePacketAPrivet));
+
+ EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _))
+ .WillOnce(SaveArg<1>(&record2));
+
+ EXPECT_CALL(delegate_privet,
+ OnNsecRecord("_privet._tcp.local", dns_protocol::kTypeA));
+
+ SimulatePacketReceive(kSamplePacketNsec,
+ sizeof(kSamplePacketNsec));
+
+ EXPECT_EQ(record1, record2);
+}
+
+
+// Note: These tests assume that the ipv4 socket will always be created first.
+// This is a simplifying assumption based on the way the code works now.
+
+class SimpleMockSocketFactory
+ : public MDnsConnection::SocketFactory {
+ public:
+ SimpleMockSocketFactory() {
+ }
+ virtual ~SimpleMockSocketFactory() {
+ }
+
+ virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE {
+ scoped_ptr<MockMDnsDatagramServerSocket> socket(
+ new StrictMock<MockMDnsDatagramServerSocket>);
+ sockets_.push(socket.get());
+ return socket.PassAs<DatagramServerSocket>();
+ }
+
+ MockMDnsDatagramServerSocket* PopFirstSocket() {
+ MockMDnsDatagramServerSocket* socket = sockets_.front();
+ sockets_.pop();
+ return socket;
+ }
+
+ size_t num_sockets() {
+ return sockets_.size();
+ }
+
+ private:
+ std::queue<MockMDnsDatagramServerSocket*> sockets_;
+};
+
+class MockMDnsConnectionDelegate : public MDnsConnection::Delegate {
+ public:
+ virtual void HandlePacket(DnsResponse* response, int size) {
+ HandlePacketInternal(std::string(response->io_buffer()->data(), size));
+ }
+
+ MOCK_METHOD1(HandlePacketInternal, void(std::string packet));
+
+ MOCK_METHOD1(OnConnectionError, void(int error));
+};
+
+class MDnsConnectionTest : public ::testing::Test {
+ public:
+ MDnsConnectionTest() : connection_(&factory_, &delegate_) {
+ }
+
+ protected:
+ // Follow successful connection initialization.
+ virtual void SetUp() OVERRIDE {
+ ASSERT_EQ(2u, factory_.num_sockets());
+
+ socket_ipv4_ = factory_.PopFirstSocket();
+ socket_ipv6_ = factory_.PopFirstSocket();
+ }
+
+ bool InitConnection() {
+ EXPECT_CALL(*socket_ipv4_, AllowAddressReuse());
+ EXPECT_CALL(*socket_ipv6_, AllowAddressReuse());
+
+ EXPECT_CALL(*socket_ipv4_, SetMulticastLoopbackMode(false));
+ EXPECT_CALL(*socket_ipv6_, SetMulticastLoopbackMode(false));
+
+ EXPECT_CALL(*socket_ipv4_, ListenInternal("0.0.0.0:5353"))
+ .WillOnce(Return(OK));
+ EXPECT_CALL(*socket_ipv6_, ListenInternal("[::]:5353"))
+ .WillOnce(Return(OK));
+
+ EXPECT_CALL(*socket_ipv4_, JoinGroupInternal("224.0.0.251"))
+ .WillOnce(Return(OK));
+ EXPECT_CALL(*socket_ipv6_, JoinGroupInternal("ff02::fb"))
+ .WillOnce(Return(OK));
+
+ return connection_.Init() == OK;
+ }
+
+ StrictMock<MockMDnsConnectionDelegate> delegate_;
+
+ MockMDnsDatagramServerSocket* socket_ipv4_;
+ MockMDnsDatagramServerSocket* socket_ipv6_;
+ SimpleMockSocketFactory factory_;
+ MDnsConnection connection_;
+ TestCompletionCallback callback_;
+};
+
+TEST_F(MDnsConnectionTest, ReceiveSynchronous) {
+ std::string sample_packet = MakeString(kSamplePacket1,
+ sizeof(kSamplePacket1));
+
+ socket_ipv6_->SetResponsePacket(sample_packet);
+ EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _))
+ .WillOnce(Return(ERR_IO_PENDING));
+ EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _))
+ .WillOnce(
+ Invoke(socket_ipv6_, &MockMDnsDatagramServerSocket::HandleRecvNow))
+ .WillOnce(Return(ERR_IO_PENDING));
+
+ EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet));
+
+ ASSERT_TRUE(InitConnection());
+}
+
+TEST_F(MDnsConnectionTest, ReceiveAsynchronous) {
+ std::string sample_packet = MakeString(kSamplePacket1,
+ sizeof(kSamplePacket1));
+ socket_ipv6_->SetResponsePacket(sample_packet);
+ EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _))
+ .WillOnce(Return(ERR_IO_PENDING));
+ EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _))
+ .WillOnce(
+ Invoke(socket_ipv6_, &MockMDnsDatagramServerSocket::HandleRecvLater))
+ .WillOnce(Return(ERR_IO_PENDING));
+
+ ASSERT_TRUE(InitConnection());
+
+ EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet));
+
+ base::MessageLoop::current()->RunUntilIdle();
+}
+
+TEST_F(MDnsConnectionTest, Send) {
+ std::string sample_packet = MakeString(kSamplePacket1,
+ sizeof(kSamplePacket1));
+
+ scoped_refptr<IOBufferWithSize> buf(
+ new IOBufferWithSize(sizeof kSamplePacket1));
+ memcpy(buf->data(), kSamplePacket1, sizeof(kSamplePacket1));
+
+ EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _))
+ .WillOnce(Return(ERR_IO_PENDING));
+ EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _))
+ .WillOnce(Return(ERR_IO_PENDING));
+
+ ASSERT_TRUE(InitConnection());
+
+ EXPECT_CALL(*socket_ipv4_,
+ SendToInternal(sample_packet, "224.0.0.251:5353", _));
+ EXPECT_CALL(*socket_ipv6_,
+ SendToInternal(sample_packet, "[ff02::fb]:5353", _));
+
+ connection_.Send(buf, buf->size());
+}
+
+TEST_F(MDnsConnectionTest, Error) {
+ CompletionCallback callback;
+
+ EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _))
+ .WillOnce(Return(ERR_IO_PENDING));
+ EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _))
+ .WillOnce(DoAll(SaveArg<3>(&callback), Return(ERR_IO_PENDING)));
+
+ ASSERT_TRUE(InitConnection());
+
+ EXPECT_CALL(delegate_, OnConnectionError(ERR_SOCKET_NOT_CONNECTED));
+ callback.Run(ERR_SOCKET_NOT_CONNECTED);
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/dns/mock_host_resolver.cc b/chromium/net/dns/mock_host_resolver.cc
new file mode 100644
index 00000000000..b3d1489c9d7
--- /dev/null
+++ b/chromium/net/dns/mock_host_resolver.cc
@@ -0,0 +1,408 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/mock_host_resolver.h"
+
+#include <string>
+#include <vector>
+
+#include "base/bind.h"
+#include "base/memory/ref_counted.h"
+#include "base/message_loop/message_loop.h"
+#include "base/stl_util.h"
+#include "base/strings/string_split.h"
+#include "base/strings/string_util.h"
+#include "base/threading/platform_thread.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+#include "net/base/test_completion_callback.h"
+#include "net/dns/host_cache.h"
+
+#if defined(OS_WIN)
+#include "net/base/winsock_init.h"
+#endif
+
+namespace net {
+
+namespace {
+
+// Cache size for the MockCachingHostResolver.
+const unsigned kMaxCacheEntries = 100;
+// TTL for the successful resolutions. Failures are not cached.
+const unsigned kCacheEntryTTLSeconds = 60;
+
+} // namespace
+
+int ParseAddressList(const std::string& host_list,
+ const std::string& canonical_name,
+ AddressList* addrlist) {
+ *addrlist = AddressList();
+ std::vector<std::string> addresses;
+ base::SplitString(host_list, ',', &addresses);
+ addrlist->set_canonical_name(canonical_name);
+ for (size_t index = 0; index < addresses.size(); ++index) {
+ IPAddressNumber ip_number;
+ if (!ParseIPLiteralToNumber(addresses[index], &ip_number)) {
+ LOG(WARNING) << "Not a supported IP literal: " << addresses[index];
+ return ERR_UNEXPECTED;
+ }
+ addrlist->push_back(IPEndPoint(ip_number, -1));
+ }
+ return OK;
+}
+
+struct MockHostResolverBase::Request {
+ Request(const RequestInfo& req_info,
+ AddressList* addr,
+ const CompletionCallback& cb)
+ : info(req_info), addresses(addr), callback(cb) {}
+ RequestInfo info;
+ AddressList* addresses;
+ CompletionCallback callback;
+};
+
+MockHostResolverBase::~MockHostResolverBase() {
+ STLDeleteValues(&requests_);
+}
+
+int MockHostResolverBase::Resolve(const RequestInfo& info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* handle,
+ const BoundNetLog& net_log) {
+ DCHECK(CalledOnValidThread());
+ num_resolve_++;
+ size_t id = next_request_id_++;
+ int rv = ResolveFromIPLiteralOrCache(info, addresses);
+ if (rv != ERR_DNS_CACHE_MISS) {
+ return rv;
+ }
+ if (synchronous_mode_) {
+ return ResolveProc(id, info, addresses);
+ }
+ // Store the request for asynchronous resolution
+ Request* req = new Request(info, addresses, callback);
+ requests_[id] = req;
+ if (handle)
+ *handle = reinterpret_cast<RequestHandle>(id);
+
+ if (!ondemand_mode_) {
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
+ }
+
+ return ERR_IO_PENDING;
+}
+
+int MockHostResolverBase::ResolveFromCache(const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& net_log) {
+ num_resolve_from_cache_++;
+ DCHECK(CalledOnValidThread());
+ next_request_id_++;
+ int rv = ResolveFromIPLiteralOrCache(info, addresses);
+ return rv;
+}
+
+void MockHostResolverBase::CancelRequest(RequestHandle handle) {
+ DCHECK(CalledOnValidThread());
+ size_t id = reinterpret_cast<size_t>(handle);
+ RequestMap::iterator it = requests_.find(id);
+ if (it != requests_.end()) {
+ scoped_ptr<Request> req(it->second);
+ requests_.erase(it);
+ } else {
+ NOTREACHED() << "CancelRequest must NOT be called after request is "
+ "complete or canceled.";
+ }
+}
+
+HostCache* MockHostResolverBase::GetHostCache() {
+ return cache_.get();
+}
+
+void MockHostResolverBase::ResolveAllPending() {
+ DCHECK(CalledOnValidThread());
+ DCHECK(ondemand_mode_);
+ for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) {
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first));
+ }
+}
+
+// start id from 1 to distinguish from NULL RequestHandle
+MockHostResolverBase::MockHostResolverBase(bool use_caching)
+ : synchronous_mode_(false),
+ ondemand_mode_(false),
+ next_request_id_(1),
+ num_resolve_(0),
+ num_resolve_from_cache_(0) {
+ rules_ = CreateCatchAllHostResolverProc();
+
+ if (use_caching) {
+ cache_.reset(new HostCache(kMaxCacheEntries));
+ }
+}
+
+int MockHostResolverBase::ResolveFromIPLiteralOrCache(const RequestInfo& info,
+ AddressList* addresses) {
+ IPAddressNumber ip;
+ if (ParseIPLiteralToNumber(info.hostname(), &ip)) {
+ *addresses = AddressList::CreateFromIPAddress(ip, info.port());
+ if (info.host_resolver_flags() & HOST_RESOLVER_CANONNAME)
+ addresses->SetDefaultCanonicalName();
+ return OK;
+ }
+ int rv = ERR_DNS_CACHE_MISS;
+ if (cache_.get() && info.allow_cached_response()) {
+ HostCache::Key key(info.hostname(),
+ info.address_family(),
+ info.host_resolver_flags());
+ const HostCache::Entry* entry = cache_->Lookup(key, base::TimeTicks::Now());
+ if (entry) {
+ rv = entry->error;
+ if (rv == OK)
+ *addresses = AddressList::CopyWithPort(entry->addrlist, info.port());
+ }
+ }
+ return rv;
+}
+
+int MockHostResolverBase::ResolveProc(size_t id,
+ const RequestInfo& info,
+ AddressList* addresses) {
+ AddressList addr;
+ int rv = rules_->Resolve(info.hostname(),
+ info.address_family(),
+ info.host_resolver_flags(),
+ &addr,
+ NULL);
+ if (cache_.get()) {
+ HostCache::Key key(info.hostname(),
+ info.address_family(),
+ info.host_resolver_flags());
+ // Storing a failure with TTL 0 so that it overwrites previous value.
+ base::TimeDelta ttl;
+ if (rv == OK)
+ ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds);
+ cache_->Set(key, HostCache::Entry(rv, addr), base::TimeTicks::Now(), ttl);
+ }
+ if (rv == OK)
+ *addresses = AddressList::CopyWithPort(addr, info.port());
+ return rv;
+}
+
+void MockHostResolverBase::ResolveNow(size_t id) {
+ RequestMap::iterator it = requests_.find(id);
+ if (it == requests_.end())
+ return; // was canceled
+
+ scoped_ptr<Request> req(it->second);
+ requests_.erase(it);
+ int rv = ResolveProc(id, req->info, req->addresses);
+ if (!req->callback.is_null())
+ req->callback.Run(rv);
+}
+
+//-----------------------------------------------------------------------------
+
+RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous)
+ : HostResolverProc(previous) {
+}
+
+void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern,
+ const std::string& replacement) {
+ AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
+ replacement);
+}
+
+void RuleBasedHostResolverProc::AddRuleForAddressFamily(
+ const std::string& host_pattern,
+ AddressFamily address_family,
+ const std::string& replacement) {
+ DCHECK(!replacement.empty());
+ HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
+ HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
+ Rule rule(Rule::kResolverTypeSystem,
+ host_pattern,
+ address_family,
+ flags,
+ replacement,
+ std::string(),
+ 0);
+ rules_.push_back(rule);
+}
+
+void RuleBasedHostResolverProc::AddIPLiteralRule(
+ const std::string& host_pattern,
+ const std::string& ip_literal,
+ const std::string& canonical_name) {
+ // Literals are always resolved to themselves by HostResolverImpl,
+ // consequently we do not support remapping them.
+ IPAddressNumber ip_number;
+ DCHECK(!ParseIPLiteralToNumber(host_pattern, &ip_number));
+ HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
+ HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
+ if (!canonical_name.empty())
+ flags |= HOST_RESOLVER_CANONNAME;
+ Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
+ ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name,
+ 0);
+ rules_.push_back(rule);
+}
+
+void RuleBasedHostResolverProc::AddRuleWithLatency(
+ const std::string& host_pattern,
+ const std::string& replacement,
+ int latency_ms) {
+ DCHECK(!replacement.empty());
+ HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
+ HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
+ Rule rule(Rule::kResolverTypeSystem,
+ host_pattern,
+ ADDRESS_FAMILY_UNSPECIFIED,
+ flags,
+ replacement,
+ std::string(),
+ latency_ms);
+ rules_.push_back(rule);
+}
+
+void RuleBasedHostResolverProc::AllowDirectLookup(
+ const std::string& host_pattern) {
+ HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
+ HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
+ Rule rule(Rule::kResolverTypeSystem,
+ host_pattern,
+ ADDRESS_FAMILY_UNSPECIFIED,
+ flags,
+ std::string(),
+ std::string(),
+ 0);
+ rules_.push_back(rule);
+}
+
+void RuleBasedHostResolverProc::AddSimulatedFailure(
+ const std::string& host_pattern) {
+ HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
+ HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
+ Rule rule(Rule::kResolverTypeFail,
+ host_pattern,
+ ADDRESS_FAMILY_UNSPECIFIED,
+ flags,
+ std::string(),
+ std::string(),
+ 0);
+ rules_.push_back(rule);
+}
+
+void RuleBasedHostResolverProc::ClearRules() {
+ rules_.clear();
+}
+
+int RuleBasedHostResolverProc::Resolve(const std::string& host,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ AddressList* addrlist,
+ int* os_error) {
+ RuleList::iterator r;
+ for (r = rules_.begin(); r != rules_.end(); ++r) {
+ bool matches_address_family =
+ r->address_family == ADDRESS_FAMILY_UNSPECIFIED ||
+ r->address_family == address_family;
+ // Flags match if all of the bitflags in host_resolver_flags are enabled
+ // in the rule's host_resolver_flags. However, the rule may have additional
+ // flags specified, in which case the flags should still be considered a
+ // match.
+ bool matches_flags = (r->host_resolver_flags & host_resolver_flags) ==
+ host_resolver_flags;
+ if (matches_flags && matches_address_family &&
+ MatchPattern(host, r->host_pattern)) {
+ if (r->latency_ms != 0) {
+ base::PlatformThread::Sleep(
+ base::TimeDelta::FromMilliseconds(r->latency_ms));
+ }
+
+ // Remap to a new host.
+ const std::string& effective_host =
+ r->replacement.empty() ? host : r->replacement;
+
+ // Apply the resolving function to the remapped hostname.
+ switch (r->resolver_type) {
+ case Rule::kResolverTypeFail:
+ return ERR_NAME_NOT_RESOLVED;
+ case Rule::kResolverTypeSystem:
+#if defined(OS_WIN)
+ net::EnsureWinsockInit();
+#endif
+ return SystemHostResolverCall(effective_host,
+ address_family,
+ host_resolver_flags,
+ addrlist, os_error);
+ case Rule::kResolverTypeIPLiteral:
+ return ParseAddressList(effective_host,
+ r->canonical_name,
+ addrlist);
+ default:
+ NOTREACHED();
+ return ERR_UNEXPECTED;
+ }
+ }
+ }
+ return ResolveUsingPrevious(host, address_family,
+ host_resolver_flags, addrlist, os_error);
+}
+
+RuleBasedHostResolverProc::~RuleBasedHostResolverProc() {
+}
+
+RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() {
+ RuleBasedHostResolverProc* catchall = new RuleBasedHostResolverProc(NULL);
+ catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost");
+
+ // Next add a rules-based layer the use controls.
+ return new RuleBasedHostResolverProc(catchall);
+}
+
+//-----------------------------------------------------------------------------
+
+int HangingHostResolver::Resolve(const RequestInfo& info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* out_req,
+ const BoundNetLog& net_log) {
+ return ERR_IO_PENDING;
+}
+
+int HangingHostResolver::ResolveFromCache(const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& net_log) {
+ return ERR_DNS_CACHE_MISS;
+}
+
+//-----------------------------------------------------------------------------
+
+ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() {}
+
+ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
+ HostResolverProc* proc) {
+ Init(proc);
+}
+
+ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() {
+ HostResolverProc* old_proc =
+ HostResolverProc::SetDefault(previous_proc_.get());
+ // The lifetimes of multiple instances must be nested.
+ CHECK_EQ(old_proc, current_proc_);
+}
+
+void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) {
+ current_proc_ = proc;
+ previous_proc_ = HostResolverProc::SetDefault(current_proc_.get());
+ current_proc_->SetLastProc(previous_proc_.get());
+}
+
+} // namespace net
diff --git a/chromium/net/dns/mock_host_resolver.h b/chromium/net/dns/mock_host_resolver.h
new file mode 100644
index 00000000000..282521cc3b9
--- /dev/null
+++ b/chromium/net/dns/mock_host_resolver.h
@@ -0,0 +1,284 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_MOCK_HOST_RESOLVER_H_
+#define NET_DNS_MOCK_HOST_RESOLVER_H_
+
+#include <list>
+#include <map>
+
+#include "base/memory/weak_ptr.h"
+#include "base/synchronization/waitable_event.h"
+#include "base/threading/non_thread_safe.h"
+#include "net/dns/host_resolver.h"
+#include "net/dns/host_resolver_proc.h"
+
+namespace net {
+
+class HostCache;
+class RuleBasedHostResolverProc;
+
+// Fills |*addrlist| with a socket address for |host_list| which should be a
+// comma-separated list of IPv4 or IPv6 literal(s) without enclosing brackets.
+// If |canonical_name| is non-empty it is used as the DNS canonical name for
+// the host. Returns OK on success, ERR_UNEXPECTED otherwise.
+int ParseAddressList(const std::string& host_list,
+ const std::string& canonical_name,
+ AddressList* addrlist);
+
+// In most cases, it is important that unit tests avoid relying on making actual
+// DNS queries since the resulting tests can be flaky, especially if the network
+// is unreliable for some reason. To simplify writing tests that avoid making
+// actual DNS queries, pass a MockHostResolver as the HostResolver dependency.
+// The socket addresses returned can be configured using the
+// RuleBasedHostResolverProc:
+//
+// host_resolver->rules()->AddRule("foo.com", "1.2.3.4");
+// host_resolver->rules()->AddRule("bar.com", "2.3.4.5");
+//
+// The above rules define a static mapping from hostnames to IP address
+// literals. The first parameter to AddRule specifies a host pattern to match
+// against, and the second parameter indicates what value should be used to
+// replace the given hostname. So, the following is also supported:
+//
+// host_mapper->AddRule("*.com", "127.0.0.1");
+//
+// Replacement doesn't have to be string representing an IP address. It can
+// re-map one hostname to another as well.
+//
+// By default, MockHostResolvers include a single rule that maps all hosts to
+// 127.0.0.1.
+
+// Base class shared by MockHostResolver and MockCachingHostResolver.
+class MockHostResolverBase : public HostResolver,
+ public base::SupportsWeakPtr<MockHostResolverBase>,
+ public base::NonThreadSafe {
+ public:
+ virtual ~MockHostResolverBase();
+
+ RuleBasedHostResolverProc* rules() { return rules_.get(); }
+ void set_rules(RuleBasedHostResolverProc* rules) { rules_ = rules; }
+
+ // Controls whether resolutions complete synchronously or asynchronously.
+ void set_synchronous_mode(bool is_synchronous) {
+ synchronous_mode_ = is_synchronous;
+ }
+
+ // Asynchronous requests are automatically resolved by default.
+ // If set_ondemand_mode() is set then Resolve() returns IO_PENDING and
+ // ResolveAllPending() must be explicitly invoked to resolve all requests
+ // that are pending.
+ void set_ondemand_mode(bool is_ondemand) {
+ ondemand_mode_ = is_ondemand;
+ }
+
+ // HostResolver methods:
+ virtual int Resolve(const RequestInfo& info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* out_req,
+ const BoundNetLog& net_log) OVERRIDE;
+ virtual int ResolveFromCache(const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& net_log) OVERRIDE;
+ virtual void CancelRequest(RequestHandle req) OVERRIDE;
+ virtual HostCache* GetHostCache() OVERRIDE;
+
+ // Resolves all pending requests. It is only valid to invoke this if
+ // set_ondemand_mode was set before. The requests are resolved asynchronously,
+ // after this call returns.
+ void ResolveAllPending();
+
+ // Returns true if there are pending requests that can be resolved by invoking
+ // ResolveAllPending().
+ bool has_pending_requests() const { return !requests_.empty(); }
+
+ // The number of times that Resolve() has been called.
+ size_t num_resolve() const {
+ return num_resolve_;
+ }
+
+ // The number of times that ResolveFromCache() has been called.
+ size_t num_resolve_from_cache() const {
+ return num_resolve_from_cache_;
+ }
+
+ protected:
+ explicit MockHostResolverBase(bool use_caching);
+
+ private:
+ struct Request;
+ typedef std::map<size_t, Request*> RequestMap;
+
+ // Resolve as IP or from |cache_| return cached error or
+ // DNS_CACHE_MISS if failed.
+ int ResolveFromIPLiteralOrCache(const RequestInfo& info,
+ AddressList* addresses);
+ // Resolve via |proc_|.
+ int ResolveProc(size_t id, const RequestInfo& info, AddressList* addresses);
+ // Resolve request stored in |requests_|. Pass rv to callback.
+ void ResolveNow(size_t id);
+
+ bool synchronous_mode_;
+ bool ondemand_mode_;
+ scoped_refptr<RuleBasedHostResolverProc> rules_;
+ scoped_ptr<HostCache> cache_;
+ RequestMap requests_;
+ size_t next_request_id_;
+
+ size_t num_resolve_;
+ size_t num_resolve_from_cache_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockHostResolverBase);
+};
+
+class MockHostResolver : public MockHostResolverBase {
+ public:
+ MockHostResolver() : MockHostResolverBase(false /*use_caching*/) {}
+ virtual ~MockHostResolver() {}
+};
+
+// Same as MockHostResolver, except internally it uses a host-cache.
+//
+// Note that tests are advised to use MockHostResolver instead, since it is
+// more predictable. (MockHostResolver also can be put into synchronous
+// operation mode in case that is what you needed from the caching version).
+class MockCachingHostResolver : public MockHostResolverBase {
+ public:
+ MockCachingHostResolver() : MockHostResolverBase(true /*use_caching*/) {}
+ virtual ~MockCachingHostResolver() {}
+};
+
+// RuleBasedHostResolverProc applies a set of rules to map a host string to
+// a replacement host string. It then uses the system host resolver to return
+// a socket address. Generally the replacement should be an IPv4 literal so
+// there is no network dependency.
+class RuleBasedHostResolverProc : public HostResolverProc {
+ public:
+ explicit RuleBasedHostResolverProc(HostResolverProc* previous);
+
+ // Any hostname matching the given pattern will be replaced with the given
+ // replacement value. Usually, replacement should be an IP address literal.
+ void AddRule(const std::string& host_pattern,
+ const std::string& replacement);
+
+ // Same as AddRule(), but further restricts to |address_family|.
+ void AddRuleForAddressFamily(const std::string& host_pattern,
+ AddressFamily address_family,
+ const std::string& replacement);
+
+ // Same as AddRule(), but the replacement is expected to be an IPv4 or IPv6
+ // literal. This can be used in place of AddRule() to bypass the system's
+ // host resolver (the address list will be constructed manually).
+ // If |canonical_name| is non-empty, it is copied to the resulting AddressList
+ // but does not impact DNS resolution.
+ // |ip_literal| can be a single IP address like "192.168.1.1" or a comma
+ // separated list of IP addresses, like "::1,192:168.1.2".
+ void AddIPLiteralRule(const std::string& host_pattern,
+ const std::string& ip_literal,
+ const std::string& canonical_name);
+
+ void AddRuleWithLatency(const std::string& host_pattern,
+ const std::string& replacement,
+ int latency_ms);
+
+ // Make sure that |host| will not be re-mapped or even processed by underlying
+ // host resolver procedures. It can also be a pattern.
+ void AllowDirectLookup(const std::string& host);
+
+ // Simulate a lookup failure for |host| (it also can be a pattern).
+ void AddSimulatedFailure(const std::string& host);
+
+ // Deletes all the rules that have been added.
+ void ClearRules();
+
+ // HostResolverProc methods:
+ virtual int Resolve(const std::string& host,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ AddressList* addrlist,
+ int* os_error) OVERRIDE;
+
+ private:
+ struct Rule {
+ enum ResolverType {
+ kResolverTypeFail,
+ kResolverTypeSystem,
+ kResolverTypeIPLiteral,
+ };
+
+ ResolverType resolver_type;
+ std::string host_pattern;
+ AddressFamily address_family;
+ HostResolverFlags host_resolver_flags;
+ std::string replacement;
+ std::string canonical_name;
+ int latency_ms; // In milliseconds.
+
+ Rule(ResolverType resolver_type,
+ const std::string& host_pattern,
+ AddressFamily address_family,
+ HostResolverFlags host_resolver_flags,
+ const std::string& replacement,
+ const std::string& canonical_name,
+ int latency_ms)
+ : resolver_type(resolver_type),
+ host_pattern(host_pattern),
+ address_family(address_family),
+ host_resolver_flags(host_resolver_flags),
+ replacement(replacement),
+ canonical_name(canonical_name),
+ latency_ms(latency_ms) {}
+ };
+
+ typedef std::list<Rule> RuleList;
+
+ virtual ~RuleBasedHostResolverProc();
+
+ RuleList rules_;
+};
+
+// Create rules that map all requests to localhost.
+RuleBasedHostResolverProc* CreateCatchAllHostResolverProc();
+
+// HangingHostResolver never completes its |Resolve| request.
+class HangingHostResolver : public HostResolver {
+ public:
+ virtual int Resolve(const RequestInfo& info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* out_req,
+ const BoundNetLog& net_log) OVERRIDE;
+ virtual int ResolveFromCache(const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& net_log) OVERRIDE;
+ virtual void CancelRequest(RequestHandle req) OVERRIDE {}
+};
+
+// This class sets the default HostResolverProc for a particular scope. The
+// chain of resolver procs starting at |proc| is placed in front of any existing
+// default resolver proc(s). This means that if multiple
+// ScopedDefaultHostResolverProcs are declared, then resolving will start with
+// the procs given to the last-allocated one, then fall back to the procs given
+// to the previously-allocated one, and so forth.
+//
+// NOTE: Only use this as a catch-all safety net. Individual tests should use
+// MockHostResolver.
+class ScopedDefaultHostResolverProc {
+ public:
+ ScopedDefaultHostResolverProc();
+ explicit ScopedDefaultHostResolverProc(HostResolverProc* proc);
+
+ ~ScopedDefaultHostResolverProc();
+
+ void Init(HostResolverProc* proc);
+
+ private:
+ scoped_refptr<HostResolverProc> current_proc_;
+ scoped_refptr<HostResolverProc> previous_proc_;
+};
+
+} // namespace net
+
+#endif // NET_DNS_MOCK_HOST_RESOLVER_H_
diff --git a/chromium/net/dns/mock_mdns_socket_factory.cc b/chromium/net/dns/mock_mdns_socket_factory.cc
new file mode 100644
index 00000000000..8c08c150cd1
--- /dev/null
+++ b/chromium/net/dns/mock_mdns_socket_factory.cc
@@ -0,0 +1,115 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <algorithm>
+
+#include "net/base/net_errors.h"
+#include "net/dns/mock_mdns_socket_factory.h"
+
+using testing::_;
+using testing::Invoke;
+
+namespace net {
+
+MockMDnsDatagramServerSocket::MockMDnsDatagramServerSocket() {
+}
+
+MockMDnsDatagramServerSocket::~MockMDnsDatagramServerSocket() {
+}
+
+int MockMDnsDatagramServerSocket::SendTo(IOBuffer* buf, int buf_len,
+ const IPEndPoint& address,
+ const CompletionCallback& callback) {
+ return SendToInternal(std::string(buf->data(), buf_len), address.ToString(),
+ callback);
+}
+
+int MockMDnsDatagramServerSocket::Listen(const IPEndPoint& address) {
+ return ListenInternal(address.ToString());
+}
+
+int MockMDnsDatagramServerSocket::JoinGroup(
+ const IPAddressNumber& group_address) const {
+ return JoinGroupInternal(IPAddressToString(group_address));
+}
+
+int MockMDnsDatagramServerSocket::LeaveGroup(
+ const IPAddressNumber& group_address) const {
+ return LeaveGroupInternal(IPAddressToString(group_address));
+}
+
+void MockMDnsDatagramServerSocket::SetResponsePacket(
+ std::string response_packet) {
+ response_packet_ = response_packet;
+}
+
+int MockMDnsDatagramServerSocket::HandleRecvNow(
+ IOBuffer* buffer, int size, IPEndPoint* address,
+ const CompletionCallback& callback) {
+ int size_returned =
+ std::min(response_packet_.size(), static_cast<size_t>(size));
+ memcpy(buffer->data(), response_packet_.data(), size_returned);
+ return size_returned;
+}
+
+int MockMDnsDatagramServerSocket::HandleRecvLater(
+ IOBuffer* buffer, int size, IPEndPoint* address,
+ const CompletionCallback& callback) {
+ int rv = HandleRecvNow(buffer, size, address, callback);
+ base::MessageLoop::current()->PostTask(FROM_HERE, base::Bind(callback, rv));
+ return ERR_IO_PENDING;
+}
+
+MockMDnsSocketFactory::MockMDnsSocketFactory() {
+}
+
+MockMDnsSocketFactory::~MockMDnsSocketFactory() {
+}
+
+scoped_ptr<DatagramServerSocket> MockMDnsSocketFactory::CreateSocket() {
+ scoped_ptr<MockMDnsDatagramServerSocket> new_socket(
+ new testing:: NiceMock<MockMDnsDatagramServerSocket>);
+
+ ON_CALL(*new_socket, SendToInternal(_, _, _))
+ .WillByDefault(Invoke(
+ this,
+ &MockMDnsSocketFactory::SendToInternal));
+
+ ON_CALL(*new_socket, RecvFrom(_, _, _, _))
+ .WillByDefault(Invoke(
+ this,
+ &MockMDnsSocketFactory::RecvFromInternal));
+
+ return new_socket.PassAs<DatagramServerSocket>();
+}
+
+void MockMDnsSocketFactory::SimulateReceive(const uint8* packet, int size) {
+ DCHECK(recv_buffer_size_ >= size);
+ DCHECK(recv_buffer_.get());
+ DCHECK(!recv_callback_.is_null());
+
+ memcpy(recv_buffer_->data(), packet, size);
+ CompletionCallback recv_callback = recv_callback_;
+ recv_callback_.Reset();
+ recv_callback.Run(size);
+}
+
+int MockMDnsSocketFactory::RecvFromInternal(
+ IOBuffer* buffer, int size,
+ IPEndPoint* address,
+ const CompletionCallback& callback) {
+ recv_buffer_ = buffer;
+ recv_buffer_size_ = size;
+ recv_callback_ = callback;
+ return ERR_IO_PENDING;
+}
+
+int MockMDnsSocketFactory::SendToInternal(
+ const std::string& packet, const std::string& address,
+ const CompletionCallback& callback) {
+ OnSendTo(packet);
+ return packet.size();
+}
+
+} // namespace net
diff --git a/chromium/net/dns/mock_mdns_socket_factory.h b/chromium/net/dns/mock_mdns_socket_factory.h
new file mode 100644
index 00000000000..f60b08c591b
--- /dev/null
+++ b/chromium/net/dns/mock_mdns_socket_factory.h
@@ -0,0 +1,101 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_MOCK_MDNS_SOCKET_FACTORY_H_
+#define NET_DNS_MOCK_MDNS_SOCKET_FACTORY_H_
+
+#include <string>
+
+#include "net/dns/mdns_client_impl.h"
+#include "testing/gmock/include/gmock/gmock.h"
+
+namespace net {
+
+class MockMDnsDatagramServerSocket : public DatagramServerSocket {
+ public:
+ MockMDnsDatagramServerSocket();
+ ~MockMDnsDatagramServerSocket();
+
+ // DatagramServerSocket implementation:
+ int Listen(const IPEndPoint& address);
+
+ MOCK_METHOD1(ListenInternal, int(const std::string& address));
+
+ MOCK_METHOD4(RecvFrom, int(IOBuffer* buffer, int size,
+ IPEndPoint* address,
+ const CompletionCallback& callback));
+
+ int SendTo(IOBuffer* buf, int buf_len, const IPEndPoint& address,
+ const CompletionCallback& callback);
+
+ MOCK_METHOD3(SendToInternal, int(const std::string& packet,
+ const std::string address,
+ const CompletionCallback& callback));
+
+ MOCK_METHOD1(SetReceiveBufferSize, bool(int32 size));
+ MOCK_METHOD1(SetSendBufferSize, bool(int32 size));
+
+ MOCK_METHOD0(Close, void());
+
+ MOCK_CONST_METHOD1(GetPeerAddress, int(IPEndPoint* address));
+ MOCK_CONST_METHOD1(GetLocalAddress, int(IPEndPoint* address));
+ MOCK_CONST_METHOD0(NetLog, const BoundNetLog&());
+
+ MOCK_METHOD0(AllowAddressReuse, void());
+ MOCK_METHOD0(AllowBroadcast, void());
+
+ int JoinGroup(const IPAddressNumber& group_address) const;
+
+ MOCK_CONST_METHOD1(JoinGroupInternal, int(const std::string& group));
+
+ int LeaveGroup(const IPAddressNumber& group_address) const;
+
+ MOCK_CONST_METHOD1(LeaveGroupInternal, int(const std::string& group));
+
+ MOCK_METHOD1(SetMulticastTimeToLive, int(int ttl));
+
+ MOCK_METHOD1(SetMulticastLoopbackMode, int(bool loopback));
+
+ void SetResponsePacket(std::string response_packet);
+
+ int HandleRecvNow(IOBuffer* buffer, int size, IPEndPoint* address,
+ const CompletionCallback& callback);
+
+ int HandleRecvLater(IOBuffer* buffer, int size, IPEndPoint* address,
+ const CompletionCallback& callback);
+
+ private:
+ std::string response_packet_;
+};
+
+class MockMDnsSocketFactory : public MDnsConnection::SocketFactory {
+ public:
+ MockMDnsSocketFactory();
+
+ virtual ~MockMDnsSocketFactory();
+
+ virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE;
+
+ void SimulateReceive(const uint8* packet, int size);
+
+ MOCK_METHOD1(OnSendTo, void(const std::string&));
+
+ private:
+ int SendToInternal(const std::string& packet, const std::string& address,
+ const CompletionCallback& callback);
+
+ // The latest receive callback is always saved, since the MDnsConnection
+ // does not care which socket a packet is received on.
+ int RecvFromInternal(IOBuffer* buffer, int size,
+ IPEndPoint* address,
+ const CompletionCallback& callback);
+
+ scoped_refptr<IOBuffer> recv_buffer_;
+ int recv_buffer_size_;
+ CompletionCallback recv_callback_;
+};
+
+} // namespace net
+
+#endif // NET_DNS_MOCK_MDNS_SOCKET_FACTORY_H_
diff --git a/chromium/net/dns/notify_watcher_mac.cc b/chromium/net/dns/notify_watcher_mac.cc
new file mode 100644
index 00000000000..286f18bb7bb
--- /dev/null
+++ b/chromium/net/dns/notify_watcher_mac.cc
@@ -0,0 +1,64 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/notify_watcher_mac.h"
+
+#include <notify.h>
+
+#include "base/logging.h"
+#include "base/posix/eintr_wrapper.h"
+
+namespace net {
+
+NotifyWatcherMac::NotifyWatcherMac() : notify_fd_(-1), notify_token_(-1) {}
+
+NotifyWatcherMac::~NotifyWatcherMac() {
+ Cancel();
+}
+
+bool NotifyWatcherMac::Watch(const char* key, const CallbackType& callback) {
+ DCHECK(key);
+ DCHECK(!callback.is_null());
+ Cancel();
+ uint32_t status = notify_register_file_descriptor(
+ key, &notify_fd_, 0, &notify_token_);
+ if (status != NOTIFY_STATUS_OK)
+ return false;
+ DCHECK_GE(notify_fd_, 0);
+ if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
+ notify_fd_,
+ true,
+ base::MessageLoopForIO::WATCH_READ,
+ &watcher_,
+ this)) {
+ Cancel();
+ return false;
+ }
+ callback_ = callback;
+ return true;
+}
+
+void NotifyWatcherMac::Cancel() {
+ if (notify_fd_ >= 0) {
+ notify_cancel(notify_token_); // Also closes |notify_fd_|.
+ notify_fd_ = -1;
+ callback_.Reset();
+ watcher_.StopWatchingFileDescriptor();
+ }
+}
+
+void NotifyWatcherMac::OnFileCanReadWithoutBlocking(int fd) {
+ int token;
+ int status = HANDLE_EINTR(read(notify_fd_, &token, sizeof(token)));
+ if (status != sizeof(token)) {
+ Cancel();
+ callback_.Run(false);
+ return;
+ }
+ // Ignoring |token| value to avoid possible endianness mismatch:
+ // http://openradar.appspot.com/8821081
+ callback_.Run(true);
+}
+
+} // namespace net
diff --git a/chromium/net/dns/notify_watcher_mac.h b/chromium/net/dns/notify_watcher_mac.h
new file mode 100644
index 00000000000..01375d56637
--- /dev/null
+++ b/chromium/net/dns/notify_watcher_mac.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_NOTIFY_WATCHER_MAC_H_
+#define NET_DNS_NOTIFY_WATCHER_MAC_H_
+
+#include "base/callback.h"
+#include "base/message_loop/message_loop.h"
+
+namespace net {
+
+// Watches for notifications from Libnotify and delivers them to a Callback.
+// After failure the watch is cancelled and will have to be restarted.
+class NotifyWatcherMac : public base::MessageLoopForIO::Watcher {
+ public:
+ // Called on received notification with true on success and false on error.
+ typedef base::Callback<void(bool succeeded)> CallbackType;
+
+ NotifyWatcherMac();
+
+ // When deleted, automatically cancels.
+ virtual ~NotifyWatcherMac();
+
+ // Registers for notifications for |key|. Returns true if succeeds. If so,
+ // will deliver asynchronous notifications and errors to |callback|.
+ bool Watch(const char* key, const CallbackType& callback);
+
+ // Cancels the watch.
+ void Cancel();
+
+ private:
+ // MessageLoopForIO::Watcher:
+ virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE;
+ virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE {}
+
+ int notify_fd_;
+ int notify_token_;
+ CallbackType callback_;
+ base::MessageLoopForIO::FileDescriptorWatcher watcher_;
+
+ DISALLOW_COPY_AND_ASSIGN(NotifyWatcherMac);
+};
+
+} // namespace net
+
+#endif // NET_DNS_NOTIFY_WATCHER_MAC_H_
diff --git a/chromium/net/dns/record_parsed.cc b/chromium/net/dns/record_parsed.cc
new file mode 100644
index 00000000000..bee6c7aac63
--- /dev/null
+++ b/chromium/net/dns/record_parsed.cc
@@ -0,0 +1,86 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/record_parsed.h"
+
+#include "base/logging.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/record_rdata.h"
+
+namespace net {
+
+RecordParsed::RecordParsed(const std::string& name, uint16 type, uint16 klass,
+ uint32 ttl, scoped_ptr<const RecordRdata> rdata,
+ base::Time time_created)
+ : name_(name), type_(type), klass_(klass), ttl_(ttl), rdata_(rdata.Pass()),
+ time_created_(time_created) {
+}
+
+RecordParsed::~RecordParsed() {
+}
+
+// static
+scoped_ptr<const RecordParsed> RecordParsed::CreateFrom(
+ DnsRecordParser* parser,
+ base::Time time_created) {
+ DnsResourceRecord record;
+ scoped_ptr<const RecordRdata> rdata;
+
+ if (!parser->ReadRecord(&record))
+ return scoped_ptr<const RecordParsed>();
+
+ switch (record.type) {
+ case ARecordRdata::kType:
+ rdata = ARecordRdata::Create(record.rdata, *parser);
+ break;
+ case AAAARecordRdata::kType:
+ rdata = AAAARecordRdata::Create(record.rdata, *parser);
+ break;
+ case CnameRecordRdata::kType:
+ rdata = CnameRecordRdata::Create(record.rdata, *parser);
+ break;
+ case PtrRecordRdata::kType:
+ rdata = PtrRecordRdata::Create(record.rdata, *parser);
+ break;
+ case SrvRecordRdata::kType:
+ rdata = SrvRecordRdata::Create(record.rdata, *parser);
+ break;
+ case TxtRecordRdata::kType:
+ rdata = TxtRecordRdata::Create(record.rdata, *parser);
+ break;
+ case NsecRecordRdata::kType:
+ rdata = NsecRecordRdata::Create(record.rdata, *parser);
+ break;
+ default:
+ LOG(WARNING) << "Unknown RData type for recieved record: " << record.type;
+ return scoped_ptr<const RecordParsed>();
+ }
+
+ if (!rdata.get())
+ return scoped_ptr<const RecordParsed>();
+
+ return scoped_ptr<const RecordParsed>(new RecordParsed(record.name,
+ record.type,
+ record.klass,
+ record.ttl,
+ rdata.Pass(),
+ time_created));
+}
+
+bool RecordParsed::IsEqual(const RecordParsed* other, bool is_mdns) const {
+ DCHECK(other);
+ uint16 klass = klass_;
+ uint16 other_klass = other->klass_;
+
+ if (is_mdns) {
+ klass &= dns_protocol::kMDnsClassMask;
+ other_klass &= dns_protocol::kMDnsClassMask;
+ }
+
+ return name_ == other->name_ &&
+ klass == other_klass &&
+ type_ == other->type_ &&
+ rdata_->IsEqual(other->rdata_.get());
+}
+}
diff --git a/chromium/net/dns/record_parsed.h b/chromium/net/dns/record_parsed.h
new file mode 100644
index 00000000000..016c4910bfb
--- /dev/null
+++ b/chromium/net/dns/record_parsed.h
@@ -0,0 +1,64 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_RECORD_PARSED_H_
+#define NET_DNS_RECORD_PARSED_H_
+
+#include <string>
+
+#include "base/memory/scoped_ptr.h"
+#include "base/time/time.h"
+#include "net/base/net_export.h"
+
+namespace net {
+
+class DnsRecordParser;
+class RecordRdata;
+
+// Parsed record. This is a form of DnsResourceRecord where the rdata section
+// has been parsed into a data structure.
+class NET_EXPORT_PRIVATE RecordParsed {
+ public:
+ virtual ~RecordParsed();
+
+ // All records are inherently immutable. Return a const pointer.
+ static scoped_ptr<const RecordParsed> CreateFrom(DnsRecordParser* parser,
+ base::Time time_created);
+
+ const std::string& name() const { return name_; }
+ uint16 type() const { return type_; }
+ uint16 klass() const { return klass_; }
+ uint32 ttl() const { return ttl_; }
+
+ base::Time time_created() const { return time_created_; }
+
+ template <class T> const T* rdata() const {
+ if (T::kType != type_)
+ return NULL;
+ return static_cast<const T*>(rdata_.get());
+ }
+
+ // Check if two records have the same data. Ignores time_created and ttl.
+ // If |is_mdns| is true, ignore the top bit of the class
+ // (the cache flush bit).
+ bool IsEqual(const RecordParsed* other, bool is_mdns) const;
+
+ private:
+ RecordParsed(const std::string& name, uint16 type, uint16 klass,
+ uint32 ttl, scoped_ptr<const RecordRdata> rdata,
+ base::Time time_created);
+
+ std::string name_; // in dotted form
+ const uint16 type_;
+ const uint16 klass_;
+ const uint32 ttl_;
+
+ const scoped_ptr<const RecordRdata> rdata_;
+
+ const base::Time time_created_;
+};
+
+} // namespace net
+
+#endif // NET_DNS_RECORD_PARSED_H_
diff --git a/chromium/net/dns/record_parsed_unittest.cc b/chromium/net/dns/record_parsed_unittest.cc
new file mode 100644
index 00000000000..2864dcbe761
--- /dev/null
+++ b/chromium/net/dns/record_parsed_unittest.cc
@@ -0,0 +1,75 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/record_parsed.h"
+
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/dns_test_util.h"
+#include "net/dns/record_rdata.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+static const uint8 kT1ResponseWithCacheFlushBit[] = {
+ 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w',
+ 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm',
+ 0x03, 'o', 'r', 'g',
+ 0x00,
+ 0x00, 0x05, // TYPE is CNAME.
+ 0x80, 0x01, // CLASS is IN with cache flush bit set.
+ 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
+ 0x24, 0x74,
+ 0x00, 0x12, // RDLENGTH is 18 bytes.
+ // ghs.l.google.com in DNS format.
+ 0x03, 'g', 'h', 's',
+ 0x01, 'l',
+ 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00
+};
+
+TEST(RecordParsedTest, ParseSingleRecord) {
+ DnsRecordParser parser(kT1ResponseDatagram, sizeof(kT1ResponseDatagram),
+ sizeof(dns_protocol::Header));
+ scoped_ptr<const RecordParsed> record;
+ const CnameRecordRdata* rdata;
+
+ parser.SkipQuestion();
+ record = RecordParsed::CreateFrom(&parser, base::Time());
+ EXPECT_TRUE(record != NULL);
+
+ ASSERT_EQ("codereview.chromium.org", record->name());
+ ASSERT_EQ(dns_protocol::kTypeCNAME, record->type());
+ ASSERT_EQ(dns_protocol::kClassIN, record->klass());
+
+ rdata = record->rdata<CnameRecordRdata>();
+ ASSERT_TRUE(rdata != NULL);
+ ASSERT_EQ(kT1CanonName, rdata->cname());
+
+ ASSERT_FALSE(record->rdata<SrvRecordRdata>());
+ ASSERT_TRUE(record->IsEqual(record.get(), true));
+}
+
+TEST(RecordParsedTest, CacheFlushBitCompare) {
+ DnsRecordParser parser1(kT1ResponseDatagram, sizeof(kT1ResponseDatagram),
+ sizeof(dns_protocol::Header));
+ parser1.SkipQuestion();
+ scoped_ptr<const RecordParsed> record1 =
+ RecordParsed::CreateFrom(&parser1, base::Time());
+
+ DnsRecordParser parser2(kT1ResponseWithCacheFlushBit,
+ sizeof(kT1ResponseWithCacheFlushBit),
+ 0);
+
+ scoped_ptr<const RecordParsed> record2 =
+ RecordParsed::CreateFrom(&parser2, base::Time());
+
+ EXPECT_FALSE(record1->IsEqual(record2.get(), false));
+ EXPECT_TRUE(record1->IsEqual(record2.get(), true));
+ EXPECT_FALSE(record2->IsEqual(record1.get(), false));
+ EXPECT_TRUE(record2->IsEqual(record1.get(), true));
+}
+
+} //namespace net
diff --git a/chromium/net/dns/record_rdata.cc b/chromium/net/dns/record_rdata.cc
new file mode 100644
index 00000000000..4ebc643a377
--- /dev/null
+++ b/chromium/net/dns/record_rdata.cc
@@ -0,0 +1,287 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/record_rdata.h"
+
+#include "net/base/big_endian.h"
+#include "net/base/dns_util.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_response.h"
+
+namespace net {
+
+static const size_t kSrvRecordMinimumSize = 6;
+
+RecordRdata::RecordRdata() {
+}
+
+SrvRecordRdata::SrvRecordRdata() : priority_(0), weight_(0), port_(0) {
+}
+
+SrvRecordRdata::~SrvRecordRdata() {}
+
+// static
+scoped_ptr<SrvRecordRdata> SrvRecordRdata::Create(
+ const base::StringPiece& data,
+ const DnsRecordParser& parser) {
+ if (data.size() < kSrvRecordMinimumSize) return scoped_ptr<SrvRecordRdata>();
+
+ scoped_ptr<SrvRecordRdata> rdata(new SrvRecordRdata);
+
+ BigEndianReader reader(data.data(), data.size());
+ // 2 bytes for priority, 2 bytes for weight, 2 bytes for port.
+ reader.ReadU16(&rdata->priority_);
+ reader.ReadU16(&rdata->weight_);
+ reader.ReadU16(&rdata->port_);
+
+ if (!parser.ReadName(data.substr(kSrvRecordMinimumSize).begin(),
+ &rdata->target_))
+ return scoped_ptr<SrvRecordRdata>();
+
+ return rdata.Pass();
+}
+
+uint16 SrvRecordRdata::Type() const {
+ return SrvRecordRdata::kType;
+}
+
+bool SrvRecordRdata::IsEqual(const RecordRdata* other) const {
+ if (other->Type() != Type()) return false;
+ const SrvRecordRdata* srv_other = static_cast<const SrvRecordRdata*>(other);
+ return weight_ == srv_other->weight_ &&
+ port_ == srv_other->port_ &&
+ priority_ == srv_other->priority_ &&
+ target_ == srv_other->target_;
+}
+
+ARecordRdata::ARecordRdata() {
+}
+
+ARecordRdata::~ARecordRdata() {
+}
+
+// static
+scoped_ptr<ARecordRdata> ARecordRdata::Create(
+ const base::StringPiece& data,
+ const DnsRecordParser& parser) {
+ if (data.size() != kIPv4AddressSize)
+ return scoped_ptr<ARecordRdata>();
+
+ scoped_ptr<ARecordRdata> rdata(new ARecordRdata);
+
+ rdata->address_.resize(kIPv4AddressSize);
+ for (unsigned i = 0; i < kIPv4AddressSize; ++i) {
+ rdata->address_[i] = data[i];
+ }
+
+ return rdata.Pass();
+}
+
+uint16 ARecordRdata::Type() const {
+ return ARecordRdata::kType;
+}
+
+bool ARecordRdata::IsEqual(const RecordRdata* other) const {
+ if (other->Type() != Type()) return false;
+ const ARecordRdata* a_other = static_cast<const ARecordRdata*>(other);
+ return address_ == a_other->address_;
+}
+
+AAAARecordRdata::AAAARecordRdata() {
+}
+
+AAAARecordRdata::~AAAARecordRdata() {
+}
+
+// static
+scoped_ptr<AAAARecordRdata> AAAARecordRdata::Create(
+ const base::StringPiece& data,
+ const DnsRecordParser& parser) {
+ if (data.size() != kIPv6AddressSize)
+ return scoped_ptr<AAAARecordRdata>();
+
+ scoped_ptr<AAAARecordRdata> rdata(new AAAARecordRdata);
+
+ rdata->address_.resize(kIPv6AddressSize);
+ for (unsigned i = 0; i < kIPv6AddressSize; ++i) {
+ rdata->address_[i] = data[i];
+ }
+
+ return rdata.Pass();
+}
+
+uint16 AAAARecordRdata::Type() const {
+ return AAAARecordRdata::kType;
+}
+
+bool AAAARecordRdata::IsEqual(const RecordRdata* other) const {
+ if (other->Type() != Type()) return false;
+ const AAAARecordRdata* a_other = static_cast<const AAAARecordRdata*>(other);
+ return address_ == a_other->address_;
+}
+
+CnameRecordRdata::CnameRecordRdata() {
+}
+
+CnameRecordRdata::~CnameRecordRdata() {
+}
+
+// static
+scoped_ptr<CnameRecordRdata> CnameRecordRdata::Create(
+ const base::StringPiece& data,
+ const DnsRecordParser& parser) {
+ scoped_ptr<CnameRecordRdata> rdata(new CnameRecordRdata);
+
+ if (!parser.ReadName(data.begin(), &rdata->cname_))
+ return scoped_ptr<CnameRecordRdata>();
+
+ return rdata.Pass();
+}
+
+uint16 CnameRecordRdata::Type() const {
+ return CnameRecordRdata::kType;
+}
+
+bool CnameRecordRdata::IsEqual(const RecordRdata* other) const {
+ if (other->Type() != Type()) return false;
+ const CnameRecordRdata* cname_other =
+ static_cast<const CnameRecordRdata*>(other);
+ return cname_ == cname_other->cname_;
+}
+
+PtrRecordRdata::PtrRecordRdata() {
+}
+
+PtrRecordRdata::~PtrRecordRdata() {
+}
+
+// static
+scoped_ptr<PtrRecordRdata> PtrRecordRdata::Create(
+ const base::StringPiece& data,
+ const DnsRecordParser& parser) {
+ scoped_ptr<PtrRecordRdata> rdata(new PtrRecordRdata);
+
+ if (!parser.ReadName(data.begin(), &rdata->ptrdomain_))
+ return scoped_ptr<PtrRecordRdata>();
+
+ return rdata.Pass();
+}
+
+uint16 PtrRecordRdata::Type() const {
+ return PtrRecordRdata::kType;
+}
+
+bool PtrRecordRdata::IsEqual(const RecordRdata* other) const {
+ if (other->Type() != Type()) return false;
+ const PtrRecordRdata* ptr_other = static_cast<const PtrRecordRdata*>(other);
+ return ptrdomain_ == ptr_other->ptrdomain_;
+}
+
+TxtRecordRdata::TxtRecordRdata() {
+}
+
+TxtRecordRdata::~TxtRecordRdata() {
+}
+
+// static
+scoped_ptr<TxtRecordRdata> TxtRecordRdata::Create(
+ const base::StringPiece& data,
+ const DnsRecordParser& parser) {
+ scoped_ptr<TxtRecordRdata> rdata(new TxtRecordRdata);
+
+ for (size_t i = 0; i < data.size(); ) {
+ uint8 length = data[i];
+
+ if (i + length >= data.size())
+ return scoped_ptr<TxtRecordRdata>();
+
+ rdata->texts_.push_back(data.substr(i + 1, length).as_string());
+
+ // Move to the next string.
+ i += length + 1;
+ }
+
+ return rdata.Pass();
+}
+
+uint16 TxtRecordRdata::Type() const {
+ return TxtRecordRdata::kType;
+}
+
+bool TxtRecordRdata::IsEqual(const RecordRdata* other) const {
+ if (other->Type() != Type()) return false;
+ const TxtRecordRdata* txt_other = static_cast<const TxtRecordRdata*>(other);
+ return texts_ == txt_other->texts_;
+}
+
+NsecRecordRdata::NsecRecordRdata() {
+}
+
+NsecRecordRdata::~NsecRecordRdata() {
+}
+
+// static
+scoped_ptr<NsecRecordRdata> NsecRecordRdata::Create(
+ const base::StringPiece& data,
+ const DnsRecordParser& parser) {
+ scoped_ptr<NsecRecordRdata> rdata(new NsecRecordRdata);
+
+ // Read the "next domain". This part for the NSEC record format is
+ // ignored for mDNS, since it has no semantic meaning.
+ unsigned next_domain_length = parser.ReadName(data.data(), NULL);
+
+ // If we did not succeed in getting the next domain or the data length
+ // is too short for reading the bitmap header, return.
+ if (next_domain_length == 0 || data.length() < next_domain_length + 2)
+ return scoped_ptr<NsecRecordRdata>();
+
+ struct BitmapHeader {
+ uint8 block_number; // The block number should be zero.
+ uint8 length; // Bitmap length in bytes. Between 1 and 32.
+ };
+
+ const BitmapHeader* header = reinterpret_cast<const BitmapHeader*>(
+ data.data() + next_domain_length);
+
+ // The block number must be zero in mDns-specific NSEC records. The bitmap
+ // length must be between 1 and 32.
+ if (header->block_number != 0 || header->length == 0 || header->length > 32)
+ return scoped_ptr<NsecRecordRdata>();
+
+ base::StringPiece bitmap_data = data.substr(next_domain_length + 2);
+
+ // Since we may only have one block, the data length must be exactly equal to
+ // the domain length plus bitmap size.
+ if (bitmap_data.length() != header->length)
+ return scoped_ptr<NsecRecordRdata>();
+
+ rdata->bitmap_.insert(rdata->bitmap_.begin(),
+ bitmap_data.begin(),
+ bitmap_data.end());
+
+ return rdata.Pass();
+}
+
+uint16 NsecRecordRdata::Type() const {
+ return NsecRecordRdata::kType;
+}
+
+bool NsecRecordRdata::IsEqual(const RecordRdata* other) const {
+ if (other->Type() != Type())
+ return false;
+ const NsecRecordRdata* nsec_other =
+ static_cast<const NsecRecordRdata*>(other);
+ return bitmap_ == nsec_other->bitmap_;
+}
+
+bool NsecRecordRdata::GetBit(unsigned i) const {
+ unsigned byte_num = i/8;
+ if (bitmap_.size() < byte_num + 1)
+ return false;
+
+ unsigned bit_num = 7 - i % 8;
+ return (bitmap_[byte_num] & (1 << bit_num)) != 0;
+}
+
+} // namespace net
diff --git a/chromium/net/dns/record_rdata.h b/chromium/net/dns/record_rdata.h
new file mode 100644
index 00000000000..f83a48650e1
--- /dev/null
+++ b/chromium/net/dns/record_rdata.h
@@ -0,0 +1,217 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_RECORD_RDATA_H_
+#define NET_DNS_RECORD_RDATA_H_
+
+#include <string>
+#include <vector>
+
+#include "base/basictypes.h"
+#include "base/compiler_specific.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/strings/string_piece.h"
+#include "net/base/big_endian.h"
+#include "net/base/net_export.h"
+#include "net/base/net_util.h"
+#include "net/dns/dns_protocol.h"
+
+namespace net {
+
+class DnsRecordParser;
+
+// Parsed represenation of the extra data in a record. Does not include standard
+// DNS record data such as TTL, Name, Type and Class.
+class NET_EXPORT_PRIVATE RecordRdata {
+ public:
+ virtual ~RecordRdata() {}
+
+ virtual bool IsEqual(const RecordRdata* other) const = 0;
+ virtual uint16 Type() const = 0;
+
+ protected:
+ RecordRdata();
+
+ DISALLOW_COPY_AND_ASSIGN(RecordRdata);
+};
+
+// SRV record format (http://www.ietf.org/rfc/rfc2782.txt):
+// 2 bytes network-order unsigned priority
+// 2 bytes network-order unsigned weight
+// 2 bytes network-order unsigned port
+// target: domain name (on-the-wire representation)
+class NET_EXPORT_PRIVATE SrvRecordRdata : public RecordRdata {
+ public:
+ static const uint16 kType = dns_protocol::kTypeSRV;
+
+ virtual ~SrvRecordRdata();
+ static scoped_ptr<SrvRecordRdata> Create(const base::StringPiece& data,
+ const DnsRecordParser& parser);
+
+ virtual bool IsEqual(const RecordRdata* other) const OVERRIDE;
+ virtual uint16 Type() const OVERRIDE;
+
+ uint16 priority() const { return priority_; }
+ uint16 weight() const { return weight_; }
+ uint16 port() const { return port_; }
+
+ const std::string& target() const { return target_; }
+
+ private:
+ SrvRecordRdata();
+
+ uint16 priority_;
+ uint16 weight_;
+ uint16 port_;
+
+ std::string target_;
+
+ DISALLOW_COPY_AND_ASSIGN(SrvRecordRdata);
+};
+
+// A Record format (http://www.ietf.org/rfc/rfc1035.txt):
+// 4 bytes for IP address.
+class NET_EXPORT_PRIVATE ARecordRdata : public RecordRdata {
+ public:
+ static const uint16 kType = dns_protocol::kTypeA;
+
+ virtual ~ARecordRdata();
+ static scoped_ptr<ARecordRdata> Create(const base::StringPiece& data,
+ const DnsRecordParser& parser);
+ virtual bool IsEqual(const RecordRdata* other) const OVERRIDE;
+ virtual uint16 Type() const OVERRIDE;
+
+ const IPAddressNumber& address() const { return address_; }
+
+ private:
+ ARecordRdata();
+
+ IPAddressNumber address_;
+
+ DISALLOW_COPY_AND_ASSIGN(ARecordRdata);
+};
+
+// AAAA Record format (http://www.ietf.org/rfc/rfc1035.txt):
+// 16 bytes for IP address.
+class NET_EXPORT_PRIVATE AAAARecordRdata : public RecordRdata {
+ public:
+ static const uint16 kType = dns_protocol::kTypeAAAA;
+
+ virtual ~AAAARecordRdata();
+ static scoped_ptr<AAAARecordRdata> Create(const base::StringPiece& data,
+ const DnsRecordParser& parser);
+ virtual bool IsEqual(const RecordRdata* other) const OVERRIDE;
+ virtual uint16 Type() const OVERRIDE;
+
+ const IPAddressNumber& address() const { return address_; }
+
+ private:
+ AAAARecordRdata();
+
+ IPAddressNumber address_;
+
+ DISALLOW_COPY_AND_ASSIGN(AAAARecordRdata);
+};
+
+// CNAME record format (http://www.ietf.org/rfc/rfc1035.txt):
+// cname: On the wire representation of domain name.
+class NET_EXPORT_PRIVATE CnameRecordRdata : public RecordRdata {
+ public:
+ static const uint16 kType = dns_protocol::kTypeCNAME;
+
+ virtual ~CnameRecordRdata();
+ static scoped_ptr<CnameRecordRdata> Create(const base::StringPiece& data,
+ const DnsRecordParser& parser);
+ virtual bool IsEqual(const RecordRdata* other) const OVERRIDE;
+ virtual uint16 Type() const OVERRIDE;
+
+ std::string cname() const { return cname_; }
+
+ private:
+ CnameRecordRdata();
+
+ std::string cname_;
+
+ DISALLOW_COPY_AND_ASSIGN(CnameRecordRdata);
+};
+
+// PTR record format (http://www.ietf.org/rfc/rfc1035.txt):
+// domain: On the wire representation of domain name.
+class NET_EXPORT_PRIVATE PtrRecordRdata : public RecordRdata {
+ public:
+ static const uint16 kType = dns_protocol::kTypePTR;
+
+ virtual ~PtrRecordRdata();
+ static scoped_ptr<PtrRecordRdata> Create(const base::StringPiece& data,
+ const DnsRecordParser& parser);
+ virtual bool IsEqual(const RecordRdata* other) const OVERRIDE;
+ virtual uint16 Type() const OVERRIDE;
+
+ std::string ptrdomain() const { return ptrdomain_; }
+
+ private:
+ PtrRecordRdata();
+
+ std::string ptrdomain_;
+
+ DISALLOW_COPY_AND_ASSIGN(PtrRecordRdata);
+};
+
+// TXT record format (http://www.ietf.org/rfc/rfc1035.txt):
+// texts: One or more <character-string>s.
+// a <character-string> is a length octet followed by as many characters.
+class NET_EXPORT_PRIVATE TxtRecordRdata : public RecordRdata {
+ public:
+ static const uint16 kType = dns_protocol::kTypeTXT;
+
+ virtual ~TxtRecordRdata();
+ static scoped_ptr<TxtRecordRdata> Create(const base::StringPiece& data,
+ const DnsRecordParser& parser);
+ virtual bool IsEqual(const RecordRdata* other) const OVERRIDE;
+ virtual uint16 Type() const OVERRIDE;
+
+ const std::vector<std::string>& texts() const { return texts_; }
+
+ private:
+ TxtRecordRdata();
+
+ std::vector<std::string> texts_;
+
+ DISALLOW_COPY_AND_ASSIGN(TxtRecordRdata);
+};
+
+// Only the subset of the NSEC record format required by mDNS is supported.
+// Nsec record format is described in http://www.ietf.org/rfc/rfc3845.txt and
+// the limited version required for mDNS described in
+// http://www.rfc-editor.org/rfc/rfc6762.txt Section 6.1.
+class NET_EXPORT_PRIVATE NsecRecordRdata : public RecordRdata {
+ public:
+ static const uint16 kType = dns_protocol::kTypeNSEC;
+
+ virtual ~NsecRecordRdata();
+ static scoped_ptr<NsecRecordRdata> Create(const base::StringPiece& data,
+ const DnsRecordParser& parser);
+ virtual bool IsEqual(const RecordRdata* other) const OVERRIDE;
+ virtual uint16 Type() const OVERRIDE;
+
+ // Length of the bitmap in bits.
+ unsigned bitmap_length() const { return bitmap_.size() * 8; }
+
+ // Returns bit i-th bit in the bitmap, where bits withing a byte are organized
+ // most to least significant. If it is set, a record with rrtype i exists for
+ // the domain name of this nsec record.
+ bool GetBit(unsigned i) const;
+
+ private:
+ NsecRecordRdata();
+
+ std::vector<uint8> bitmap_;
+
+ DISALLOW_COPY_AND_ASSIGN(NsecRecordRdata);
+};
+
+
+} // namespace net
+
+#endif // NET_DNS_RECORD_RDATA_H_
diff --git a/chromium/net/dns/record_rdata_unittest.cc b/chromium/net/dns/record_rdata_unittest.cc
new file mode 100644
index 00000000000..90bac446e2e
--- /dev/null
+++ b/chromium/net/dns/record_rdata_unittest.cc
@@ -0,0 +1,222 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/memory/scoped_ptr.h"
+#include "net/base/net_util.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/record_rdata.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+base::StringPiece MakeStringPiece(const uint8* data, unsigned size) {
+ const char* data_cc = reinterpret_cast<const char*>(data);
+ return base::StringPiece(data_cc, size);
+}
+
+TEST(RecordRdataTest, ParseSrvRecord) {
+ scoped_ptr<SrvRecordRdata> record1_obj;
+ scoped_ptr<SrvRecordRdata> record2_obj;
+
+ // These are just the rdata portions of the DNS records, rather than complete
+ // records, but it works well enough for this test.
+
+ const uint8 record[] = {
+ 0x00, 0x01,
+ 0x00, 0x02,
+ 0x00, 0x50,
+ 0x03, 'w', 'w', 'w',
+ 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00,
+ 0x01, 0x01,
+ 0x01, 0x02,
+ 0x01, 0x03,
+ 0x04, 'w', 'w', 'w', '2',
+ 0xc0, 0x0a, // Pointer to "google.com"
+ };
+
+ DnsRecordParser parser(record, sizeof(record), 0);
+ const unsigned first_record_len = 22;
+ base::StringPiece record1_strpiece = MakeStringPiece(
+ record, first_record_len);
+ base::StringPiece record2_strpiece = MakeStringPiece(
+ record + first_record_len, sizeof(record) - first_record_len);
+
+ record1_obj = SrvRecordRdata::Create(record1_strpiece, parser);
+ ASSERT_TRUE(record1_obj != NULL);
+ ASSERT_EQ(1, record1_obj->priority());
+ ASSERT_EQ(2, record1_obj->weight());
+ ASSERT_EQ(80, record1_obj->port());
+
+ ASSERT_EQ("www.google.com", record1_obj->target());
+
+ record2_obj = SrvRecordRdata::Create(record2_strpiece, parser);
+ ASSERT_TRUE(record2_obj != NULL);
+ ASSERT_EQ(257, record2_obj->priority());
+ ASSERT_EQ(258, record2_obj->weight());
+ ASSERT_EQ(259, record2_obj->port());
+
+ ASSERT_EQ("www2.google.com", record2_obj->target());
+
+ ASSERT_TRUE(record1_obj->IsEqual(record1_obj.get()));
+ ASSERT_FALSE(record1_obj->IsEqual(record2_obj.get()));
+}
+
+TEST(RecordRdataTest, ParseARecord) {
+ scoped_ptr<ARecordRdata> record_obj;
+
+ // These are just the rdata portions of the DNS records, rather than complete
+ // records, but it works well enough for this test.
+
+ const uint8 record[] = {
+ 0x7F, 0x00, 0x00, 0x01 // 127.0.0.1
+ };
+
+ DnsRecordParser parser(record, sizeof(record), 0);
+ base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record));
+
+ record_obj = ARecordRdata::Create(record_strpiece, parser);
+ ASSERT_TRUE(record_obj != NULL);
+
+ ASSERT_EQ("127.0.0.1", IPAddressToString(record_obj->address()));
+
+ ASSERT_TRUE(record_obj->IsEqual(record_obj.get()));
+}
+
+TEST(RecordRdataTest, ParseAAAARecord) {
+ scoped_ptr<AAAARecordRdata> record_obj;
+
+ // These are just the rdata portions of the DNS records, rather than complete
+ // records, but it works well enough for this test.
+
+ const uint8 record[] = {
+ 0x12, 0x34, 0x56, 0x78,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x09 // 1234:5678::9A
+ };
+
+ DnsRecordParser parser(record, sizeof(record), 0);
+ base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record));
+
+ record_obj = AAAARecordRdata::Create(record_strpiece, parser);
+ ASSERT_TRUE(record_obj != NULL);
+
+ ASSERT_EQ("1234:5678::9",
+ IPAddressToString(record_obj->address()));
+
+ ASSERT_TRUE(record_obj->IsEqual(record_obj.get()));
+}
+
+TEST(RecordRdataTest, ParseCnameRecord) {
+ scoped_ptr<CnameRecordRdata> record_obj;
+
+ // These are just the rdata portions of the DNS records, rather than complete
+ // records, but it works well enough for this test.
+
+ const uint8 record[] = {
+ 0x03, 'w', 'w', 'w',
+ 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00
+ };
+
+ DnsRecordParser parser(record, sizeof(record), 0);
+ base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record));
+
+ record_obj = CnameRecordRdata::Create(record_strpiece, parser);
+ ASSERT_TRUE(record_obj != NULL);
+
+ ASSERT_EQ("www.google.com", record_obj->cname());
+
+ ASSERT_TRUE(record_obj->IsEqual(record_obj.get()));
+}
+
+TEST(RecordRdataTest, ParsePtrRecord) {
+ scoped_ptr<PtrRecordRdata> record_obj;
+
+ // These are just the rdata portions of the DNS records, rather than complete
+ // records, but it works well enough for this test.
+
+ const uint8 record[] = {
+ 0x03, 'w', 'w', 'w',
+ 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00
+ };
+
+ DnsRecordParser parser(record, sizeof(record), 0);
+ base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record));
+
+ record_obj = PtrRecordRdata::Create(record_strpiece, parser);
+ ASSERT_TRUE(record_obj != NULL);
+
+ ASSERT_EQ("www.google.com", record_obj->ptrdomain());
+
+ ASSERT_TRUE(record_obj->IsEqual(record_obj.get()));
+}
+
+TEST(RecordRdataTest, ParseTxtRecord) {
+ scoped_ptr<TxtRecordRdata> record_obj;
+
+ // These are just the rdata portions of the DNS records, rather than complete
+ // records, but it works well enough for this test.
+
+ const uint8 record[] = {
+ 0x03, 'w', 'w', 'w',
+ 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
+ 0x03, 'c', 'o', 'm'
+ };
+
+ DnsRecordParser parser(record, sizeof(record), 0);
+ base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record));
+
+ record_obj = TxtRecordRdata::Create(record_strpiece, parser);
+ ASSERT_TRUE(record_obj != NULL);
+
+ std::vector<std::string> expected;
+ expected.push_back("www");
+ expected.push_back("google");
+ expected.push_back("com");
+
+ ASSERT_EQ(expected, record_obj->texts());
+
+ ASSERT_TRUE(record_obj->IsEqual(record_obj.get()));
+}
+
+TEST(RecordRdataTest, ParseNsecRecord) {
+ scoped_ptr<NsecRecordRdata> record_obj;
+
+ // These are just the rdata portions of the DNS records, rather than complete
+ // records, but it works well enough for this test.
+
+ const uint8 record[] = {
+ 0x03, 'w', 'w', 'w',
+ 0x06, 'g', 'o', 'o', 'g', 'l', 'e',
+ 0x03, 'c', 'o', 'm',
+ 0x00,
+ 0x00, 0x02, 0x40, 0x01
+ };
+
+ DnsRecordParser parser(record, sizeof(record), 0);
+ base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record));
+
+ record_obj = NsecRecordRdata::Create(record_strpiece, parser);
+ ASSERT_TRUE(record_obj != NULL);
+
+ ASSERT_EQ(16u, record_obj->bitmap_length());
+
+ EXPECT_FALSE(record_obj->GetBit(0));
+ EXPECT_TRUE(record_obj->GetBit(1));
+ for (int i = 2; i < 15; i++) {
+ EXPECT_FALSE(record_obj->GetBit(i));
+ }
+ EXPECT_TRUE(record_obj->GetBit(15));
+
+ ASSERT_TRUE(record_obj->IsEqual(record_obj.get()));
+}
+
+
+} // namespace net
diff --git a/chromium/net/dns/serial_worker.cc b/chromium/net/dns/serial_worker.cc
new file mode 100644
index 00000000000..394721c1a65
--- /dev/null
+++ b/chromium/net/dns/serial_worker.cc
@@ -0,0 +1,104 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/serial_worker.h"
+
+#include "base/bind.h"
+#include "base/location.h"
+#include "base/message_loop/message_loop_proxy.h"
+#include "base/threading/worker_pool.h"
+
+namespace net {
+
+namespace {
+ // Delay between calls to WorkerPool::PostTask
+ const int kWorkerPoolRetryDelayMs = 100;
+}
+
+SerialWorker::SerialWorker()
+ : message_loop_(base::MessageLoopProxy::current()),
+ state_(IDLE) {}
+
+SerialWorker::~SerialWorker() {}
+
+void SerialWorker::WorkNow() {
+ DCHECK(message_loop_->BelongsToCurrentThread());
+ switch (state_) {
+ case IDLE:
+ if (!base::WorkerPool::PostTask(FROM_HERE, base::Bind(
+ &SerialWorker::DoWorkJob, this), false)) {
+#if defined(OS_POSIX)
+ // See worker_pool_posix.cc.
+ NOTREACHED() << "WorkerPool::PostTask is not expected to fail on posix";
+#else
+ LOG(WARNING) << "Failed to WorkerPool::PostTask, will retry later";
+ message_loop_->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(&SerialWorker::RetryWork, this),
+ base::TimeDelta::FromMilliseconds(kWorkerPoolRetryDelayMs));
+ state_ = WAITING;
+ return;
+#endif
+ }
+ state_ = WORKING;
+ return;
+ case WORKING:
+ // Remember to re-read after |DoRead| finishes.
+ state_ = PENDING;
+ return;
+ case CANCELLED:
+ case PENDING:
+ case WAITING:
+ return;
+ default:
+ NOTREACHED() << "Unexpected state " << state_;
+ }
+}
+
+void SerialWorker::Cancel() {
+ DCHECK(message_loop_->BelongsToCurrentThread());
+ state_ = CANCELLED;
+}
+
+void SerialWorker::DoWorkJob() {
+ this->DoWork();
+ // If this fails, the loop is gone, so there is no point retrying.
+ message_loop_->PostTask(FROM_HERE, base::Bind(
+ &SerialWorker::OnWorkJobFinished, this));
+}
+
+void SerialWorker::OnWorkJobFinished() {
+ DCHECK(message_loop_->BelongsToCurrentThread());
+ switch (state_) {
+ case CANCELLED:
+ return;
+ case WORKING:
+ state_ = IDLE;
+ this->OnWorkFinished();
+ return;
+ case PENDING:
+ state_ = IDLE;
+ WorkNow();
+ return;
+ default:
+ NOTREACHED() << "Unexpected state " << state_;
+ }
+}
+
+void SerialWorker::RetryWork() {
+ DCHECK(message_loop_->BelongsToCurrentThread());
+ switch (state_) {
+ case CANCELLED:
+ return;
+ case WAITING:
+ state_ = IDLE;
+ WorkNow();
+ return;
+ default:
+ NOTREACHED() << "Unexpected state " << state_;
+ }
+}
+
+} // namespace net
+
diff --git a/chromium/net/dns/serial_worker.h b/chromium/net/dns/serial_worker.h
new file mode 100644
index 00000000000..59a4d3f189c
--- /dev/null
+++ b/chromium/net/dns/serial_worker.h
@@ -0,0 +1,96 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_SERIAL_WORKER_H_
+#define NET_DNS_SERIAL_WORKER_H_
+
+#include <string>
+
+#include "base/compiler_specific.h"
+#include "base/memory/ref_counted.h"
+#include "net/base/net_export.h"
+
+// Forward declaration
+namespace base {
+class MessageLoopProxy;
+}
+
+namespace net {
+
+// SerialWorker executes a job on WorkerPool serially -- **once at a time**.
+// On |WorkNow|, a call to |DoWork| is scheduled on the WorkerPool. Once it
+// completes, |OnWorkFinished| is called on the origin thread.
+// If |WorkNow| is called (1 or more times) while |DoWork| is already under way,
+// |DoWork| will be called once: after current |DoWork| completes, before a
+// call to |OnWorkFinished|.
+//
+// This behavior is designed for updating a result after some trigger, for
+// example reading a file once FilePathWatcher indicates it changed.
+//
+// Derived classes should store results of work done in |DoWork| in dedicated
+// fields and read them in |OnWorkFinished| which is executed on the origin
+// thread. This avoids the need to template this class.
+//
+// This implementation avoids locking by using the |state_| member to ensure
+// that |DoWork| and |OnWorkFinished| cannot execute in parallel.
+//
+// TODO(szym): update to WorkerPool::PostTaskAndReply once available.
+class NET_EXPORT_PRIVATE SerialWorker
+ : NON_EXPORTED_BASE(public base::RefCountedThreadSafe<SerialWorker>) {
+ public:
+ SerialWorker();
+
+ // Unless already scheduled, post |DoWork| to WorkerPool.
+ // Made virtual to allow mocking.
+ virtual void WorkNow();
+
+ // Stop scheduling jobs.
+ void Cancel();
+
+ bool IsCancelled() const { return state_ == CANCELLED; }
+
+ protected:
+ friend class base::RefCountedThreadSafe<SerialWorker>;
+ // protected to allow sub-classing, but prevent deleting
+ virtual ~SerialWorker();
+
+ // Executed on WorkerPool, at most once at a time.
+ virtual void DoWork() = 0;
+
+ // Executed on origin thread after |DoRead| completes.
+ virtual void OnWorkFinished() = 0;
+
+ base::MessageLoopProxy* loop() { return message_loop_.get(); }
+
+ private:
+ enum State {
+ CANCELLED = -1,
+ IDLE = 0,
+ WORKING, // |DoWorkJob| posted on WorkerPool, until |OnWorkJobFinished|
+ PENDING, // |WorkNow| while WORKING, must re-do work
+ WAITING, // WorkerPool is busy, |RetryWork| is posted
+ };
+
+ // Called on the worker thread, executes |DoWork| and notifies the origin
+ // thread.
+ void DoWorkJob();
+
+ // Called on the the origin thread after |DoWork| completes.
+ void OnWorkJobFinished();
+
+ // Posted to message loop in case WorkerPool is busy. (state == WAITING)
+ void RetryWork();
+
+ // Message loop for the thread of origin.
+ scoped_refptr<base::MessageLoopProxy> message_loop_;
+
+ State state_;
+
+ DISALLOW_COPY_AND_ASSIGN(SerialWorker);
+};
+
+} // namespace net
+
+#endif // NET_DNS_SERIAL_WORKER_H_
+
diff --git a/chromium/net/dns/serial_worker_unittest.cc b/chromium/net/dns/serial_worker_unittest.cc
new file mode 100644
index 00000000000..442526f29d2
--- /dev/null
+++ b/chromium/net/dns/serial_worker_unittest.cc
@@ -0,0 +1,163 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/serial_worker.h"
+
+#include "base/bind.h"
+#include "base/message_loop/message_loop.h"
+#include "base/synchronization/lock.h"
+#include "base/synchronization/waitable_event.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+class SerialWorkerTest : public testing::Test {
+ public:
+ // The class under test
+ class TestSerialWorker : public SerialWorker {
+ public:
+ explicit TestSerialWorker(SerialWorkerTest* t)
+ : test_(t) {}
+ virtual void DoWork() OVERRIDE {
+ ASSERT_TRUE(test_);
+ test_->OnWork();
+ }
+ virtual void OnWorkFinished() OVERRIDE {
+ ASSERT_TRUE(test_);
+ test_->OnWorkFinished();
+ }
+ private:
+ virtual ~TestSerialWorker() {}
+ SerialWorkerTest* test_;
+ };
+
+ // Mocks
+
+ void OnWork() {
+ { // Check that OnWork is executed serially.
+ base::AutoLock lock(work_lock_);
+ EXPECT_FALSE(work_running_) << "DoRead is not called serially!";
+ work_running_ = true;
+ }
+ BreakNow("OnWork");
+ work_allowed_.Wait();
+ // Calling from WorkerPool, but protected by work_allowed_/work_called_.
+ output_value_ = input_value_;
+
+ { // This lock might be destroyed after work_called_ is signalled.
+ base::AutoLock lock(work_lock_);
+ work_running_ = false;
+ }
+ work_called_.Signal();
+ }
+
+ void OnWorkFinished() {
+ EXPECT_TRUE(message_loop_ == base::MessageLoop::current());
+ EXPECT_EQ(output_value_, input_value_);
+ BreakNow("OnWorkFinished");
+ }
+
+ protected:
+ void BreakCallback(std::string breakpoint) {
+ breakpoint_ = breakpoint;
+ base::MessageLoop::current()->QuitNow();
+ }
+
+ void BreakNow(std::string b) {
+ message_loop_->PostTask(FROM_HERE,
+ base::Bind(&SerialWorkerTest::BreakCallback,
+ base::Unretained(this), b));
+ }
+
+ void RunUntilBreak(std::string b) {
+ message_loop_->Run();
+ ASSERT_EQ(breakpoint_, b);
+ }
+
+ SerialWorkerTest()
+ : input_value_(0),
+ output_value_(-1),
+ work_allowed_(false, false),
+ work_called_(false, false),
+ work_running_(false) {
+ }
+
+ // Helpers for tests.
+
+ // Lets OnWork run and waits for it to complete. Can only return if OnWork is
+ // executed on a concurrent thread.
+ void WaitForWork() {
+ RunUntilBreak("OnWork");
+ work_allowed_.Signal();
+ work_called_.Wait();
+ }
+
+ // test::Test methods
+ virtual void SetUp() OVERRIDE {
+ message_loop_ = base::MessageLoop::current();
+ worker_ = new TestSerialWorker(this);
+ }
+
+ virtual void TearDown() OVERRIDE {
+ // Cancel the worker to catch if it makes a late DoWork call.
+ worker_->Cancel();
+ // Check if OnWork is stalled.
+ EXPECT_FALSE(work_running_) << "OnWork should be done by TearDown";
+ // Release it for cleanliness.
+ if (work_running_) {
+ WaitForWork();
+ }
+ }
+
+ // Input value read on WorkerPool.
+ int input_value_;
+ // Output value written on WorkerPool.
+ int output_value_;
+
+ // read is called on WorkerPool so we need to synchronize with it.
+ base::WaitableEvent work_allowed_;
+ base::WaitableEvent work_called_;
+
+ // Protected by read_lock_. Used to verify that read calls are serialized.
+ bool work_running_;
+ base::Lock work_lock_;
+
+ // Loop for this thread.
+ base::MessageLoop* message_loop_;
+
+ // WatcherDelegate under test.
+ scoped_refptr<TestSerialWorker> worker_;
+
+ std::string breakpoint_;
+};
+
+TEST_F(SerialWorkerTest, ExecuteAndSerializeReads) {
+ for (int i = 0; i < 3; ++i) {
+ ++input_value_;
+ worker_->WorkNow();
+ WaitForWork();
+ RunUntilBreak("OnWorkFinished");
+
+ EXPECT_TRUE(message_loop_->IsIdleForTesting());
+ }
+
+ // Schedule two calls. OnWork checks if it is called serially.
+ ++input_value_;
+ worker_->WorkNow();
+ // read is blocked, so this will have to induce re-work
+ worker_->WorkNow();
+ WaitForWork();
+ WaitForWork();
+ RunUntilBreak("OnWorkFinished");
+
+ // No more tasks should remain.
+ EXPECT_TRUE(message_loop_->IsIdleForTesting());
+}
+
+} // namespace
+
+} // namespace net
+
diff --git a/chromium/net/dns/single_request_host_resolver.cc b/chromium/net/dns/single_request_host_resolver.cc
new file mode 100644
index 00000000000..31ef4c56b6e
--- /dev/null
+++ b/chromium/net/dns/single_request_host_resolver.cc
@@ -0,0 +1,77 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/single_request_host_resolver.h"
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/compiler_specific.h"
+#include "base/logging.h"
+#include "net/base/net_errors.h"
+
+namespace net {
+
+SingleRequestHostResolver::SingleRequestHostResolver(HostResolver* resolver)
+ : resolver_(resolver),
+ cur_request_(NULL),
+ callback_(
+ base::Bind(&SingleRequestHostResolver::OnResolveCompletion,
+ base::Unretained(this))) {
+ DCHECK(resolver_ != NULL);
+}
+
+SingleRequestHostResolver::~SingleRequestHostResolver() {
+ Cancel();
+}
+
+int SingleRequestHostResolver::Resolve(
+ const HostResolver::RequestInfo& info, AddressList* addresses,
+ const CompletionCallback& callback, const BoundNetLog& net_log) {
+ DCHECK(addresses);
+ DCHECK_EQ(false, callback.is_null());
+ DCHECK(cur_request_callback_.is_null()) << "resolver already in use";
+
+ HostResolver::RequestHandle request = NULL;
+
+ // We need to be notified of completion before |callback| is called, so that
+ // we can clear out |cur_request_*|.
+ CompletionCallback transient_callback =
+ callback.is_null() ? CompletionCallback() : callback_;
+
+ int rv = resolver_->Resolve(
+ info, addresses, transient_callback, &request, net_log);
+
+ if (rv == ERR_IO_PENDING) {
+ DCHECK_EQ(false, callback.is_null());
+ // Cleared in OnResolveCompletion().
+ cur_request_ = request;
+ cur_request_callback_ = callback;
+ }
+
+ return rv;
+}
+
+void SingleRequestHostResolver::Cancel() {
+ if (!cur_request_callback_.is_null()) {
+ resolver_->CancelRequest(cur_request_);
+ cur_request_ = NULL;
+ cur_request_callback_.Reset();
+ }
+}
+
+void SingleRequestHostResolver::OnResolveCompletion(int result) {
+ DCHECK(cur_request_);
+ DCHECK_EQ(false, cur_request_callback_.is_null());
+
+ CompletionCallback callback = cur_request_callback_;
+
+ // Clear the outstanding request information.
+ cur_request_ = NULL;
+ cur_request_callback_.Reset();
+
+ // Call the user's original callback.
+ callback.Run(result);
+}
+
+} // namespace net
diff --git a/chromium/net/dns/single_request_host_resolver.h b/chromium/net/dns/single_request_host_resolver.h
new file mode 100644
index 00000000000..52d01328911
--- /dev/null
+++ b/chromium/net/dns/single_request_host_resolver.h
@@ -0,0 +1,56 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_DNS_SINGLE_REQUEST_HOST_RESOLVER_H_
+#define NET_DNS_SINGLE_REQUEST_HOST_RESOLVER_H_
+
+#include "net/dns/host_resolver.h"
+
+namespace net {
+
+// This class represents the task of resolving a hostname (or IP address
+// literal) to an AddressList object. It wraps HostResolver to resolve only a
+// single hostname at a time and cancels this request when going out of scope.
+class NET_EXPORT SingleRequestHostResolver {
+ public:
+ // |resolver| must remain valid for the lifetime of |this|.
+ explicit SingleRequestHostResolver(HostResolver* resolver);
+
+ // If a completion callback is pending when the resolver is destroyed, the
+ // host resolution is cancelled, and the completion callback will not be
+ // called.
+ ~SingleRequestHostResolver();
+
+ // Resolves the given hostname (or IP address literal), filling out the
+ // |addresses| object upon success. See HostResolver::Resolve() for details.
+ int Resolve(const HostResolver::RequestInfo& info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log);
+
+ // Cancels the in-progress request, if any. This prevents the callback
+ // from being invoked. Resolve() can be called again after cancelling.
+ void Cancel();
+
+ private:
+ // Callback for when the request to |resolver_| completes, so we dispatch
+ // to the user's callback.
+ void OnResolveCompletion(int result);
+
+ // The actual host resolver that will handle the request.
+ HostResolver* const resolver_;
+
+ // The current request (if any).
+ HostResolver::RequestHandle cur_request_;
+ CompletionCallback cur_request_callback_;
+
+ // Completion callback for when request to |resolver_| completes.
+ CompletionCallback callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(SingleRequestHostResolver);
+};
+
+} // namespace net
+
+#endif // NET_DNS_SINGLE_REQUEST_HOST_RESOLVER_H_
diff --git a/chromium/net/dns/single_request_host_resolver_unittest.cc b/chromium/net/dns/single_request_host_resolver_unittest.cc
new file mode 100644
index 00000000000..1b0198f4fde
--- /dev/null
+++ b/chromium/net/dns/single_request_host_resolver_unittest.cc
@@ -0,0 +1,124 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/single_request_host_resolver.h"
+
+#include "net/base/address_list.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/test_completion_callback.h"
+#include "net/dns/mock_host_resolver.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+// Helper class used by SingleRequestHostResolverTest.Cancel test.
+// It checks that only one request is outstanding at a time, and that
+// it is cancelled before the class is destroyed.
+class HangingHostResolver : public HostResolver {
+ public:
+ HangingHostResolver() : outstanding_request_(NULL) {}
+
+ virtual ~HangingHostResolver() {
+ EXPECT_TRUE(!has_outstanding_request());
+ }
+
+ bool has_outstanding_request() const {
+ return outstanding_request_ != NULL;
+ }
+
+ virtual int Resolve(const RequestInfo& info,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* out_req,
+ const BoundNetLog& net_log) OVERRIDE {
+ EXPECT_FALSE(has_outstanding_request());
+ outstanding_request_ = reinterpret_cast<RequestHandle>(0x1234);
+ *out_req = outstanding_request_;
+
+ // Never complete this request! Caller is expected to cancel it
+ // before destroying the resolver.
+ return ERR_IO_PENDING;
+ }
+
+ virtual int ResolveFromCache(const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& net_log) OVERRIDE {
+ NOTIMPLEMENTED();
+ return ERR_UNEXPECTED;
+ }
+
+ virtual void CancelRequest(RequestHandle req) OVERRIDE {
+ EXPECT_TRUE(has_outstanding_request());
+ EXPECT_EQ(req, outstanding_request_);
+ outstanding_request_ = NULL;
+ }
+
+ private:
+ RequestHandle outstanding_request_;
+
+ DISALLOW_COPY_AND_ASSIGN(HangingHostResolver);
+};
+
+// Test that a regular end-to-end lookup returns the expected result.
+TEST(SingleRequestHostResolverTest, NormalResolve) {
+ // Create a host resolver dependency that returns address "199.188.1.166"
+ // for resolutions of "watsup".
+ MockHostResolver resolver;
+ resolver.rules()->AddIPLiteralRule("watsup", "199.188.1.166", std::string());
+
+ SingleRequestHostResolver single_request_resolver(&resolver);
+
+ // Resolve "watsup:90" using our SingleRequestHostResolver.
+ AddressList addrlist;
+ TestCompletionCallback callback;
+ HostResolver::RequestInfo request(HostPortPair("watsup", 90));
+ int rv = single_request_resolver.Resolve(
+ request, &addrlist, callback.callback(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ // Verify that the result is what we specified in the MockHostResolver.
+ ASSERT_FALSE(addrlist.empty());
+ EXPECT_EQ("199.188.1.166", addrlist.front().ToStringWithoutPort());
+}
+
+// Test that the Cancel() method cancels any outstanding request.
+TEST(SingleRequestHostResolverTest, Cancel) {
+ HangingHostResolver resolver;
+
+ {
+ SingleRequestHostResolver single_request_resolver(&resolver);
+
+ // Resolve "watsup:90" using our SingleRequestHostResolver.
+ AddressList addrlist;
+ TestCompletionCallback callback;
+ HostResolver::RequestInfo request(HostPortPair("watsup", 90));
+ int rv = single_request_resolver.Resolve(
+ request, &addrlist, callback.callback(), BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_TRUE(resolver.has_outstanding_request());
+ }
+
+ // Now that the SingleRequestHostResolver has been destroyed, the
+ // in-progress request should have been aborted.
+ EXPECT_FALSE(resolver.has_outstanding_request());
+}
+
+// Test that the Cancel() method is a no-op when there is no outstanding
+// request.
+TEST(SingleRequestHostResolverTest, CancelWhileNoPendingRequest) {
+ HangingHostResolver resolver;
+ SingleRequestHostResolver single_request_resolver(&resolver);
+ single_request_resolver.Cancel();
+
+ // To pass, HangingHostResolver should not have received a cancellation
+ // request (since there is nothing to cancel). If it does, it will crash.
+}
+
+} // namespace
+
+} // namespace net