diff options
Diffstat (limited to 'chromium/net/dns')
76 files changed, 20576 insertions, 0 deletions
diff --git a/chromium/net/dns/address_sorter.h b/chromium/net/dns/address_sorter.h new file mode 100644 index 00000000000..6ac943086ef --- /dev/null +++ b/chromium/net/dns/address_sorter.h @@ -0,0 +1,46 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_ADDRESS_SORTER_H_ +#define NET_DNS_ADDRESS_SORTER_H_ + +#include "base/basictypes.h" +#include "base/callback.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/net_export.h" + +namespace net { + +class AddressList; + +// Sorts AddressList according to RFC3484, by likelihood of successful +// connection. Depending on the platform, the sort could be performed +// asynchronously by the OS, or synchronously by local implementation. +// AddressSorter does not necessarily preserve port numbers on the sorted list. +class NET_EXPORT AddressSorter { + public: + typedef base::Callback<void(bool success, + const AddressList& list)> CallbackType; + + virtual ~AddressSorter() {} + + // Sorts |list|, which must include at least one IPv6 address. + // Calls |callback| upon completion. Could complete synchronously. Could + // complete after this AddressSorter is destroyed. + virtual void Sort(const AddressList& list, + const CallbackType& callback) const = 0; + + // Creates platform-dependent AddressSorter. + static scoped_ptr<AddressSorter> CreateAddressSorter(); + + protected: + AddressSorter() {} + + private: + DISALLOW_COPY_AND_ASSIGN(AddressSorter); +}; + +} // namespace net + +#endif // NET_DNS_ADDRESS_SORTER_H_ diff --git a/chromium/net/dns/address_sorter_posix.cc b/chromium/net/dns/address_sorter_posix.cc new file mode 100644 index 00000000000..8d87774587d --- /dev/null +++ b/chromium/net/dns/address_sorter_posix.cc @@ -0,0 +1,426 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter_posix.h" + +#include <netinet/in.h> + +#if defined(OS_MACOSX) || defined(OS_BSD) +#include <sys/socket.h> // Must be included before ifaddrs.h. +#include <ifaddrs.h> +#include <net/if.h> +#include <netinet/in_var.h> +#include <string.h> +#include <sys/ioctl.h> +#endif + +#include <algorithm> + +#include "base/logging.h" +#include "base/memory/scoped_vector.h" +#include "base/posix/eintr_wrapper.h" +#include "net/socket/client_socket_factory.h" +#include "net/udp/datagram_client_socket.h" + +#if defined(OS_LINUX) +#include "net/base/address_tracker_linux.h" +#endif + +namespace net { + +namespace { + +// Address sorting is performed according to RFC3484 with revisions. +// http://tools.ietf.org/html/draft-ietf-6man-rfc3484bis-06 +// Precedence and label are separate to support override through /etc/gai.conf. + +// Returns true if |p1| should precede |p2| in the table. +// Sorts table by decreasing prefix size to allow longest prefix matching. +bool ComparePolicy(const AddressSorterPosix::PolicyEntry& p1, + const AddressSorterPosix::PolicyEntry& p2) { + return p1.prefix_length > p2.prefix_length; +} + +// Creates sorted PolicyTable from |table| with |size| entries. +AddressSorterPosix::PolicyTable LoadPolicy( + AddressSorterPosix::PolicyEntry* table, + size_t size) { + AddressSorterPosix::PolicyTable result(table, table + size); + std::sort(result.begin(), result.end(), ComparePolicy); + return result; +} + +// Search |table| for matching prefix of |address|. |table| must be sorted by +// descending prefix (prefix of another prefix must be later in table). +unsigned GetPolicyValue(const AddressSorterPosix::PolicyTable& table, + const IPAddressNumber& address) { + if (address.size() == kIPv4AddressSize) + return GetPolicyValue(table, ConvertIPv4NumberToIPv6Number(address)); + for (unsigned i = 0; i < table.size(); ++i) { + const AddressSorterPosix::PolicyEntry& entry = table[i]; + IPAddressNumber prefix(entry.prefix, entry.prefix + kIPv6AddressSize); + if (IPNumberMatchesPrefix(address, prefix, entry.prefix_length)) + return entry.value; + } + NOTREACHED(); + // The last entry is the least restrictive, so assume it's default. + return table.back().value; +} + +bool IsIPv6Multicast(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + return address[0] == 0xFF; +} + +AddressSorterPosix::AddressScope GetIPv6MulticastScope( + const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + return static_cast<AddressSorterPosix::AddressScope>(address[1] & 0x0F); +} + +bool IsIPv6Loopback(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + // IN6_IS_ADDR_LOOPBACK + unsigned char kLoopback[kIPv6AddressSize] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, + }; + return address == IPAddressNumber(kLoopback, kLoopback + kIPv6AddressSize); +} + +bool IsIPv6LinkLocal(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + // IN6_IS_ADDR_LINKLOCAL + return (address[0] == 0xFE) && ((address[1] & 0xC0) == 0x80); +} + +bool IsIPv6SiteLocal(const IPAddressNumber& address) { + DCHECK_EQ(kIPv6AddressSize, address.size()); + // IN6_IS_ADDR_SITELOCAL + return (address[0] == 0xFE) && ((address[1] & 0xC0) == 0xC0); +} + +AddressSorterPosix::AddressScope GetScope( + const AddressSorterPosix::PolicyTable& ipv4_scope_table, + const IPAddressNumber& address) { + if (address.size() == kIPv6AddressSize) { + if (IsIPv6Multicast(address)) { + return GetIPv6MulticastScope(address); + } else if (IsIPv6Loopback(address) || IsIPv6LinkLocal(address)) { + return AddressSorterPosix::SCOPE_LINKLOCAL; + } else if (IsIPv6SiteLocal(address)) { + return AddressSorterPosix::SCOPE_SITELOCAL; + } else { + return AddressSorterPosix::SCOPE_GLOBAL; + } + } else if (address.size() == kIPv4AddressSize) { + return static_cast<AddressSorterPosix::AddressScope>( + GetPolicyValue(ipv4_scope_table, address)); + } else { + NOTREACHED(); + return AddressSorterPosix::SCOPE_NODELOCAL; + } +} + +// Default policy table. RFC 3484, Section 2.1. +AddressSorterPosix::PolicyEntry kDefaultPrecedenceTable[] = { + // ::1/128 -- loopback + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, 128, 50 }, + // ::/0 -- any + { { }, 0, 40 }, + // ::ffff:0:0/96 -- IPv4 mapped + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF }, 96, 35 }, + // 2002::/16 -- 6to4 + { { 0x20, 0x02, }, 16, 30 }, + // 2001::/32 -- Teredo + { { 0x20, 0x01, 0, 0 }, 32, 5 }, + // fc00::/7 -- unique local address + { { 0xFC }, 7, 3 }, + // ::/96 -- IPv4 compatible + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 96, 1 }, + // fec0::/10 -- site-local expanded scope + { { 0xFE, 0xC0 }, 10, 1 }, + // 3ffe::/16 -- 6bone + { { 0x3F, 0xFE }, 16, 1 }, +}; + +AddressSorterPosix::PolicyEntry kDefaultLabelTable[] = { + // ::1/128 -- loopback + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, 128, 0 }, + // ::/0 -- any + { { }, 0, 1 }, + // ::ffff:0:0/96 -- IPv4 mapped + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF }, 96, 4 }, + // 2002::/16 -- 6to4 + { { 0x20, 0x02, }, 16, 2 }, + // 2001::/32 -- Teredo + { { 0x20, 0x01, 0, 0 }, 32, 5 }, + // fc00::/7 -- unique local address + { { 0xFC }, 7, 13 }, + // ::/96 -- IPv4 compatible + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 96, 3 }, + // fec0::/10 -- site-local expanded scope + { { 0xFE, 0xC0 }, 10, 11 }, + // 3ffe::/16 -- 6bone + { { 0x3F, 0xFE }, 16, 12 }, +}; + +// Default mapping of IPv4 addresses to scope. +AddressSorterPosix::PolicyEntry kDefaultIPv4ScopeTable[] = { + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0x7F }, 104, + AddressSorterPosix::SCOPE_LINKLOCAL }, + { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xA9, 0xFE }, 112, + AddressSorterPosix::SCOPE_LINKLOCAL }, + { { }, 0, AddressSorterPosix::SCOPE_GLOBAL }, +}; + +// Returns number of matching initial bits between the addresses |a1| and |a2|. +unsigned CommonPrefixLength(const IPAddressNumber& a1, + const IPAddressNumber& a2) { + DCHECK_EQ(a1.size(), a2.size()); + for (size_t i = 0; i < a1.size(); ++i) { + unsigned diff = a1[i] ^ a2[i]; + if (!diff) + continue; + for (unsigned j = 0; j < CHAR_BIT; ++j) { + if (diff & (1 << (CHAR_BIT - 1))) + return i * CHAR_BIT + j; + diff <<= 1; + } + NOTREACHED(); + } + return a1.size() * CHAR_BIT; +} + +// Computes the number of leading 1-bits in |mask|. +unsigned MaskPrefixLength(const IPAddressNumber& mask) { + IPAddressNumber all_ones(mask.size(), 0xFF); + return CommonPrefixLength(mask, all_ones); +} + +struct DestinationInfo { + IPAddressNumber address; + AddressSorterPosix::AddressScope scope; + unsigned precedence; + unsigned label; + const AddressSorterPosix::SourceAddressInfo* src; + unsigned common_prefix_length; +}; + +// Returns true iff |dst_a| should precede |dst_b| in the address list. +// RFC 3484, section 6. +bool CompareDestinations(const DestinationInfo* dst_a, + const DestinationInfo* dst_b) { + // Rule 1: Avoid unusable destinations. + // Unusable destinations are already filtered out. + DCHECK(dst_a->src); + DCHECK(dst_b->src); + + // Rule 2: Prefer matching scope. + bool scope_match1 = (dst_a->src->scope == dst_a->scope); + bool scope_match2 = (dst_b->src->scope == dst_b->scope); + if (scope_match1 != scope_match2) + return scope_match1; + + // Rule 3: Avoid deprecated addresses. + if (dst_a->src->deprecated != dst_b->src->deprecated) + return !dst_a->src->deprecated; + + // Rule 4: Prefer home addresses. + if (dst_a->src->home != dst_b->src->home) + return dst_a->src->home; + + // Rule 5: Prefer matching label. + bool label_match1 = (dst_a->src->label == dst_a->label); + bool label_match2 = (dst_b->src->label == dst_b->label); + if (label_match1 != label_match2) + return label_match1; + + // Rule 6: Prefer higher precedence. + if (dst_a->precedence != dst_b->precedence) + return dst_a->precedence > dst_b->precedence; + + // Rule 7: Prefer native transport. + if (dst_a->src->native != dst_b->src->native) + return dst_a->src->native; + + // Rule 8: Prefer smaller scope. + if (dst_a->scope != dst_b->scope) + return dst_a->scope < dst_b->scope; + + // Rule 9: Use longest matching prefix. Only for matching address families. + if (dst_a->address.size() == dst_b->address.size()) { + if (dst_a->common_prefix_length != dst_b->common_prefix_length) + return dst_a->common_prefix_length > dst_b->common_prefix_length; + } + + // Rule 10: Leave the order unchanged. + // stable_sort takes care of that. + return false; +} + +} // namespace + +AddressSorterPosix::AddressSorterPosix(ClientSocketFactory* socket_factory) + : socket_factory_(socket_factory), + precedence_table_(LoadPolicy(kDefaultPrecedenceTable, + arraysize(kDefaultPrecedenceTable))), + label_table_(LoadPolicy(kDefaultLabelTable, + arraysize(kDefaultLabelTable))), + ipv4_scope_table_(LoadPolicy(kDefaultIPv4ScopeTable, + arraysize(kDefaultIPv4ScopeTable))) { + NetworkChangeNotifier::AddIPAddressObserver(this); + OnIPAddressChanged(); +} + +AddressSorterPosix::~AddressSorterPosix() { + NetworkChangeNotifier::RemoveIPAddressObserver(this); +} + +void AddressSorterPosix::Sort(const AddressList& list, + const CallbackType& callback) const { + DCHECK(CalledOnValidThread()); + ScopedVector<DestinationInfo> sort_list; + + for (size_t i = 0; i < list.size(); ++i) { + scoped_ptr<DestinationInfo> info(new DestinationInfo()); + info->address = list[i].address(); + info->scope = GetScope(ipv4_scope_table_, info->address); + info->precedence = GetPolicyValue(precedence_table_, info->address); + info->label = GetPolicyValue(label_table_, info->address); + + // Each socket can only be bound once. + scoped_ptr<DatagramClientSocket> socket( + socket_factory_->CreateDatagramClientSocket( + DatagramSocket::DEFAULT_BIND, + RandIntCallback(), + NULL /* NetLog */, + NetLog::Source())); + + // Even though no packets are sent, cannot use port 0 in Connect. + IPEndPoint dest(info->address, 80 /* port */); + int rv = socket->Connect(dest); + if (rv != OK) { + LOG(WARNING) << "Could not connect to " << dest.ToStringWithoutPort() + << " reason " << rv; + continue; + } + // Filter out unusable destinations. + IPEndPoint src; + rv = socket->GetLocalAddress(&src); + if (rv != OK) { + LOG(WARNING) << "Could not get local address for " + << src.ToStringWithoutPort() << " reason " << rv; + continue; + } + + SourceAddressInfo& src_info = source_map_[src.address()]; + if (src_info.scope == SCOPE_UNDEFINED) { + // If |source_info_| is out of date, |src| might be missing, but we still + // want to sort, even though the HostCache will be cleared soon. + FillPolicy(src.address(), &src_info); + } + info->src = &src_info; + + if (info->address.size() == src.address().size()) { + info->common_prefix_length = std::min( + CommonPrefixLength(info->address, src.address()), + info->src->prefix_length); + } + sort_list.push_back(info.release()); + } + + std::stable_sort(sort_list.begin(), sort_list.end(), CompareDestinations); + + AddressList result; + for (size_t i = 0; i < sort_list.size(); ++i) + result.push_back(IPEndPoint(sort_list[i]->address, 0 /* port */)); + + callback.Run(true, result); +} + +void AddressSorterPosix::OnIPAddressChanged() { + DCHECK(CalledOnValidThread()); + source_map_.clear(); +#if defined(OS_LINUX) + const internal::AddressTrackerLinux* tracker = + NetworkChangeNotifier::GetAddressTracker(); + if (!tracker) + return; + typedef internal::AddressTrackerLinux::AddressMap AddressMap; + AddressMap map = tracker->GetAddressMap(); + for (AddressMap::const_iterator it = map.begin(); it != map.end(); ++it) { + const IPAddressNumber& address = it->first; + const struct ifaddrmsg& msg = it->second; + SourceAddressInfo& info = source_map_[address]; + info.native = false; // TODO(szym): obtain this via netlink. + info.deprecated = msg.ifa_flags & IFA_F_DEPRECATED; + info.home = msg.ifa_flags & IFA_F_HOMEADDRESS; + info.prefix_length = msg.ifa_prefixlen; + FillPolicy(address, &info); + } +#elif defined(OS_MACOSX) || defined(OS_BSD) + // It's not clear we will receive notification when deprecated flag changes. + // Socket for ioctl. + int ioctl_socket = socket(AF_INET6, SOCK_DGRAM, 0); + if (ioctl_socket < 0) + return; + + struct ifaddrs* addrs; + int rv = getifaddrs(&addrs); + if (rv < 0) { + LOG(WARNING) << "getifaddrs failed " << rv; + close(ioctl_socket); + return; + } + + for (struct ifaddrs* ifa = addrs; ifa != NULL; ifa = ifa->ifa_next) { + IPEndPoint src; + if (!src.FromSockAddr(ifa->ifa_addr, ifa->ifa_addr->sa_len)) + continue; + SourceAddressInfo& info = source_map_[src.address()]; + // Note: no known way to fill in |native| and |home|. + info.native = info.home = info.deprecated = false; + if (ifa->ifa_addr->sa_family == AF_INET6) { + struct in6_ifreq ifr = {}; + strncpy(ifr.ifr_name, ifa->ifa_name, sizeof(ifr.ifr_name) - 1); + DCHECK_LE(ifa->ifa_addr->sa_len, sizeof(ifr.ifr_ifru.ifru_addr)); + memcpy(&ifr.ifr_ifru.ifru_addr, ifa->ifa_addr, ifa->ifa_addr->sa_len); + int rv = ioctl(ioctl_socket, SIOCGIFAFLAG_IN6, &ifr); + if (rv >= 0) { + info.deprecated = ifr.ifr_ifru.ifru_flags & IN6_IFF_DEPRECATED; + } else { + LOG(WARNING) << "SIOCGIFAFLAG_IN6 failed " << rv; + } + } + if (ifa->ifa_netmask) { + IPEndPoint netmask; + if (netmask.FromSockAddr(ifa->ifa_netmask, ifa->ifa_addr->sa_len)) { + info.prefix_length = MaskPrefixLength(netmask.address()); + } else { + LOG(WARNING) << "FromSockAddr failed on netmask"; + } + } + FillPolicy(src.address(), &info); + } + freeifaddrs(addrs); + close(ioctl_socket); +#endif +} + +void AddressSorterPosix::FillPolicy(const IPAddressNumber& address, + SourceAddressInfo* info) const { + DCHECK(CalledOnValidThread()); + info->scope = GetScope(ipv4_scope_table_, address); + info->label = GetPolicyValue(label_table_, address); +} + +// static +scoped_ptr<AddressSorter> AddressSorter::CreateAddressSorter() { + return scoped_ptr<AddressSorter>( + new AddressSorterPosix(ClientSocketFactory::GetDefaultFactory())); +} + +} // namespace net + diff --git a/chromium/net/dns/address_sorter_posix.h b/chromium/net/dns/address_sorter_posix.h new file mode 100644 index 00000000000..1c88ad2f978 --- /dev/null +++ b/chromium/net/dns/address_sorter_posix.h @@ -0,0 +1,94 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_ADDRESS_SORTER_POSIX_H_ +#define NET_DNS_ADDRESS_SORTER_POSIX_H_ + +#include <map> +#include <vector> + +#include "base/threading/non_thread_safe.h" +#include "net/base/address_list.h" +#include "net/base/net_errors.h" +#include "net/base/net_export.h" +#include "net/base/net_util.h" +#include "net/base/network_change_notifier.h" +#include "net/dns/address_sorter.h" + +namespace net { + +class ClientSocketFactory; + +// This implementation uses explicit policy to perform the sorting. It is not +// thread-safe and always completes synchronously. +class NET_EXPORT_PRIVATE AddressSorterPosix + : public AddressSorter, + public base::NonThreadSafe, + public NetworkChangeNotifier::IPAddressObserver { + public: + // Generic policy entry. + struct PolicyEntry { + // IPv4 addresses must be mapped to IPv6. + unsigned char prefix[kIPv6AddressSize]; + unsigned prefix_length; + unsigned value; + }; + + typedef std::vector<PolicyEntry> PolicyTable; + + enum AddressScope { + SCOPE_UNDEFINED = 0, + SCOPE_NODELOCAL = 1, + SCOPE_LINKLOCAL = 2, + SCOPE_SITELOCAL = 5, + SCOPE_ORGLOCAL = 8, + SCOPE_GLOBAL = 14, + }; + + struct SourceAddressInfo { + // Values read from policy tables. + AddressScope scope; + unsigned label; + + // Values from the OS, matter only if more than one source address is used. + unsigned prefix_length; + bool deprecated; // vs. preferred RFC4862 + bool home; // vs. care-of RFC6275 + bool native; + }; + + typedef std::map<IPAddressNumber, SourceAddressInfo> SourceAddressMap; + + explicit AddressSorterPosix(ClientSocketFactory* socket_factory); + virtual ~AddressSorterPosix(); + + // AddressSorter: + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE; + + private: + friend class AddressSorterPosixTest; + + // NetworkChangeNotifier::IPAddressObserver: + virtual void OnIPAddressChanged() OVERRIDE; + + // Fills |info| with values for |address| from policy tables. + void FillPolicy(const IPAddressNumber& address, + SourceAddressInfo* info) const; + + // Mutable to allow using default values for source addresses which were not + // found in most recent OnIPAddressChanged. + mutable SourceAddressMap source_map_; + + ClientSocketFactory* socket_factory_; + PolicyTable precedence_table_; + PolicyTable label_table_; + PolicyTable ipv4_scope_table_; + + DISALLOW_COPY_AND_ASSIGN(AddressSorterPosix); +}; + +} // namespace net + +#endif // NET_DNS_ADDRESS_SORTER_POSIX_H_ diff --git a/chromium/net/dns/address_sorter_posix_unittest.cc b/chromium/net/dns/address_sorter_posix_unittest.cc new file mode 100644 index 00000000000..c4517379957 --- /dev/null +++ b/chromium/net/dns/address_sorter_posix_unittest.cc @@ -0,0 +1,327 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter_posix.h" + +#include "base/bind.h" +#include "base/logging.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/stream_socket.h" +#include "net/udp/datagram_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace { + +// Used to map destination address to source address. +typedef std::map<IPAddressNumber, IPAddressNumber> AddressMapping; + +IPAddressNumber ParseIP(const std::string& str) { + IPAddressNumber addr; + CHECK(ParseIPLiteralToNumber(str, &addr)); + return addr; +} + +// A mock socket which binds to source address according to AddressMapping. +class TestUDPClientSocket : public DatagramClientSocket { + public: + explicit TestUDPClientSocket(const AddressMapping* mapping) + : mapping_(mapping), connected_(false) {} + + virtual ~TestUDPClientSocket() {} + + virtual int Read(IOBuffer*, int, const CompletionCallback&) OVERRIDE { + NOTIMPLEMENTED(); + return OK; + } + virtual int Write(IOBuffer*, int, const CompletionCallback&) OVERRIDE { + NOTIMPLEMENTED(); + return OK; + } + virtual bool SetReceiveBufferSize(int32) OVERRIDE { + return true; + } + virtual bool SetSendBufferSize(int32) OVERRIDE { + return true; + } + + virtual void Close() OVERRIDE {} + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + NOTIMPLEMENTED(); + return OK; + } + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + if (!connected_) + return ERR_UNEXPECTED; + *address = local_endpoint_; + return OK; + } + + virtual int Connect(const IPEndPoint& remote) OVERRIDE { + if (connected_) + return ERR_UNEXPECTED; + AddressMapping::const_iterator it = mapping_->find(remote.address()); + if (it == mapping_->end()) + return ERR_FAILED; + connected_ = true; + local_endpoint_ = IPEndPoint(it->second, 39874 /* arbitrary port */); + return OK; + } + + virtual const BoundNetLog& NetLog() const OVERRIDE { + return net_log_; + } + + private: + BoundNetLog net_log_; + const AddressMapping* mapping_; + bool connected_; + IPEndPoint local_endpoint_; + + DISALLOW_COPY_AND_ASSIGN(TestUDPClientSocket); +}; + +// Creates TestUDPClientSockets and maintains an AddressMapping. +class TestSocketFactory : public ClientSocketFactory { + public: + TestSocketFactory() {} + virtual ~TestSocketFactory() {} + + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + DatagramSocket::BindType, + const RandIntCallback&, + NetLog*, + const NetLog::Source&) OVERRIDE { + return scoped_ptr<DatagramClientSocket>(new TestUDPClientSocket(&mapping_)); + } + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + const AddressList&, + NetLog*, + const NetLog::Source&) OVERRIDE { + NOTIMPLEMENTED(); + return scoped_ptr<StreamSocket>(); + } + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle>, + const HostPortPair&, + const SSLConfig&, + const SSLClientSocketContext&) OVERRIDE { + NOTIMPLEMENTED(); + return scoped_ptr<SSLClientSocket>(); + } + virtual void ClearSSLSessionCache() OVERRIDE { + NOTIMPLEMENTED(); + } + + void AddMapping(const IPAddressNumber& dst, const IPAddressNumber& src) { + mapping_[dst] = src; + } + + private: + AddressMapping mapping_; + + DISALLOW_COPY_AND_ASSIGN(TestSocketFactory); +}; + +void OnSortComplete(AddressList* result_buf, + const CompletionCallback& callback, + bool success, + const AddressList& result) { + EXPECT_TRUE(success); + if (success) + *result_buf = result; + callback.Run(OK); +} + +} // namespace + +class AddressSorterPosixTest : public testing::Test { + protected: + AddressSorterPosixTest() : sorter_(&socket_factory_) {} + + void AddMapping(const std::string& dst, const std::string& src) { + socket_factory_.AddMapping(ParseIP(dst), ParseIP(src)); + } + + AddressSorterPosix::SourceAddressInfo* GetSourceInfo( + const std::string& addr) { + IPAddressNumber address = ParseIP(addr); + AddressSorterPosix::SourceAddressInfo* info = &sorter_.source_map_[address]; + if (info->scope == AddressSorterPosix::SCOPE_UNDEFINED) + sorter_.FillPolicy(address, info); + return info; + } + + // Verify that NULL-terminated |addresses| matches (-1)-terminated |order| + // after sorting. + void Verify(const char* addresses[], const int order[]) { + AddressList list; + for (const char** addr = addresses; *addr != NULL; ++addr) + list.push_back(IPEndPoint(ParseIP(*addr), 80)); + for (size_t i = 0; order[i] >= 0; ++i) + CHECK_LT(order[i], static_cast<int>(list.size())); + + AddressList result; + TestCompletionCallback callback; + sorter_.Sort(list, base::Bind(&OnSortComplete, &result, + callback.callback())); + callback.WaitForResult(); + + for (size_t i = 0; (i < result.size()) || (order[i] >= 0); ++i) { + IPEndPoint expected = order[i] >= 0 ? list[order[i]] : IPEndPoint(); + IPEndPoint actual = i < result.size() ? result[i] : IPEndPoint(); + EXPECT_TRUE(expected.address() == actual.address()) << + "Address out of order at position " << i << "\n" << + " Actual: " << actual.ToStringWithoutPort() << "\n" << + "Expected: " << expected.ToStringWithoutPort(); + } + } + + TestSocketFactory socket_factory_; + AddressSorterPosix sorter_; +}; + +// Rule 1: Avoid unusable destinations. +TEST_F(AddressSorterPosixTest, Rule1) { + AddMapping("10.0.0.231", "10.0.0.1"); + const char* addresses[] = { "::1", "10.0.0.231", "127.0.0.1", NULL }; + const int order[] = { 1, -1 }; + Verify(addresses, order); +} + +// Rule 2: Prefer matching scope. +TEST_F(AddressSorterPosixTest, Rule2) { + AddMapping("3002::1", "4000::10"); // matching global + AddMapping("ff32::1", "fe81::10"); // matching link-local + AddMapping("fec1::1", "fec1::10"); // matching node-local + AddMapping("3002::2", "::1"); // global vs. link-local + AddMapping("fec1::2", "fe81::10"); // site-local vs. link-local + AddMapping("8.0.0.1", "169.254.0.10"); // global vs. link-local + // In all three cases, matching scope is preferred. + const int order[] = { 1, 0, -1 }; + const char* addresses1[] = { "3002::2", "3002::1", NULL }; + Verify(addresses1, order); + const char* addresses2[] = { "fec1::2", "ff32::1", NULL }; + Verify(addresses2, order); + const char* addresses3[] = { "8.0.0.1", "fec1::1", NULL }; + Verify(addresses3, order); +} + +// Rule 3: Avoid deprecated addresses. +TEST_F(AddressSorterPosixTest, Rule3) { + // Matching scope. + AddMapping("3002::1", "4000::10"); + GetSourceInfo("4000::10")->deprecated = true; + AddMapping("3002::2", "4000::20"); + const char* addresses[] = { "3002::1", "3002::2", NULL }; + const int order[] = { 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 4: Prefer home addresses. +TEST_F(AddressSorterPosixTest, Rule4) { + AddMapping("3002::1", "4000::10"); + AddMapping("3002::2", "4000::20"); + GetSourceInfo("4000::20")->home = true; + const char* addresses[] = { "3002::1", "3002::2", NULL }; + const int order[] = { 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 5: Prefer matching label. +TEST_F(AddressSorterPosixTest, Rule5) { + AddMapping("::1", "::1"); // matching loopback + AddMapping("::ffff:1234:1", "::ffff:1234:10"); // matching IPv4-mapped + AddMapping("2001::1", "::ffff:1234:10"); // Teredo vs. IPv4-mapped + AddMapping("2002::1", "2001::10"); // 6to4 vs. Teredo + const int order[] = { 1, 0, -1 }; + { + const char* addresses[] = { "2001::1", "::1", NULL }; + Verify(addresses, order); + } + { + const char* addresses[] = { "2002::1", "::ffff:1234:1", NULL }; + Verify(addresses, order); + } +} + +// Rule 6: Prefer higher precedence. +TEST_F(AddressSorterPosixTest, Rule6) { + AddMapping("::1", "::1"); // loopback + AddMapping("ff32::1", "fe81::10"); // multicast + AddMapping("::ffff:1234:1", "::ffff:1234:10"); // IPv4-mapped + AddMapping("2001::1", "2001::10"); // Teredo + const char* addresses[] = { "2001::1", "::ffff:1234:1", "ff32::1", "::1", + NULL }; + const int order[] = { 3, 2, 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 7: Prefer native transport. +TEST_F(AddressSorterPosixTest, Rule7) { + AddMapping("3002::1", "4000::10"); + AddMapping("3002::2", "4000::20"); + GetSourceInfo("4000::20")->native = true; + const char* addresses[] = { "3002::1", "3002::2", NULL }; + const int order[] = { 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 8: Prefer smaller scope. +TEST_F(AddressSorterPosixTest, Rule8) { + // Matching scope. Should precede the others by Rule 2. + AddMapping("fe81::1", "fe81::10"); // link-local + AddMapping("3000::1", "4000::10"); // global + // Mismatched scope. + AddMapping("ff32::1", "4000::10"); // link-local + AddMapping("ff35::1", "4000::10"); // site-local + AddMapping("ff38::1", "4000::10"); // org-local + const char* addresses[] = { "ff38::1", "3000::1", "ff35::1", "ff32::1", + "fe81::1", NULL }; + const int order[] = { 4, 1, 3, 2, 0, -1 }; + Verify(addresses, order); +} + +// Rule 9: Use longest matching prefix. +TEST_F(AddressSorterPosixTest, Rule9) { + AddMapping("3000::1", "3000:ffff::10"); // 16 bit match + GetSourceInfo("3000:ffff::10")->prefix_length = 16; + AddMapping("4000::1", "4000::10"); // 123 bit match, limited to 15 + GetSourceInfo("4000::10")->prefix_length = 15; + AddMapping("4002::1", "4000::10"); // 14 bit match + AddMapping("4080::1", "4000::10"); // 8 bit match + const char* addresses[] = { "4080::1", "4002::1", "4000::1", "3000::1", + NULL }; + const int order[] = { 3, 2, 1, 0, -1 }; + Verify(addresses, order); +} + +// Rule 10: Leave the order unchanged. +TEST_F(AddressSorterPosixTest, Rule10) { + AddMapping("4000::1", "4000::10"); + AddMapping("4000::2", "4000::10"); + AddMapping("4000::3", "4000::10"); + const char* addresses[] = { "4000::1", "4000::2", "4000::3", NULL }; + const int order[] = { 0, 1, 2, -1 }; + Verify(addresses, order); +} + +TEST_F(AddressSorterPosixTest, MultipleRules) { + AddMapping("::1", "::1"); // loopback + AddMapping("ff32::1", "fe81::10"); // link-local multicast + AddMapping("ff3e::1", "4000::10"); // global multicast + AddMapping("4000::1", "4000::10"); // global unicast + AddMapping("ff32::2", "fe81::20"); // deprecated link-local multicast + GetSourceInfo("fe81::20")->deprecated = true; + const char* addresses[] = { "ff3e::1", "ff32::2", "4000::1", "ff32::1", "::1", + "8.0.0.1", NULL }; + const int order[] = { 4, 3, 0, 2, 1, -1 }; + Verify(addresses, order); +} + +} // namespace net diff --git a/chromium/net/dns/address_sorter_unittest.cc b/chromium/net/dns/address_sorter_unittest.cc new file mode 100644 index 00000000000..0c2be884d29 --- /dev/null +++ b/chromium/net/dns/address_sorter_unittest.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter.h" + +#if defined(OS_WIN) +#include <winsock2.h> +#endif + +#include "base/bind.h" +#include "base/logging.h" +#include "net/base/address_list.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "testing/gtest/include/gtest/gtest.h" + +#if defined(OS_WIN) +#include "net/base/winsock_init.h" +#endif + +namespace net { +namespace { + +IPEndPoint MakeEndPoint(const std::string& str) { + IPAddressNumber addr; + CHECK(ParseIPLiteralToNumber(str, &addr)); + return IPEndPoint(addr, 0); +} + +void OnSortComplete(AddressList* result_buf, + const CompletionCallback& callback, + bool success, + const AddressList& result) { + if (success) + *result_buf = result; + callback.Run(success ? OK : ERR_FAILED); +} + +TEST(AddressSorterTest, Sort) { + int expected_result = OK; +#if defined(OS_WIN) + EnsureWinsockInit(); + SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP); + if (sock == INVALID_SOCKET) { + expected_result = ERR_FAILED; + } else { + closesocket(sock); + } +#endif + scoped_ptr<AddressSorter> sorter(AddressSorter::CreateAddressSorter()); + AddressList list; + list.push_back(MakeEndPoint("10.0.0.1")); + list.push_back(MakeEndPoint("8.8.8.8")); + list.push_back(MakeEndPoint("::1")); + list.push_back(MakeEndPoint("2001:4860:4860::8888")); + + AddressList result; + TestCompletionCallback callback; + sorter->Sort(list, base::Bind(&OnSortComplete, &result, + callback.callback())); + EXPECT_EQ(expected_result, callback.WaitForResult()); +} + +} // namespace +} // namespace net diff --git a/chromium/net/dns/address_sorter_win.cc b/chromium/net/dns/address_sorter_win.cc new file mode 100644 index 00000000000..3e1afa734cc --- /dev/null +++ b/chromium/net/dns/address_sorter_win.cc @@ -0,0 +1,198 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/address_sorter.h" + +#include <winsock2.h> + +#include <algorithm> + +#include "base/bind.h" +#include "base/location.h" +#include "base/logging.h" +#include "base/threading/worker_pool.h" +#include "base/win/windows_version.h" +#include "net/base/address_list.h" +#include "net/base/ip_endpoint.h" +#include "net/base/winsock_init.h" + +namespace net { + +namespace { + +class AddressSorterWin : public AddressSorter { + public: + AddressSorterWin() { + EnsureWinsockInit(); + } + + virtual ~AddressSorterWin() {} + + // AddressSorter: + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE { + DCHECK(!list.empty()); + scoped_refptr<Job> job = new Job(list, callback); + } + + private: + // Executes the SIO_ADDRESS_LIST_SORT ioctl on the WorkerPool, and + // performs the necessary conversions to/from AddressList. + class Job : public base::RefCountedThreadSafe<Job> { + public: + Job(const AddressList& list, const CallbackType& callback) + : callback_(callback), + buffer_size_(sizeof(SOCKET_ADDRESS_LIST) + + list.size() * (sizeof(SOCKET_ADDRESS) + + sizeof(SOCKADDR_STORAGE))), + input_buffer_(reinterpret_cast<SOCKET_ADDRESS_LIST*>( + malloc(buffer_size_))), + output_buffer_(reinterpret_cast<SOCKET_ADDRESS_LIST*>( + malloc(buffer_size_))), + success_(false) { + input_buffer_->iAddressCount = list.size(); + SOCKADDR_STORAGE* storage = reinterpret_cast<SOCKADDR_STORAGE*>( + input_buffer_->Address + input_buffer_->iAddressCount); + + for (size_t i = 0; i < list.size(); ++i) { + IPEndPoint ipe = list[i]; + // Addresses must be sockaddr_in6. + if (ipe.GetFamily() == ADDRESS_FAMILY_IPV4) { + ipe = IPEndPoint(ConvertIPv4NumberToIPv6Number(ipe.address()), + ipe.port()); + } + + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(storage + i); + socklen_t addr_len = sizeof(SOCKADDR_STORAGE); + bool result = ipe.ToSockAddr(addr, &addr_len); + DCHECK(result); + input_buffer_->Address[i].lpSockaddr = addr; + input_buffer_->Address[i].iSockaddrLength = addr_len; + } + + if (!base::WorkerPool::PostTaskAndReply( + FROM_HERE, + base::Bind(&Job::Run, this), + base::Bind(&Job::OnComplete, this), + false /* task is slow */)) { + LOG(ERROR) << "WorkerPool::PostTaskAndReply failed"; + OnComplete(); + } + } + + private: + friend class base::RefCountedThreadSafe<Job>; + ~Job() {} + + // Executed on the WorkerPool. + void Run() { + SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP); + if (sock == INVALID_SOCKET) + return; + DWORD result_size = 0; + int result = WSAIoctl(sock, SIO_ADDRESS_LIST_SORT, input_buffer_.get(), + buffer_size_, output_buffer_.get(), buffer_size_, + &result_size, NULL, NULL); + if (result == SOCKET_ERROR) { + LOG(ERROR) << "SIO_ADDRESS_LIST_SORT failed " << WSAGetLastError(); + } else { + success_ = true; + } + closesocket(sock); + } + + // Executed on the calling thread. + void OnComplete() { + AddressList list; + if (success_) { + list.reserve(output_buffer_->iAddressCount); + for (int i = 0; i < output_buffer_->iAddressCount; ++i) { + IPEndPoint ipe; + ipe.FromSockAddr(output_buffer_->Address[i].lpSockaddr, + output_buffer_->Address[i].iSockaddrLength); + // Unmap V4MAPPED IPv6 addresses so that Happy Eyeballs works. + if (IsIPv4Mapped(ipe.address())) { + ipe = IPEndPoint(ConvertIPv4MappedToIPv4(ipe.address()), + ipe.port()); + } + list.push_back(ipe); + } + } + callback_.Run(success_, list); + } + + const CallbackType callback_; + const size_t buffer_size_; + scoped_ptr_malloc<SOCKET_ADDRESS_LIST> input_buffer_; + scoped_ptr_malloc<SOCKET_ADDRESS_LIST> output_buffer_; + bool success_; + + DISALLOW_COPY_AND_ASSIGN(Job); + }; + + DISALLOW_COPY_AND_ASSIGN(AddressSorterWin); +}; + +// Merges |list_ipv4| and |list_ipv6| before passing it to |callback|, but +// only if |success| is true. +void MergeResults(const AddressSorter::CallbackType& callback, + const AddressList& list_ipv4, + bool success, + const AddressList& list_ipv6) { + if (!success) { + callback.Run(false, AddressList()); + return; + } + AddressList list; + list.insert(list.end(), list_ipv6.begin(), list_ipv6.end()); + list.insert(list.end(), list_ipv4.begin(), list_ipv4.end()); + callback.Run(true, list); +} + +// Wrapper for AddressSorterWin which does not sort IPv4 or IPv4-mapped +// addresses but always puts them at the end of the list. Needed because the +// SIO_ADDRESS_LIST_SORT does not support IPv4 addresses on Windows XP. +class AddressSorterWinXP : public AddressSorter { + public: + AddressSorterWinXP() {} + virtual ~AddressSorterWinXP() {} + + // AddressSorter: + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE { + AddressList list_ipv4; + AddressList list_ipv6; + for (size_t i = 0; i < list.size(); ++i) { + const IPEndPoint& ipe = list[i]; + if (ipe.GetFamily() == ADDRESS_FAMILY_IPV4) { + list_ipv4.push_back(ipe); + } else { + list_ipv6.push_back(ipe); + } + } + if (!list_ipv6.empty()) { + sorter_.Sort(list_ipv6, base::Bind(&MergeResults, callback, list_ipv4)); + } else { + NOTREACHED() << "Should not be called with IPv4-only addresses."; + callback.Run(true, list); + } + } + + private: + AddressSorterWin sorter_; + + DISALLOW_COPY_AND_ASSIGN(AddressSorterWinXP); +}; + +} // namespace + +// static +scoped_ptr<AddressSorter> AddressSorter::CreateAddressSorter() { + if (base::win::GetVersion() < base::win::VERSION_VISTA) + return scoped_ptr<AddressSorter>(new AddressSorterWinXP()); + return scoped_ptr<AddressSorter>(new AddressSorterWin()); +} + +} // namespace net + diff --git a/chromium/net/dns/dns_client.cc b/chromium/net/dns/dns_client.cc new file mode 100644 index 00000000000..976f1533905 --- /dev/null +++ b/chromium/net/dns/dns_client.cc @@ -0,0 +1,71 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_client.h" + +#include "base/bind.h" +#include "base/rand_util.h" +#include "net/base/net_log.h" +#include "net/dns/address_sorter.h" +#include "net/dns/dns_config_service.h" +#include "net/dns/dns_session.h" +#include "net/dns/dns_socket_pool.h" +#include "net/dns/dns_transaction.h" +#include "net/socket/client_socket_factory.h" + +namespace net { + +namespace { + +class DnsClientImpl : public DnsClient { + public: + explicit DnsClientImpl(NetLog* net_log) + : address_sorter_(AddressSorter::CreateAddressSorter()), + net_log_(net_log) {} + + virtual void SetConfig(const DnsConfig& config) OVERRIDE { + factory_.reset(); + session_ = NULL; + if (config.IsValid()) { + ClientSocketFactory* factory = ClientSocketFactory::GetDefaultFactory(); + scoped_ptr<DnsSocketPool> socket_pool( + config.randomize_ports ? DnsSocketPool::CreateDefault(factory) + : DnsSocketPool::CreateNull(factory)); + session_ = new DnsSession(config, + socket_pool.Pass(), + base::Bind(&base::RandInt), + net_log_); + factory_ = DnsTransactionFactory::CreateFactory(session_.get()); + } + } + + virtual const DnsConfig* GetConfig() const OVERRIDE { + return session_.get() ? &session_->config() : NULL; + } + + virtual DnsTransactionFactory* GetTransactionFactory() OVERRIDE { + return session_.get() ? factory_.get() : NULL; + } + + virtual AddressSorter* GetAddressSorter() OVERRIDE { + return address_sorter_.get(); + } + + private: + scoped_refptr<DnsSession> session_; + scoped_ptr<DnsTransactionFactory> factory_; + scoped_ptr<AddressSorter> address_sorter_; + + NetLog* net_log_; +}; + +} // namespace + +// static +scoped_ptr<DnsClient> DnsClient::CreateClient(NetLog* net_log) { + return scoped_ptr<DnsClient>(new DnsClientImpl(net_log)); +} + +} // namespace net + diff --git a/chromium/net/dns/dns_client.h b/chromium/net/dns/dns_client.h new file mode 100644 index 00000000000..650c7d0d416 --- /dev/null +++ b/chromium/net/dns/dns_client.h @@ -0,0 +1,44 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_CLIENT_H_ +#define NET_DNS_DNS_CLIENT_H_ + +#include "base/memory/scoped_ptr.h" +#include "net/base/net_export.h" + +namespace net { + +class AddressSorter; +struct DnsConfig; +class DnsTransactionFactory; +class NetLog; + +// Convenience wrapper which allows easy injection of DnsTransaction into +// HostResolverImpl. Pointers returned by the Get* methods are only guaranteed +// to remain valid until next time SetConfig is called. +class NET_EXPORT DnsClient { + public: + virtual ~DnsClient() {} + + // Creates a new DnsTransactionFactory according to the new |config|. + virtual void SetConfig(const DnsConfig& config) = 0; + + // Returns NULL if the current config is not valid. + virtual const DnsConfig* GetConfig() const = 0; + + // Returns NULL if the current config is not valid. + virtual DnsTransactionFactory* GetTransactionFactory() = 0; + + // Returns NULL if the current config is not valid. + virtual AddressSorter* GetAddressSorter() = 0; + + // Creates default client. + static scoped_ptr<DnsClient> CreateClient(NetLog* net_log); +}; + +} // namespace net + +#endif // NET_DNS_DNS_CLIENT_H_ + diff --git a/chromium/net/dns/dns_config_service.cc b/chromium/net/dns/dns_config_service.cc new file mode 100644 index 00000000000..ea8a3421cd2 --- /dev/null +++ b/chromium/net/dns/dns_config_service.cc @@ -0,0 +1,226 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_config_service.h" + +#include "base/logging.h" +#include "base/metrics/histogram.h" +#include "base/values.h" +#include "net/base/ip_endpoint.h" + +namespace net { + +// Default values are taken from glibc resolv.h except timeout which is set to +// |kDnsTimeoutSeconds|. +DnsConfig::DnsConfig() + : append_to_multi_label_name(true), + randomize_ports(false), + ndots(1), + timeout(base::TimeDelta::FromSeconds(kDnsTimeoutSeconds)), + attempts(2), + rotate(false), + edns0(false) {} + +DnsConfig::~DnsConfig() {} + +bool DnsConfig::Equals(const DnsConfig& d) const { + return EqualsIgnoreHosts(d) && (hosts == d.hosts); +} + +bool DnsConfig::EqualsIgnoreHosts(const DnsConfig& d) const { + return (nameservers == d.nameservers) && + (search == d.search) && + (append_to_multi_label_name == d.append_to_multi_label_name) && + (ndots == d.ndots) && + (timeout == d.timeout) && + (attempts == d.attempts) && + (rotate == d.rotate) && + (edns0 == d.edns0); +} + +void DnsConfig::CopyIgnoreHosts(const DnsConfig& d) { + nameservers = d.nameservers; + search = d.search; + append_to_multi_label_name = d.append_to_multi_label_name; + ndots = d.ndots; + timeout = d.timeout; + attempts = d.attempts; + rotate = d.rotate; + edns0 = d.edns0; +} + +base::Value* DnsConfig::ToValue() const { + base::DictionaryValue* dict = new base::DictionaryValue(); + + base::ListValue* list = new base::ListValue(); + for (size_t i = 0; i < nameservers.size(); ++i) + list->Append(new base::StringValue(nameservers[i].ToString())); + dict->Set("nameservers", list); + + list = new base::ListValue(); + for (size_t i = 0; i < search.size(); ++i) + list->Append(new base::StringValue(search[i])); + dict->Set("search", list); + + dict->SetBoolean("append_to_multi_label_name", append_to_multi_label_name); + dict->SetInteger("ndots", ndots); + dict->SetDouble("timeout", timeout.InSecondsF()); + dict->SetInteger("attempts", attempts); + dict->SetBoolean("rotate", rotate); + dict->SetBoolean("edns0", edns0); + dict->SetInteger("num_hosts", hosts.size()); + + return dict; +} + + +DnsConfigService::DnsConfigService() + : watch_failed_(false), + have_config_(false), + have_hosts_(false), + need_update_(false), + last_sent_empty_(true) {} + +DnsConfigService::~DnsConfigService() { +} + +void DnsConfigService::ReadConfig(const CallbackType& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(!callback.is_null()); + DCHECK(callback_.is_null()); + callback_ = callback; + ReadNow(); +} + +void DnsConfigService::WatchConfig(const CallbackType& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(!callback.is_null()); + DCHECK(callback_.is_null()); + callback_ = callback; + watch_failed_ = !StartWatching(); + ReadNow(); +} + +void DnsConfigService::InvalidateConfig() { + DCHECK(CalledOnValidThread()); + base::TimeTicks now = base::TimeTicks::Now(); + if (!last_invalidate_config_time_.is_null()) { + UMA_HISTOGRAM_LONG_TIMES("AsyncDNS.ConfigNotifyInterval", + now - last_invalidate_config_time_); + } + last_invalidate_config_time_ = now; + if (!have_config_) + return; + have_config_ = false; + StartTimer(); +} + +void DnsConfigService::InvalidateHosts() { + DCHECK(CalledOnValidThread()); + base::TimeTicks now = base::TimeTicks::Now(); + if (!last_invalidate_hosts_time_.is_null()) { + UMA_HISTOGRAM_LONG_TIMES("AsyncDNS.HostsNotifyInterval", + now - last_invalidate_hosts_time_); + } + last_invalidate_hosts_time_ = now; + if (!have_hosts_) + return; + have_hosts_ = false; + StartTimer(); +} + +void DnsConfigService::OnConfigRead(const DnsConfig& config) { + DCHECK(CalledOnValidThread()); + DCHECK(config.IsValid()); + + bool changed = false; + if (!config.EqualsIgnoreHosts(dns_config_)) { + dns_config_.CopyIgnoreHosts(config); + need_update_ = true; + changed = true; + } + if (!changed && !last_sent_empty_time_.is_null()) { + UMA_HISTOGRAM_LONG_TIMES("AsyncDNS.UnchangedConfigInterval", + base::TimeTicks::Now() - last_sent_empty_time_); + } + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.ConfigChange", changed); + + have_config_ = true; + if (have_hosts_ || watch_failed_) + OnCompleteConfig(); +} + +void DnsConfigService::OnHostsRead(const DnsHosts& hosts) { + DCHECK(CalledOnValidThread()); + + bool changed = false; + if (hosts != dns_config_.hosts) { + dns_config_.hosts = hosts; + need_update_ = true; + changed = true; + } + if (!changed && !last_sent_empty_time_.is_null()) { + UMA_HISTOGRAM_LONG_TIMES("AsyncDNS.UnchangedHostsInterval", + base::TimeTicks::Now() - last_sent_empty_time_); + } + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.HostsChange", changed); + + have_hosts_ = true; + if (have_config_ || watch_failed_) + OnCompleteConfig(); +} + +void DnsConfigService::StartTimer() { + DCHECK(CalledOnValidThread()); + if (last_sent_empty_) { + DCHECK(!timer_.IsRunning()); + return; // No need to withdraw again. + } + timer_.Stop(); + + // Give it a short timeout to come up with a valid config. Otherwise withdraw + // the config from the receiver. The goal is to avoid perceivable network + // outage (when using the wrong config) but at the same time avoid + // unnecessary Job aborts in HostResolverImpl. The signals come from multiple + // sources so it might receive multiple events during a config change. + + // DHCP and user-induced changes are on the order of seconds, so 150ms should + // not add perceivable delay. On the other hand, config readers should finish + // within 150ms with the rare exception of I/O block or extra large HOSTS. + const base::TimeDelta kTimeout = base::TimeDelta::FromMilliseconds(150); + + timer_.Start(FROM_HERE, + kTimeout, + this, + &DnsConfigService::OnTimeout); +} + +void DnsConfigService::OnTimeout() { + DCHECK(CalledOnValidThread()); + DCHECK(!last_sent_empty_); + // Indicate that even if there is no change in On*Read, we will need to + // update the receiver when the config becomes complete. + need_update_ = true; + // Empty config is considered invalid. + last_sent_empty_ = true; + last_sent_empty_time_ = base::TimeTicks::Now(); + callback_.Run(DnsConfig()); +} + +void DnsConfigService::OnCompleteConfig() { + timer_.Stop(); + if (!need_update_) + return; + need_update_ = false; + last_sent_empty_ = false; + if (watch_failed_) { + // If a watch failed, the config may not be accurate, so report empty. + callback_.Run(DnsConfig()); + } else { + callback_.Run(dns_config_); + } +} + +} // namespace net + diff --git a/chromium/net/dns/dns_config_service.h b/chromium/net/dns/dns_config_service.h new file mode 100644 index 00000000000..4babb9e7d37 --- /dev/null +++ b/chromium/net/dns/dns_config_service.h @@ -0,0 +1,174 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_CONFIG_SERVICE_H_ +#define NET_DNS_DNS_CONFIG_SERVICE_H_ + +#include <map> +#include <string> +#include <vector> + +#include "base/gtest_prod_util.h" +#include "base/memory/scoped_ptr.h" +#include "base/threading/non_thread_safe.h" +#include "base/time/time.h" +#include "base/timer/timer.h" +// Needed on shared build with MSVS2010 to avoid multiple definitions of +// std::vector<IPEndPoint>. +#include "net/base/address_list.h" +#include "net/base/ip_endpoint.h" // win requires size of IPEndPoint +#include "net/base/net_export.h" +#include "net/dns/dns_hosts.h" + +namespace base { +class Value; +} + +namespace net { + +// Always use 1 second timeout (followed by binary exponential backoff). +// TODO(szym): Remove code which reads timeout from system. +const unsigned kDnsTimeoutSeconds = 1; + +// DnsConfig stores configuration of the system resolver. +struct NET_EXPORT_PRIVATE DnsConfig { + DnsConfig(); + virtual ~DnsConfig(); + + bool Equals(const DnsConfig& d) const; + + bool EqualsIgnoreHosts(const DnsConfig& d) const; + + void CopyIgnoreHosts(const DnsConfig& src); + + // Returns a Value representation of |this|. Caller takes ownership of the + // returned Value. For performance reasons, the Value only contains the + // number of hosts rather than the full list. + base::Value* ToValue() const; + + bool IsValid() const { + return !nameservers.empty(); + } + + // List of name server addresses. + std::vector<IPEndPoint> nameservers; + // Suffix search list; used on first lookup when number of dots in given name + // is less than |ndots|. + std::vector<std::string> search; + + DnsHosts hosts; + + // AppendToMultiLabelName: is suffix search performed for multi-label names? + // True, except on Windows where it can be configured. + bool append_to_multi_label_name; + + // Indicates that source port randomization is required. This uses additional + // resources on some platforms. + bool randomize_ports; + + // Resolver options; see man resolv.conf. + + // Minimum number of dots before global resolution precedes |search|. + int ndots; + // Time between retransmissions, see res_state.retrans. + base::TimeDelta timeout; + // Maximum number of attempts, see res_state.retry. + int attempts; + // Round robin entries in |nameservers| for subsequent requests. + bool rotate; + // Enable EDNS0 extensions. + bool edns0; +}; + + +// Service for reading system DNS settings, on demand or when signalled by +// internal watchers and NetworkChangeNotifier. +class NET_EXPORT_PRIVATE DnsConfigService + : NON_EXPORTED_BASE(public base::NonThreadSafe) { + public: + // Callback interface for the client, called on the same thread as + // ReadConfig() and WatchConfig(). + typedef base::Callback<void(const DnsConfig& config)> CallbackType; + + // Creates the platform-specific DnsConfigService. + static scoped_ptr<DnsConfigService> CreateSystemService(); + + DnsConfigService(); + virtual ~DnsConfigService(); + + // Attempts to read the configuration. Will run |callback| when succeeded. + // Can be called at most once. + void ReadConfig(const CallbackType& callback); + + // Registers systems watchers. Will attempt to read config after watch starts, + // but only if watchers started successfully. Will run |callback| iff config + // changes from last call or has to be withdrawn. Can be called at most once. + // Might require MessageLoopForIO. + void WatchConfig(const CallbackType& callback); + + protected: + enum WatchStatus { + DNS_CONFIG_WATCH_STARTED = 0, + DNS_CONFIG_WATCH_FAILED_TO_START_CONFIG, + DNS_CONFIG_WATCH_FAILED_TO_START_HOSTS, + DNS_CONFIG_WATCH_FAILED_CONFIG, + DNS_CONFIG_WATCH_FAILED_HOSTS, + DNS_CONFIG_WATCH_MAX, + }; + + // Immediately attempts to read the current configuration. + virtual void ReadNow() = 0; + // Registers system watchers. Returns true iff succeeds. + virtual bool StartWatching() = 0; + + // Called when the current config (except hosts) has changed. + void InvalidateConfig(); + // Called when the current hosts have changed. + void InvalidateHosts(); + + // Called with new config. |config|.hosts is ignored. + void OnConfigRead(const DnsConfig& config); + // Called with new hosts. Rest of the config is assumed unchanged. + void OnHostsRead(const DnsHosts& hosts); + + void set_watch_failed(bool value) { watch_failed_ = value; } + + private: + // The timer counts from the last Invalidate* until complete config is read. + void StartTimer(); + void OnTimeout(); + // Called when the config becomes complete. Stops the timer. + void OnCompleteConfig(); + + CallbackType callback_; + + DnsConfig dns_config_; + + // True if any of the necessary watchers failed. In that case, the service + // will communicate changes via OnTimeout, but will only send empty DnsConfig. + bool watch_failed_; + // True after On*Read, before Invalidate*. Tells if the config is complete. + bool have_config_; + bool have_hosts_; + // True if receiver needs to be updated when the config becomes complete. + bool need_update_; + // True if the last config sent was empty (instead of |dns_config_|). + // Set when |timer_| expires. + bool last_sent_empty_; + + // Initialized and updated on Invalidate* call. + base::TimeTicks last_invalidate_config_time_; + base::TimeTicks last_invalidate_hosts_time_; + // Initialized and updated when |timer_| expires. + base::TimeTicks last_sent_empty_time_; + + // Started in Invalidate*, cleared in On*Read. + base::OneShotTimer<DnsConfigService> timer_; + + DISALLOW_COPY_AND_ASSIGN(DnsConfigService); +}; + +} // namespace net + +#endif // NET_DNS_DNS_CONFIG_SERVICE_H_ diff --git a/chromium/net/dns/dns_config_service_posix.cc b/chromium/net/dns/dns_config_service_posix.cc new file mode 100644 index 00000000000..ff2295e704a --- /dev/null +++ b/chromium/net/dns/dns_config_service_posix.cc @@ -0,0 +1,404 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_config_service_posix.h" + +#include <string> + +#include "base/basictypes.h" +#include "base/bind.h" +#include "base/files/file_path.h" +#include "base/files/file_path_watcher.h" +#include "base/memory/scoped_ptr.h" +#include "base/metrics/histogram.h" +#include "base/time/time.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_util.h" +#include "net/dns/dns_hosts.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/notify_watcher_mac.h" +#include "net/dns/serial_worker.h" + +namespace net { + +#if !defined(OS_ANDROID) +namespace internal { + +namespace { + +const base::FilePath::CharType* kFilePathHosts = + FILE_PATH_LITERAL("/etc/hosts"); + +#if defined(OS_MACOSX) +// From 10.7.3 configd-395.10/dnsinfo/dnsinfo.h +static const char* kDnsNotifyKey = + "com.apple.system.SystemConfiguration.dns_configuration"; + +class ConfigWatcher { + public: + bool Watch(const base::Callback<void(bool succeeded)>& callback) { + return watcher_.Watch(kDnsNotifyKey, callback); + } + + private: + NotifyWatcherMac watcher_; +}; +#else + +#ifndef _PATH_RESCONF // Normally defined in <resolv.h> +#define _PATH_RESCONF "/etc/resolv.conf" +#endif + +static const base::FilePath::CharType* kFilePathConfig = + FILE_PATH_LITERAL(_PATH_RESCONF); + +class ConfigWatcher { + public: + typedef base::Callback<void(bool succeeded)> CallbackType; + + bool Watch(const CallbackType& callback) { + callback_ = callback; + return watcher_.Watch(base::FilePath(kFilePathConfig), false, + base::Bind(&ConfigWatcher::OnCallback, + base::Unretained(this))); + } + + private: + void OnCallback(const base::FilePath& path, bool error) { + callback_.Run(!error); + } + + base::FilePathWatcher watcher_; + CallbackType callback_; +}; +#endif + +ConfigParsePosixResult ReadDnsConfig(DnsConfig* config) { + ConfigParsePosixResult result; +#if defined(OS_OPENBSD) + // Note: res_ninit in glibc always returns 0 and sets RES_INIT. + // res_init behaves the same way. + memset(&_res, 0, sizeof(_res)); + if (res_init() == 0) { + result = ConvertResStateToDnsConfig(_res, config); + } else { + result = CONFIG_PARSE_POSIX_RES_INIT_FAILED; + } +#else // all other OS_POSIX + struct __res_state res; + memset(&res, 0, sizeof(res)); + if (res_ninit(&res) == 0) { + result = ConvertResStateToDnsConfig(res, config); + } else { + result = CONFIG_PARSE_POSIX_RES_INIT_FAILED; + } + // Prefer res_ndestroy where available. +#if defined(OS_MACOSX) || defined(OS_FREEBSD) + res_ndestroy(&res); +#else + res_nclose(&res); +#endif +#endif + // Override timeout value to match default setting on Windows. + config->timeout = base::TimeDelta::FromSeconds(kDnsTimeoutSeconds); + return result; +} + +} // namespace + +class DnsConfigServicePosix::Watcher { + public: + explicit Watcher(DnsConfigServicePosix* service) + : weak_factory_(this), + service_(service) {} + ~Watcher() {} + + bool Watch() { + bool success = true; + if (!config_watcher_.Watch(base::Bind(&Watcher::OnConfigChanged, + base::Unretained(this)))) { + LOG(ERROR) << "DNS config watch failed to start."; + success = false; + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", + DNS_CONFIG_WATCH_FAILED_TO_START_CONFIG, + DNS_CONFIG_WATCH_MAX); + } + if (!hosts_watcher_.Watch(base::FilePath(kFilePathHosts), false, + base::Bind(&Watcher::OnHostsChanged, + base::Unretained(this)))) { + LOG(ERROR) << "DNS hosts watch failed to start."; + success = false; + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", + DNS_CONFIG_WATCH_FAILED_TO_START_HOSTS, + DNS_CONFIG_WATCH_MAX); + } + return success; + } + + private: + void OnConfigChanged(bool succeeded) { + // Ignore transient flutter of resolv.conf by delaying the signal a bit. + const base::TimeDelta kDelay = base::TimeDelta::FromMilliseconds(50); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&Watcher::OnConfigChangedDelayed, + weak_factory_.GetWeakPtr(), + succeeded), + kDelay); + } + void OnConfigChangedDelayed(bool succeeded) { + service_->OnConfigChanged(succeeded); + } + void OnHostsChanged(const base::FilePath& path, bool error) { + service_->OnHostsChanged(!error); + } + + base::WeakPtrFactory<Watcher> weak_factory_; + DnsConfigServicePosix* service_; + ConfigWatcher config_watcher_; + base::FilePathWatcher hosts_watcher_; + + DISALLOW_COPY_AND_ASSIGN(Watcher); +}; + +// A SerialWorker that uses libresolv to initialize res_state and converts +// it to DnsConfig. +class DnsConfigServicePosix::ConfigReader : public SerialWorker { + public: + explicit ConfigReader(DnsConfigServicePosix* service) + : service_(service), success_(false) {} + + virtual void DoWork() OVERRIDE { + base::TimeTicks start_time = base::TimeTicks::Now(); + ConfigParsePosixResult result = ReadDnsConfig(&dns_config_); + success_ = (result == CONFIG_PARSE_POSIX_OK); + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ConfigParsePosix", + result, CONFIG_PARSE_POSIX_MAX); + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.ConfigParseResult", success_); + UMA_HISTOGRAM_TIMES("AsyncDNS.ConfigParseDuration", + base::TimeTicks::Now() - start_time); + } + + virtual void OnWorkFinished() OVERRIDE { + DCHECK(!IsCancelled()); + if (success_) { + service_->OnConfigRead(dns_config_); + } else { + LOG(WARNING) << "Failed to read DnsConfig."; + } + } + + private: + virtual ~ConfigReader() {} + + DnsConfigServicePosix* service_; + // Written in DoWork, read in OnWorkFinished, no locking necessary. + DnsConfig dns_config_; + bool success_; + + DISALLOW_COPY_AND_ASSIGN(ConfigReader); +}; + +// A SerialWorker that reads the HOSTS file and runs Callback. +class DnsConfigServicePosix::HostsReader : public SerialWorker { + public: + explicit HostsReader(DnsConfigServicePosix* service) + : service_(service), path_(kFilePathHosts), success_(false) {} + + private: + virtual ~HostsReader() {} + + virtual void DoWork() OVERRIDE { + base::TimeTicks start_time = base::TimeTicks::Now(); + success_ = ParseHostsFile(path_, &hosts_); + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.HostParseResult", success_); + UMA_HISTOGRAM_TIMES("AsyncDNS.HostsParseDuration", + base::TimeTicks::Now() - start_time); + } + + virtual void OnWorkFinished() OVERRIDE { + if (success_) { + service_->OnHostsRead(hosts_); + } else { + LOG(WARNING) << "Failed to read DnsHosts."; + } + } + + DnsConfigServicePosix* service_; + const base::FilePath path_; + // Written in DoWork, read in OnWorkFinished, no locking necessary. + DnsHosts hosts_; + bool success_; + + DISALLOW_COPY_AND_ASSIGN(HostsReader); +}; + +DnsConfigServicePosix::DnsConfigServicePosix() + : config_reader_(new ConfigReader(this)), + hosts_reader_(new HostsReader(this)) {} + +DnsConfigServicePosix::~DnsConfigServicePosix() { + config_reader_->Cancel(); + hosts_reader_->Cancel(); +} + +void DnsConfigServicePosix::ReadNow() { + config_reader_->WorkNow(); + hosts_reader_->WorkNow(); +} + +bool DnsConfigServicePosix::StartWatching() { + // TODO(szym): re-start watcher if that makes sense. http://crbug.com/116139 + watcher_.reset(new Watcher(this)); + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", DNS_CONFIG_WATCH_STARTED, + DNS_CONFIG_WATCH_MAX); + return watcher_->Watch(); +} + +void DnsConfigServicePosix::OnConfigChanged(bool succeeded) { + InvalidateConfig(); + if (succeeded) { + config_reader_->WorkNow(); + } else { + LOG(ERROR) << "DNS config watch failed."; + set_watch_failed(true); + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", + DNS_CONFIG_WATCH_FAILED_CONFIG, + DNS_CONFIG_WATCH_MAX); + } +} + +void DnsConfigServicePosix::OnHostsChanged(bool succeeded) { + InvalidateHosts(); + if (succeeded) { + hosts_reader_->WorkNow(); + } else { + LOG(ERROR) << "DNS hosts watch failed."; + set_watch_failed(true); + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", + DNS_CONFIG_WATCH_FAILED_HOSTS, + DNS_CONFIG_WATCH_MAX); + } +} + +ConfigParsePosixResult ConvertResStateToDnsConfig(const struct __res_state& res, + DnsConfig* dns_config) { + CHECK(dns_config != NULL); + if (!(res.options & RES_INIT)) + return CONFIG_PARSE_POSIX_RES_INIT_UNSET; + + dns_config->nameservers.clear(); + +#if defined(OS_MACOSX) || defined(OS_FREEBSD) + union res_sockaddr_union addresses[MAXNS]; + int nscount = res_getservers(const_cast<res_state>(&res), addresses, MAXNS); + DCHECK_GE(nscount, 0); + DCHECK_LE(nscount, MAXNS); + for (int i = 0; i < nscount; ++i) { + IPEndPoint ipe; + if (!ipe.FromSockAddr( + reinterpret_cast<const struct sockaddr*>(&addresses[i]), + sizeof addresses[i])) { + return CONFIG_PARSE_POSIX_BAD_ADDRESS; + } + dns_config->nameservers.push_back(ipe); + } +#elif defined(OS_LINUX) + COMPILE_ASSERT(arraysize(res.nsaddr_list) >= MAXNS && + arraysize(res._u._ext.nsaddrs) >= MAXNS, + incompatible_libresolv_res_state); + DCHECK_LE(res.nscount, MAXNS); + // Initially, glibc stores IPv6 in |_ext.nsaddrs| and IPv4 in |nsaddr_list|. + // In res_send.c:res_nsend, it merges |nsaddr_list| into |nsaddrs|, + // but we have to combine the two arrays ourselves. + for (int i = 0; i < res.nscount; ++i) { + IPEndPoint ipe; + const struct sockaddr* addr = NULL; + size_t addr_len = 0; + if (res.nsaddr_list[i].sin_family) { // The indicator used by res_nsend. + addr = reinterpret_cast<const struct sockaddr*>(&res.nsaddr_list[i]); + addr_len = sizeof res.nsaddr_list[i]; + } else if (res._u._ext.nsaddrs[i] != NULL) { + addr = reinterpret_cast<const struct sockaddr*>(res._u._ext.nsaddrs[i]); + addr_len = sizeof *res._u._ext.nsaddrs[i]; + } else { + return CONFIG_PARSE_POSIX_BAD_EXT_STRUCT; + } + if (!ipe.FromSockAddr(addr, addr_len)) + return CONFIG_PARSE_POSIX_BAD_ADDRESS; + dns_config->nameservers.push_back(ipe); + } +#else // !(defined(OS_LINUX) || defined(OS_MACOSX) || defined(OS_FREEBSD)) + DCHECK_LE(res.nscount, MAXNS); + for (int i = 0; i < res.nscount; ++i) { + IPEndPoint ipe; + if (!ipe.FromSockAddr( + reinterpret_cast<const struct sockaddr*>(&res.nsaddr_list[i]), + sizeof res.nsaddr_list[i])) { + return CONFIG_PARSE_POSIX_BAD_ADDRESS; + } + dns_config->nameservers.push_back(ipe); + } +#endif + + dns_config->search.clear(); + for (int i = 0; (i < MAXDNSRCH) && res.dnsrch[i]; ++i) { + dns_config->search.push_back(std::string(res.dnsrch[i])); + } + + dns_config->ndots = res.ndots; + dns_config->timeout = base::TimeDelta::FromSeconds(res.retrans); + dns_config->attempts = res.retry; +#if defined(RES_ROTATE) + dns_config->rotate = res.options & RES_ROTATE; +#endif + dns_config->edns0 = res.options & RES_USE_EDNS0; + + // The current implementation assumes these options are set. They normally + // cannot be overwritten by /etc/resolv.conf + unsigned kRequiredOptions = RES_RECURSE | RES_DEFNAMES | RES_DNSRCH; + if ((res.options & kRequiredOptions) != kRequiredOptions) + return CONFIG_PARSE_POSIX_MISSING_OPTIONS; + + unsigned kUnhandledOptions = RES_USEVC | RES_IGNTC | RES_USE_DNSSEC; + if (res.options & kUnhandledOptions) + return CONFIG_PARSE_POSIX_UNHANDLED_OPTIONS; + + if (dns_config->nameservers.empty()) + return CONFIG_PARSE_POSIX_NO_NAMESERVERS; + + // If any name server is 0.0.0.0, assume the configuration is invalid. + // TODO(szym): Measure how often this happens. http://crbug.com/125599 + const IPAddressNumber kEmptyAddress(kIPv4AddressSize); + for (unsigned i = 0; i < dns_config->nameservers.size(); ++i) { + if (dns_config->nameservers[i].address() == kEmptyAddress) + return CONFIG_PARSE_POSIX_NULL_ADDRESS; + } + return CONFIG_PARSE_POSIX_OK; +} + +} // namespace internal + +// static +scoped_ptr<DnsConfigService> DnsConfigService::CreateSystemService() { + return scoped_ptr<DnsConfigService>(new internal::DnsConfigServicePosix()); +} + +#else // defined(OS_ANDROID) +// Android NDK provides only a stub <resolv.h> header. +class StubDnsConfigService : public DnsConfigService { + public: + StubDnsConfigService() {} + virtual ~StubDnsConfigService() {} + private: + virtual void ReadNow() OVERRIDE {} + virtual bool StartWatching() OVERRIDE { return false; } +}; +// static +scoped_ptr<DnsConfigService> DnsConfigService::CreateSystemService() { + return scoped_ptr<DnsConfigService>(new StubDnsConfigService()); +} +#endif + +} // namespace net diff --git a/chromium/net/dns/dns_config_service_posix.h b/chromium/net/dns/dns_config_service_posix.h new file mode 100644 index 00000000000..95a4377d932 --- /dev/null +++ b/chromium/net/dns/dns_config_service_posix.h @@ -0,0 +1,67 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_CONFIG_SERVICE_POSIX_H_ +#define NET_DNS_DNS_CONFIG_SERVICE_POSIX_H_ + +#include <sys/types.h> +#include <netinet/in.h> +#include <resolv.h> + +#include "base/compiler_specific.h" +#include "net/base/net_export.h" +#include "net/dns/dns_config_service.h" + +namespace net { + +// Use DnsConfigService::CreateSystemService to use it outside of tests. +namespace internal { + +class NET_EXPORT_PRIVATE DnsConfigServicePosix : public DnsConfigService { + public: + DnsConfigServicePosix(); + virtual ~DnsConfigServicePosix(); + + protected: + // DnsConfigService: + virtual void ReadNow() OVERRIDE; + virtual bool StartWatching() OVERRIDE; + + private: + class Watcher; + class ConfigReader; + class HostsReader; + + void OnConfigChanged(bool succeeded); + void OnHostsChanged(bool succeeded); + + scoped_ptr<Watcher> watcher_; + scoped_refptr<ConfigReader> config_reader_; + scoped_refptr<HostsReader> hosts_reader_; + + DISALLOW_COPY_AND_ASSIGN(DnsConfigServicePosix); +}; + +enum ConfigParsePosixResult { + CONFIG_PARSE_POSIX_OK = 0, + CONFIG_PARSE_POSIX_RES_INIT_FAILED, + CONFIG_PARSE_POSIX_RES_INIT_UNSET, + CONFIG_PARSE_POSIX_BAD_ADDRESS, + CONFIG_PARSE_POSIX_BAD_EXT_STRUCT, + CONFIG_PARSE_POSIX_NULL_ADDRESS, + CONFIG_PARSE_POSIX_NO_NAMESERVERS, + CONFIG_PARSE_POSIX_MISSING_OPTIONS, + CONFIG_PARSE_POSIX_UNHANDLED_OPTIONS, + CONFIG_PARSE_POSIX_MAX // Bounding values for enumeration. +}; + +// Fills in |dns_config| from |res|. +ConfigParsePosixResult NET_EXPORT_PRIVATE ConvertResStateToDnsConfig( + const struct __res_state& res, DnsConfig* dns_config); + +} // namespace internal + +} // namespace net + +#endif // NET_DNS_DNS_CONFIG_SERVICE_POSIX_H_ diff --git a/chromium/net/dns/dns_config_service_posix_unittest.cc b/chromium/net/dns/dns_config_service_posix_unittest.cc new file mode 100644 index 00000000000..92e46ed1977 --- /dev/null +++ b/chromium/net/dns/dns_config_service_posix_unittest.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <resolv.h> + +#include "base/sys_byteorder.h" +#include "net/dns/dns_config_service_posix.h" + +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace { + +// MAXNS is normally 3, but let's test 4 if possible. +const char* kNameserversIPv4[] = { + "8.8.8.8", + "192.168.1.1", + "63.1.2.4", + "1.0.0.1", +}; + +#if defined(OS_LINUX) +const char* kNameserversIPv6[] = { + NULL, + "2001:DB8:0::42", + NULL, + "::FFFF:129.144.52.38", +}; +#endif + +// Fills in |res| with sane configuration. +void InitializeResState(res_state res) { + memset(res, 0, sizeof(*res)); + res->options = RES_INIT | RES_RECURSE | RES_DEFNAMES | RES_DNSRCH | + RES_ROTATE; + res->ndots = 2; + res->retrans = 4; + res->retry = 7; + + const char kDnsrch[] = "chromium.org" "\0" "example.com"; + memcpy(res->defdname, kDnsrch, sizeof(kDnsrch)); + res->dnsrch[0] = res->defdname; + res->dnsrch[1] = res->defdname + sizeof("chromium.org"); + + for (unsigned i = 0; i < arraysize(kNameserversIPv4) && i < MAXNS; ++i) { + struct sockaddr_in sa; + sa.sin_family = AF_INET; + sa.sin_port = base::HostToNet16(NS_DEFAULTPORT + i); + inet_pton(AF_INET, kNameserversIPv4[i], &sa.sin_addr); + res->nsaddr_list[i] = sa; + ++res->nscount; + } + +#if defined(OS_LINUX) + // Install IPv6 addresses, replacing the corresponding IPv4 addresses. + unsigned nscount6 = 0; + for (unsigned i = 0; i < arraysize(kNameserversIPv6) && i < MAXNS; ++i) { + if (!kNameserversIPv6[i]) + continue; + // Must use malloc to mimick res_ninit. + struct sockaddr_in6 *sa6; + sa6 = (struct sockaddr_in6 *)malloc(sizeof(*sa6)); + sa6->sin6_family = AF_INET6; + sa6->sin6_port = base::HostToNet16(NS_DEFAULTPORT - i); + inet_pton(AF_INET6, kNameserversIPv6[i], &sa6->sin6_addr); + res->_u._ext.nsaddrs[i] = sa6; + memset(&res->nsaddr_list[i], 0, sizeof res->nsaddr_list[i]); + ++nscount6; + } + res->_u._ext.nscount6 = nscount6; +#endif +} + +void CloseResState(res_state res) { +#if defined(OS_LINUX) + for (int i = 0; i < res->nscount; ++i) { + if (res->_u._ext.nsaddrs[i] != NULL) + free(res->_u._ext.nsaddrs[i]); + } +#endif +} + +void InitializeExpectedConfig(DnsConfig* config) { + config->ndots = 2; + config->timeout = base::TimeDelta::FromSeconds(4); + config->attempts = 7; + config->rotate = true; + config->edns0 = false; + config->append_to_multi_label_name = true; + config->search.clear(); + config->search.push_back("chromium.org"); + config->search.push_back("example.com"); + + config->nameservers.clear(); + for (unsigned i = 0; i < arraysize(kNameserversIPv4) && i < MAXNS; ++i) { + IPAddressNumber ip; + ParseIPLiteralToNumber(kNameserversIPv4[i], &ip); + config->nameservers.push_back(IPEndPoint(ip, NS_DEFAULTPORT + i)); + } + +#if defined(OS_LINUX) + for (unsigned i = 0; i < arraysize(kNameserversIPv6) && i < MAXNS; ++i) { + if (!kNameserversIPv6[i]) + continue; + IPAddressNumber ip; + ParseIPLiteralToNumber(kNameserversIPv6[i], &ip); + config->nameservers[i] = IPEndPoint(ip, NS_DEFAULTPORT - i); + } +#endif +} + +TEST(DnsConfigServicePosixTest, ConvertResStateToDnsConfig) { + struct __res_state res; + DnsConfig config; + EXPECT_FALSE(config.IsValid()); + InitializeResState(&res); + ASSERT_EQ(internal::CONFIG_PARSE_POSIX_OK, + internal::ConvertResStateToDnsConfig(res, &config)); + CloseResState(&res); + EXPECT_TRUE(config.IsValid()); + + DnsConfig expected_config; + EXPECT_FALSE(expected_config.EqualsIgnoreHosts(config)); + InitializeExpectedConfig(&expected_config); + EXPECT_TRUE(expected_config.EqualsIgnoreHosts(config)); +} + +TEST(DnsConfigServicePosixTest, RejectEmptyNameserver) { + struct __res_state res = {}; + res.options = RES_INIT | RES_RECURSE | RES_DEFNAMES | RES_DNSRCH; + const char kDnsrch[] = "chromium.org"; + memcpy(res.defdname, kDnsrch, sizeof(kDnsrch)); + res.dnsrch[0] = res.defdname; + + struct sockaddr_in sa = {}; + sa.sin_family = AF_INET; + sa.sin_port = base::HostToNet16(NS_DEFAULTPORT); + sa.sin_addr.s_addr = INADDR_ANY; + res.nsaddr_list[0] = sa; + sa.sin_addr.s_addr = 0xCAFE1337; + res.nsaddr_list[1] = sa; + res.nscount = 2; + + DnsConfig config; + EXPECT_EQ(internal::CONFIG_PARSE_POSIX_NULL_ADDRESS, + internal::ConvertResStateToDnsConfig(res, &config)); + + sa.sin_addr.s_addr = 0xDEADBEEF; + res.nsaddr_list[0] = sa; + EXPECT_EQ(internal::CONFIG_PARSE_POSIX_OK, + internal::ConvertResStateToDnsConfig(res, &config)); +} + +} // namespace +} // namespace net diff --git a/chromium/net/dns/dns_config_service_unittest.cc b/chromium/net/dns/dns_config_service_unittest.cc new file mode 100644 index 00000000000..f42068032de --- /dev/null +++ b/chromium/net/dns/dns_config_service_unittest.cc @@ -0,0 +1,258 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_config_service.h" + +#include "base/basictypes.h" +#include "base/bind.h" +#include "base/cancelable_callback.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/test/test_timeouts.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +class DnsConfigServiceTest : public testing::Test { + public: + void OnConfigChanged(const DnsConfig& config) { + last_config_ = config; + if (quit_on_config_) + base::MessageLoop::current()->Quit(); + } + + protected: + class TestDnsConfigService : public DnsConfigService { + public: + virtual void ReadNow() OVERRIDE {} + virtual bool StartWatching() OVERRIDE { return true; } + + // Expose the protected methods to this test suite. + void InvalidateConfig() { + DnsConfigService::InvalidateConfig(); + } + + void InvalidateHosts() { + DnsConfigService::InvalidateHosts(); + } + + void OnConfigRead(const DnsConfig& config) { + DnsConfigService::OnConfigRead(config); + } + + void OnHostsRead(const DnsHosts& hosts) { + DnsConfigService::OnHostsRead(hosts); + } + + void set_watch_failed(bool value) { + DnsConfigService::set_watch_failed(value); + } + }; + + void WaitForConfig(base::TimeDelta timeout) { + base::CancelableClosure closure(base::MessageLoop::QuitClosure()); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, closure.callback(), timeout); + quit_on_config_ = true; + base::MessageLoop::current()->Run(); + quit_on_config_ = false; + closure.Cancel(); + } + + // Generate a config using the given seed.. + DnsConfig MakeConfig(unsigned seed) { + DnsConfig config; + IPAddressNumber ip; + CHECK(ParseIPLiteralToNumber("1.2.3.4", &ip)); + config.nameservers.push_back(IPEndPoint(ip, seed & 0xFFFF)); + EXPECT_TRUE(config.IsValid()); + return config; + } + + // Generate hosts using the given seed. + DnsHosts MakeHosts(unsigned seed) { + DnsHosts hosts; + std::string hosts_content = "127.0.0.1 localhost"; + hosts_content.append(seed, '1'); + ParseHosts(hosts_content, &hosts); + EXPECT_FALSE(hosts.empty()); + return hosts; + } + + virtual void SetUp() OVERRIDE { + quit_on_config_ = false; + + service_.reset(new TestDnsConfigService()); + service_->WatchConfig(base::Bind(&DnsConfigServiceTest::OnConfigChanged, + base::Unretained(this))); + EXPECT_FALSE(last_config_.IsValid()); + } + + DnsConfig last_config_; + bool quit_on_config_; + + // Service under test. + scoped_ptr<TestDnsConfigService> service_; +}; + +} // namespace + +TEST_F(DnsConfigServiceTest, FirstConfig) { + DnsConfig config = MakeConfig(1); + + service_->OnConfigRead(config); + // No hosts yet, so no config. + EXPECT_TRUE(last_config_.Equals(DnsConfig())); + + service_->OnHostsRead(config.hosts); + EXPECT_TRUE(last_config_.Equals(config)); +} + +TEST_F(DnsConfigServiceTest, Timeout) { + DnsConfig config = MakeConfig(1); + config.hosts = MakeHosts(1); + ASSERT_TRUE(config.IsValid()); + + service_->OnConfigRead(config); + service_->OnHostsRead(config.hosts); + EXPECT_FALSE(last_config_.Equals(DnsConfig())); + EXPECT_TRUE(last_config_.Equals(config)); + + service_->InvalidateConfig(); + WaitForConfig(TestTimeouts::action_timeout()); + EXPECT_FALSE(last_config_.Equals(config)); + EXPECT_TRUE(last_config_.Equals(DnsConfig())); + + service_->OnConfigRead(config); + EXPECT_FALSE(last_config_.Equals(DnsConfig())); + EXPECT_TRUE(last_config_.Equals(config)); + + service_->InvalidateHosts(); + WaitForConfig(TestTimeouts::action_timeout()); + EXPECT_FALSE(last_config_.Equals(config)); + EXPECT_TRUE(last_config_.Equals(DnsConfig())); + + DnsConfig bad_config = last_config_ = MakeConfig(0xBAD); + service_->InvalidateConfig(); + // We don't expect an update. This should time out. + WaitForConfig(base::TimeDelta::FromMilliseconds(100) + + TestTimeouts::tiny_timeout()); + EXPECT_TRUE(last_config_.Equals(bad_config)) << "Unexpected change"; + + last_config_ = DnsConfig(); + service_->OnConfigRead(config); + service_->OnHostsRead(config.hosts); + EXPECT_FALSE(last_config_.Equals(DnsConfig())); + EXPECT_TRUE(last_config_.Equals(config)); +} + +TEST_F(DnsConfigServiceTest, SameConfig) { + DnsConfig config = MakeConfig(1); + config.hosts = MakeHosts(1); + + service_->OnConfigRead(config); + service_->OnHostsRead(config.hosts); + EXPECT_FALSE(last_config_.Equals(DnsConfig())); + EXPECT_TRUE(last_config_.Equals(config)); + + last_config_ = DnsConfig(); + service_->OnConfigRead(config); + EXPECT_TRUE(last_config_.Equals(DnsConfig())) << "Unexpected change"; + + service_->OnHostsRead(config.hosts); + EXPECT_TRUE(last_config_.Equals(DnsConfig())) << "Unexpected change"; +} + +TEST_F(DnsConfigServiceTest, DifferentConfig) { + DnsConfig config1 = MakeConfig(1); + DnsConfig config2 = MakeConfig(2); + DnsConfig config3 = MakeConfig(1); + config1.hosts = MakeHosts(1); + config2.hosts = MakeHosts(1); + config3.hosts = MakeHosts(2); + ASSERT_TRUE(config1.EqualsIgnoreHosts(config3)); + ASSERT_FALSE(config1.Equals(config2)); + ASSERT_FALSE(config1.Equals(config3)); + ASSERT_FALSE(config2.Equals(config3)); + + service_->OnConfigRead(config1); + service_->OnHostsRead(config1.hosts); + EXPECT_FALSE(last_config_.Equals(DnsConfig())); + EXPECT_TRUE(last_config_.Equals(config1)); + + // It doesn't matter for this tests, but increases coverage. + service_->InvalidateConfig(); + service_->InvalidateHosts(); + + service_->OnConfigRead(config2); + EXPECT_TRUE(last_config_.Equals(config1)) << "Unexpected change"; + service_->OnHostsRead(config2.hosts); // Not an actual change. + EXPECT_FALSE(last_config_.Equals(config1)); + EXPECT_TRUE(last_config_.Equals(config2)); + + service_->OnConfigRead(config3); + EXPECT_TRUE(last_config_.EqualsIgnoreHosts(config3)); + service_->OnHostsRead(config3.hosts); + EXPECT_FALSE(last_config_.Equals(config2)); + EXPECT_TRUE(last_config_.Equals(config3)); +} + +TEST_F(DnsConfigServiceTest, WatchFailure) { + DnsConfig config1 = MakeConfig(1); + DnsConfig config2 = MakeConfig(2); + config1.hosts = MakeHosts(1); + config2.hosts = MakeHosts(2); + + service_->OnConfigRead(config1); + service_->OnHostsRead(config1.hosts); + EXPECT_FALSE(last_config_.Equals(DnsConfig())); + EXPECT_TRUE(last_config_.Equals(config1)); + + // Simulate watch failure. + service_->set_watch_failed(true); + service_->InvalidateConfig(); + WaitForConfig(TestTimeouts::action_timeout()); + EXPECT_FALSE(last_config_.Equals(config1)); + EXPECT_TRUE(last_config_.Equals(DnsConfig())); + + DnsConfig bad_config = last_config_ = MakeConfig(0xBAD); + // Actual change in config, so expect an update, but it should be empty. + service_->OnConfigRead(config1); + EXPECT_FALSE(last_config_.Equals(bad_config)); + EXPECT_TRUE(last_config_.Equals(DnsConfig())); + + last_config_ = bad_config; + // Actual change in config, so expect an update, but it should be empty. + service_->InvalidateConfig(); + service_->OnConfigRead(config2); + EXPECT_FALSE(last_config_.Equals(bad_config)); + EXPECT_TRUE(last_config_.Equals(DnsConfig())); + + last_config_ = bad_config; + // No change, so no update. + service_->InvalidateConfig(); + service_->OnConfigRead(config2); + EXPECT_TRUE(last_config_.Equals(bad_config)); +} + +#if (defined(OS_POSIX) && !defined(OS_ANDROID)) || defined(OS_WIN) +// TODO(szym): This is really an integration test and can time out if HOSTS is +// huge. http://crbug.com/107810 +TEST_F(DnsConfigServiceTest, DISABLED_GetSystemConfig) { + service_.reset(); + scoped_ptr<DnsConfigService> service(DnsConfigService::CreateSystemService()); + + service->ReadConfig(base::Bind(&DnsConfigServiceTest::OnConfigChanged, + base::Unretained(this))); + base::TimeDelta kTimeout = TestTimeouts::action_max_timeout(); + WaitForConfig(kTimeout); + ASSERT_TRUE(last_config_.IsValid()) << "Did not receive DnsConfig in " << + kTimeout.InSecondsF() << "s"; +} +#endif // OS_POSIX || OS_WIN + +} // namespace net + diff --git a/chromium/net/dns/dns_config_service_win.cc b/chromium/net/dns/dns_config_service_win.cc new file mode 100644 index 00000000000..6d12155d8cc --- /dev/null +++ b/chromium/net/dns/dns_config_service_win.cc @@ -0,0 +1,737 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_config_service_win.h" + +#include <algorithm> +#include <string> + +#include "base/bind.h" +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/files/file_path.h" +#include "base/files/file_path_watcher.h" +#include "base/logging.h" +#include "base/memory/scoped_ptr.h" +#include "base/metrics/histogram.h" +#include "base/strings/string_split.h" +#include "base/strings/string_util.h" +#include "base/strings/utf_string_conversions.h" +#include "base/synchronization/lock.h" +#include "base/threading/non_thread_safe.h" +#include "base/threading/thread_restrictions.h" +#include "base/time/time.h" +#include "base/win/object_watcher.h" +#include "base/win/registry.h" +#include "base/win/windows_version.h" +#include "net/base/net_util.h" +#include "net/base/network_change_notifier.h" +#include "net/dns/dns_hosts.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/serial_worker.h" +#include "url/url_canon.h" + +#pragma comment(lib, "iphlpapi.lib") + +namespace net { + +namespace internal { + +namespace { + +// Interval between retries to parse config. Used only until parsing succeeds. +const int kRetryIntervalSeconds = 5; + +const wchar_t* const kPrimaryDnsSuffixPath = + L"SOFTWARE\\Policies\\Microsoft\\System\\DNSClient"; + +enum HostsParseWinResult { + HOSTS_PARSE_WIN_OK = 0, + HOSTS_PARSE_WIN_UNREADABLE_HOSTS_FILE, + HOSTS_PARSE_WIN_COMPUTER_NAME_FAILED, + HOSTS_PARSE_WIN_IPHELPER_FAILED, + HOSTS_PARSE_WIN_BAD_ADDRESS, + HOSTS_PARSE_WIN_MAX // Bounding values for enumeration. +}; + +// Convenience for reading values using RegKey. +class RegistryReader : public base::NonThreadSafe { + public: + explicit RegistryReader(const wchar_t* key) { + // Ignoring the result. |key_.Valid()| will catch failures. + key_.Open(HKEY_LOCAL_MACHINE, key, KEY_QUERY_VALUE); + } + + bool ReadString(const wchar_t* name, + DnsSystemSettings::RegString* out) const { + DCHECK(CalledOnValidThread()); + out->set = false; + if (!key_.Valid()) { + // Assume that if the |key_| is invalid then the key is missing. + return true; + } + LONG result = key_.ReadValue(name, &out->value); + if (result == ERROR_SUCCESS) { + out->set = true; + return true; + } + return (result == ERROR_FILE_NOT_FOUND); + } + + bool ReadDword(const wchar_t* name, + DnsSystemSettings::RegDword* out) const { + DCHECK(CalledOnValidThread()); + out->set = false; + if (!key_.Valid()) { + // Assume that if the |key_| is invalid then the key is missing. + return true; + } + LONG result = key_.ReadValueDW(name, &out->value); + if (result == ERROR_SUCCESS) { + out->set = true; + return true; + } + return (result == ERROR_FILE_NOT_FOUND); + } + + private: + base::win::RegKey key_; + + DISALLOW_COPY_AND_ASSIGN(RegistryReader); +}; + +// Wrapper for GetAdaptersAddresses. Returns NULL if failed. +scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> ReadIpHelper(ULONG flags) { + base::ThreadRestrictions::AssertIOAllowed(); + + scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> out; + ULONG len = 15000; // As recommended by MSDN for GetAdaptersAddresses. + UINT rv = ERROR_BUFFER_OVERFLOW; + // Try up to three times. + for (unsigned tries = 0; (tries < 3) && (rv == ERROR_BUFFER_OVERFLOW); + tries++) { + out.reset(reinterpret_cast<PIP_ADAPTER_ADDRESSES>(malloc(len))); + rv = GetAdaptersAddresses(AF_UNSPEC, flags, NULL, out.get(), &len); + } + if (rv != NO_ERROR) + out.reset(); + return out.Pass(); +} + +// Converts a base::string16 domain name to ASCII, possibly using punycode. +// Returns true if the conversion succeeds and output is not empty. In case of +// failure, |domain| might become dirty. +bool ParseDomainASCII(const base::string16& widestr, std::string* domain) { + DCHECK(domain); + if (widestr.empty()) + return false; + + // Check if already ASCII. + if (IsStringASCII(widestr)) { + *domain = UTF16ToASCII(widestr); + return true; + } + + // Otherwise try to convert it from IDN to punycode. + const int kInitialBufferSize = 256; + url_canon::RawCanonOutputT<char16, kInitialBufferSize> punycode; + if (!url_canon::IDNToASCII(widestr.data(), widestr.length(), &punycode)) + return false; + + // |punycode_output| should now be ASCII; convert it to a std::string. + // (We could use UTF16ToASCII() instead, but that requires an extra string + // copy. Since ASCII is a subset of UTF8 the following is equivalent). + bool success = UTF16ToUTF8(punycode.data(), punycode.length(), domain); + DCHECK(success); + DCHECK(IsStringASCII(*domain)); + return success && !domain->empty(); +} + +bool ReadDevolutionSetting(const RegistryReader& reader, + DnsSystemSettings::DevolutionSetting* setting) { + return reader.ReadDword(L"UseDomainNameDevolution", &setting->enabled) && + reader.ReadDword(L"DomainNameDevolutionLevel", &setting->level); +} + +// Reads DnsSystemSettings from IpHelper and registry. +ConfigParseWinResult ReadSystemSettings(DnsSystemSettings* settings) { + settings->addresses = ReadIpHelper(GAA_FLAG_SKIP_ANYCAST | + GAA_FLAG_SKIP_UNICAST | + GAA_FLAG_SKIP_MULTICAST | + GAA_FLAG_SKIP_FRIENDLY_NAME); + if (!settings->addresses.get()) + return CONFIG_PARSE_WIN_READ_IPHELPER; + + RegistryReader tcpip_reader(kTcpipPath); + RegistryReader tcpip6_reader(kTcpip6Path); + RegistryReader dnscache_reader(kDnscachePath); + RegistryReader policy_reader(kPolicyPath); + RegistryReader primary_dns_suffix_reader(kPrimaryDnsSuffixPath); + + if (!policy_reader.ReadString(L"SearchList", + &settings->policy_search_list)) { + return CONFIG_PARSE_WIN_READ_POLICY_SEARCHLIST; + } + + if (!tcpip_reader.ReadString(L"SearchList", &settings->tcpip_search_list)) + return CONFIG_PARSE_WIN_READ_TCPIP_SEARCHLIST; + + if (!tcpip_reader.ReadString(L"Domain", &settings->tcpip_domain)) + return CONFIG_PARSE_WIN_READ_DOMAIN; + + if (!ReadDevolutionSetting(policy_reader, &settings->policy_devolution)) + return CONFIG_PARSE_WIN_READ_POLICY_DEVOLUTION; + + if (!ReadDevolutionSetting(dnscache_reader, &settings->dnscache_devolution)) + return CONFIG_PARSE_WIN_READ_DNSCACHE_DEVOLUTION; + + if (!ReadDevolutionSetting(tcpip_reader, &settings->tcpip_devolution)) + return CONFIG_PARSE_WIN_READ_TCPIP_DEVOLUTION; + + if (!policy_reader.ReadDword(L"AppendToMultiLabelName", + &settings->append_to_multi_label_name)) { + return CONFIG_PARSE_WIN_READ_APPEND_MULTILABEL; + } + + if (!primary_dns_suffix_reader.ReadString(L"PrimaryDnsSuffix", + &settings->primary_dns_suffix)) { + return CONFIG_PARSE_WIN_READ_PRIMARY_SUFFIX; + } + return CONFIG_PARSE_WIN_OK; +} + +// Default address of "localhost" and local computer name can be overridden +// by the HOSTS file, but if it's not there, then we need to fill it in. +HostsParseWinResult AddLocalhostEntries(DnsHosts* hosts) { + const unsigned char kIPv4Localhost[] = { 127, 0, 0, 1 }; + const unsigned char kIPv6Localhost[] = { 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1 }; + IPAddressNumber loopback_ipv4(kIPv4Localhost, + kIPv4Localhost + arraysize(kIPv4Localhost)); + IPAddressNumber loopback_ipv6(kIPv6Localhost, + kIPv6Localhost + arraysize(kIPv6Localhost)); + + // This does not override any pre-existing entries from the HOSTS file. + hosts->insert(std::make_pair(DnsHostsKey("localhost", ADDRESS_FAMILY_IPV4), + loopback_ipv4)); + hosts->insert(std::make_pair(DnsHostsKey("localhost", ADDRESS_FAMILY_IPV6), + loopback_ipv6)); + + WCHAR buffer[MAX_PATH]; + DWORD size = MAX_PATH; + std::string localname; + if (!GetComputerNameExW(ComputerNameDnsHostname, buffer, &size) || + !ParseDomainASCII(buffer, &localname)) { + return HOSTS_PARSE_WIN_COMPUTER_NAME_FAILED; + } + StringToLowerASCII(&localname); + + bool have_ipv4 = + hosts->count(DnsHostsKey(localname, ADDRESS_FAMILY_IPV4)) > 0; + bool have_ipv6 = + hosts->count(DnsHostsKey(localname, ADDRESS_FAMILY_IPV6)) > 0; + + if (have_ipv4 && have_ipv6) + return HOSTS_PARSE_WIN_OK; + + scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> addresses = + ReadIpHelper(GAA_FLAG_SKIP_ANYCAST | + GAA_FLAG_SKIP_DNS_SERVER | + GAA_FLAG_SKIP_MULTICAST | + GAA_FLAG_SKIP_FRIENDLY_NAME); + if (!addresses.get()) + return HOSTS_PARSE_WIN_IPHELPER_FAILED; + + // The order of adapters is the network binding order, so stick to the + // first good adapter for each family. + for (const IP_ADAPTER_ADDRESSES* adapter = addresses.get(); + adapter != NULL && (!have_ipv4 || !have_ipv6); + adapter = adapter->Next) { + if (adapter->OperStatus != IfOperStatusUp) + continue; + if (adapter->IfType == IF_TYPE_SOFTWARE_LOOPBACK) + continue; + + for (const IP_ADAPTER_UNICAST_ADDRESS* address = + adapter->FirstUnicastAddress; + address != NULL; + address = address->Next) { + IPEndPoint ipe; + if (!ipe.FromSockAddr(address->Address.lpSockaddr, + address->Address.iSockaddrLength)) { + return HOSTS_PARSE_WIN_BAD_ADDRESS; + } + if (!have_ipv4 && (ipe.GetFamily() == ADDRESS_FAMILY_IPV4)) { + have_ipv4 = true; + (*hosts)[DnsHostsKey(localname, ADDRESS_FAMILY_IPV4)] = ipe.address(); + } else if (!have_ipv6 && (ipe.GetFamily() == ADDRESS_FAMILY_IPV6)) { + have_ipv6 = true; + (*hosts)[DnsHostsKey(localname, ADDRESS_FAMILY_IPV6)] = ipe.address(); + } + } + } + return HOSTS_PARSE_WIN_OK; +} + +// Watches a single registry key for changes. +class RegistryWatcher : public base::win::ObjectWatcher::Delegate, + public base::NonThreadSafe { + public: + typedef base::Callback<void(bool succeeded)> CallbackType; + RegistryWatcher() {} + + bool Watch(const wchar_t* key, const CallbackType& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(!callback.is_null()); + DCHECK(callback_.is_null()); + callback_ = callback; + if (key_.Open(HKEY_LOCAL_MACHINE, key, KEY_NOTIFY) != ERROR_SUCCESS) + return false; + if (key_.StartWatching() != ERROR_SUCCESS) + return false; + if (!watcher_.StartWatching(key_.watch_event(), this)) + return false; + return true; + } + + virtual void OnObjectSignaled(HANDLE object) OVERRIDE { + DCHECK(CalledOnValidThread()); + bool succeeded = (key_.StartWatching() == ERROR_SUCCESS) && + watcher_.StartWatching(key_.watch_event(), this); + if (!succeeded && key_.Valid()) { + watcher_.StopWatching(); + key_.StopWatching(); + key_.Close(); + } + if (!callback_.is_null()) + callback_.Run(succeeded); + } + + private: + CallbackType callback_; + base::win::RegKey key_; + base::win::ObjectWatcher watcher_; + + DISALLOW_COPY_AND_ASSIGN(RegistryWatcher); +}; + +// Returns true iff |address| is DNS address from IPv6 stateless discovery, +// i.e., matches fec0:0:0:ffff::{1,2,3}. +// http://tools.ietf.org/html/draft-ietf-ipngwg-dns-discovery +bool IsStatelessDiscoveryAddress(const IPAddressNumber& address) { + if (address.size() != kIPv6AddressSize) + return false; + const uint8 kPrefix[] = { + 0xfe, 0xc0, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + return std::equal(kPrefix, kPrefix + arraysize(kPrefix), + address.begin()) && (address.back() < 4); +} + +} // namespace + +base::FilePath GetHostsPath() { + TCHAR buffer[MAX_PATH]; + UINT rc = GetSystemDirectory(buffer, MAX_PATH); + DCHECK(0 < rc && rc < MAX_PATH); + return base::FilePath(buffer).Append( + FILE_PATH_LITERAL("drivers\\etc\\hosts")); +} + +bool ParseSearchList(const base::string16& value, + std::vector<std::string>* output) { + DCHECK(output); + if (value.empty()) + return false; + + output->clear(); + + // If the list includes an empty hostname (",," or ", ,"), it is terminated. + // Although nslookup and network connection property tab ignore such + // fragments ("a,b,,c" becomes ["a", "b", "c"]), our reference is getaddrinfo + // (which sees ["a", "b"]). WMI queries also return a matching search list. + std::vector<base::string16> woutput; + base::SplitString(value, ',', &woutput); + for (size_t i = 0; i < woutput.size(); ++i) { + // Convert non-ASCII to punycode, although getaddrinfo does not properly + // handle such suffixes. + const base::string16& t = woutput[i]; + std::string parsed; + if (!ParseDomainASCII(t, &parsed)) + break; + output->push_back(parsed); + } + return !output->empty(); +} + +ConfigParseWinResult ConvertSettingsToDnsConfig( + const DnsSystemSettings& settings, + DnsConfig* config) { + *config = DnsConfig(); + + // Use GetAdapterAddresses to get effective DNS server order and + // connection-specific DNS suffix. Ignore disconnected and loopback adapters. + // The order of adapters is the network binding order, so stick to the + // first good adapter. + for (const IP_ADAPTER_ADDRESSES* adapter = settings.addresses.get(); + adapter != NULL && config->nameservers.empty(); + adapter = adapter->Next) { + if (adapter->OperStatus != IfOperStatusUp) + continue; + if (adapter->IfType == IF_TYPE_SOFTWARE_LOOPBACK) + continue; + + for (const IP_ADAPTER_DNS_SERVER_ADDRESS* address = + adapter->FirstDnsServerAddress; + address != NULL; + address = address->Next) { + IPEndPoint ipe; + if (ipe.FromSockAddr(address->Address.lpSockaddr, + address->Address.iSockaddrLength)) { + if (IsStatelessDiscoveryAddress(ipe.address())) + continue; + // Override unset port. + if (!ipe.port()) + ipe = IPEndPoint(ipe.address(), dns_protocol::kDefaultPort); + config->nameservers.push_back(ipe); + } else { + return CONFIG_PARSE_WIN_BAD_ADDRESS; + } + } + + // IP_ADAPTER_ADDRESSES in Vista+ has a search list at |FirstDnsSuffix|, + // but it came up empty in all trials. + // |DnsSuffix| stores the effective connection-specific suffix, which is + // obtained via DHCP (regkey: Tcpip\Parameters\Interfaces\{XXX}\DhcpDomain) + // or specified by the user (regkey: Tcpip\Parameters\Domain). + std::string dns_suffix; + if (ParseDomainASCII(adapter->DnsSuffix, &dns_suffix)) + config->search.push_back(dns_suffix); + } + + if (config->nameservers.empty()) + return CONFIG_PARSE_WIN_NO_NAMESERVERS; // No point continuing. + + // Windows always tries a multi-label name "as is" before using suffixes. + config->ndots = 1; + + if (!settings.append_to_multi_label_name.set) { + // The default setting is true for XP, false for Vista+. + if (base::win::GetVersion() >= base::win::VERSION_VISTA) { + config->append_to_multi_label_name = false; + } else { + config->append_to_multi_label_name = true; + } + } else { + config->append_to_multi_label_name = + (settings.append_to_multi_label_name.value != 0); + } + + // SearchList takes precedence, so check it first. + if (settings.policy_search_list.set) { + std::vector<std::string> search; + if (ParseSearchList(settings.policy_search_list.value, &search)) { + config->search.swap(search); + return CONFIG_PARSE_WIN_OK; + } + // Even if invalid, the policy disables the user-specified setting below. + } else if (settings.tcpip_search_list.set) { + std::vector<std::string> search; + if (ParseSearchList(settings.tcpip_search_list.value, &search)) { + config->search.swap(search); + return CONFIG_PARSE_WIN_OK; + } + } + + // In absence of explicit search list, suffix search is: + // [primary suffix, connection-specific suffix, devolution of primary suffix]. + // Primary suffix can be set by policy (primary_dns_suffix) or + // user setting (tcpip_domain). + // + // The policy (primary_dns_suffix) can be edited via Group Policy Editor + // (gpedit.msc) at Local Computer Policy => Computer Configuration + // => Administrative Template => Network => DNS Client => Primary DNS Suffix. + // + // The user setting (tcpip_domain) can be configurred at Computer Name in + // System Settings + std::string primary_suffix; + if ((settings.primary_dns_suffix.set && + ParseDomainASCII(settings.primary_dns_suffix.value, &primary_suffix)) || + (settings.tcpip_domain.set && + ParseDomainASCII(settings.tcpip_domain.value, &primary_suffix))) { + // Primary suffix goes in front. + config->search.insert(config->search.begin(), primary_suffix); + } else { + return CONFIG_PARSE_WIN_OK; // No primary suffix, hence no devolution. + } + + // Devolution is determined by precedence: policy > dnscache > tcpip. + // |enabled|: UseDomainNameDevolution and |level|: DomainNameDevolutionLevel + // are overridden independently. + DnsSystemSettings::DevolutionSetting devolution = settings.policy_devolution; + + if (!devolution.enabled.set) + devolution.enabled = settings.dnscache_devolution.enabled; + if (!devolution.enabled.set) + devolution.enabled = settings.tcpip_devolution.enabled; + if (devolution.enabled.set && (devolution.enabled.value == 0)) + return CONFIG_PARSE_WIN_OK; // Devolution disabled. + + // By default devolution is enabled. + + if (!devolution.level.set) + devolution.level = settings.dnscache_devolution.level; + if (!devolution.level.set) + devolution.level = settings.tcpip_devolution.level; + + // After the recent update, Windows will try to determine a safe default + // value by comparing the forest root domain (FRD) to the primary suffix. + // See http://support.microsoft.com/kb/957579 for details. + // For now, if the level is not set, we disable devolution, assuming that + // we will fallback to the system getaddrinfo anyway. This might cause + // performance loss for resolutions which depend on the system default + // devolution setting. + // + // If the level is explicitly set below 2, devolution is disabled. + if (!devolution.level.set || devolution.level.value < 2) + return CONFIG_PARSE_WIN_OK; // Devolution disabled. + + // Devolve the primary suffix. This naive logic matches the observed + // behavior (see also ParseSearchList). If a suffix is not valid, it will be + // discarded when the fully-qualified name is converted to DNS format. + + unsigned num_dots = std::count(primary_suffix.begin(), + primary_suffix.end(), '.'); + + for (size_t offset = 0; num_dots >= devolution.level.value; --num_dots) { + offset = primary_suffix.find('.', offset + 1); + config->search.push_back(primary_suffix.substr(offset + 1)); + } + return CONFIG_PARSE_WIN_OK; +} + +// Watches registry and HOSTS file for changes. Must live on a thread which +// allows IO. +class DnsConfigServiceWin::Watcher + : public NetworkChangeNotifier::IPAddressObserver { + public: + explicit Watcher(DnsConfigServiceWin* service) : service_(service) {} + ~Watcher() { + NetworkChangeNotifier::RemoveIPAddressObserver(this); + } + + bool Watch() { + RegistryWatcher::CallbackType callback = + base::Bind(&DnsConfigServiceWin::OnConfigChanged, + base::Unretained(service_)); + + bool success = true; + + // The Tcpip key must be present. + if (!tcpip_watcher_.Watch(kTcpipPath, callback)) { + LOG(ERROR) << "DNS registry watch failed to start."; + success = false; + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", + DNS_CONFIG_WATCH_FAILED_TO_START_CONFIG, + DNS_CONFIG_WATCH_MAX); + } + + // Watch for IPv6 nameservers. + tcpip6_watcher_.Watch(kTcpip6Path, callback); + + // DNS suffix search list and devolution can be configured via group + // policy which sets this registry key. If the key is missing, the policy + // does not apply, and the DNS client uses Tcpip and Dnscache settings. + // If a policy is installed, DnsConfigService will need to be restarted. + // BUG=99509 + + dnscache_watcher_.Watch(kDnscachePath, callback); + policy_watcher_.Watch(kPolicyPath, callback); + + if (!hosts_watcher_.Watch(GetHostsPath(), false, + base::Bind(&Watcher::OnHostsChanged, + base::Unretained(this)))) { + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", + DNS_CONFIG_WATCH_FAILED_TO_START_HOSTS, + DNS_CONFIG_WATCH_MAX); + LOG(ERROR) << "DNS hosts watch failed to start."; + success = false; + } else { + // Also need to observe changes to local non-loopback IP for DnsHosts. + NetworkChangeNotifier::AddIPAddressObserver(this); + } + return success; + } + + private: + void OnHostsChanged(const base::FilePath& path, bool error) { + if (error) + NetworkChangeNotifier::RemoveIPAddressObserver(this); + service_->OnHostsChanged(!error); + } + + // NetworkChangeNotifier::IPAddressObserver: + virtual void OnIPAddressChanged() OVERRIDE { + // Need to update non-loopback IP of local host. + service_->OnHostsChanged(true); + } + + DnsConfigServiceWin* service_; + + RegistryWatcher tcpip_watcher_; + RegistryWatcher tcpip6_watcher_; + RegistryWatcher dnscache_watcher_; + RegistryWatcher policy_watcher_; + base::FilePathWatcher hosts_watcher_; + + DISALLOW_COPY_AND_ASSIGN(Watcher); +}; + +// Reads config from registry and IpHelper. All work performed on WorkerPool. +class DnsConfigServiceWin::ConfigReader : public SerialWorker { + public: + explicit ConfigReader(DnsConfigServiceWin* service) + : service_(service), + success_(false) {} + + private: + virtual ~ConfigReader() {} + + virtual void DoWork() OVERRIDE { + // Should be called on WorkerPool. + base::TimeTicks start_time = base::TimeTicks::Now(); + DnsSystemSettings settings = {}; + ConfigParseWinResult result = ReadSystemSettings(&settings); + if (result == CONFIG_PARSE_WIN_OK) + result = ConvertSettingsToDnsConfig(settings, &dns_config_); + success_ = (result == CONFIG_PARSE_WIN_OK); + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ConfigParseWin", + result, CONFIG_PARSE_WIN_MAX); + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.ConfigParseResult", success_); + UMA_HISTOGRAM_TIMES("AsyncDNS.ConfigParseDuration", + base::TimeTicks::Now() - start_time); + } + + virtual void OnWorkFinished() OVERRIDE { + DCHECK(loop()->BelongsToCurrentThread()); + DCHECK(!IsCancelled()); + if (success_) { + service_->OnConfigRead(dns_config_); + } else { + LOG(WARNING) << "Failed to read DnsConfig."; + // Try again in a while in case DnsConfigWatcher missed the signal. + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&ConfigReader::WorkNow, this), + base::TimeDelta::FromSeconds(kRetryIntervalSeconds)); + } + } + + DnsConfigServiceWin* service_; + // Written in DoWork(), read in OnWorkFinished(). No locking required. + DnsConfig dns_config_; + bool success_; +}; + +// Reads hosts from HOSTS file and fills in localhost and local computer name if +// necessary. All work performed on WorkerPool. +class DnsConfigServiceWin::HostsReader : public SerialWorker { + public: + explicit HostsReader(DnsConfigServiceWin* service) + : path_(GetHostsPath()), + service_(service), + success_(false) { + } + + private: + virtual ~HostsReader() {} + + virtual void DoWork() OVERRIDE { + base::TimeTicks start_time = base::TimeTicks::Now(); + HostsParseWinResult result = HOSTS_PARSE_WIN_UNREADABLE_HOSTS_FILE; + if (ParseHostsFile(path_, &hosts_)) + result = AddLocalhostEntries(&hosts_); + success_ = (result == HOSTS_PARSE_WIN_OK); + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.HostsParseWin", + result, HOSTS_PARSE_WIN_MAX); + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.HostParseResult", success_); + UMA_HISTOGRAM_TIMES("AsyncDNS.HostsParseDuration", + base::TimeTicks::Now() - start_time); + } + + virtual void OnWorkFinished() OVERRIDE { + DCHECK(loop()->BelongsToCurrentThread()); + if (success_) { + service_->OnHostsRead(hosts_); + } else { + LOG(WARNING) << "Failed to read DnsHosts."; + } + } + + const base::FilePath path_; + DnsConfigServiceWin* service_; + // Written in DoWork, read in OnWorkFinished, no locking necessary. + DnsHosts hosts_; + bool success_; + + DISALLOW_COPY_AND_ASSIGN(HostsReader); +}; + +DnsConfigServiceWin::DnsConfigServiceWin() + : config_reader_(new ConfigReader(this)), + hosts_reader_(new HostsReader(this)) {} + +DnsConfigServiceWin::~DnsConfigServiceWin() { + config_reader_->Cancel(); + hosts_reader_->Cancel(); +} + +void DnsConfigServiceWin::ReadNow() { + config_reader_->WorkNow(); + hosts_reader_->WorkNow(); +} + +bool DnsConfigServiceWin::StartWatching() { + // TODO(szym): re-start watcher if that makes sense. http://crbug.com/116139 + watcher_.reset(new Watcher(this)); + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", DNS_CONFIG_WATCH_STARTED, + DNS_CONFIG_WATCH_MAX); + return watcher_->Watch(); +} + +void DnsConfigServiceWin::OnConfigChanged(bool succeeded) { + InvalidateConfig(); + if (succeeded) { + config_reader_->WorkNow(); + } else { + LOG(ERROR) << "DNS config watch failed."; + set_watch_failed(true); + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", + DNS_CONFIG_WATCH_FAILED_CONFIG, + DNS_CONFIG_WATCH_MAX); + } +} + +void DnsConfigServiceWin::OnHostsChanged(bool succeeded) { + InvalidateHosts(); + if (succeeded) { + hosts_reader_->WorkNow(); + } else { + LOG(ERROR) << "DNS hosts watch failed."; + set_watch_failed(true); + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.WatchStatus", + DNS_CONFIG_WATCH_FAILED_HOSTS, + DNS_CONFIG_WATCH_MAX); + } +} + +} // namespace internal + +// static +scoped_ptr<DnsConfigService> DnsConfigService::CreateSystemService() { + return scoped_ptr<DnsConfigService>(new internal::DnsConfigServiceWin()); +} + +} // namespace net diff --git a/chromium/net/dns/dns_config_service_win.h b/chromium/net/dns/dns_config_service_win.h new file mode 100644 index 00000000000..06fc0d9663b --- /dev/null +++ b/chromium/net/dns/dns_config_service_win.h @@ -0,0 +1,154 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_CONFIG_SERVICE_WIN_H_ +#define NET_DNS_DNS_CONFIG_SERVICE_WIN_H_ + +// The sole purpose of dns_config_service_win.h is for unittests so we just +// include these headers here. +#include <winsock2.h> +#include <iphlpapi.h> + +#include <string> +#include <vector> + +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/strings/string16.h" +#include "net/base/net_export.h" +#include "net/dns/dns_config_service.h" + +// The general effort of DnsConfigServiceWin is to configure |nameservers| and +// |search| in DnsConfig. The settings are stored in the Windows registry, but +// to simplify the task we use the IP Helper API wherever possible. That API +// yields the complete and ordered |nameservers|, but to determine |search| we +// need to use the registry. On Windows 7, WMI does return the correct |search| +// but on earlier versions it is insufficient. +// +// Experimental evaluation of Windows behavior suggests that domain parsing is +// naive. Domain suffixes in |search| are not validated until they are appended +// to the resolved name. We attempt to replicate this behavior. + +namespace net { + +namespace internal { + +// Registry key paths. +const wchar_t* const kTcpipPath = + L"SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters"; +const wchar_t* const kTcpip6Path = + L"SYSTEM\\CurrentControlSet\\Services\\Tcpip6\\Parameters"; +const wchar_t* const kDnscachePath = + L"SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters"; +const wchar_t* const kPolicyPath = + L"SOFTWARE\\Policies\\Microsoft\\Windows NT\\DNSClient"; + +// Returns the path to the HOSTS file. +base::FilePath GetHostsPath(); + +// Parses |value| as search list (comma-delimited list of domain names) from +// a registry key and stores it in |out|. Returns true on success. Empty +// entries (e.g., "chromium.org,,org") terminate the list. Non-ascii hostnames +// are converted to punycode. +bool NET_EXPORT_PRIVATE ParseSearchList(const base::string16& value, + std::vector<std::string>* out); + +// All relevant settings read from registry and IP Helper. This isolates our +// logic from system calls and is exposed for unit tests. Keep it an aggregate +// struct for easy initialization. +struct NET_EXPORT_PRIVATE DnsSystemSettings { + // The |set| flag distinguishes between empty and unset values. + struct RegString { + bool set; + base::string16 value; + }; + + struct RegDword { + bool set; + DWORD value; + }; + + struct DevolutionSetting { + // UseDomainNameDevolution + RegDword enabled; + // DomainNameDevolutionLevel + RegDword level; + }; + + // Filled in by GetAdapterAddresses. Note that the alternative + // GetNetworkParams does not include IPv6 addresses. + scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> addresses; + + // SOFTWARE\Policies\Microsoft\Windows NT\DNSClient\SearchList + RegString policy_search_list; + // SYSTEM\CurrentControlSet\Tcpip\Parameters\SearchList + RegString tcpip_search_list; + // SYSTEM\CurrentControlSet\Tcpip\Parameters\Domain + RegString tcpip_domain; + // SOFTWARE\Policies\Microsoft\System\DNSClient\PrimaryDnsSuffix + RegString primary_dns_suffix; + + // SOFTWARE\Policies\Microsoft\Windows NT\DNSClient + DevolutionSetting policy_devolution; + // SYSTEM\CurrentControlSet\Dnscache\Parameters + DevolutionSetting dnscache_devolution; + // SYSTEM\CurrentControlSet\Tcpip\Parameters + DevolutionSetting tcpip_devolution; + + // SOFTWARE\Policies\Microsoft\Windows NT\DNSClient\AppendToMultiLabelName + RegDword append_to_multi_label_name; +}; + +enum ConfigParseWinResult { + CONFIG_PARSE_WIN_OK = 0, + CONFIG_PARSE_WIN_READ_IPHELPER, + CONFIG_PARSE_WIN_READ_POLICY_SEARCHLIST, + CONFIG_PARSE_WIN_READ_TCPIP_SEARCHLIST, + CONFIG_PARSE_WIN_READ_DOMAIN, + CONFIG_PARSE_WIN_READ_POLICY_DEVOLUTION, + CONFIG_PARSE_WIN_READ_DNSCACHE_DEVOLUTION, + CONFIG_PARSE_WIN_READ_TCPIP_DEVOLUTION, + CONFIG_PARSE_WIN_READ_APPEND_MULTILABEL, + CONFIG_PARSE_WIN_READ_PRIMARY_SUFFIX, + CONFIG_PARSE_WIN_BAD_ADDRESS, + CONFIG_PARSE_WIN_NO_NAMESERVERS, + CONFIG_PARSE_WIN_MAX // Bounding values for enumeration. +}; + +// Fills in |dns_config| from |settings|. Exposed for tests. +ConfigParseWinResult NET_EXPORT_PRIVATE ConvertSettingsToDnsConfig( + const DnsSystemSettings& settings, + DnsConfig* dns_config); + +// Use DnsConfigService::CreateSystemService to use it outside of tests. +class NET_EXPORT_PRIVATE DnsConfigServiceWin : public DnsConfigService { + public: + DnsConfigServiceWin(); + virtual ~DnsConfigServiceWin(); + + private: + class Watcher; + class ConfigReader; + class HostsReader; + + // DnsConfigService: + virtual void ReadNow() OVERRIDE; + virtual bool StartWatching() OVERRIDE; + + void OnConfigChanged(bool succeeded); + void OnHostsChanged(bool succeeded); + + scoped_ptr<Watcher> watcher_; + scoped_refptr<ConfigReader> config_reader_; + scoped_refptr<HostsReader> hosts_reader_; + + DISALLOW_COPY_AND_ASSIGN(DnsConfigServiceWin); +}; + +} // namespace internal + +} // namespace net + +#endif // NET_DNS_DNS_CONFIG_SERVICE_WIN_H_ + diff --git a/chromium/net/dns/dns_config_service_win_unittest.cc b/chromium/net/dns/dns_config_service_win_unittest.cc new file mode 100644 index 00000000000..b28b8e9554a --- /dev/null +++ b/chromium/net/dns/dns_config_service_win_unittest.cc @@ -0,0 +1,430 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_config_service_win.h" + +#include "base/logging.h" +#include "base/win/windows_version.h" +#include "net/dns/dns_protocol.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +TEST(DnsConfigServiceWinTest, ParseSearchList) { + const struct TestCase { + const wchar_t* input; + const char* output[4]; // NULL-terminated, empty if expected false + } cases[] = { + { L"chromium.org", { "chromium.org", NULL } }, + { L"chromium.org,org", { "chromium.org", "org", NULL } }, + // Empty suffixes terminate the list + { L"crbug.com,com,,org", { "crbug.com", "com", NULL } }, + // IDN are converted to punycode + { L"\u017c\xf3\u0142ta.pi\u0119\u015b\u0107.pl,pl", + { "xn--ta-4ja03asj.xn--pi-wla5e0q.pl", "pl", NULL } }, + // Empty search list is invalid + { L"", { NULL } }, + { L",,", { NULL } }, + }; + + std::vector<std::string> actual_output, expected_output; + for (unsigned i = 0; i < arraysize(cases); ++i) { + const TestCase& t = cases[i]; + actual_output.clear(); + actual_output.push_back("UNSET"); + expected_output.clear(); + for (const char* const* output = t.output; *output; ++output) { + expected_output.push_back(*output); + } + bool result = internal::ParseSearchList(t.input, &actual_output); + if (!expected_output.empty()) { + EXPECT_TRUE(result); + EXPECT_EQ(expected_output, actual_output); + } else { + EXPECT_FALSE(result) << "Unexpected parse success on " << t.input; + } + } +} + +struct AdapterInfo { + IFTYPE if_type; + IF_OPER_STATUS oper_status; + const WCHAR* dns_suffix; + std::string dns_server_addresses[4]; // Empty string indicates end. + int ports[4]; +}; + +scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> CreateAdapterAddresses( + const AdapterInfo* infos) { + size_t num_adapters = 0; + size_t num_addresses = 0; + for (size_t i = 0; infos[i].if_type; ++i) { + ++num_adapters; + for (size_t j = 0; !infos[i].dns_server_addresses[j].empty(); ++j) { + ++num_addresses; + } + } + + size_t heap_size = num_adapters * sizeof(IP_ADAPTER_ADDRESSES) + + num_addresses * (sizeof(IP_ADAPTER_DNS_SERVER_ADDRESS) + + sizeof(struct sockaddr_storage)); + scoped_ptr_malloc<IP_ADAPTER_ADDRESSES> heap( + reinterpret_cast<IP_ADAPTER_ADDRESSES*>(malloc(heap_size))); + CHECK(heap.get()); + memset(heap.get(), 0, heap_size); + + IP_ADAPTER_ADDRESSES* adapters = heap.get(); + IP_ADAPTER_DNS_SERVER_ADDRESS* addresses = + reinterpret_cast<IP_ADAPTER_DNS_SERVER_ADDRESS*>(adapters + num_adapters); + struct sockaddr_storage* storage = + reinterpret_cast<struct sockaddr_storage*>(addresses + num_addresses); + + for (size_t i = 0; i < num_adapters; ++i) { + const AdapterInfo& info = infos[i]; + IP_ADAPTER_ADDRESSES* adapter = adapters + i; + if (i + 1 < num_adapters) + adapter->Next = adapter + 1; + adapter->IfType = info.if_type; + adapter->OperStatus = info.oper_status; + adapter->DnsSuffix = const_cast<PWCHAR>(info.dns_suffix); + IP_ADAPTER_DNS_SERVER_ADDRESS* address = NULL; + for (size_t j = 0; !info.dns_server_addresses[j].empty(); ++j) { + --num_addresses; + if (j == 0) { + address = adapter->FirstDnsServerAddress = addresses + num_addresses; + } else { + // Note that |address| is moving backwards. + address = address->Next = address - 1; + } + IPAddressNumber ip; + CHECK(ParseIPLiteralToNumber(info.dns_server_addresses[j], &ip)); + IPEndPoint ipe(ip, info.ports[j]); + address->Address.lpSockaddr = + reinterpret_cast<LPSOCKADDR>(storage + num_addresses); + socklen_t length = sizeof(struct sockaddr_storage); + CHECK(ipe.ToSockAddr(address->Address.lpSockaddr, &length)); + address->Address.iSockaddrLength = static_cast<int>(length); + } + } + + return heap.Pass(); +} + +TEST(DnsConfigServiceWinTest, ConvertAdapterAddresses) { + // Check nameservers and connection-specific suffix. + const struct TestCase { + AdapterInfo input_adapters[4]; // |if_type| == 0 indicates end. + std::string expected_nameservers[4]; // Empty string indicates end. + std::string expected_suffix; + int expected_ports[4]; + } cases[] = { + { // Ignore loopback and inactive adapters. + { + { IF_TYPE_SOFTWARE_LOOPBACK, IfOperStatusUp, L"funnyloop", + { "2.0.0.2" } }, + { IF_TYPE_FASTETHER, IfOperStatusDormant, L"example.com", + { "1.0.0.1" } }, + { IF_TYPE_USB, IfOperStatusUp, L"chromium.org", + { "10.0.0.10", "2001:FFFF::1111" } }, + { 0 }, + }, + { "10.0.0.10", "2001:FFFF::1111" }, + "chromium.org", + }, + { // Respect configured ports. + { + { IF_TYPE_USB, IfOperStatusUp, L"chromium.org", + { "10.0.0.10", "2001:FFFF::1111" }, { 1024, 24 } }, + { 0 }, + }, + { "10.0.0.10", "2001:FFFF::1111" }, + "chromium.org", + { 1024, 24 }, + }, + { // Use the preferred adapter (first in binding order) and filter + // stateless DNS discovery addresses. + { + { IF_TYPE_SOFTWARE_LOOPBACK, IfOperStatusUp, L"funnyloop", + { "2.0.0.2" } }, + { IF_TYPE_FASTETHER, IfOperStatusUp, L"example.com", + { "1.0.0.1", "fec0:0:0:ffff::2", "8.8.8.8" } }, + { IF_TYPE_USB, IfOperStatusUp, L"chromium.org", + { "10.0.0.10", "2001:FFFF::1111" } }, + { 0 }, + }, + { "1.0.0.1", "8.8.8.8" }, + "example.com", + }, + { // No usable adapters. + { + { IF_TYPE_SOFTWARE_LOOPBACK, IfOperStatusUp, L"localhost", + { "2.0.0.2" } }, + { IF_TYPE_FASTETHER, IfOperStatusDormant, L"example.com", + { "1.0.0.1" } }, + { IF_TYPE_USB, IfOperStatusUp, L"chromium.org" }, + { 0 }, + }, + }, + }; + + for (size_t i = 0; i < arraysize(cases); ++i) { + const TestCase& t = cases[i]; + internal::DnsSystemSettings settings = { + CreateAdapterAddresses(t.input_adapters), + // Default settings for the rest. + }; + std::vector<IPEndPoint> expected_nameservers; + for (size_t j = 0; !t.expected_nameservers[j].empty(); ++j) { + IPAddressNumber ip; + ASSERT_TRUE(ParseIPLiteralToNumber(t.expected_nameservers[j], &ip)); + int port = t.expected_ports[j]; + if (!port) + port = dns_protocol::kDefaultPort; + expected_nameservers.push_back(IPEndPoint(ip, port)); + } + + DnsConfig config; + internal::ConfigParseWinResult result = + internal::ConvertSettingsToDnsConfig(settings, &config); + internal::ConfigParseWinResult expected_result = + expected_nameservers.empty() ? internal::CONFIG_PARSE_WIN_NO_NAMESERVERS + : internal::CONFIG_PARSE_WIN_OK; + EXPECT_EQ(expected_result, result); + EXPECT_EQ(expected_nameservers, config.nameservers); + if (result == internal::CONFIG_PARSE_WIN_OK) { + ASSERT_EQ(1u, config.search.size()); + EXPECT_EQ(t.expected_suffix, config.search[0]); + } + } +} + +TEST(DnsConfigServiceWinTest, ConvertSuffixSearch) { + AdapterInfo infos[2] = { + { IF_TYPE_USB, IfOperStatusUp, L"connection.suffix", { "1.0.0.1" } }, + { 0 }, + }; + + const struct TestCase { + internal::DnsSystemSettings input_settings; + std::string expected_search[5]; + } cases[] = { + { // Policy SearchList override. + { + CreateAdapterAddresses(infos), + { true, L"policy.searchlist.a,policy.searchlist.b" }, + { true, L"tcpip.searchlist.a,tcpip.searchlist.b" }, + { true, L"tcpip.domain" }, + { true, L"primary.dns.suffix" }, + }, + { "policy.searchlist.a", "policy.searchlist.b" }, + }, + { // User-specified SearchList override. + { + CreateAdapterAddresses(infos), + { false }, + { true, L"tcpip.searchlist.a,tcpip.searchlist.b" }, + { true, L"tcpip.domain" }, + { true, L"primary.dns.suffix" }, + }, + { "tcpip.searchlist.a", "tcpip.searchlist.b" }, + }, + { // Void SearchList. Using tcpip.domain + { + CreateAdapterAddresses(infos), + { true, L",bad.searchlist,parsed.as.empty" }, + { true, L"tcpip.searchlist,good.but.overridden" }, + { true, L"tcpip.domain" }, + { false }, + }, + { "tcpip.domain", "connection.suffix" }, + }, + { // Void SearchList. Using primary.dns.suffix + { + CreateAdapterAddresses(infos), + { true, L",bad.searchlist,parsed.as.empty" }, + { true, L"tcpip.searchlist,good.but.overridden" }, + { true, L"tcpip.domain" }, + { true, L"primary.dns.suffix" }, + }, + { "primary.dns.suffix", "connection.suffix" }, + }, + { // Void SearchList. Using tcpip.domain when primary.dns.suffix is empty + { + CreateAdapterAddresses(infos), + { true, L",bad.searchlist,parsed.as.empty" }, + { true, L"tcpip.searchlist,good.but.overridden" }, + { true, L"tcpip.domain" }, + { true, L"" }, + }, + { "tcpip.domain", "connection.suffix" }, + }, + { // Void SearchList. Using tcpip.domain when primary.dns.suffix is NULL + { + CreateAdapterAddresses(infos), + { true, L",bad.searchlist,parsed.as.empty" }, + { true, L"tcpip.searchlist,good.but.overridden" }, + { true, L"tcpip.domain" }, + { true }, + }, + { "tcpip.domain", "connection.suffix" }, + }, + { // No primary suffix. Devolution does not matter. + { + CreateAdapterAddresses(infos), + { false }, + { false }, + { true }, + { true }, + { { true, 1 }, { true, 2 } }, + }, + { "connection.suffix" }, + }, + { // Devolution enabled by policy, level by dnscache. + { + CreateAdapterAddresses(infos), + { false }, + { false }, + { true, L"a.b.c.d.e" }, + { false }, + { { true, 1 }, { false } }, // policy_devolution: enabled, level + { { true, 0 }, { true, 3 } }, // dnscache_devolution + { { true, 0 }, { true, 1 } }, // tcpip_devolution + }, + { "a.b.c.d.e", "connection.suffix", "b.c.d.e", "c.d.e" }, + }, + { // Devolution enabled by dnscache, level by policy. + { + CreateAdapterAddresses(infos), + { false }, + { false }, + { true, L"a.b.c.d.e" }, + { true, L"f.g.i.l.j" }, + { { false }, { true, 4 } }, + { { true, 1 }, { false } }, + { { true, 0 }, { true, 3 } }, + }, + { "f.g.i.l.j", "connection.suffix", "g.i.l.j" }, + }, + { // Devolution enabled by default. + { + CreateAdapterAddresses(infos), + { false }, + { false }, + { true, L"a.b.c.d.e" }, + { false }, + { { false }, { false } }, + { { false }, { true, 3 } }, + { { false }, { true, 1 } }, + }, + { "a.b.c.d.e", "connection.suffix", "b.c.d.e", "c.d.e" }, + }, + { // Devolution enabled at level = 2, but nothing to devolve. + { + CreateAdapterAddresses(infos), + { false }, + { false }, + { true, L"a.b" }, + { false }, + { { false }, { false } }, + { { false }, { true, 2 } }, + { { false }, { true, 2 } }, + }, + { "a.b", "connection.suffix" }, + }, + { // Devolution disabled when no explicit level. + // Windows XP and Vista use a default level = 2, but we don't. + { + CreateAdapterAddresses(infos), + { false }, + { false }, + { true, L"a.b.c.d.e" }, + { false }, + { { true, 1 }, { false } }, + { { true, 1 }, { false } }, + { { true, 1 }, { false } }, + }, + { "a.b.c.d.e", "connection.suffix" }, + }, + { // Devolution disabled by policy level. + { + CreateAdapterAddresses(infos), + { false }, + { false }, + { true, L"a.b.c.d.e" }, + { false }, + { { false }, { true, 1 } }, + { { true, 1 }, { true, 3 } }, + { { true, 1 }, { true, 4 } }, + }, + { "a.b.c.d.e", "connection.suffix" }, + }, + { // Devolution disabled by user setting. + { + CreateAdapterAddresses(infos), + { false }, + { false }, + { true, L"a.b.c.d.e" }, + { false }, + { { false }, { true, 3 } }, + { { false }, { true, 3 } }, + { { true, 0 }, { true, 3 } }, + }, + { "a.b.c.d.e", "connection.suffix" }, + }, + }; + + for (size_t i = 0; i < arraysize(cases); ++i) { + const TestCase& t = cases[i]; + DnsConfig config; + EXPECT_EQ(internal::CONFIG_PARSE_WIN_OK, + internal::ConvertSettingsToDnsConfig(t.input_settings, &config)); + std::vector<std::string> expected_search; + for (size_t j = 0; !t.expected_search[j].empty(); ++j) { + expected_search.push_back(t.expected_search[j]); + } + EXPECT_EQ(expected_search, config.search); + } +} + +TEST(DnsConfigServiceWinTest, AppendToMultiLabelName) { + AdapterInfo infos[2] = { + { IF_TYPE_USB, IfOperStatusUp, L"connection.suffix", { "1.0.0.1" } }, + { 0 }, + }; + + // The default setting was true pre-Vista. + bool default_value = (base::win::GetVersion() < base::win::VERSION_VISTA); + + const struct TestCase { + internal::DnsSystemSettings::RegDword input; + bool expected_output; + } cases[] = { + { { true, 0 }, false }, + { { true, 1 }, true }, + { { false, 0 }, default_value }, + }; + + for (size_t i = 0; i < arraysize(cases); ++i) { + const TestCase& t = cases[i]; + internal::DnsSystemSettings settings = { + CreateAdapterAddresses(infos), + { false }, { false }, { false }, { false }, + { { false }, { false } }, + { { false }, { false } }, + { { false }, { false } }, + t.input, + }; + DnsConfig config; + EXPECT_EQ(internal::CONFIG_PARSE_WIN_OK, + internal::ConvertSettingsToDnsConfig(settings, &config)); + EXPECT_EQ(config.append_to_multi_label_name, t.expected_output); + } +} + +} // namespace + +} // namespace net + diff --git a/chromium/net/dns/dns_hosts.cc b/chromium/net/dns/dns_hosts.cc new file mode 100644 index 00000000000..852d35c8bb4 --- /dev/null +++ b/chromium/net/dns/dns_hosts.cc @@ -0,0 +1,169 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_hosts.h" + +#include "base/file_util.h" +#include "base/logging.h" +#include "base/metrics/histogram.h" +#include "base/strings/string_util.h" +#include "base/strings/string_tokenizer.h" + +using base::StringPiece; + +namespace net { + +// Parses the contents of a hosts file. Returns one token (IP or hostname) at +// a time. Doesn't copy anything; accepts the file as a StringPiece and +// returns tokens as StringPieces. +class HostsParser { + public: + explicit HostsParser(const StringPiece& text) + : text_(text), + data_(text.data()), + end_(text.size()), + pos_(0), + token_(), + token_is_ip_(false) {} + + // Advances to the next token (IP or hostname). Returns whether another + // token was available. |token_is_ip| and |token| can be used to find out + // the type and text of the token. + bool Advance() { + bool next_is_ip = (pos_ == 0); + while (pos_ < end_ && pos_ != std::string::npos) { + switch (text_[pos_]) { + case ' ': + case '\t': + SkipWhitespace(); + break; + + case '\r': + case '\n': + next_is_ip = true; + pos_++; + break; + + case '#': + SkipRestOfLine(); + break; + + default: { + size_t token_start = pos_; + SkipToken(); + size_t token_end = (pos_ == std::string::npos) ? end_ : pos_; + + token_ = StringPiece(data_ + token_start, token_end - token_start); + token_is_ip_ = next_is_ip; + + return true; + } + } + } + + text_ = StringPiece(); + return false; + } + + // Fast-forwards the parser to the next line. Should be called if an IP + // address doesn't parse, to avoid wasting time tokenizing hostnames that + // will be ignored. + void SkipRestOfLine() { + pos_ = text_.find("\n", pos_); + } + + // Returns whether the last-parsed token is an IP address (true) or a + // hostname (false). + bool token_is_ip() { return token_is_ip_; } + + // Returns the text of the last-parsed token as a StringPiece referencing + // the same underlying memory as the StringPiece passed to the constructor. + // Returns an empty StringPiece if no token has been parsed or the end of + // the input string has been reached. + const StringPiece& token() { return token_; } + + private: + void SkipToken() { + pos_ = text_.find_first_of(" \t\n\r#", pos_); + } + + void SkipWhitespace() { + pos_ = text_.find_first_not_of(" \t", pos_); + } + + StringPiece text_; + const char* data_; + const size_t end_; + + size_t pos_; + StringPiece token_; + bool token_is_ip_; + + DISALLOW_COPY_AND_ASSIGN(HostsParser); +}; + + + +void ParseHosts(const std::string& contents, DnsHosts* dns_hosts) { + CHECK(dns_hosts); + DnsHosts& hosts = *dns_hosts; + + StringPiece ip_text; + IPAddressNumber ip; + AddressFamily family = ADDRESS_FAMILY_IPV4; + HostsParser parser(contents); + while (parser.Advance()) { + if (parser.token_is_ip()) { + StringPiece new_ip_text = parser.token(); + // Some ad-blocking hosts files contain thousands of entries pointing to + // the same IP address (usually 127.0.0.1). Don't bother parsing the IP + // again if it's the same as the one above it. + if (new_ip_text != ip_text) { + IPAddressNumber new_ip; + if (ParseIPLiteralToNumber(parser.token().as_string(), &new_ip)) { + ip_text = new_ip_text; + ip.swap(new_ip); + family = (ip.size() == 4) ? ADDRESS_FAMILY_IPV4 : ADDRESS_FAMILY_IPV6; + } else { + parser.SkipRestOfLine(); + } + } + } else { + DnsHostsKey key(parser.token().as_string(), family); + StringToLowerASCII(&key.first); + IPAddressNumber& mapped_ip = hosts[key]; + if (mapped_ip.empty()) + mapped_ip = ip; + // else ignore this entry (first hit counts) + } + } +} + +bool ParseHostsFile(const base::FilePath& path, DnsHosts* dns_hosts) { + dns_hosts->clear(); + // Missing file indicates empty HOSTS. + if (!base::PathExists(path)) + return true; + + int64 size; + if (!file_util::GetFileSize(path, &size)) + return false; + + UMA_HISTOGRAM_COUNTS("AsyncDNS.HostsSize", size); + + // Reject HOSTS files larger than |kMaxHostsSize| bytes. + const int64 kMaxHostsSize = 1 << 25; // 32MB + if (size > kMaxHostsSize) + return false; + + std::string contents; + if (!file_util::ReadFileToString(path, &contents)) + return false; + + ParseHosts(contents, dns_hosts); + return true; +} + +} // namespace net + diff --git a/chromium/net/dns/dns_hosts.h b/chromium/net/dns/dns_hosts.h new file mode 100644 index 00000000000..c2b290907ad --- /dev/null +++ b/chromium/net/dns/dns_hosts.h @@ -0,0 +1,79 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_HOSTS_H_ +#define NET_DNS_DNS_HOSTS_H_ + +#include <map> +#include <string> +#include <utility> +#include <vector> + +#include "base/basictypes.h" +#include "base/containers/hash_tables.h" +#include "base/files/file_path.h" +#include "net/base/address_family.h" +#include "net/base/net_export.h" +#include "net/base/net_util.h" // can't forward-declare IPAddressNumber + +namespace net { + typedef std::pair<std::string, AddressFamily> DnsHostsKey; +}; + +namespace BASE_HASH_NAMESPACE { +#if defined(COMPILER_GCC) + +template<> +struct hash<net::DnsHostsKey> { + std::size_t operator()(const net::DnsHostsKey& key) const { + hash<base::StringPiece> string_piece_hash; + return string_piece_hash(key.first) + key.second; + } +}; + +#elif defined(COMPILER_MSVC) + +inline size_t hash_value(const net::DnsHostsKey& key) { + return hash_value(key.first) + key.second; +} + +#endif // COMPILER + +} // namespace BASE_HASH_NAMESPACE + +namespace net { + +// Parsed results of a Hosts file. +// +// Although Hosts files map IP address to a list of domain names, for name +// resolution the desired mapping direction is: domain name to IP address. +// When parsing Hosts, we apply the "first hit" rule as Windows and glibc do. +// With a Hosts file of: +// 300.300.300.300 localhost # bad ip +// 127.0.0.1 localhost +// 10.0.0.1 localhost +// The expected resolution of localhost is 127.0.0.1. +#if !defined(OS_ANDROID) +typedef base::hash_map<DnsHostsKey, IPAddressNumber> DnsHosts; +#else +// Android's hash_map doesn't support ==, so fall back to map. (Chromium on +// Android doesn't use the built-in DNS resolver anyway, so it's irrelevant.) +typedef std::map<DnsHostsKey, IPAddressNumber> DnsHosts; +#endif + +// Parses |contents| (as read from /etc/hosts or equivalent) and stores results +// in |dns_hosts|. Invalid lines are ignored (as in most implementations). +void NET_EXPORT_PRIVATE ParseHosts(const std::string& contents, + DnsHosts* dns_hosts); + +// As above but reads the file pointed to by |path|. +bool NET_EXPORT_PRIVATE ParseHostsFile(const base::FilePath& path, + DnsHosts* dns_hosts); + + + +} // namespace net + +#endif // NET_DNS_DNS_HOSTS_H_ + diff --git a/chromium/net/dns/dns_hosts_unittest.cc b/chromium/net/dns/dns_hosts_unittest.cc new file mode 100644 index 00000000000..c0e8805fc47 --- /dev/null +++ b/chromium/net/dns/dns_hosts_unittest.cc @@ -0,0 +1,124 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_hosts.h" + +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +TEST(DnsHostsTest, ParseHosts) { + std::string contents = + "127.0.0.1 localhost\tlocalhost.localdomain # standard\n" + "\n" + "1.0.0.1 localhost # ignored, first hit above\n" + "fe00::x example company # ignored, malformed IPv6\n" + "1.0.0.300 company # ignored, malformed IPv4\n" + "1.0.0.1 # ignored, missing hostname\n" + "1.0.0.1\t CoMpANy # normalized to 'company' \n" + "::1\tlocalhost ip6-localhost ip6-loopback # comment # within a comment\n" + "\t fe00::0 ip6-localnet\r\n" + "2048::2 example\n" + "2048::1 company example # ignored for 'example' \n" + "127.0.0.1 cache1\n" + "127.0.0.1 cache2 # should reuse parsed IP\n" + "256.0.0.0 cache3 # bogus IP should not clear parsed IP cache\n" + "127.0.0.1 cache4 # should still be reused\n" + "127.0.0.2 cache5\n" + "gibberish"; + + const struct { + const char* host; + AddressFamily family; + const char* ip; + } entries[] = { + { "localhost", ADDRESS_FAMILY_IPV4, "127.0.0.1" }, + { "localhost.localdomain", ADDRESS_FAMILY_IPV4, "127.0.0.1" }, + { "company", ADDRESS_FAMILY_IPV4, "1.0.0.1" }, + { "localhost", ADDRESS_FAMILY_IPV6, "::1" }, + { "ip6-localhost", ADDRESS_FAMILY_IPV6, "::1" }, + { "ip6-loopback", ADDRESS_FAMILY_IPV6, "::1" }, + { "ip6-localnet", ADDRESS_FAMILY_IPV6, "fe00::0" }, + { "company", ADDRESS_FAMILY_IPV6, "2048::1" }, + { "example", ADDRESS_FAMILY_IPV6, "2048::2" }, + { "cache1", ADDRESS_FAMILY_IPV4, "127.0.0.1" }, + { "cache2", ADDRESS_FAMILY_IPV4, "127.0.0.1" }, + { "cache4", ADDRESS_FAMILY_IPV4, "127.0.0.1" }, + { "cache5", ADDRESS_FAMILY_IPV4, "127.0.0.2" }, + }; + + DnsHosts expected; + for (size_t i = 0; i < ARRAYSIZE_UNSAFE(entries); ++i) { + DnsHostsKey key(entries[i].host, entries[i].family); + IPAddressNumber& ip = expected[key]; + ASSERT_TRUE(ip.empty()); + ASSERT_TRUE(ParseIPLiteralToNumber(entries[i].ip, &ip)); + ASSERT_EQ(ip.size(), (entries[i].family == ADDRESS_FAMILY_IPV4) ? 4u : 16u); + } + + DnsHosts hosts; + ParseHosts(contents, &hosts); + ASSERT_EQ(expected, hosts); +} + +TEST(DnsHostsTest, HostsParser_Empty) { + DnsHosts hosts; + ParseHosts("", &hosts); + EXPECT_EQ(0u, hosts.size()); +} + +TEST(DnsHostsTest, HostsParser_OnlyWhitespace) { + DnsHosts hosts; + ParseHosts(" ", &hosts); + EXPECT_EQ(0u, hosts.size()); +} + +TEST(DnsHostsTest, HostsParser_EndsWithNothing) { + DnsHosts hosts; + ParseHosts("127.0.0.1 localhost", &hosts); + EXPECT_EQ(1u, hosts.size()); +} + +TEST(DnsHostsTest, HostsParser_EndsWithWhitespace) { + DnsHosts hosts; + ParseHosts("127.0.0.1 localhost ", &hosts); + EXPECT_EQ(1u, hosts.size()); +} + +TEST(DnsHostsTest, HostsParser_EndsWithComment) { + DnsHosts hosts; + ParseHosts("127.0.0.1 localhost # comment", &hosts); + EXPECT_EQ(1u, hosts.size()); +} + +TEST(DnsHostsTest, HostsParser_EndsWithNewline) { + DnsHosts hosts; + ParseHosts("127.0.0.1 localhost\n", &hosts); + EXPECT_EQ(1u, hosts.size()); +} + +TEST(DnsHostsTest, HostsParser_EndsWithTwoNewlines) { + DnsHosts hosts; + ParseHosts("127.0.0.1 localhost\n\n", &hosts); + EXPECT_EQ(1u, hosts.size()); +} + +TEST(DnsHostsTest, HostsParser_EndsWithNewlineAndWhitespace) { + DnsHosts hosts; + ParseHosts("127.0.0.1 localhost\n ", &hosts); + EXPECT_EQ(1u, hosts.size()); +} + +TEST(DnsHostsTest, HostsParser_EndsWithNewlineAndToken) { + DnsHosts hosts; + ParseHosts("127.0.0.1 localhost\ntoken", &hosts); + EXPECT_EQ(1u, hosts.size()); +} + +} // namespace + +} // namespace net + diff --git a/chromium/net/dns/dns_protocol.h b/chromium/net/dns/dns_protocol.h new file mode 100644 index 00000000000..a8aad65c2fb --- /dev/null +++ b/chromium/net/dns/dns_protocol.h @@ -0,0 +1,143 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_PROTOCOL_H_ +#define NET_DNS_DNS_PROTOCOL_H_ + +#include "base/basictypes.h" +#include "net/base/net_export.h" + +namespace net { + +namespace dns_protocol { + +static const uint16 kDefaultPort = 53; +static const uint16 kDefaultPortMulticast = 5353; + +// DNS packet consists of a header followed by questions and/or answers. +// For the meaning of specific fields, please see RFC 1035 and 2535 + +// Header format. +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ID | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// |QR| Opcode |AA|TC|RD|RA| Z|AD|CD| RCODE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QDCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ANCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | NSCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ARCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + +// Question format. +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | | +// / QNAME / +// / / +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QTYPE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QCLASS | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + +// Answer format. +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | | +// / / +// / NAME / +// | | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | TYPE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | CLASS | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | TTL | +// | | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | RDLENGTH | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--| +// / RDATA / +// / / +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + +#pragma pack(push) +#pragma pack(1) + +// On-the-wire header. All uint16 are in network order. +// Used internally in DnsQuery and DnsResponseParser. +struct NET_EXPORT_PRIVATE Header { + uint16 id; + uint16 flags; + uint16 qdcount; + uint16 ancount; + uint16 nscount; + uint16 arcount; +}; + +#pragma pack(pop) + +static const uint8 kLabelMask = 0xc0; +static const uint8 kLabelPointer = 0xc0; +static const uint8 kLabelDirect = 0x0; +static const uint16 kOffsetMask = 0x3fff; + +// In MDns the most significant bit of the rrclass is designated as the +// "cache-flush bit", as described in http://www.rfc-editor.org/rfc/rfc6762.txt +// section 10.2. +static const uint16 kMDnsClassMask = 0x7FFF; + +static const int kMaxNameLength = 255; + +// RFC 1035, section 4.2.1: Messages carried by UDP are restricted to 512 +// bytes (not counting the IP nor UDP headers). +static const int kMaxUDPSize = 512; + +// RFC 6762, section 17: Messages over the local link are restricted by the +// medium's MTU, and must be under 9000 bytes +static const int kMaxMulticastSize = 9000; + +// DNS class types. +static const uint16 kClassIN = 1; + +// DNS resource record types. See +// http://www.iana.org/assignments/dns-parameters +static const uint16 kTypeA = 1; +static const uint16 kTypeCNAME = 5; +static const uint16 kTypePTR = 12; +static const uint16 kTypeTXT = 16; +static const uint16 kTypeAAAA = 28; +static const uint16 kTypeSRV = 33; +static const uint16 kTypeNSEC = 47; + + +// DNS rcode values. +static const uint8 kRcodeMask = 0xf; +static const uint8 kRcodeNOERROR = 0; +static const uint8 kRcodeFORMERR = 1; +static const uint8 kRcodeSERVFAIL = 2; +static const uint8 kRcodeNXDOMAIN = 3; +static const uint8 kRcodeNOTIMP = 4; +static const uint8 kRcodeREFUSED = 5; + +// DNS flags. +static const uint16 kFlagResponse = 0x8000; +static const uint16 kFlagRA = 0x80; +static const uint16 kFlagRD = 0x100; +static const uint16 kFlagTC = 0x200; +static const uint16 kFlagAA = 0x400; + +} // namespace dns_protocol + +} // namespace net + +#endif // NET_DNS_DNS_PROTOCOL_H_ diff --git a/chromium/net/dns/dns_query.cc b/chromium/net/dns/dns_query.cc new file mode 100644 index 00000000000..270757e7be9 --- /dev/null +++ b/chromium/net/dns/dns_query.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_query.h" + +#include <limits> + +#include "base/sys_byteorder.h" +#include "net/base/big_endian.h" +#include "net/base/dns_util.h" +#include "net/base/io_buffer.h" +#include "net/dns/dns_protocol.h" + +namespace net { + +// DNS query consists of a 12-byte header followed by a question section. +// For details, see RFC 1035 section 4.1.1. This header template sets RD +// bit, which directs the name server to pursue query recursively, and sets +// the QDCOUNT to 1, meaning the question section has a single entry. +DnsQuery::DnsQuery(uint16 id, const base::StringPiece& qname, uint16 qtype) + : qname_size_(qname.size()) { + DCHECK(!DNSDomainToString(qname).empty()); + // QNAME + QTYPE + QCLASS + size_t question_size = qname_size_ + sizeof(uint16) + sizeof(uint16); + io_buffer_ = new IOBufferWithSize(sizeof(dns_protocol::Header) + + question_size); + dns_protocol::Header* header = + reinterpret_cast<dns_protocol::Header*>(io_buffer_->data()); + memset(header, 0, sizeof(dns_protocol::Header)); + header->id = base::HostToNet16(id); + header->flags = base::HostToNet16(dns_protocol::kFlagRD); + header->qdcount = base::HostToNet16(1); + + // Write question section after the header. + BigEndianWriter writer(reinterpret_cast<char*>(header + 1), question_size); + writer.WriteBytes(qname.data(), qname.size()); + writer.WriteU16(qtype); + writer.WriteU16(dns_protocol::kClassIN); +} + +DnsQuery::~DnsQuery() { +} + +DnsQuery* DnsQuery::CloneWithNewId(uint16 id) const { + return new DnsQuery(*this, id); +} + +uint16 DnsQuery::id() const { + const dns_protocol::Header* header = + reinterpret_cast<const dns_protocol::Header*>(io_buffer_->data()); + return base::NetToHost16(header->id); +} + +base::StringPiece DnsQuery::qname() const { + return base::StringPiece(io_buffer_->data() + sizeof(dns_protocol::Header), + qname_size_); +} + +uint16 DnsQuery::qtype() const { + uint16 type; + ReadBigEndian<uint16>(io_buffer_->data() + + sizeof(dns_protocol::Header) + + qname_size_, &type); + return type; +} + +base::StringPiece DnsQuery::question() const { + return base::StringPiece(io_buffer_->data() + sizeof(dns_protocol::Header), + qname_size_ + sizeof(uint16) + sizeof(uint16)); +} + +DnsQuery::DnsQuery(const DnsQuery& orig, uint16 id) { + qname_size_ = orig.qname_size_; + io_buffer_ = new IOBufferWithSize(orig.io_buffer()->size()); + memcpy(io_buffer_.get()->data(), orig.io_buffer()->data(), + io_buffer_.get()->size()); + dns_protocol::Header* header = + reinterpret_cast<dns_protocol::Header*>(io_buffer_->data()); + header->id = base::HostToNet16(id); +} + +void DnsQuery::set_flags(uint16 flags) { + dns_protocol::Header* header = + reinterpret_cast<dns_protocol::Header*>(io_buffer_->data()); + header->flags = flags; +} + +} // namespace net diff --git a/chromium/net/dns/dns_query.h b/chromium/net/dns/dns_query.h new file mode 100644 index 00000000000..e1469bdfbc0 --- /dev/null +++ b/chromium/net/dns/dns_query.h @@ -0,0 +1,58 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_QUERY_H_ +#define NET_DNS_DNS_QUERY_H_ + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/strings/string_piece.h" +#include "net/base/net_export.h" + +namespace net { + +class IOBufferWithSize; + +// Represents on-the-wire DNS query message as an object. +// TODO(szym): add support for the OPT pseudo-RR (EDNS0/DNSSEC). +class NET_EXPORT_PRIVATE DnsQuery { + public: + // Constructs a query message from |qname| which *MUST* be in a valid + // DNS name format, and |qtype|. The qclass is set to IN. + DnsQuery(uint16 id, const base::StringPiece& qname, uint16 qtype); + ~DnsQuery(); + + // Clones |this| verbatim, with ID field of the header set to |id|. + DnsQuery* CloneWithNewId(uint16 id) const; + + // DnsQuery field accessors. + uint16 id() const; + base::StringPiece qname() const; + uint16 qtype() const; + + // Returns the Question section of the query. Used when matching the + // response. + base::StringPiece question() const; + + // IOBuffer accessor to be used for writing out the query. + IOBufferWithSize* io_buffer() const { return io_buffer_.get(); } + + void set_flags(uint16 flags); + + private: + DnsQuery(const DnsQuery& orig, uint16 id); + + // Size of the DNS name (*NOT* hostname) we are trying to resolve; used + // to calculate offsets. + size_t qname_size_; + + // Contains query bytes to be consumed by higher level Write() call. + scoped_refptr<IOBufferWithSize> io_buffer_; + + DISALLOW_COPY_AND_ASSIGN(DnsQuery); +}; + +} // namespace net + +#endif // NET_DNS_DNS_QUERY_H_ diff --git a/chromium/net/dns/dns_query_unittest.cc b/chromium/net/dns/dns_query_unittest.cc new file mode 100644 index 00000000000..ffe02e70a61 --- /dev/null +++ b/chromium/net/dns/dns_query_unittest.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_query.h" + +#include "base/bind.h" +#include "net/base/dns_util.h" +#include "net/base/io_buffer.h" +#include "net/dns/dns_protocol.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +TEST(DnsQueryTest, Constructor) { + // This includes \0 at the end. + const char qname_data[] = "\x03""www""\x07""example""\x03""com"; + const uint8 query_data[] = { + // Header + 0xbe, 0xef, + 0x01, 0x00, // Flags -- set RD (recursion desired) bit. + 0x00, 0x01, // Set QDCOUNT (question count) to 1, all the + // rest are 0 for a query. + 0x00, 0x00, + 0x00, 0x00, + 0x00, 0x00, + + // Question + 0x03, 'w', 'w', 'w', // QNAME: www.example.com in DNS format. + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, + + 0x00, 0x01, // QTYPE: A query. + 0x00, 0x01, // QCLASS: IN class. + }; + + base::StringPiece qname(qname_data, sizeof(qname_data)); + DnsQuery q1(0xbeef, qname, dns_protocol::kTypeA); + EXPECT_EQ(dns_protocol::kTypeA, q1.qtype()); + + ASSERT_EQ(static_cast<int>(sizeof(query_data)), q1.io_buffer()->size()); + EXPECT_EQ(0, memcmp(q1.io_buffer()->data(), query_data, sizeof(query_data))); + EXPECT_EQ(qname, q1.qname()); + + base::StringPiece question(reinterpret_cast<const char*>(query_data) + 12, + 21); + EXPECT_EQ(question, q1.question()); +} + +TEST(DnsQueryTest, Clone) { + // This includes \0 at the end. + const char qname_data[] = "\x03""www""\x07""example""\x03""com"; + base::StringPiece qname(qname_data, sizeof(qname_data)); + + DnsQuery q1(0, qname, dns_protocol::kTypeA); + EXPECT_EQ(0, q1.id()); + scoped_ptr<DnsQuery> q2(q1.CloneWithNewId(42)); + EXPECT_EQ(42, q2->id()); + EXPECT_EQ(q1.io_buffer()->size(), q2->io_buffer()->size()); + EXPECT_EQ(q1.qtype(), q2->qtype()); + EXPECT_EQ(q1.question(), q2->question()); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/dns/dns_response.cc b/chromium/net/dns/dns_response.cc new file mode 100644 index 00000000000..d29d3c4813c --- /dev/null +++ b/chromium/net/dns/dns_response.cc @@ -0,0 +1,337 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_response.h" + +#include "base/strings/string_util.h" +#include "base/sys_byteorder.h" +#include "net/base/address_list.h" +#include "net/base/big_endian.h" +#include "net/base/dns_util.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_query.h" + +namespace net { + +DnsResourceRecord::DnsResourceRecord() { +} + +DnsResourceRecord::~DnsResourceRecord() { +} + +DnsRecordParser::DnsRecordParser() : packet_(NULL), length_(0), cur_(0) { +} + +DnsRecordParser::DnsRecordParser(const void* packet, + size_t length, + size_t offset) + : packet_(reinterpret_cast<const char*>(packet)), + length_(length), + cur_(packet_ + offset) { + DCHECK_LE(offset, length); +} + +unsigned DnsRecordParser::ReadName(const void* const vpos, + std::string* out) const { + const char* const pos = reinterpret_cast<const char*>(vpos); + DCHECK(packet_); + DCHECK_LE(packet_, pos); + DCHECK_LE(pos, packet_ + length_); + + const char* p = pos; + const char* end = packet_ + length_; + // Count number of seen bytes to detect loops. + unsigned seen = 0; + // Remember how many bytes were consumed before first jump. + unsigned consumed = 0; + + if (pos >= end) + return 0; + + if (out) { + out->clear(); + out->reserve(dns_protocol::kMaxNameLength); + } + + for (;;) { + // The first two bits of the length give the type of the length. It's + // either a direct length or a pointer to the remainder of the name. + switch (*p & dns_protocol::kLabelMask) { + case dns_protocol::kLabelPointer: { + if (p + sizeof(uint16) > end) + return 0; + if (consumed == 0) { + consumed = p - pos + sizeof(uint16); + if (!out) + return consumed; // If name is not stored, that's all we need. + } + seen += sizeof(uint16); + // If seen the whole packet, then we must be in a loop. + if (seen > length_) + return 0; + uint16 offset; + ReadBigEndian<uint16>(p, &offset); + offset &= dns_protocol::kOffsetMask; + p = packet_ + offset; + if (p >= end) + return 0; + break; + } + case dns_protocol::kLabelDirect: { + uint8 label_len = *p; + ++p; + // Note: root domain (".") is NOT included. + if (label_len == 0) { + if (consumed == 0) { + consumed = p - pos; + } // else we set |consumed| before first jump + return consumed; + } + if (p + label_len >= end) + return 0; // Truncated or missing label. + if (out) { + if (!out->empty()) + out->append("."); + out->append(p, label_len); + } + p += label_len; + seen += 1 + label_len; + break; + } + default: + // unhandled label type + return 0; + } + } +} + +bool DnsRecordParser::ReadRecord(DnsResourceRecord* out) { + DCHECK(packet_); + size_t consumed = ReadName(cur_, &out->name); + if (!consumed) + return false; + BigEndianReader reader(cur_ + consumed, + packet_ + length_ - (cur_ + consumed)); + uint16 rdlen; + if (reader.ReadU16(&out->type) && + reader.ReadU16(&out->klass) && + reader.ReadU32(&out->ttl) && + reader.ReadU16(&rdlen) && + reader.ReadPiece(&out->rdata, rdlen)) { + cur_ = reader.ptr(); + return true; + } + return false; +} + +bool DnsRecordParser::SkipQuestion() { + size_t consumed = ReadName(cur_, NULL); + if (!consumed) + return false; + + const char* next = cur_ + consumed + 2 * sizeof(uint16); // QTYPE + QCLASS + if (next > packet_ + length_) + return false; + + cur_ = next; + + return true; +} + +DnsResponse::DnsResponse() + : io_buffer_(new IOBufferWithSize(dns_protocol::kMaxUDPSize + 1)) { +} + +DnsResponse::DnsResponse(size_t length) + : io_buffer_(new IOBufferWithSize(length)) { +} + +DnsResponse::DnsResponse(const void* data, + size_t length, + size_t answer_offset) + : io_buffer_(new IOBufferWithSize(length)), + parser_(io_buffer_->data(), length, answer_offset) { + DCHECK(data); + memcpy(io_buffer_->data(), data, length); +} + +DnsResponse::~DnsResponse() { +} + +bool DnsResponse::InitParse(int nbytes, const DnsQuery& query) { + DCHECK_GE(nbytes, 0); + // Response includes query, it should be at least that size. + if (nbytes < query.io_buffer()->size() || nbytes >= io_buffer_->size()) + return false; + + // Match the query id. + if (base::NetToHost16(header()->id) != query.id()) + return false; + + // Match question count. + if (base::NetToHost16(header()->qdcount) != 1) + return false; + + // Match the question section. + const size_t hdr_size = sizeof(dns_protocol::Header); + const base::StringPiece question = query.question(); + if (question != base::StringPiece(io_buffer_->data() + hdr_size, + question.size())) { + return false; + } + + // Construct the parser. + parser_ = DnsRecordParser(io_buffer_->data(), + nbytes, + hdr_size + question.size()); + return true; +} + +bool DnsResponse::InitParseWithoutQuery(int nbytes) { + DCHECK_GE(nbytes, 0); + + size_t hdr_size = sizeof(dns_protocol::Header); + + if (nbytes < static_cast<int>(hdr_size) || nbytes >= io_buffer_->size()) + return false; + + parser_ = DnsRecordParser( + io_buffer_->data(), nbytes, hdr_size); + + unsigned qdcount = base::NetToHost16(header()->qdcount); + for (unsigned i = 0; i < qdcount; ++i) { + if (!parser_.SkipQuestion()) { + parser_ = DnsRecordParser(); // Make parser invalid again. + return false; + } + } + + return true; +} + +bool DnsResponse::IsValid() const { + return parser_.IsValid(); +} + +uint16 DnsResponse::flags() const { + DCHECK(parser_.IsValid()); + return base::NetToHost16(header()->flags) & ~(dns_protocol::kRcodeMask); +} + +uint8 DnsResponse::rcode() const { + DCHECK(parser_.IsValid()); + return base::NetToHost16(header()->flags) & dns_protocol::kRcodeMask; +} + +unsigned DnsResponse::answer_count() const { + DCHECK(parser_.IsValid()); + return base::NetToHost16(header()->ancount); +} + +unsigned DnsResponse::additional_answer_count() const { + DCHECK(parser_.IsValid()); + return base::NetToHost16(header()->arcount); +} + +base::StringPiece DnsResponse::qname() const { + DCHECK(parser_.IsValid()); + // The response is HEADER QNAME QTYPE QCLASS ANSWER. + // |parser_| is positioned at the beginning of ANSWER, so the end of QNAME is + // two uint16s before it. + const size_t hdr_size = sizeof(dns_protocol::Header); + const size_t qname_size = parser_.GetOffset() - 2 * sizeof(uint16) - hdr_size; + return base::StringPiece(io_buffer_->data() + hdr_size, qname_size); +} + +uint16 DnsResponse::qtype() const { + DCHECK(parser_.IsValid()); + // QTYPE starts where QNAME ends. + const size_t type_offset = parser_.GetOffset() - 2 * sizeof(uint16); + uint16 type; + ReadBigEndian<uint16>(io_buffer_->data() + type_offset, &type); + return type; +} + +std::string DnsResponse::GetDottedName() const { + return DNSDomainToString(qname()); +} + +DnsRecordParser DnsResponse::Parser() const { + DCHECK(parser_.IsValid()); + // Return a copy of the parser. + return parser_; +} + +const dns_protocol::Header* DnsResponse::header() const { + return reinterpret_cast<const dns_protocol::Header*>(io_buffer_->data()); +} + +DnsResponse::Result DnsResponse::ParseToAddressList( + AddressList* addr_list, + base::TimeDelta* ttl) const { + DCHECK(IsValid()); + // DnsTransaction already verified that |response| matches the issued query. + // We still need to determine if there is a valid chain of CNAMEs from the + // query name to the RR owner name. + // We err on the side of caution with the assumption that if we are too picky, + // we can always fall back to the system getaddrinfo. + + // Expected owner of record. No trailing dot. + std::string expected_name = GetDottedName(); + + uint16 expected_type = qtype(); + DCHECK(expected_type == dns_protocol::kTypeA || + expected_type == dns_protocol::kTypeAAAA); + + size_t expected_size = (expected_type == dns_protocol::kTypeAAAA) + ? kIPv6AddressSize : kIPv4AddressSize; + + uint32 ttl_sec = kuint32max; + IPAddressList ip_addresses; + DnsRecordParser parser = Parser(); + DnsResourceRecord record; + unsigned ancount = answer_count(); + for (unsigned i = 0; i < ancount; ++i) { + if (!parser.ReadRecord(&record)) + return DNS_MALFORMED_RESPONSE; + + if (record.type == dns_protocol::kTypeCNAME) { + // Following the CNAME chain, only if no addresses seen. + if (!ip_addresses.empty()) + return DNS_CNAME_AFTER_ADDRESS; + + if (base::strcasecmp(record.name.c_str(), expected_name.c_str()) != 0) + return DNS_NAME_MISMATCH; + + if (record.rdata.size() != + parser.ReadName(record.rdata.begin(), &expected_name)) + return DNS_MALFORMED_CNAME; + + ttl_sec = std::min(ttl_sec, record.ttl); + } else if (record.type == expected_type) { + if (record.rdata.size() != expected_size) + return DNS_SIZE_MISMATCH; + + if (base::strcasecmp(record.name.c_str(), expected_name.c_str()) != 0) + return DNS_NAME_MISMATCH; + + ttl_sec = std::min(ttl_sec, record.ttl); + ip_addresses.push_back(IPAddressNumber(record.rdata.begin(), + record.rdata.end())); + } + } + + // TODO(szym): Extract TTL for NODATA results. http://crbug.com/115051 + + // getcanonname in eglibc returns the first owner name of an A or AAAA RR. + // If the response passed all the checks so far, then |expected_name| is it. + *addr_list = AddressList::CreateFromIPAddressList(ip_addresses, + expected_name); + *ttl = base::TimeDelta::FromSeconds(ttl_sec); + return DNS_PARSE_OK; +} + +} // namespace net diff --git a/chromium/net/dns/dns_response.h b/chromium/net/dns/dns_response.h new file mode 100644 index 00000000000..57abecf6253 --- /dev/null +++ b/chromium/net/dns/dns_response.h @@ -0,0 +1,169 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_RESPONSE_H_ +#define NET_DNS_DNS_RESPONSE_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/strings/string_piece.h" +#include "base/time/time.h" +#include "net/base/net_export.h" +#include "net/base/net_util.h" + +namespace net { + +class AddressList; +class DnsQuery; +class IOBufferWithSize; + +namespace dns_protocol { +struct Header; +} + +// Parsed resource record. +struct NET_EXPORT_PRIVATE DnsResourceRecord { + DnsResourceRecord(); + ~DnsResourceRecord(); + + std::string name; // in dotted form + uint16 type; + uint16 klass; + uint32 ttl; + base::StringPiece rdata; // points to the original response buffer +}; + +// Iterator to walk over resource records of the DNS response packet. +class NET_EXPORT_PRIVATE DnsRecordParser { + public: + // Construct an uninitialized iterator. + DnsRecordParser(); + + // Construct an iterator to process the |packet| of given |length|. + // |offset| points to the beginning of the answer section. + DnsRecordParser(const void* packet, size_t length, size_t offset); + + // Returns |true| if initialized. + bool IsValid() const { return packet_ != NULL; } + + // Returns |true| if no more bytes remain in the packet. + bool AtEnd() const { return cur_ == packet_ + length_; } + + // Returns current offset into the packet. + size_t GetOffset() const { return cur_ - packet_; } + + // Parses a (possibly compressed) DNS name from the packet starting at + // |pos|. Stores output (even partial) in |out| unless |out| is NULL. |out| + // is stored in the dotted form, e.g., "example.com". Returns number of bytes + // consumed or 0 on failure. + // This is exposed to allow parsing compressed names within RRDATA for TYPEs + // such as NS, CNAME, PTR, MX, SOA. + // See RFC 1035 section 4.1.4. + unsigned ReadName(const void* pos, std::string* out) const; + + // Parses the next resource record into |record|. Returns true if succeeded. + bool ReadRecord(DnsResourceRecord* record); + + // Skip a question section, returns true if succeeded. + bool SkipQuestion(); + + private: + const char* packet_; + size_t length_; + // Current offset within the packet. + const char* cur_; +}; + +// Buffer-holder for the DNS response allowing easy access to the header fields +// and resource records. After reading into |io_buffer| must call InitParse to +// position the RR parser. +class NET_EXPORT_PRIVATE DnsResponse { + public: + // Possible results from ParseToAddressList. + enum Result { + DNS_PARSE_OK = 0, + DNS_MALFORMED_RESPONSE, // DnsRecordParser failed before the end of + // packet. + DNS_MALFORMED_CNAME, // Could not parse CNAME out of RRDATA. + DNS_NAME_MISMATCH, // Got an address but no ordered chain of CNAMEs + // leads there. + DNS_SIZE_MISMATCH, // Got an address but size does not match. + DNS_CNAME_AFTER_ADDRESS, // Found CNAME after an address record. + DNS_ADDRESS_TTL_MISMATCH, // OBSOLETE. No longer used. + DNS_NO_ADDRESSES, // OBSOLETE. No longer used. + // Only add new values here. + DNS_PARSE_RESULT_MAX, // Bounding value for histograms. + }; + + // Constructs a response buffer large enough to store one byte more than + // largest possible response, to detect malformed responses. + DnsResponse(); + + // Constructs a response buffer of given length. Used for TCP transactions. + explicit DnsResponse(size_t length); + + // Constructs a response from |data|. Used for testing purposes only! + DnsResponse(const void* data, size_t length, size_t answer_offset); + + ~DnsResponse(); + + // Internal buffer accessor into which actual bytes of response will be + // read. + IOBufferWithSize* io_buffer() { return io_buffer_.get(); } + + // Assuming the internal buffer holds |nbytes| bytes, returns true iff the + // packet matches the |query| id and question. + bool InitParse(int nbytes, const DnsQuery& query); + + // Assuming the internal buffer holds |nbytes| bytes, initialize the parser + // without matching it against an existing query. + bool InitParseWithoutQuery(int nbytes); + + // Returns true if response is valid, that is, after successful InitParse. + bool IsValid() const; + + // All of the methods below are valid only if the response is valid. + + // Accessors for the header. + uint16 flags() const; // excluding rcode + uint8 rcode() const; + + unsigned answer_count() const; + unsigned additional_answer_count() const; + + // Accessors to the question. The qname is unparsed. + base::StringPiece qname() const; + uint16 qtype() const; + + // Returns qname in dotted format. + std::string GetDottedName() const; + + // Returns an iterator to the resource records in the answer section. + // The iterator is valid only in the scope of the DnsResponse. + // This operation is idempotent. + DnsRecordParser Parser() const; + + // Extracts an AddressList from this response. Returns SUCCESS if succeeded. + // Otherwise returns a detailed error number. + Result ParseToAddressList(AddressList* addr_list, base::TimeDelta* ttl) const; + + private: + // Convenience for header access. + const dns_protocol::Header* header() const; + + // Buffer into which response bytes are read. + scoped_refptr<IOBufferWithSize> io_buffer_; + + // Iterator constructed after InitParse positioned at the answer section. + // It is never updated afterwards, so can be used in accessors. + DnsRecordParser parser_; + + DISALLOW_COPY_AND_ASSIGN(DnsResponse); +}; + +} // namespace net + +#endif // NET_DNS_DNS_RESPONSE_H_ diff --git a/chromium/net/dns/dns_response_unittest.cc b/chromium/net/dns/dns_response_unittest.cc new file mode 100644 index 00000000000..12b0377fea6 --- /dev/null +++ b/chromium/net/dns/dns_response_unittest.cc @@ -0,0 +1,581 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_response.h" + +#include "base/time/time.h" +#include "net/base/address_list.h" +#include "net/base/io_buffer.h" +#include "net/base/net_util.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +TEST(DnsRecordParserTest, Constructor) { + const char data[] = { 0 }; + + EXPECT_FALSE(DnsRecordParser().IsValid()); + EXPECT_TRUE(DnsRecordParser(data, 1, 0).IsValid()); + EXPECT_TRUE(DnsRecordParser(data, 1, 1).IsValid()); + + EXPECT_FALSE(DnsRecordParser(data, 1, 0).AtEnd()); + EXPECT_TRUE(DnsRecordParser(data, 1, 1).AtEnd()); +} + +TEST(DnsRecordParserTest, ReadName) { + const uint8 data[] = { + // all labels "foo.example.com" + 0x03, 'f', 'o', 'o', + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'c', 'o', 'm', + // byte 0x10 + 0x00, + // byte 0x11 + // part label, part pointer, "bar.example.com" + 0x03, 'b', 'a', 'r', + 0xc0, 0x04, + // byte 0x17 + // all pointer to "bar.example.com", 2 jumps + 0xc0, 0x11, + // byte 0x1a + }; + + std::string out; + DnsRecordParser parser(data, sizeof(data), 0); + ASSERT_TRUE(parser.IsValid()); + + EXPECT_EQ(0x11u, parser.ReadName(data + 0x00, &out)); + EXPECT_EQ("foo.example.com", out); + // Check that the last "." is never stored. + out.clear(); + EXPECT_EQ(0x1u, parser.ReadName(data + 0x10, &out)); + EXPECT_EQ("", out); + out.clear(); + EXPECT_EQ(0x6u, parser.ReadName(data + 0x11, &out)); + EXPECT_EQ("bar.example.com", out); + out.clear(); + EXPECT_EQ(0x2u, parser.ReadName(data + 0x17, &out)); + EXPECT_EQ("bar.example.com", out); + + // Parse name without storing it. + EXPECT_EQ(0x11u, parser.ReadName(data + 0x00, NULL)); + EXPECT_EQ(0x1u, parser.ReadName(data + 0x10, NULL)); + EXPECT_EQ(0x6u, parser.ReadName(data + 0x11, NULL)); + EXPECT_EQ(0x2u, parser.ReadName(data + 0x17, NULL)); + + // Check that it works even if initial position is different. + parser = DnsRecordParser(data, sizeof(data), 0x12); + EXPECT_EQ(0x6u, parser.ReadName(data + 0x11, NULL)); +} + +TEST(DnsRecordParserTest, ReadNameFail) { + const uint8 data[] = { + // label length beyond packet + 0x30, 'x', 'x', + 0x00, + // pointer offset beyond packet + 0xc0, 0x20, + // pointer loop + 0xc0, 0x08, + 0xc0, 0x06, + // incorrect label type (currently supports only direct and pointer) + 0x80, 0x00, + // truncated name (missing root label) + 0x02, 'x', 'x', + }; + + DnsRecordParser parser(data, sizeof(data), 0); + ASSERT_TRUE(parser.IsValid()); + + std::string out; + EXPECT_EQ(0u, parser.ReadName(data + 0x00, &out)); + EXPECT_EQ(0u, parser.ReadName(data + 0x04, &out)); + EXPECT_EQ(0u, parser.ReadName(data + 0x08, &out)); + EXPECT_EQ(0u, parser.ReadName(data + 0x0a, &out)); + EXPECT_EQ(0u, parser.ReadName(data + 0x0c, &out)); + EXPECT_EQ(0u, parser.ReadName(data + 0x0e, &out)); +} + +TEST(DnsRecordParserTest, ReadRecord) { + const uint8 data[] = { + // Type CNAME record. + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, + 0x00, 0x05, // TYPE is CNAME. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, 0x24, 0x74, // TTL is 0x00012474. + 0x00, 0x06, // RDLENGTH is 6 bytes. + 0x03, 'f', 'o', 'o', // compressed name in record + 0xc0, 0x00, + // Type A record. + 0x03, 'b', 'a', 'r', // compressed owner name + 0xc0, 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x20, 0x13, 0x55, // TTL is 0x00201355. + 0x00, 0x04, // RDLENGTH is 4 bytes. + 0x7f, 0x02, 0x04, 0x01, // IP is 127.2.4.1 + }; + + std::string out; + DnsRecordParser parser(data, sizeof(data), 0); + + DnsResourceRecord record; + EXPECT_TRUE(parser.ReadRecord(&record)); + EXPECT_EQ("example.com", record.name); + EXPECT_EQ(dns_protocol::kTypeCNAME, record.type); + EXPECT_EQ(dns_protocol::kClassIN, record.klass); + EXPECT_EQ(0x00012474u, record.ttl); + EXPECT_EQ(6u, record.rdata.length()); + EXPECT_EQ(6u, parser.ReadName(record.rdata.data(), &out)); + EXPECT_EQ("foo.example.com", out); + EXPECT_FALSE(parser.AtEnd()); + + EXPECT_TRUE(parser.ReadRecord(&record)); + EXPECT_EQ("bar.example.com", record.name); + EXPECT_EQ(dns_protocol::kTypeA, record.type); + EXPECT_EQ(dns_protocol::kClassIN, record.klass); + EXPECT_EQ(0x00201355u, record.ttl); + EXPECT_EQ(4u, record.rdata.length()); + EXPECT_EQ(base::StringPiece("\x7f\x02\x04\x01"), record.rdata); + EXPECT_TRUE(parser.AtEnd()); + + // Test truncated record. + parser = DnsRecordParser(data, sizeof(data) - 2, 0); + EXPECT_TRUE(parser.ReadRecord(&record)); + EXPECT_FALSE(parser.AtEnd()); + EXPECT_FALSE(parser.ReadRecord(&record)); +} + +TEST(DnsResponseTest, InitParse) { + // This includes \0 at the end. + const char qname_data[] = "\x0A""codereview""\x08""chromium""\x03""org"; + const base::StringPiece qname(qname_data, sizeof(qname_data)); + // Compilers want to copy when binding temporary to const &, so must use heap. + scoped_ptr<DnsQuery> query(new DnsQuery(0xcafe, qname, dns_protocol::kTypeA)); + + const uint8 response_data[] = { + // Header + 0xca, 0xfe, // ID + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x01, // 1 question + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Question + // This part is echoed back from the respective query. + 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w', + 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm', + 0x03, 'o', 'r', 'g', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + + // Answer 1 + 0xc0, 0x0c, // NAME is a pointer to name in Question section. + 0x00, 0x05, // TYPE is CNAME. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x12, // RDLENGTH is 18 bytes. + // ghs.l.google.com in DNS format. + 0x03, 'g', 'h', 's', + 0x01, 'l', + 0x06, 'g', 'o', 'o', 'g', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, + + // Answer 2 + 0xc0, 0x35, // NAME is a pointer to name in Answer 1. + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 53 seconds. + 0x00, 0x35, + 0x00, 0x04, // RDLENGTH is 4 bytes. + 0x4a, 0x7d, // RDATA is the IP: 74.125.95.121 + 0x5f, 0x79, + }; + + DnsResponse resp; + memcpy(resp.io_buffer()->data(), response_data, sizeof(response_data)); + + // Reject too short. + EXPECT_FALSE(resp.InitParse(query->io_buffer()->size() - 1, *query)); + EXPECT_FALSE(resp.IsValid()); + + // Reject wrong id. + scoped_ptr<DnsQuery> other_query(query->CloneWithNewId(0xbeef)); + EXPECT_FALSE(resp.InitParse(sizeof(response_data), *other_query)); + EXPECT_FALSE(resp.IsValid()); + + // Reject wrong question. + scoped_ptr<DnsQuery> wrong_query( + new DnsQuery(0xcafe, qname, dns_protocol::kTypeCNAME)); + EXPECT_FALSE(resp.InitParse(sizeof(response_data), *wrong_query)); + EXPECT_FALSE(resp.IsValid()); + + // Accept matching question. + EXPECT_TRUE(resp.InitParse(sizeof(response_data), *query)); + EXPECT_TRUE(resp.IsValid()); + + // Check header access. + EXPECT_EQ(0x8180, resp.flags()); + EXPECT_EQ(0x0, resp.rcode()); + EXPECT_EQ(2u, resp.answer_count()); + + // Check question access. + EXPECT_EQ(query->qname(), resp.qname()); + EXPECT_EQ(query->qtype(), resp.qtype()); + EXPECT_EQ("codereview.chromium.org", resp.GetDottedName()); + + DnsResourceRecord record; + DnsRecordParser parser = resp.Parser(); + EXPECT_TRUE(parser.ReadRecord(&record)); + EXPECT_FALSE(parser.AtEnd()); + EXPECT_TRUE(parser.ReadRecord(&record)); + EXPECT_TRUE(parser.AtEnd()); + EXPECT_FALSE(parser.ReadRecord(&record)); +} + +TEST(DnsResponseTest, InitParseWithoutQuery) { + DnsResponse resp; + memcpy(resp.io_buffer()->data(), kT0ResponseDatagram, + sizeof(kT0ResponseDatagram)); + + // Accept matching question. + EXPECT_TRUE(resp.InitParseWithoutQuery(sizeof(kT0ResponseDatagram))); + EXPECT_TRUE(resp.IsValid()); + + // Check header access. + EXPECT_EQ(0x8180, resp.flags()); + EXPECT_EQ(0x0, resp.rcode()); + EXPECT_EQ(kT0RecordCount, resp.answer_count()); + + // Check question access. + EXPECT_EQ(kT0Qtype, resp.qtype()); + EXPECT_EQ(kT0HostName, resp.GetDottedName()); + + DnsResourceRecord record; + DnsRecordParser parser = resp.Parser(); + for (unsigned i = 0; i < kT0RecordCount; i ++) { + EXPECT_FALSE(parser.AtEnd()); + EXPECT_TRUE(parser.ReadRecord(&record)); + } + EXPECT_TRUE(parser.AtEnd()); + EXPECT_FALSE(parser.ReadRecord(&record)); +} + +TEST(DnsResponseTest, InitParseWithoutQueryNoQuestions) { + const uint8 response_data[] = { + // Header + 0xca, 0xfe, // ID + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No question + 0x00, 0x01, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w', + 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm', + 0x03, 'o', 'r', 'g', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 53 seconds. + 0x00, 0x35, + 0x00, 0x04, // RDLENGTH is 4 bytes. + 0x4a, 0x7d, // RDATA is the IP: 74.125.95.121 + 0x5f, 0x79, + }; + + DnsResponse resp; + memcpy(resp.io_buffer()->data(), response_data, sizeof(response_data)); + + EXPECT_TRUE(resp.InitParseWithoutQuery(sizeof(response_data))); + + // Check header access. + EXPECT_EQ(0x8180, resp.flags()); + EXPECT_EQ(0x0, resp.rcode()); + EXPECT_EQ(0x1u, resp.answer_count()); + + DnsResourceRecord record; + DnsRecordParser parser = resp.Parser(); + + EXPECT_FALSE(parser.AtEnd()); + EXPECT_TRUE(parser.ReadRecord(&record)); + EXPECT_EQ("codereview.chromium.org", record.name); + EXPECT_EQ(0x00000035u, record.ttl); + EXPECT_EQ(dns_protocol::kTypeA, record.type); + + EXPECT_TRUE(parser.AtEnd()); + EXPECT_FALSE(parser.ReadRecord(&record)); +} + +TEST(DnsResponseTest, InitParseWithoutQueryTwoQuestions) { + const uint8 response_data[] = { + // Header + 0xca, 0xfe, // ID + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x02, // 2 questions + 0x00, 0x01, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Question 1 + 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w', + 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm', + 0x03, 'o', 'r', 'g', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + + // Question 2 + 0x0b, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w', '2', + 0xc0, 0x18, // pointer to "chromium.org" + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + + // Answer 1 + 0xc0, 0x0c, // NAME is a pointer to name in Question section. + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 53 seconds. + 0x00, 0x35, + 0x00, 0x04, // RDLENGTH is 4 bytes. + 0x4a, 0x7d, // RDATA is the IP: 74.125.95.121 + 0x5f, 0x79, + }; + + DnsResponse resp; + memcpy(resp.io_buffer()->data(), response_data, sizeof(response_data)); + + EXPECT_TRUE(resp.InitParseWithoutQuery(sizeof(response_data))); + + // Check header access. + EXPECT_EQ(0x8180, resp.flags()); + EXPECT_EQ(0x0, resp.rcode()); + EXPECT_EQ(0x01u, resp.answer_count()); + + DnsResourceRecord record; + DnsRecordParser parser = resp.Parser(); + + EXPECT_FALSE(parser.AtEnd()); + EXPECT_TRUE(parser.ReadRecord(&record)); + EXPECT_EQ("codereview.chromium.org", record.name); + EXPECT_EQ(0x35u, record.ttl); + EXPECT_EQ(dns_protocol::kTypeA, record.type); + + EXPECT_TRUE(parser.AtEnd()); + EXPECT_FALSE(parser.ReadRecord(&record)); +} + +TEST(DnsResponseTest, InitParseWithoutQueryPacketTooShort) { + const uint8 response_data[] = { + // Header + 0xca, 0xfe, // ID + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No question + }; + + DnsResponse resp; + memcpy(resp.io_buffer()->data(), response_data, sizeof(response_data)); + + EXPECT_FALSE(resp.InitParseWithoutQuery(sizeof(response_data))); +} + +void VerifyAddressList(const std::vector<const char*>& ip_addresses, + const AddressList& addrlist) { + ASSERT_EQ(ip_addresses.size(), addrlist.size()); + + for (size_t i = 0; i < addrlist.size(); ++i) { + EXPECT_EQ(ip_addresses[i], addrlist[i].ToStringWithoutPort()); + } +} + +TEST(DnsResponseTest, ParseToAddressList) { + const struct TestCase { + size_t query_size; + const uint8* response_data; + size_t response_size; + const char* const* expected_addresses; + size_t num_expected_addresses; + const char* expected_cname; + int expected_ttl_sec; + } cases[] = { + { + kT0QuerySize, + kT0ResponseDatagram, arraysize(kT0ResponseDatagram), + kT0IpAddresses, arraysize(kT0IpAddresses), + kT0CanonName, + kT0TTL, + }, + { + kT1QuerySize, + kT1ResponseDatagram, arraysize(kT1ResponseDatagram), + kT1IpAddresses, arraysize(kT1IpAddresses), + kT1CanonName, + kT1TTL, + }, + { + kT2QuerySize, + kT2ResponseDatagram, arraysize(kT2ResponseDatagram), + kT2IpAddresses, arraysize(kT2IpAddresses), + kT2CanonName, + kT2TTL, + }, + { + kT3QuerySize, + kT3ResponseDatagram, arraysize(kT3ResponseDatagram), + kT3IpAddresses, arraysize(kT3IpAddresses), + kT3CanonName, + kT3TTL, + }, + }; + + for (size_t i = 0; i < ARRAYSIZE_UNSAFE(cases); ++i) { + const TestCase& t = cases[i]; + DnsResponse response(t.response_data, t.response_size, t.query_size); + AddressList addr_list; + base::TimeDelta ttl; + EXPECT_EQ(DnsResponse::DNS_PARSE_OK, + response.ParseToAddressList(&addr_list, &ttl)); + std::vector<const char*> expected_addresses( + t.expected_addresses, + t.expected_addresses + t.num_expected_addresses); + VerifyAddressList(expected_addresses, addr_list); + EXPECT_EQ(t.expected_cname, addr_list.canonical_name()); + EXPECT_EQ(base::TimeDelta::FromSeconds(t.expected_ttl_sec), ttl); + } +} + +const uint8 kResponseTruncatedRecord[] = { + // Header: 1 question, 1 answer RR + 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + // Question: name = 'a', type = A (0x1) + 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01, + // Answer: name = 'a', type = A, TTL = 0xFF, RDATA = 10.10.10.10 + 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, + 0x00, 0x04, 0x0A, 0x0A, 0x0A, // Truncated RDATA. +}; + +const uint8 kResponseTruncatedCNAME[] = { + // Header: 1 question, 1 answer RR + 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + // Question: name = 'a', type = A (0x1) + 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01, + // Answer: name = 'a', type = CNAME, TTL = 0xFF, RDATA = 'foo' (truncated) + 0x01, 'a', 0x00, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, + 0x00, 0x03, 0x03, 'f', 'o', // Truncated name. +}; + +const uint8 kResponseNameMismatch[] = { + // Header: 1 question, 1 answer RR + 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + // Question: name = 'a', type = A (0x1) + 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01, + // Answer: name = 'b', type = A, TTL = 0xFF, RDATA = 10.10.10.10 + 0x01, 'b', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, + 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0A, +}; + +const uint8 kResponseNameMismatchInChain[] = { + // Header: 1 question, 3 answer RR + 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, + // Question: name = 'a', type = A (0x1) + 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01, + // Answer: name = 'a', type = CNAME, TTL = 0xFF, RDATA = 'b' + 0x01, 'a', 0x00, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, + 0x00, 0x03, 0x01, 'b', 0x00, + // Answer: name = 'b', type = A, TTL = 0xFF, RDATA = 10.10.10.10 + 0x01, 'b', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, + 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0A, + // Answer: name = 'c', type = A, TTL = 0xFF, RDATA = 10.10.10.11 + 0x01, 'c', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, + 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0B, +}; + +const uint8 kResponseSizeMismatch[] = { + // Header: 1 answer RR + 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + // Question: name = 'a', type = AAAA (0x1c) + 0x01, 'a', 0x00, 0x00, 0x1c, 0x00, 0x01, + // Answer: name = 'a', type = AAAA, TTL = 0xFF, RDATA = 10.10.10.10 + 0x01, 'a', 0x00, 0x00, 0x1c, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, + 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0A, +}; + +const uint8 kResponseCNAMEAfterAddress[] = { + // Header: 2 answer RR + 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, + // Question: name = 'a', type = A (0x1) + 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01, + // Answer: name = 'a', type = A, TTL = 0xFF, RDATA = 10.10.10.10. + 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, + 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0A, + // Answer: name = 'a', type = CNAME, TTL = 0xFF, RDATA = 'b' + 0x01, 'a', 0x00, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, + 0x00, 0x03, 0x01, 'b', 0x00, +}; + +const uint8 kResponseNoAddresses[] = { + // Header: 1 question, 1 answer RR, 1 authority RR + 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, + // Question: name = 'a', type = A (0x1) + 0x01, 'a', 0x00, 0x00, 0x01, 0x00, 0x01, + // Answer: name = 'a', type = CNAME, TTL = 0xFF, RDATA = 'b' + 0x01, 'a', 0x00, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, + 0x00, 0x03, 0x01, 'b', 0x00, + // Authority section + // Answer: name = 'b', type = A, TTL = 0xFF, RDATA = 10.10.10.10 + 0x01, 'b', 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, + 0x00, 0x04, 0x0A, 0x0A, 0x0A, 0x0A, +}; + +TEST(DnsResponseTest, ParseToAddressListFail) { + const struct TestCase { + const uint8* data; + size_t size; + DnsResponse::Result expected_result; + } cases[] = { + { kResponseTruncatedRecord, arraysize(kResponseTruncatedRecord), + DnsResponse::DNS_MALFORMED_RESPONSE }, + { kResponseTruncatedCNAME, arraysize(kResponseTruncatedCNAME), + DnsResponse::DNS_MALFORMED_CNAME }, + { kResponseNameMismatch, arraysize(kResponseNameMismatch), + DnsResponse::DNS_NAME_MISMATCH }, + { kResponseNameMismatchInChain, arraysize(kResponseNameMismatchInChain), + DnsResponse::DNS_NAME_MISMATCH }, + { kResponseSizeMismatch, arraysize(kResponseSizeMismatch), + DnsResponse::DNS_SIZE_MISMATCH }, + { kResponseCNAMEAfterAddress, arraysize(kResponseCNAMEAfterAddress), + DnsResponse::DNS_CNAME_AFTER_ADDRESS }, + // Not actually a failure, just an empty result. + { kResponseNoAddresses, arraysize(kResponseNoAddresses), + DnsResponse::DNS_PARSE_OK }, + }; + + const size_t kQuerySize = 12 + 7; + + for (size_t i = 0; i < ARRAYSIZE_UNSAFE(cases); ++i) { + const TestCase& t = cases[i]; + + DnsResponse response(t.data, t.size, kQuerySize); + AddressList addr_list; + base::TimeDelta ttl; + EXPECT_EQ(t.expected_result, + response.ParseToAddressList(&addr_list, &ttl)); + } +} + +} // namespace + +} // namespace net diff --git a/chromium/net/dns/dns_session.cc b/chromium/net/dns/dns_session.cc new file mode 100644 index 00000000000..ea8b6a14274 --- /dev/null +++ b/chromium/net/dns/dns_session.cc @@ -0,0 +1,298 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_session.h" + +#include "base/basictypes.h" +#include "base/bind.h" +#include "base/lazy_instance.h" +#include "base/metrics/histogram.h" +#include "base/metrics/sample_vector.h" +#include "base/rand_util.h" +#include "base/stl_util.h" +#include "base/time/time.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/dns/dns_config_service.h" +#include "net/dns/dns_socket_pool.h" +#include "net/socket/stream_socket.h" +#include "net/udp/datagram_client_socket.h" + +namespace net { + +namespace { +// Never exceed max timeout. +const unsigned kMaxTimeoutMs = 5000; +// Set min timeout, in case we are talking to a local DNS proxy. +const unsigned kMinTimeoutMs = 10; + +// Number of buckets in the histogram of observed RTTs. +const size_t kRTTBucketCount = 100; +// Target percentile in the RTT histogram used for retransmission timeout. +const unsigned kRTOPercentile = 99; +} // namespace + +// Runtime statistics of DNS server. +struct DnsSession::ServerStats { + ServerStats(base::TimeDelta rtt_estimate_param, RttBuckets* buckets) + : last_failure_count(0), rtt_estimate(rtt_estimate_param) { + rtt_histogram.reset(new base::SampleVector(buckets)); + // Seed histogram with 2 samples at |rtt_estimate| timeout. + rtt_histogram->Accumulate(rtt_estimate.InMilliseconds(), 2); + } + + // Count of consecutive failures after last success. + int last_failure_count; + + // Last time when server returned failure or timeout. + base::Time last_failure; + // Last time when server returned success. + base::Time last_success; + + // Estimated RTT using moving average. + base::TimeDelta rtt_estimate; + // Estimated error in the above. + base::TimeDelta rtt_deviation; + + // A histogram of observed RTT . + scoped_ptr<base::SampleVector> rtt_histogram; + + DISALLOW_COPY_AND_ASSIGN(ServerStats); +}; + +// static +base::LazyInstance<DnsSession::RttBuckets>::Leaky DnsSession::rtt_buckets_ = + LAZY_INSTANCE_INITIALIZER; + +DnsSession::RttBuckets::RttBuckets() : base::BucketRanges(kRTTBucketCount + 1) { + base::Histogram::InitializeBucketRanges(1, 5000, this); +} + +DnsSession::SocketLease::SocketLease(scoped_refptr<DnsSession> session, + unsigned server_index, + scoped_ptr<DatagramClientSocket> socket) + : session_(session), server_index_(server_index), socket_(socket.Pass()) {} + +DnsSession::SocketLease::~SocketLease() { + session_->FreeSocket(server_index_, socket_.Pass()); +} + +DnsSession::DnsSession(const DnsConfig& config, + scoped_ptr<DnsSocketPool> socket_pool, + const RandIntCallback& rand_int_callback, + NetLog* net_log) + : config_(config), + socket_pool_(socket_pool.Pass()), + rand_callback_(base::Bind(rand_int_callback, 0, kuint16max)), + net_log_(net_log), + server_index_(0) { + socket_pool_->Initialize(&config_.nameservers, net_log); + UMA_HISTOGRAM_CUSTOM_COUNTS( + "AsyncDNS.ServerCount", config_.nameservers.size(), 0, 10, 10); + for (size_t i = 0; i < config_.nameservers.size(); ++i) { + server_stats_.push_back(new ServerStats(config_.timeout, + rtt_buckets_.Pointer())); + } +} + +DnsSession::~DnsSession() { + RecordServerStats(); +} + +int DnsSession::NextQueryId() const { return rand_callback_.Run(); } + +unsigned DnsSession::NextFirstServerIndex() { + unsigned index = NextGoodServerIndex(server_index_); + if (config_.rotate) + server_index_ = (server_index_ + 1) % config_.nameservers.size(); + return index; +} + +unsigned DnsSession::NextGoodServerIndex(unsigned server_index) { + unsigned index = server_index; + base::Time oldest_server_failure(base::Time::Now()); + unsigned oldest_server_failure_index = 0; + + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.ServerIsGood", + server_stats_[server_index]->last_failure.is_null()); + + do { + base::Time cur_server_failure = server_stats_[index]->last_failure; + // If number of failures on this server doesn't exceed number of allowed + // attempts, return its index. + if (server_stats_[server_index]->last_failure_count < config_.attempts) { + return index; + } + // Track oldest failed server. + if (cur_server_failure < oldest_server_failure) { + oldest_server_failure = cur_server_failure; + oldest_server_failure_index = index; + } + index = (index + 1) % config_.nameservers.size(); + } while (index != server_index); + + // If we are here it means that there are no successful servers, so we have + // to use one that has failed oldest. + return oldest_server_failure_index; +} + +void DnsSession::RecordServerFailure(unsigned server_index) { + UMA_HISTOGRAM_CUSTOM_COUNTS( + "AsyncDNS.ServerFailureIndex", server_index, 0, 10, 10); + ++(server_stats_[server_index]->last_failure_count); + server_stats_[server_index]->last_failure = base::Time::Now(); +} + +void DnsSession::RecordServerSuccess(unsigned server_index) { + if (server_stats_[server_index]->last_success.is_null()) { + UMA_HISTOGRAM_COUNTS_100("AsyncDNS.ServerFailuresAfterNetworkChange", + server_stats_[server_index]->last_failure_count); + } else { + UMA_HISTOGRAM_COUNTS_100("AsyncDNS.ServerFailuresBeforeSuccess", + server_stats_[server_index]->last_failure_count); + } + server_stats_[server_index]->last_failure_count = 0; + server_stats_[server_index]->last_failure = base::Time(); + server_stats_[server_index]->last_success = base::Time::Now(); +} + +void DnsSession::RecordRTT(unsigned server_index, base::TimeDelta rtt) { + DCHECK_LT(server_index, server_stats_.size()); + + // For measurement, assume it is the first attempt (no backoff). + base::TimeDelta timeout_jacobson = NextTimeoutFromJacobson(server_index, 0); + base::TimeDelta timeout_histogram = NextTimeoutFromHistogram(server_index, 0); + UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutErrorJacobson", rtt - timeout_jacobson); + UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutErrorHistogram", + rtt - timeout_histogram); + UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutErrorJacobsonUnder", + timeout_jacobson - rtt); + UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutErrorHistogramUnder", + timeout_histogram - rtt); + + // Jacobson/Karels algorithm for TCP. + // Using parameters: alpha = 1/8, delta = 1/4, beta = 4 + base::TimeDelta& estimate = server_stats_[server_index]->rtt_estimate; + base::TimeDelta& deviation = server_stats_[server_index]->rtt_deviation; + base::TimeDelta current_error = rtt - estimate; + estimate += current_error / 8; // * alpha + base::TimeDelta abs_error = base::TimeDelta::FromInternalValue( + std::abs(current_error.ToInternalValue())); + deviation += (abs_error - deviation) / 4; // * delta + + // Histogram-based method. + server_stats_[server_index]->rtt_histogram + ->Accumulate(rtt.InMilliseconds(), 1); +} + +void DnsSession::RecordLostPacket(unsigned server_index, int attempt) { + base::TimeDelta timeout_jacobson = + NextTimeoutFromJacobson(server_index, attempt); + base::TimeDelta timeout_histogram = + NextTimeoutFromHistogram(server_index, attempt); + UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutSpentJacobson", timeout_jacobson); + UMA_HISTOGRAM_TIMES("AsyncDNS.TimeoutSpentHistogram", timeout_histogram); +} + +void DnsSession::RecordServerStats() { + for (size_t index = 0; index < server_stats_.size(); ++index) { + if (server_stats_[index]->last_failure_count) { + if (server_stats_[index]->last_success.is_null()) { + UMA_HISTOGRAM_COUNTS("AsyncDNS.ServerFailuresWithoutSuccess", + server_stats_[index]->last_failure_count); + } else { + UMA_HISTOGRAM_COUNTS("AsyncDNS.ServerFailuresAfterSuccess", + server_stats_[index]->last_failure_count); + } + } + } +} + + +base::TimeDelta DnsSession::NextTimeout(unsigned server_index, int attempt) { + // Respect config timeout if it exceeds |kMaxTimeoutMs|. + if (config_.timeout.InMilliseconds() >= kMaxTimeoutMs) + return config_.timeout; + return NextTimeoutFromHistogram(server_index, attempt); +} + +// Allocate a socket, already connected to the server address. +scoped_ptr<DnsSession::SocketLease> DnsSession::AllocateSocket( + unsigned server_index, const NetLog::Source& source) { + scoped_ptr<DatagramClientSocket> socket; + + socket = socket_pool_->AllocateSocket(server_index); + if (!socket.get()) + return scoped_ptr<SocketLease>(); + + socket->NetLog().BeginEvent(NetLog::TYPE_SOCKET_IN_USE, + source.ToEventParametersCallback()); + + SocketLease* lease = new SocketLease(this, server_index, socket.Pass()); + return scoped_ptr<SocketLease>(lease); +} + +scoped_ptr<StreamSocket> DnsSession::CreateTCPSocket( + unsigned server_index, const NetLog::Source& source) { + return socket_pool_->CreateTCPSocket(server_index, source); +} + +// Release a socket. +void DnsSession::FreeSocket(unsigned server_index, + scoped_ptr<DatagramClientSocket> socket) { + DCHECK(socket.get()); + + socket->NetLog().EndEvent(NetLog::TYPE_SOCKET_IN_USE); + + socket_pool_->FreeSocket(server_index, socket.Pass()); +} + +base::TimeDelta DnsSession::NextTimeoutFromJacobson(unsigned server_index, + int attempt) { + DCHECK_LT(server_index, server_stats_.size()); + + base::TimeDelta timeout = server_stats_[server_index]->rtt_estimate + + 4 * server_stats_[server_index]->rtt_deviation; + + timeout = std::max(timeout, base::TimeDelta::FromMilliseconds(kMinTimeoutMs)); + + // The timeout doubles every full round. + unsigned num_backoffs = attempt / config_.nameservers.size(); + + return std::min(timeout * (1 << num_backoffs), + base::TimeDelta::FromMilliseconds(kMaxTimeoutMs)); +} + +base::TimeDelta DnsSession::NextTimeoutFromHistogram(unsigned server_index, + int attempt) { + DCHECK_LT(server_index, server_stats_.size()); + + COMPILE_ASSERT(std::numeric_limits<base::HistogramBase::Count>::is_signed, + histogram_base_count_assumed_to_be_signed); + + // Use fixed percentile of observed samples. + const base::SampleVector& samples = + *server_stats_[server_index]->rtt_histogram; + + base::HistogramBase::Count total = samples.TotalCount(); + base::HistogramBase::Count remaining_count = kRTOPercentile * total / 100; + size_t index = 0; + while (remaining_count > 0 && index < rtt_buckets_.Get().size()) { + remaining_count -= samples.GetCountAtIndex(index); + ++index; + } + + base::TimeDelta timeout = + base::TimeDelta::FromMilliseconds(rtt_buckets_.Get().range(index)); + + timeout = std::max(timeout, base::TimeDelta::FromMilliseconds(kMinTimeoutMs)); + + // The timeout still doubles every full round. + unsigned num_backoffs = attempt / config_.nameservers.size(); + + return std::min(timeout * (1 << num_backoffs), + base::TimeDelta::FromMilliseconds(kMaxTimeoutMs)); +} + +} // namespace net diff --git a/chromium/net/dns/dns_session.h b/chromium/net/dns/dns_session.h new file mode 100644 index 00000000000..01ba5e5d154 --- /dev/null +++ b/chromium/net/dns/dns_session.h @@ -0,0 +1,147 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_SESSION_H_ +#define NET_DNS_DNS_SESSION_H_ + +#include <vector> + +#include "base/lazy_instance.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/scoped_vector.h" +#include "base/metrics/bucket_ranges.h" +#include "base/time/time.h" +#include "net/base/net_export.h" +#include "net/base/rand_callback.h" +#include "net/dns/dns_config_service.h" +#include "net/dns/dns_socket_pool.h" + +namespace base { +class BucketRanges; +class SampleVector; +} + +namespace net { + +class ClientSocketFactory; +class DatagramClientSocket; +class NetLog; +class StreamSocket; + +// Session parameters and state shared between DNS transactions. +// Ref-counted so that DnsClient::Request can keep working in absence of +// DnsClient. A DnsSession must be recreated when DnsConfig changes. +class NET_EXPORT_PRIVATE DnsSession + : NON_EXPORTED_BASE(public base::RefCounted<DnsSession>) { + public: + typedef base::Callback<int()> RandCallback; + + class NET_EXPORT_PRIVATE SocketLease { + public: + SocketLease(scoped_refptr<DnsSession> session, + unsigned server_index, + scoped_ptr<DatagramClientSocket> socket); + ~SocketLease(); + + unsigned server_index() const { return server_index_; } + + DatagramClientSocket* socket() { return socket_.get(); } + + private: + scoped_refptr<DnsSession> session_; + unsigned server_index_; + scoped_ptr<DatagramClientSocket> socket_; + + DISALLOW_COPY_AND_ASSIGN(SocketLease); + }; + + DnsSession(const DnsConfig& config, + scoped_ptr<DnsSocketPool> socket_pool, + const RandIntCallback& rand_int_callback, + NetLog* net_log); + + const DnsConfig& config() const { return config_; } + NetLog* net_log() const { return net_log_; } + + // Return the next random query ID. + int NextQueryId() const; + + // Return the index of the first configured server to use on first attempt. + unsigned NextFirstServerIndex(); + + // Start with |server_index| and find the index of the next known good server + // to use on this attempt. Returns |server_index| if this server has no + // recorded failures, or if there are no other servers that have not failed + // or have failed longer time ago. + unsigned NextGoodServerIndex(unsigned server_index); + + // Record that server failed to respond (due to SRV_FAIL or timeout). + void RecordServerFailure(unsigned server_index); + + // Record that server responded successfully. + void RecordServerSuccess(unsigned server_index); + + // Record how long it took to receive a response from the server. + void RecordRTT(unsigned server_index, base::TimeDelta rtt); + + // Record suspected loss of a packet for a specific server. + void RecordLostPacket(unsigned server_index, int attempt); + + // Record server stats before it is destroyed. + void RecordServerStats(); + + // Return the timeout for the next query. |attempt| counts from 0 and is used + // for exponential backoff. + base::TimeDelta NextTimeout(unsigned server_index, int attempt); + + // Allocate a socket, already connected to the server address. + // When the SocketLease is destroyed, the socket will be freed. + scoped_ptr<SocketLease> AllocateSocket(unsigned server_index, + const NetLog::Source& source); + + // Creates a StreamSocket from the factory for a transaction over TCP. These + // sockets are not pooled. + scoped_ptr<StreamSocket> CreateTCPSocket(unsigned server_index, + const NetLog::Source& source); + + private: + friend class base::RefCounted<DnsSession>; + ~DnsSession(); + + // Release a socket. + void FreeSocket(unsigned server_index, + scoped_ptr<DatagramClientSocket> socket); + + // Return the timeout using the TCP timeout method. + base::TimeDelta NextTimeoutFromJacobson(unsigned server_index, int attempt); + + // Compute the timeout using the histogram method. + base::TimeDelta NextTimeoutFromHistogram(unsigned server_index, int attempt); + + const DnsConfig config_; + scoped_ptr<DnsSocketPool> socket_pool_; + RandCallback rand_callback_; + NetLog* net_log_; + + // Current index into |config_.nameservers| to begin resolution with. + int server_index_; + + struct ServerStats; + + // Track runtime statistics of each DNS server. + ScopedVector<ServerStats> server_stats_; + + // Buckets shared for all |ServerStats::rtt_histogram|. + struct RttBuckets : public base::BucketRanges { + RttBuckets(); + }; + static base::LazyInstance<RttBuckets>::Leaky rtt_buckets_; + + DISALLOW_COPY_AND_ASSIGN(DnsSession); +}; + +} // namespace net + +#endif // NET_DNS_DNS_SESSION_H_ diff --git a/chromium/net/dns/dns_session_unittest.cc b/chromium/net/dns/dns_session_unittest.cc new file mode 100644 index 00000000000..ed726f23234 --- /dev/null +++ b/chromium/net/dns/dns_session_unittest.cc @@ -0,0 +1,252 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_session.h" + +#include <list> + +#include "base/bind.h" +#include "base/memory/scoped_ptr.h" +#include "base/rand_util.h" +#include "base/stl_util.h" +#include "net/base/net_log.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_socket_pool.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/stream_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +class TestClientSocketFactory : public ClientSocketFactory { + public: + virtual ~TestClientSocketFactory(); + + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + net::NetLog* net_log, + const net::NetLog::Source& source) OVERRIDE; + + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + const AddressList& addresses, + NetLog*, const NetLog::Source&) OVERRIDE { + NOTIMPLEMENTED(); + return scoped_ptr<StreamSocket>(); + } + + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) OVERRIDE { + NOTIMPLEMENTED(); + return scoped_ptr<SSLClientSocket>(); + } + + virtual void ClearSSLSessionCache() OVERRIDE { + NOTIMPLEMENTED(); + } + + private: + std::list<SocketDataProvider*> data_providers_; +}; + +struct PoolEvent { + enum { ALLOCATE, FREE } action; + unsigned server_index; +}; + +class DnsSessionTest : public testing::Test { + public: + void OnSocketAllocated(unsigned server_index); + void OnSocketFreed(unsigned server_index); + + protected: + void Initialize(unsigned num_servers); + scoped_ptr<DnsSession::SocketLease> Allocate(unsigned server_index); + bool DidAllocate(unsigned server_index); + bool DidFree(unsigned server_index); + bool NoMoreEvents(); + + DnsConfig config_; + scoped_ptr<TestClientSocketFactory> test_client_socket_factory_; + scoped_refptr<DnsSession> session_; + NetLog::Source source_; + + private: + bool ExpectEvent(const PoolEvent& event); + std::list<PoolEvent> events_; +}; + +class MockDnsSocketPool : public DnsSocketPool { + public: + MockDnsSocketPool(ClientSocketFactory* factory, DnsSessionTest* test) + : DnsSocketPool(factory), test_(test) { } + + virtual ~MockDnsSocketPool() { } + + virtual void Initialize( + const std::vector<IPEndPoint>* nameservers, + NetLog* net_log) OVERRIDE { + InitializeInternal(nameservers, net_log); + } + + virtual scoped_ptr<DatagramClientSocket> AllocateSocket( + unsigned server_index) OVERRIDE { + test_->OnSocketAllocated(server_index); + return CreateConnectedSocket(server_index); + } + + virtual void FreeSocket( + unsigned server_index, + scoped_ptr<DatagramClientSocket> socket) OVERRIDE { + test_->OnSocketFreed(server_index); + } + + private: + DnsSessionTest* test_; +}; + +void DnsSessionTest::Initialize(unsigned num_servers) { + CHECK(num_servers < 256u); + config_.nameservers.clear(); + IPAddressNumber dns_ip; + bool rv = ParseIPLiteralToNumber("192.168.1.0", &dns_ip); + EXPECT_TRUE(rv); + for (unsigned char i = 0; i < num_servers; ++i) { + dns_ip[3] = i; + IPEndPoint dns_endpoint(dns_ip, dns_protocol::kDefaultPort); + config_.nameservers.push_back(dns_endpoint); + } + + test_client_socket_factory_.reset(new TestClientSocketFactory()); + + DnsSocketPool* dns_socket_pool = + new MockDnsSocketPool(test_client_socket_factory_.get(), this); + + session_ = new DnsSession(config_, + scoped_ptr<DnsSocketPool>(dns_socket_pool), + base::Bind(&base::RandInt), + NULL /* NetLog */); + + events_.clear(); +} + +scoped_ptr<DnsSession::SocketLease> DnsSessionTest::Allocate( + unsigned server_index) { + return session_->AllocateSocket(server_index, source_); +} + +bool DnsSessionTest::DidAllocate(unsigned server_index) { + PoolEvent expected_event = { PoolEvent::ALLOCATE, server_index }; + return ExpectEvent(expected_event); +} + +bool DnsSessionTest::DidFree(unsigned server_index) { + PoolEvent expected_event = { PoolEvent::FREE, server_index }; + return ExpectEvent(expected_event); +} + +bool DnsSessionTest::NoMoreEvents() { + return events_.empty(); +} + +void DnsSessionTest::OnSocketAllocated(unsigned server_index) { + PoolEvent event = { PoolEvent::ALLOCATE, server_index }; + events_.push_back(event); +} + +void DnsSessionTest::OnSocketFreed(unsigned server_index) { + PoolEvent event = { PoolEvent::FREE, server_index }; + events_.push_back(event); +} + +bool DnsSessionTest::ExpectEvent(const PoolEvent& expected) { + if (events_.empty()) { + return false; + } + + const PoolEvent actual = events_.front(); + if ((expected.action != actual.action) + || (expected.server_index != actual.server_index)) { + return false; + } + events_.pop_front(); + + return true; +} + +scoped_ptr<DatagramClientSocket> +TestClientSocketFactory::CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + net::NetLog* net_log, + const net::NetLog::Source& source) { + // We're not actually expecting to send or receive any data, so use the + // simplest SocketDataProvider with no data supplied. + SocketDataProvider* data_provider = new StaticSocketDataProvider(); + data_providers_.push_back(data_provider); + scoped_ptr<MockUDPClientSocket> socket( + new MockUDPClientSocket(data_provider, net_log)); + data_provider->set_socket(socket.get()); + return socket.PassAs<DatagramClientSocket>(); +} + +TestClientSocketFactory::~TestClientSocketFactory() { + STLDeleteElements(&data_providers_); +} + +TEST_F(DnsSessionTest, AllocateFree) { + scoped_ptr<DnsSession::SocketLease> lease1, lease2; + + Initialize(2); + EXPECT_TRUE(NoMoreEvents()); + + lease1 = Allocate(0); + EXPECT_TRUE(DidAllocate(0)); + EXPECT_TRUE(NoMoreEvents()); + + lease2 = Allocate(1); + EXPECT_TRUE(DidAllocate(1)); + EXPECT_TRUE(NoMoreEvents()); + + lease1.reset(); + EXPECT_TRUE(DidFree(0)); + EXPECT_TRUE(NoMoreEvents()); + + lease2.reset(); + EXPECT_TRUE(DidFree(1)); + EXPECT_TRUE(NoMoreEvents()); +} + +// Expect default calculated timeout to be within 10ms of in DnsConfig. +TEST_F(DnsSessionTest, HistogramTimeoutNormal) { + Initialize(2); + base::TimeDelta timeoutDelta = session_->NextTimeout(0, 0) - config_.timeout; + EXPECT_LT(timeoutDelta.InMilliseconds(), 10); +} + +// Expect short calculated timeout to be within 10ms of in DnsConfig. +TEST_F(DnsSessionTest, HistogramTimeoutShort) { + config_.timeout = base::TimeDelta::FromMilliseconds(15); + Initialize(2); + base::TimeDelta timeoutDelta = session_->NextTimeout(0, 0) - config_.timeout; + EXPECT_LT(timeoutDelta.InMilliseconds(), 10); +} + +// Expect long calculated timeout to be equal to one in DnsConfig. +TEST_F(DnsSessionTest, HistogramTimeoutLong) { + config_.timeout = base::TimeDelta::FromSeconds(15); + Initialize(2); + base::TimeDelta timeout = session_->NextTimeout(0, 0); + EXPECT_EQ(config_.timeout.InMilliseconds(), timeout.InMilliseconds()); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/dns/dns_socket_pool.cc b/chromium/net/dns/dns_socket_pool.cc new file mode 100644 index 00000000000..7a7ecd6ee8f --- /dev/null +++ b/chromium/net/dns/dns_socket_pool.cc @@ -0,0 +1,234 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_socket_pool.h" + +#include "base/logging.h" +#include "base/rand_util.h" +#include "base/stl_util.h" +#include "net/base/address_list.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/rand_callback.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/stream_socket.h" +#include "net/udp/datagram_client_socket.h" + +namespace net { + +namespace { + +// When we initialize the SocketPool, we allocate kInitialPoolSize sockets. +// When we allocate a socket, we ensure we have at least kAllocateMinSize +// sockets to choose from. When we free a socket, we retain it if we have +// less than kRetainMaxSize sockets in the pool. + +// On Windows, we can't request specific (random) ports, since that will +// trigger firewall prompts, so request default ones, but keep a pile of +// them. Everywhere else, request fresh, random ports each time. +#if defined(OS_WIN) +const DatagramSocket::BindType kBindType = DatagramSocket::DEFAULT_BIND; +const unsigned kInitialPoolSize = 256; +const unsigned kAllocateMinSize = 256; +const unsigned kRetainMaxSize = 0; +#else +const DatagramSocket::BindType kBindType = DatagramSocket::RANDOM_BIND; +const unsigned kInitialPoolSize = 0; +const unsigned kAllocateMinSize = 1; +const unsigned kRetainMaxSize = 0; +#endif + +} // namespace + +DnsSocketPool::DnsSocketPool(ClientSocketFactory* socket_factory) + : socket_factory_(socket_factory), + net_log_(NULL), + nameservers_(NULL), + initialized_(false) { +} + +void DnsSocketPool::InitializeInternal( + const std::vector<IPEndPoint>* nameservers, + NetLog* net_log) { + DCHECK(nameservers); + DCHECK(!initialized_); + + net_log_ = net_log; + nameservers_ = nameservers; + initialized_ = true; +} + +scoped_ptr<StreamSocket> DnsSocketPool::CreateTCPSocket( + unsigned server_index, + const NetLog::Source& source) { + DCHECK_LT(server_index, nameservers_->size()); + + return scoped_ptr<StreamSocket>( + socket_factory_->CreateTransportClientSocket( + AddressList((*nameservers_)[server_index]), net_log_, source)); +} + +scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket( + unsigned server_index) { + DCHECK_LT(server_index, nameservers_->size()); + + scoped_ptr<DatagramClientSocket> socket; + + NetLog::Source no_source; + socket = socket_factory_->CreateDatagramClientSocket( + kBindType, base::Bind(&base::RandInt), net_log_, no_source); + + if (socket.get()) { + int rv = socket->Connect((*nameservers_)[server_index]); + if (rv != OK) { + LOG(WARNING) << "Failed to connect socket: " << rv; + socket.reset(); + } + } else { + LOG(WARNING) << "Failed to create socket."; + } + + return socket.Pass(); +} + +class NullDnsSocketPool : public DnsSocketPool { + public: + NullDnsSocketPool(ClientSocketFactory* factory) + : DnsSocketPool(factory) { + } + + virtual void Initialize( + const std::vector<IPEndPoint>* nameservers, + NetLog* net_log) OVERRIDE { + InitializeInternal(nameservers, net_log); + } + + virtual scoped_ptr<DatagramClientSocket> AllocateSocket( + unsigned server_index) OVERRIDE { + return CreateConnectedSocket(server_index); + } + + virtual void FreeSocket( + unsigned server_index, + scoped_ptr<DatagramClientSocket> socket) OVERRIDE { + } + + private: + DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool); +}; + +// static +scoped_ptr<DnsSocketPool> DnsSocketPool::CreateNull( + ClientSocketFactory* factory) { + return scoped_ptr<DnsSocketPool>(new NullDnsSocketPool(factory)); +} + +class DefaultDnsSocketPool : public DnsSocketPool { + public: + DefaultDnsSocketPool(ClientSocketFactory* factory) + : DnsSocketPool(factory) { + }; + + virtual ~DefaultDnsSocketPool(); + + virtual void Initialize( + const std::vector<IPEndPoint>* nameservers, + NetLog* net_log) OVERRIDE; + + virtual scoped_ptr<DatagramClientSocket> AllocateSocket( + unsigned server_index) OVERRIDE; + + virtual void FreeSocket( + unsigned server_index, + scoped_ptr<DatagramClientSocket> socket) OVERRIDE; + + private: + void FillPool(unsigned server_index, unsigned size); + + typedef std::vector<DatagramClientSocket*> SocketVector; + + std::vector<SocketVector> pools_; + + DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool); +}; + +// static +scoped_ptr<DnsSocketPool> DnsSocketPool::CreateDefault( + ClientSocketFactory* factory) { + return scoped_ptr<DnsSocketPool>(new DefaultDnsSocketPool(factory)); +} + +void DefaultDnsSocketPool::Initialize( + const std::vector<IPEndPoint>* nameservers, + NetLog* net_log) { + InitializeInternal(nameservers, net_log); + + DCHECK(pools_.empty()); + const unsigned num_servers = nameservers->size(); + pools_.resize(num_servers); + for (unsigned server_index = 0; server_index < num_servers; ++server_index) + FillPool(server_index, kInitialPoolSize); +} + +DefaultDnsSocketPool::~DefaultDnsSocketPool() { + unsigned num_servers = pools_.size(); + for (unsigned server_index = 0; server_index < num_servers; ++server_index) { + SocketVector& pool = pools_[server_index]; + STLDeleteElements(&pool); + } +} + +scoped_ptr<DatagramClientSocket> DefaultDnsSocketPool::AllocateSocket( + unsigned server_index) { + DCHECK_LT(server_index, pools_.size()); + SocketVector& pool = pools_[server_index]; + + FillPool(server_index, kAllocateMinSize); + if (pool.size() == 0) { + LOG(WARNING) << "No DNS sockets available in pool " << server_index << "!"; + return scoped_ptr<DatagramClientSocket>(); + } + + if (pool.size() < kAllocateMinSize) { + LOG(WARNING) << "Low DNS port entropy: wanted " << kAllocateMinSize + << " sockets to choose from, but only have " << pool.size() + << " in pool " << server_index << "."; + } + + unsigned socket_index = base::RandInt(0, pool.size() - 1); + DatagramClientSocket* socket = pool[socket_index]; + pool[socket_index] = pool.back(); + pool.pop_back(); + + return scoped_ptr<DatagramClientSocket>(socket); +} + +void DefaultDnsSocketPool::FreeSocket( + unsigned server_index, + scoped_ptr<DatagramClientSocket> socket) { + DCHECK_LT(server_index, pools_.size()); + + // In some builds, kRetainMaxSize will be 0 if we never reuse sockets. + // In that case, don't compile this code to avoid a "tautological + // comparison" warning from clang. +#if kRetainMaxSize > 0 + SocketVector& pool = pools_[server_index]; + if (pool.size() < kRetainMaxSize) + pool.push_back(socket.release()); +#endif +} + +void DefaultDnsSocketPool::FillPool(unsigned server_index, unsigned size) { + SocketVector& pool = pools_[server_index]; + + for (unsigned pool_index = pool.size(); pool_index < size; ++pool_index) { + DatagramClientSocket* socket = + CreateConnectedSocket(server_index).release(); + if (!socket) + break; + pool.push_back(socket); + } +} + +} // namespace net diff --git a/chromium/net/dns/dns_socket_pool.h b/chromium/net/dns/dns_socket_pool.h new file mode 100644 index 00000000000..6bfe474d6a9 --- /dev/null +++ b/chromium/net/dns/dns_socket_pool.h @@ -0,0 +1,91 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_SOCKET_POOL_H_ +#define NET_DNS_DNS_SOCKET_POOL_H_ + +#include <vector> + +#include "base/memory/scoped_ptr.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" + +namespace net { + +class ClientSocketFactory; +class DatagramClientSocket; +class IPEndPoint; +class NetLog; +class StreamSocket; + +// A DnsSocketPool is an abstraction layer around a ClientSocketFactory that +// allows preallocation, reuse, or other strategies to manage sockets connected +// to DNS servers. +class NET_EXPORT_PRIVATE DnsSocketPool { + public: + virtual ~DnsSocketPool() { } + + // Creates a DnsSocketPool that implements the default strategy for managing + // sockets. (This varies by platform; see DnsSocketPoolImpl in + // dns_socket_pool.cc for details.) + static scoped_ptr<DnsSocketPool> CreateDefault( + ClientSocketFactory* factory); + + // Creates a DnsSocketPool that implements a "null" strategy -- no sockets are + // preallocated, allocation requests are satisfied by calling the factory + // directly, and returned sockets are deleted immediately. + static scoped_ptr<DnsSocketPool> CreateNull( + ClientSocketFactory* factory); + + // Initializes the DnsSocketPool. |nameservers| is the list of nameservers + // for which the DnsSocketPool will manage sockets; |net_log| is the NetLog + // used when constructing sockets with the factory. + // + // Initialize may not be called more than once, and must be called before + // calling AllocateSocket or FreeSocket. + virtual void Initialize( + const std::vector<IPEndPoint>* nameservers, + NetLog* net_log) = 0; + + // Allocates a socket that is already connected to the nameserver referenced + // by |server_index|. May return a scoped_ptr to NULL if no sockets are + // available to reuse and the factory fails to produce a socket (or produces + // one on which Connect fails). + virtual scoped_ptr<DatagramClientSocket> AllocateSocket( + unsigned server_index) = 0; + + // Frees a socket allocated by AllocateSocket. |server_index| must be the + // same index passed to AllocateSocket. + virtual void FreeSocket( + unsigned server_index, + scoped_ptr<DatagramClientSocket> socket) = 0; + + // Creates a StreamSocket from the factory for a transaction over TCP. These + // sockets are not pooled. + scoped_ptr<StreamSocket> CreateTCPSocket( + unsigned server_index, + const NetLog::Source& source); + + protected: + DnsSocketPool(ClientSocketFactory* socket_factory); + + void InitializeInternal( + const std::vector<IPEndPoint>* nameservers, + NetLog* net_log); + + scoped_ptr<DatagramClientSocket> CreateConnectedSocket( + unsigned server_index); + + private: + ClientSocketFactory* socket_factory_; + NetLog* net_log_; + const std::vector<IPEndPoint>* nameservers_; + bool initialized_; + + DISALLOW_COPY_AND_ASSIGN(DnsSocketPool); +}; + +} // namespace net + +#endif // NET_DNS_DNS_SOCKET_POOL_H_ diff --git a/chromium/net/dns/dns_test_util.cc b/chromium/net/dns/dns_test_util.cc new file mode 100644 index 00000000000..37bf855dd10 --- /dev/null +++ b/chromium/net/dns/dns_test_util.cc @@ -0,0 +1,210 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_test_util.h" + +#include <string> + +#include "base/bind.h" +#include "base/memory/weak_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/sys_byteorder.h" +#include "net/base/big_endian.h" +#include "net/base/dns_util.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/dns/address_sorter.h" +#include "net/dns/dns_client.h" +#include "net/dns/dns_config_service.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_transaction.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace { + +// A DnsTransaction which uses MockDnsClientRuleList to determine the response. +class MockTransaction : public DnsTransaction, + public base::SupportsWeakPtr<MockTransaction> { + public: + MockTransaction(const MockDnsClientRuleList& rules, + const std::string& hostname, + uint16 qtype, + const DnsTransactionFactory::CallbackType& callback) + : result_(MockDnsClientRule::FAIL), + hostname_(hostname), + qtype_(qtype), + callback_(callback), + started_(false) { + // Find the relevant rule which matches |qtype| and prefix of |hostname|. + for (size_t i = 0; i < rules.size(); ++i) { + const std::string& prefix = rules[i].prefix; + if ((rules[i].qtype == qtype) && + (hostname.size() >= prefix.size()) && + (hostname.compare(0, prefix.size(), prefix) == 0)) { + result_ = rules[i].result; + break; + } + } + } + + virtual const std::string& GetHostname() const OVERRIDE { + return hostname_; + } + + virtual uint16 GetType() const OVERRIDE { + return qtype_; + } + + virtual void Start() OVERRIDE { + EXPECT_FALSE(started_); + started_ = true; + // Using WeakPtr to cleanly cancel when transaction is destroyed. + base::MessageLoop::current()->PostTask( + FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr())); + } + + private: + void Finish() { + switch (result_) { + case MockDnsClientRule::EMPTY: + case MockDnsClientRule::OK: { + std::string qname; + DNSDomainFromDot(hostname_, &qname); + DnsQuery query(0, qname, qtype_); + + DnsResponse response; + char* buffer = response.io_buffer()->data(); + int nbytes = query.io_buffer()->size(); + memcpy(buffer, query.io_buffer()->data(), nbytes); + dns_protocol::Header* header = + reinterpret_cast<dns_protocol::Header*>(buffer); + header->flags |= dns_protocol::kFlagResponse; + + if (MockDnsClientRule::OK == result_) { + const uint16 kPointerToQueryName = + static_cast<uint16>(0xc000 | sizeof(*header)); + + const uint32 kTTL = 86400; // One day. + + // Size of RDATA which is a IPv4 or IPv6 address. + size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? + net::kIPv4AddressSize : net::kIPv6AddressSize; + + // 12 is the sum of sizes of the compressed name reference, TYPE, + // CLASS, TTL and RDLENGTH. + size_t answer_size = 12 + rdata_size; + + // Write answer with loopback IP address. + header->ancount = base::HostToNet16(1); + BigEndianWriter writer(buffer + nbytes, answer_size); + writer.WriteU16(kPointerToQueryName); + writer.WriteU16(qtype_); + writer.WriteU16(net::dns_protocol::kClassIN); + writer.WriteU32(kTTL); + writer.WriteU16(rdata_size); + if (qtype_ == net::dns_protocol::kTypeA) { + char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; + writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); + } else { + char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1 }; + writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); + } + nbytes += answer_size; + } + EXPECT_TRUE(response.InitParse(nbytes, query)); + callback_.Run(this, OK, &response); + } break; + case MockDnsClientRule::FAIL: + callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); + break; + case MockDnsClientRule::TIMEOUT: + callback_.Run(this, ERR_DNS_TIMED_OUT, NULL); + break; + default: + NOTREACHED(); + break; + } + } + + MockDnsClientRule::Result result_; + const std::string hostname_; + const uint16 qtype_; + DnsTransactionFactory::CallbackType callback_; + bool started_; +}; + + +// A DnsTransactionFactory which creates MockTransaction. +class MockTransactionFactory : public DnsTransactionFactory { + public: + explicit MockTransactionFactory(const MockDnsClientRuleList& rules) + : rules_(rules) {} + virtual ~MockTransactionFactory() {} + + virtual scoped_ptr<DnsTransaction> CreateTransaction( + const std::string& hostname, + uint16 qtype, + const DnsTransactionFactory::CallbackType& callback, + const BoundNetLog&) OVERRIDE { + return scoped_ptr<DnsTransaction>( + new MockTransaction(rules_, hostname, qtype, callback)); + } + + private: + MockDnsClientRuleList rules_; +}; + +class MockAddressSorter : public AddressSorter { + public: + virtual ~MockAddressSorter() {} + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE { + // Do nothing. + callback.Run(true, list); + } +}; + +// MockDnsClient provides MockTransactionFactory. +class MockDnsClient : public DnsClient { + public: + MockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules) + : config_(config), factory_(rules) {} + virtual ~MockDnsClient() {} + + virtual void SetConfig(const DnsConfig& config) OVERRIDE { + config_ = config; + } + + virtual const DnsConfig* GetConfig() const OVERRIDE { + return config_.IsValid() ? &config_ : NULL; + } + + virtual DnsTransactionFactory* GetTransactionFactory() OVERRIDE { + return config_.IsValid() ? &factory_ : NULL; + } + + virtual AddressSorter* GetAddressSorter() OVERRIDE { + return &address_sorter_; + } + + private: + DnsConfig config_; + MockTransactionFactory factory_; + MockAddressSorter address_sorter_; +}; + +} // namespace + +// static +scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules) { + return scoped_ptr<DnsClient>(new MockDnsClient(config, rules)); +} + +} // namespace net diff --git a/chromium/net/dns/dns_test_util.h b/chromium/net/dns/dns_test_util.h new file mode 100644 index 00000000000..d447b299c86 --- /dev/null +++ b/chromium/net/dns/dns_test_util.h @@ -0,0 +1,205 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_TEST_UTIL_H_ +#define NET_DNS_DNS_TEST_UTIL_H_ + +#include <string> +#include <vector> + +#include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" +#include "net/dns/dns_config_service.h" +#include "net/dns/dns_protocol.h" + +namespace net { + +//----------------------------------------------------------------------------- +// Query/response set for www.google.com, ID is fixed to 0. +static const char kT0HostName[] = "www.google.com"; +static const uint16 kT0Qtype = dns_protocol::kTypeA; +static const char kT0DnsName[] = { + 0x03, 'w', 'w', 'w', + 0x06, 'g', 'o', 'o', 'g', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00 +}; +static const size_t kT0QuerySize = 32; +static const uint8 kT0ResponseDatagram[] = { + // response contains one CNAME for www.l.google.com and the following + // IP addresses: 74.125.226.{179,180,176,177,178} + 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x06, + 0x00, 0x00, 0x00, 0x00, 0x03, 0x77, 0x77, 0x77, + 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, + 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, + 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x01, + 0x4d, 0x13, 0x00, 0x08, 0x03, 0x77, 0x77, 0x77, + 0x01, 0x6c, 0xc0, 0x10, 0xc0, 0x2c, 0x00, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x04, + 0x4a, 0x7d, 0xe2, 0xb3, 0xc0, 0x2c, 0x00, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x04, + 0x4a, 0x7d, 0xe2, 0xb4, 0xc0, 0x2c, 0x00, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x04, + 0x4a, 0x7d, 0xe2, 0xb0, 0xc0, 0x2c, 0x00, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x04, + 0x4a, 0x7d, 0xe2, 0xb1, 0xc0, 0x2c, 0x00, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x04, + 0x4a, 0x7d, 0xe2, 0xb2 +}; +static const char* const kT0IpAddresses[] = { + "74.125.226.179", "74.125.226.180", "74.125.226.176", + "74.125.226.177", "74.125.226.178" +}; +static const char kT0CanonName[] = "www.l.google.com"; +static const int kT0TTL = 0x000000e4; +// +1 for the CNAME record. +static const unsigned kT0RecordCount = arraysize(kT0IpAddresses) + 1; + +//----------------------------------------------------------------------------- +// Query/response set for codereview.chromium.org, ID is fixed to 1. +static const char kT1HostName[] = "codereview.chromium.org"; +static const uint16 kT1Qtype = dns_protocol::kTypeA; +static const char kT1DnsName[] = { + 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w', + 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm', + 0x03, 'o', 'r', 'g', + 0x00 +}; +static const size_t kT1QuerySize = 41; +static const uint8 kT1ResponseDatagram[] = { + // response contains one CNAME for ghs.l.google.com and the following + // IP address: 64.233.169.121 + 0x00, 0x01, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x0a, 0x63, 0x6f, 0x64, + 0x65, 0x72, 0x65, 0x76, 0x69, 0x65, 0x77, 0x08, + 0x63, 0x68, 0x72, 0x6f, 0x6d, 0x69, 0x75, 0x6d, + 0x03, 0x6f, 0x72, 0x67, 0x00, 0x00, 0x01, 0x00, + 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, + 0x01, 0x41, 0x75, 0x00, 0x12, 0x03, 0x67, 0x68, + 0x73, 0x01, 0x6c, 0x06, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0xc0, + 0x35, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, + 0x0b, 0x00, 0x04, 0x40, 0xe9, 0xa9, 0x79 +}; +static const char* const kT1IpAddresses[] = { + "64.233.169.121" +}; +static const char kT1CanonName[] = "ghs.l.google.com"; +static const int kT1TTL = 0x0000010b; +// +1 for the CNAME record. +static const unsigned kT1RecordCount = arraysize(kT1IpAddresses) + 1; + +//----------------------------------------------------------------------------- +// Query/response set for www.ccs.neu.edu, ID is fixed to 2. +static const char kT2HostName[] = "www.ccs.neu.edu"; +static const uint16 kT2Qtype = dns_protocol::kTypeA; +static const char kT2DnsName[] = { + 0x03, 'w', 'w', 'w', + 0x03, 'c', 'c', 's', + 0x03, 'n', 'e', 'u', + 0x03, 'e', 'd', 'u', + 0x00 +}; +static const size_t kT2QuerySize = 33; +static const uint8 kT2ResponseDatagram[] = { + // response contains one CNAME for vulcan.ccs.neu.edu and the following + // IP address: 129.10.116.81 + 0x00, 0x02, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x03, 0x77, 0x77, 0x77, + 0x03, 0x63, 0x63, 0x73, 0x03, 0x6e, 0x65, 0x75, + 0x03, 0x65, 0x64, 0x75, 0x00, 0x00, 0x01, 0x00, + 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, + 0x00, 0x01, 0x2c, 0x00, 0x09, 0x06, 0x76, 0x75, + 0x6c, 0x63, 0x61, 0x6e, 0xc0, 0x10, 0xc0, 0x2d, + 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c, + 0x00, 0x04, 0x81, 0x0a, 0x74, 0x51 +}; +static const char* const kT2IpAddresses[] = { + "129.10.116.81" +}; +static const char kT2CanonName[] = "vulcan.ccs.neu.edu"; +static const int kT2TTL = 0x0000012c; +// +1 for the CNAME record. +static const unsigned kT2RecordCount = arraysize(kT2IpAddresses) + 1; + +//----------------------------------------------------------------------------- +// Query/response set for www.google.az, ID is fixed to 3. +static const char kT3HostName[] = "www.google.az"; +static const uint16 kT3Qtype = dns_protocol::kTypeA; +static const char kT3DnsName[] = { + 0x03, 'w', 'w', 'w', + 0x06, 'g', 'o', 'o', 'g', 'l', 'e', + 0x02, 'a', 'z', + 0x00 +}; +static const size_t kT3QuerySize = 31; +static const uint8 kT3ResponseDatagram[] = { + // response contains www.google.com as CNAME for www.google.az and + // www.l.google.com as CNAME for www.google.com and the following + // IP addresses: 74.125.226.{178,179,180,176,177} + // The TTLs on the records are: 0x00015099, 0x00025099, 0x00000415, + // 0x00003015, 0x00002015, 0x00000015, 0x00001015. + // The last record is an imaginary TXT record for t.google.com. + 0x00, 0x03, 0x81, 0x80, 0x00, 0x01, 0x00, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x03, 0x77, 0x77, 0x77, + 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x02, + 0x61, 0x7a, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, + 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x01, 0x50, + 0x99, 0x00, 0x10, 0x03, 0x77, 0x77, 0x77, 0x06, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, + 0x6f, 0x6d, 0x00, 0xc0, 0x2b, 0x00, 0x05, 0x00, + 0x01, 0x00, 0x02, 0x50, 0x99, 0x00, 0x08, 0x03, + 0x77, 0x77, 0x77, 0x01, 0x6c, 0xc0, 0x2f, 0xc0, + 0x47, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x04, + 0x15, 0x00, 0x04, 0x4a, 0x7d, 0xe2, 0xb2, 0xc0, + 0x47, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x30, + 0x15, 0x00, 0x04, 0x4a, 0x7d, 0xe2, 0xb3, 0xc0, + 0x47, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x20, + 0x15, 0x00, 0x04, 0x4a, 0x7d, 0xe2, 0xb4, 0xc0, + 0x47, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x15, 0x00, 0x04, 0x4a, 0x7d, 0xe2, 0xb0, 0xc0, + 0x47, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x10, + 0x15, 0x00, 0x04, 0x4a, 0x7d, 0xe2, 0xb1, 0x01, + 0x74, 0xc0, 0x2f, 0x00, 0x10, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x01, 0x00, 0x04, 0xde, 0xad, 0xfe, + 0xed +}; +static const char* const kT3IpAddresses[] = { + "74.125.226.178", "74.125.226.179", "74.125.226.180", + "74.125.226.176", "74.125.226.177" +}; +static const char kT3CanonName[] = "www.l.google.com"; +static const int kT3TTL = 0x00000015; +// +2 for the CNAME records, +1 for TXT record. +static const unsigned kT3RecordCount = arraysize(kT3IpAddresses) + 3; + +class DnsClient; + +struct MockDnsClientRule { + enum Result { + FAIL, // Fail asynchronously with ERR_NAME_NOT_RESOLVED. + TIMEOUT, // Fail asynchronously with ERR_DNS_TIMEOUT. + EMPTY, // Return an empty response. + OK, // Return a response with loopback address. + }; + + MockDnsClientRule(const std::string& prefix_arg, + uint16 qtype_arg, + Result result_arg) + : result(result_arg), prefix(prefix_arg), qtype(qtype_arg) { } + + Result result; + std::string prefix; + uint16 qtype; +}; + +typedef std::vector<MockDnsClientRule> MockDnsClientRuleList; + +// Creates mock DnsClient for testing HostResolverImpl. +scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules); + +} // namespace net + +#endif // NET_DNS_DNS_TEST_UTIL_H_ diff --git a/chromium/net/dns/dns_transaction.cc b/chromium/net/dns/dns_transaction.cc new file mode 100644 index 00000000000..170ea678d4b --- /dev/null +++ b/chromium/net/dns/dns_transaction.cc @@ -0,0 +1,963 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_transaction.h" + +#include <deque> +#include <string> +#include <vector> + +#include "base/bind.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/scoped_vector.h" +#include "base/memory/weak_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/metrics/histogram.h" +#include "base/rand_util.h" +#include "base/stl_util.h" +#include "base/strings/string_piece.h" +#include "base/threading/non_thread_safe.h" +#include "base/timer/timer.h" +#include "base/values.h" +#include "net/base/big_endian.h" +#include "net/base/completion_callback.h" +#include "net/base/dns_util.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_session.h" +#include "net/socket/stream_socket.h" +#include "net/udp/datagram_client_socket.h" + +namespace net { + +namespace { + +// Provide a common macro to simplify code and readability. We must use a +// macro as the underlying HISTOGRAM macro creates static variables. +#define DNS_HISTOGRAM(name, time) UMA_HISTOGRAM_CUSTOM_TIMES(name, time, \ + base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromHours(1), 100) + +// Count labels in the fully-qualified name in DNS format. +int CountLabels(const std::string& name) { + size_t count = 0; + for (size_t i = 0; i < name.size() && name[i]; i += name[i] + 1) + ++count; + return count; +} + +bool IsIPLiteral(const std::string& hostname) { + IPAddressNumber ip; + return ParseIPLiteralToNumber(hostname, &ip); +} + +base::Value* NetLogStartCallback(const std::string* hostname, + uint16 qtype, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetString("hostname", *hostname); + dict->SetInteger("query_type", qtype); + return dict; +}; + +// ---------------------------------------------------------------------------- + +// A single asynchronous DNS exchange, which consists of sending out a +// DNS query, waiting for a response, and returning the response that it +// matches. Logging is done in the socket and in the outer DnsTransaction. +class DnsAttempt { + public: + explicit DnsAttempt(unsigned server_index) + : result_(ERR_FAILED), server_index_(server_index) {} + + virtual ~DnsAttempt() {} + // Starts the attempt. Returns ERR_IO_PENDING if cannot complete synchronously + // and calls |callback| upon completion. + virtual int Start(const CompletionCallback& callback) = 0; + + // Returns the query of this attempt. + virtual const DnsQuery* GetQuery() const = 0; + + // Returns the response or NULL if has not received a matching response from + // the server. + virtual const DnsResponse* GetResponse() const = 0; + + // Returns the net log bound to the source of the socket. + virtual const BoundNetLog& GetSocketNetLog() const = 0; + + // Returns the index of the destination server within DnsConfig::nameservers. + unsigned server_index() const { return server_index_; } + + // Returns a Value representing the received response, along with a reference + // to the NetLog source source of the UDP socket used. The request must have + // completed before this is called. + base::Value* NetLogResponseCallback(NetLog::LogLevel log_level) const { + DCHECK(GetResponse()->IsValid()); + + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetInteger("rcode", GetResponse()->rcode()); + dict->SetInteger("answer_count", GetResponse()->answer_count()); + GetSocketNetLog().source().AddToEventParameters(dict); + return dict; + } + + void set_result(int result) { + result_ = result; + } + + // True if current attempt is pending (waiting for server response). + bool is_pending() const { + return result_ == ERR_IO_PENDING; + } + + // True if attempt is completed (received server response). + bool is_completed() const { + return (result_ == OK) || (result_ == ERR_NAME_NOT_RESOLVED) || + (result_ == ERR_DNS_SERVER_REQUIRES_TCP); + } + + private: + // Result of last operation. + int result_; + + const unsigned server_index_; +}; + +class DnsUDPAttempt : public DnsAttempt { + public: + DnsUDPAttempt(unsigned server_index, + scoped_ptr<DnsSession::SocketLease> socket_lease, + scoped_ptr<DnsQuery> query) + : DnsAttempt(server_index), + next_state_(STATE_NONE), + received_malformed_response_(false), + socket_lease_(socket_lease.Pass()), + query_(query.Pass()) {} + + // DnsAttempt: + virtual int Start(const CompletionCallback& callback) OVERRIDE { + DCHECK_EQ(STATE_NONE, next_state_); + callback_ = callback; + start_time_ = base::TimeTicks::Now(); + next_state_ = STATE_SEND_QUERY; + return DoLoop(OK); + } + + virtual const DnsQuery* GetQuery() const OVERRIDE { + return query_.get(); + } + + virtual const DnsResponse* GetResponse() const OVERRIDE { + const DnsResponse* resp = response_.get(); + return (resp != NULL && resp->IsValid()) ? resp : NULL; + } + + virtual const BoundNetLog& GetSocketNetLog() const OVERRIDE { + return socket_lease_->socket()->NetLog(); + } + + private: + enum State { + STATE_SEND_QUERY, + STATE_SEND_QUERY_COMPLETE, + STATE_READ_RESPONSE, + STATE_READ_RESPONSE_COMPLETE, + STATE_NONE, + }; + + DatagramClientSocket* socket() { + return socket_lease_->socket(); + } + + int DoLoop(int result) { + CHECK_NE(STATE_NONE, next_state_); + int rv = result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_SEND_QUERY: + rv = DoSendQuery(); + break; + case STATE_SEND_QUERY_COMPLETE: + rv = DoSendQueryComplete(rv); + break; + case STATE_READ_RESPONSE: + rv = DoReadResponse(); + break; + case STATE_READ_RESPONSE_COMPLETE: + rv = DoReadResponseComplete(rv); + break; + default: + NOTREACHED(); + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + + set_result(rv); + // If we received a malformed response, and are now waiting for another one, + // indicate to the transaction that the server might be misbehaving. + if (rv == ERR_IO_PENDING && received_malformed_response_) + return ERR_DNS_MALFORMED_RESPONSE; + if (rv == OK) { + DCHECK_EQ(STATE_NONE, next_state_); + DNS_HISTOGRAM("AsyncDNS.UDPAttemptSuccess", + base::TimeTicks::Now() - start_time_); + } else if (rv != ERR_IO_PENDING) { + DNS_HISTOGRAM("AsyncDNS.UDPAttemptFail", + base::TimeTicks::Now() - start_time_); + } + return rv; + } + + int DoSendQuery() { + next_state_ = STATE_SEND_QUERY_COMPLETE; + return socket()->Write(query_->io_buffer(), + query_->io_buffer()->size(), + base::Bind(&DnsUDPAttempt::OnIOComplete, + base::Unretained(this))); + } + + int DoSendQueryComplete(int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + if (rv < 0) + return rv; + + // Writing to UDP should not result in a partial datagram. + if (rv != query_->io_buffer()->size()) + return ERR_MSG_TOO_BIG; + + next_state_ = STATE_READ_RESPONSE; + return OK; + } + + int DoReadResponse() { + next_state_ = STATE_READ_RESPONSE_COMPLETE; + response_.reset(new DnsResponse()); + return socket()->Read(response_->io_buffer(), + response_->io_buffer()->size(), + base::Bind(&DnsUDPAttempt::OnIOComplete, + base::Unretained(this))); + } + + int DoReadResponseComplete(int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + if (rv < 0) + return rv; + + DCHECK(rv); + if (!response_->InitParse(rv, *query_)) { + // Other implementations simply ignore mismatched responses. Since each + // DnsUDPAttempt binds to a different port, we might find that responses + // to previously timed out queries lead to failures in the future. + // Our solution is to make another attempt, in case the query truly + // failed, but keep this attempt alive, in case it was a false alarm. + received_malformed_response_ = true; + next_state_ = STATE_READ_RESPONSE; + return OK; + } + if (response_->flags() & dns_protocol::kFlagTC) + return ERR_DNS_SERVER_REQUIRES_TCP; + // TODO(szym): Extract TTL for NXDOMAIN results. http://crbug.com/115051 + if (response_->rcode() == dns_protocol::kRcodeNXDOMAIN) + return ERR_NAME_NOT_RESOLVED; + if (response_->rcode() != dns_protocol::kRcodeNOERROR) + return ERR_DNS_SERVER_FAILED; + + return OK; + } + + void OnIOComplete(int rv) { + rv = DoLoop(rv); + if (rv != ERR_IO_PENDING) + callback_.Run(rv); + } + + State next_state_; + bool received_malformed_response_; + base::TimeTicks start_time_; + + scoped_ptr<DnsSession::SocketLease> socket_lease_; + scoped_ptr<DnsQuery> query_; + + scoped_ptr<DnsResponse> response_; + + CompletionCallback callback_; + + DISALLOW_COPY_AND_ASSIGN(DnsUDPAttempt); +}; + +class DnsTCPAttempt : public DnsAttempt { + public: + DnsTCPAttempt(unsigned server_index, + scoped_ptr<StreamSocket> socket, + scoped_ptr<DnsQuery> query) + : DnsAttempt(server_index), + next_state_(STATE_NONE), + socket_(socket.Pass()), + query_(query.Pass()), + length_buffer_(new IOBufferWithSize(sizeof(uint16))), + response_length_(0) {} + + // DnsAttempt: + virtual int Start(const CompletionCallback& callback) OVERRIDE { + DCHECK_EQ(STATE_NONE, next_state_); + callback_ = callback; + start_time_ = base::TimeTicks::Now(); + next_state_ = STATE_CONNECT_COMPLETE; + int rv = socket_->Connect(base::Bind(&DnsTCPAttempt::OnIOComplete, + base::Unretained(this))); + if (rv == ERR_IO_PENDING) { + set_result(rv); + return rv; + } + return DoLoop(rv); + } + + virtual const DnsQuery* GetQuery() const OVERRIDE { + return query_.get(); + } + + virtual const DnsResponse* GetResponse() const OVERRIDE { + const DnsResponse* resp = response_.get(); + return (resp != NULL && resp->IsValid()) ? resp : NULL; + } + + virtual const BoundNetLog& GetSocketNetLog() const OVERRIDE { + return socket_->NetLog(); + } + + private: + enum State { + STATE_CONNECT_COMPLETE, + STATE_SEND_LENGTH, + STATE_SEND_QUERY, + STATE_READ_LENGTH, + STATE_READ_RESPONSE, + STATE_NONE, + }; + + int DoLoop(int result) { + CHECK_NE(STATE_NONE, next_state_); + int rv = result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_CONNECT_COMPLETE: + rv = DoConnectComplete(rv); + break; + case STATE_SEND_LENGTH: + rv = DoSendLength(rv); + break; + case STATE_SEND_QUERY: + rv = DoSendQuery(rv); + break; + case STATE_READ_LENGTH: + rv = DoReadLength(rv); + break; + case STATE_READ_RESPONSE: + rv = DoReadResponse(rv); + break; + default: + NOTREACHED(); + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + + set_result(rv); + if (rv == OK) { + DCHECK_EQ(STATE_NONE, next_state_); + DNS_HISTOGRAM("AsyncDNS.TCPAttemptSuccess", + base::TimeTicks::Now() - start_time_); + } else if (rv != ERR_IO_PENDING) { + DNS_HISTOGRAM("AsyncDNS.TCPAttemptFail", + base::TimeTicks::Now() - start_time_); + } + return rv; + } + + int DoConnectComplete(int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + if (rv < 0) + return rv; + + WriteBigEndian<uint16>(length_buffer_->data(), query_->io_buffer()->size()); + buffer_ = + new DrainableIOBuffer(length_buffer_.get(), length_buffer_->size()); + next_state_ = STATE_SEND_LENGTH; + return OK; + } + + int DoSendLength(int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + if (rv < 0) + return rv; + + buffer_->DidConsume(rv); + if (buffer_->BytesRemaining() > 0) { + next_state_ = STATE_SEND_LENGTH; + return socket_->Write( + buffer_.get(), + buffer_->BytesRemaining(), + base::Bind(&DnsTCPAttempt::OnIOComplete, base::Unretained(this))); + } + buffer_ = new DrainableIOBuffer(query_->io_buffer(), + query_->io_buffer()->size()); + next_state_ = STATE_SEND_QUERY; + return OK; + } + + int DoSendQuery(int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + if (rv < 0) + return rv; + + buffer_->DidConsume(rv); + if (buffer_->BytesRemaining() > 0) { + next_state_ = STATE_SEND_QUERY; + return socket_->Write( + buffer_.get(), + buffer_->BytesRemaining(), + base::Bind(&DnsTCPAttempt::OnIOComplete, base::Unretained(this))); + } + buffer_ = + new DrainableIOBuffer(length_buffer_.get(), length_buffer_->size()); + next_state_ = STATE_READ_LENGTH; + return OK; + } + + int DoReadLength(int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + if (rv < 0) + return rv; + + buffer_->DidConsume(rv); + if (buffer_->BytesRemaining() > 0) { + next_state_ = STATE_READ_LENGTH; + return socket_->Read( + buffer_.get(), + buffer_->BytesRemaining(), + base::Bind(&DnsTCPAttempt::OnIOComplete, base::Unretained(this))); + } + ReadBigEndian<uint16>(length_buffer_->data(), &response_length_); + // Check if advertised response is too short. (Optimization only.) + if (response_length_ < query_->io_buffer()->size()) + return ERR_DNS_MALFORMED_RESPONSE; + // Allocate more space so that DnsResponse::InitParse sanity check passes. + response_.reset(new DnsResponse(response_length_ + 1)); + buffer_ = new DrainableIOBuffer(response_->io_buffer(), response_length_); + next_state_ = STATE_READ_RESPONSE; + return OK; + } + + int DoReadResponse(int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + if (rv < 0) + return rv; + + buffer_->DidConsume(rv); + if (buffer_->BytesRemaining() > 0) { + next_state_ = STATE_READ_RESPONSE; + return socket_->Read( + buffer_.get(), + buffer_->BytesRemaining(), + base::Bind(&DnsTCPAttempt::OnIOComplete, base::Unretained(this))); + } + if (!response_->InitParse(buffer_->BytesConsumed(), *query_)) + return ERR_DNS_MALFORMED_RESPONSE; + if (response_->flags() & dns_protocol::kFlagTC) + return ERR_UNEXPECTED; + // TODO(szym): Frankly, none of these are expected. + if (response_->rcode() == dns_protocol::kRcodeNXDOMAIN) + return ERR_NAME_NOT_RESOLVED; + if (response_->rcode() != dns_protocol::kRcodeNOERROR) + return ERR_DNS_SERVER_FAILED; + + return OK; + } + + void OnIOComplete(int rv) { + rv = DoLoop(rv); + if (rv != ERR_IO_PENDING) + callback_.Run(rv); + } + + State next_state_; + base::TimeTicks start_time_; + + scoped_ptr<StreamSocket> socket_; + scoped_ptr<DnsQuery> query_; + scoped_refptr<IOBufferWithSize> length_buffer_; + scoped_refptr<DrainableIOBuffer> buffer_; + + uint16 response_length_; + scoped_ptr<DnsResponse> response_; + + CompletionCallback callback_; + + DISALLOW_COPY_AND_ASSIGN(DnsTCPAttempt); +}; + +// ---------------------------------------------------------------------------- + +// Implements DnsTransaction. Configuration is supplied by DnsSession. +// The suffix list is built according to the DnsConfig from the session. +// The timeout for each DnsUDPAttempt is given by DnsSession::NextTimeout. +// The first server to attempt on each query is given by +// DnsSession::NextFirstServerIndex, and the order is round-robin afterwards. +// Each server is attempted DnsConfig::attempts times. +class DnsTransactionImpl : public DnsTransaction, + public base::NonThreadSafe, + public base::SupportsWeakPtr<DnsTransactionImpl> { + public: + DnsTransactionImpl(DnsSession* session, + const std::string& hostname, + uint16 qtype, + const DnsTransactionFactory::CallbackType& callback, + const BoundNetLog& net_log) + : session_(session), + hostname_(hostname), + qtype_(qtype), + callback_(callback), + net_log_(net_log), + qnames_initial_size_(0), + attempts_count_(0), + had_tcp_attempt_(false), + first_server_index_(0) { + DCHECK(session_.get()); + DCHECK(!hostname_.empty()); + DCHECK(!callback_.is_null()); + DCHECK(!IsIPLiteral(hostname_)); + } + + virtual ~DnsTransactionImpl() { + if (!callback_.is_null()) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_DNS_TRANSACTION, + ERR_ABORTED); + } // otherwise logged in DoCallback or Start + } + + virtual const std::string& GetHostname() const OVERRIDE { + DCHECK(CalledOnValidThread()); + return hostname_; + } + + virtual uint16 GetType() const OVERRIDE { + DCHECK(CalledOnValidThread()); + return qtype_; + } + + virtual void Start() OVERRIDE { + DCHECK(!callback_.is_null()); + DCHECK(attempts_.empty()); + net_log_.BeginEvent(NetLog::TYPE_DNS_TRANSACTION, + base::Bind(&NetLogStartCallback, &hostname_, qtype_)); + AttemptResult result(PrepareSearch(), NULL); + if (result.rv == OK) { + qnames_initial_size_ = qnames_.size(); + if (qtype_ == dns_protocol::kTypeA) + UMA_HISTOGRAM_COUNTS("AsyncDNS.SuffixSearchStart", qnames_.size()); + result = ProcessAttemptResult(StartQuery()); + } + + // Must always return result asynchronously, to avoid reentrancy. + if (result.rv != ERR_IO_PENDING) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&DnsTransactionImpl::DoCallback, AsWeakPtr(), result)); + } + } + + private: + // Wrapper for the result of a DnsUDPAttempt. + struct AttemptResult { + AttemptResult(int rv, const DnsAttempt* attempt) + : rv(rv), attempt(attempt) {} + + int rv; + const DnsAttempt* attempt; + }; + + // Prepares |qnames_| according to the DnsConfig. + int PrepareSearch() { + const DnsConfig& config = session_->config(); + + std::string labeled_hostname; + if (!DNSDomainFromDot(hostname_, &labeled_hostname)) + return ERR_INVALID_ARGUMENT; + + if (hostname_[hostname_.size() - 1] == '.') { + // It's a fully-qualified name, no suffix search. + qnames_.push_back(labeled_hostname); + return OK; + } + + int ndots = CountLabels(labeled_hostname) - 1; + + if (ndots > 0 && !config.append_to_multi_label_name) { + qnames_.push_back(labeled_hostname); + return OK; + } + + // Set true when |labeled_hostname| is put on the list. + bool had_hostname = false; + + if (ndots >= config.ndots) { + qnames_.push_back(labeled_hostname); + had_hostname = true; + } + + std::string qname; + for (size_t i = 0; i < config.search.size(); ++i) { + // Ignore invalid (too long) combinations. + if (!DNSDomainFromDot(hostname_ + "." + config.search[i], &qname)) + continue; + if (qname.size() == labeled_hostname.size()) { + if (had_hostname) + continue; + had_hostname = true; + } + qnames_.push_back(qname); + } + + if (ndots > 0 && !had_hostname) + qnames_.push_back(labeled_hostname); + + return qnames_.empty() ? ERR_DNS_SEARCH_EMPTY : OK; + } + + void DoCallback(AttemptResult result) { + DCHECK(!callback_.is_null()); + DCHECK_NE(ERR_IO_PENDING, result.rv); + const DnsResponse* response = result.attempt ? + result.attempt->GetResponse() : NULL; + CHECK(result.rv != OK || response != NULL); + + timer_.Stop(); + RecordLostPacketsIfAny(); + if (result.rv == OK) + UMA_HISTOGRAM_COUNTS("AsyncDNS.AttemptCountSuccess", attempts_count_); + else + UMA_HISTOGRAM_COUNTS("AsyncDNS.AttemptCountFail", attempts_count_); + + if (response && qtype_ == dns_protocol::kTypeA) { + UMA_HISTOGRAM_COUNTS("AsyncDNS.SuffixSearchRemain", qnames_.size()); + UMA_HISTOGRAM_COUNTS("AsyncDNS.SuffixSearchDone", + qnames_initial_size_ - qnames_.size()); + } + + DnsTransactionFactory::CallbackType callback = callback_; + callback_.Reset(); + + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_DNS_TRANSACTION, result.rv); + callback.Run(this, result.rv, response); + } + + // Makes another attempt at the current name, |qnames_.front()|, using the + // next nameserver. + AttemptResult MakeAttempt() { + unsigned attempt_number = attempts_.size(); + + uint16 id = session_->NextQueryId(); + scoped_ptr<DnsQuery> query; + if (attempts_.empty()) { + query.reset(new DnsQuery(id, qnames_.front(), qtype_)); + } else { + query.reset(attempts_[0]->GetQuery()->CloneWithNewId(id)); + } + + const DnsConfig& config = session_->config(); + + unsigned server_index = + (first_server_index_ + attempt_number) % config.nameservers.size(); + // Skip over known failed servers. + server_index = session_->NextGoodServerIndex(server_index); + + scoped_ptr<DnsSession::SocketLease> lease = + session_->AllocateSocket(server_index, net_log_.source()); + + bool got_socket = !!lease.get(); + + DnsUDPAttempt* attempt = + new DnsUDPAttempt(server_index, lease.Pass(), query.Pass()); + + attempts_.push_back(attempt); + ++attempts_count_; + + if (!got_socket) + return AttemptResult(ERR_CONNECTION_REFUSED, NULL); + + net_log_.AddEvent( + NetLog::TYPE_DNS_TRANSACTION_ATTEMPT, + attempt->GetSocketNetLog().source().ToEventParametersCallback()); + + int rv = attempt->Start( + base::Bind(&DnsTransactionImpl::OnUdpAttemptComplete, + base::Unretained(this), attempt_number, + base::TimeTicks::Now())); + if (rv == ERR_IO_PENDING) { + base::TimeDelta timeout = session_->NextTimeout(server_index, + attempt_number); + timer_.Start(FROM_HERE, timeout, this, &DnsTransactionImpl::OnTimeout); + } + return AttemptResult(rv, attempt); + } + + AttemptResult MakeTCPAttempt(const DnsAttempt* previous_attempt) { + DCHECK(previous_attempt); + DCHECK(!had_tcp_attempt_); + + unsigned server_index = previous_attempt->server_index(); + + scoped_ptr<StreamSocket> socket( + session_->CreateTCPSocket(server_index, net_log_.source())); + + // TODO(szym): Reuse the same id to help the server? + uint16 id = session_->NextQueryId(); + scoped_ptr<DnsQuery> query( + previous_attempt->GetQuery()->CloneWithNewId(id)); + + RecordLostPacketsIfAny(); + // Cancel all other attempts, no point waiting on them. + attempts_.clear(); + + unsigned attempt_number = attempts_.size(); + + DnsTCPAttempt* attempt = new DnsTCPAttempt(server_index, socket.Pass(), + query.Pass()); + + attempts_.push_back(attempt); + ++attempts_count_; + had_tcp_attempt_ = true; + + net_log_.AddEvent( + NetLog::TYPE_DNS_TRANSACTION_TCP_ATTEMPT, + attempt->GetSocketNetLog().source().ToEventParametersCallback()); + + int rv = attempt->Start(base::Bind(&DnsTransactionImpl::OnAttemptComplete, + base::Unretained(this), + attempt_number)); + if (rv == ERR_IO_PENDING) { + // Custom timeout for TCP attempt. + base::TimeDelta timeout = timer_.GetCurrentDelay() * 2; + timer_.Start(FROM_HERE, timeout, this, &DnsTransactionImpl::OnTimeout); + } + return AttemptResult(rv, attempt); + } + + // Begins query for the current name. Makes the first attempt. + AttemptResult StartQuery() { + std::string dotted_qname = DNSDomainToString(qnames_.front()); + net_log_.BeginEvent(NetLog::TYPE_DNS_TRANSACTION_QUERY, + NetLog::StringCallback("qname", &dotted_qname)); + + first_server_index_ = session_->NextFirstServerIndex(); + RecordLostPacketsIfAny(); + attempts_.clear(); + had_tcp_attempt_ = false; + return MakeAttempt(); + } + + void OnUdpAttemptComplete(unsigned attempt_number, + base::TimeTicks start, + int rv) { + DCHECK_LT(attempt_number, attempts_.size()); + const DnsAttempt* attempt = attempts_[attempt_number]; + if (attempt->GetResponse()) { + session_->RecordRTT(attempt->server_index(), + base::TimeTicks::Now() - start); + } + OnAttemptComplete(attempt_number, rv); + } + + void OnAttemptComplete(unsigned attempt_number, int rv) { + if (callback_.is_null()) + return; + DCHECK_LT(attempt_number, attempts_.size()); + const DnsAttempt* attempt = attempts_[attempt_number]; + AttemptResult result = ProcessAttemptResult(AttemptResult(rv, attempt)); + if (result.rv != ERR_IO_PENDING) + DoCallback(result); + } + + // Record packet loss for any incomplete attempts. + void RecordLostPacketsIfAny() { + // Loop through attempts until we find first that is completed + size_t first_completed = 0; + for (first_completed = 0; first_completed < attempts_.size(); + ++first_completed) { + if (attempts_[first_completed]->is_completed()) + break; + } + // If there were no completed attempts, then we must be offline, so don't + // record any attempts as lost packets. + if (first_completed == attempts_.size()) + return; + + size_t num_servers = session_->config().nameservers.size(); + std::vector<int> server_attempts(num_servers); + for (size_t i = 0; i < first_completed; ++i) { + unsigned server_index = attempts_[i]->server_index(); + int server_attempt = server_attempts[server_index]++; + // Don't record lost packet unless attempt is in pending state. + if (!attempts_[i]->is_pending()) + continue; + session_->RecordLostPacket(server_index, server_attempt); + } + } + + void LogResponse(const DnsAttempt* attempt) { + if (attempt && attempt->GetResponse()) { + net_log_.AddEvent( + NetLog::TYPE_DNS_TRANSACTION_RESPONSE, + base::Bind(&DnsAttempt::NetLogResponseCallback, + base::Unretained(attempt))); + } + } + + bool MoreAttemptsAllowed() const { + if (had_tcp_attempt_) + return false; + const DnsConfig& config = session_->config(); + return attempts_.size() < config.attempts * config.nameservers.size(); + } + + // Resolves the result of a DnsAttempt until a terminal result is reached + // or it will complete asynchronously (ERR_IO_PENDING). + AttemptResult ProcessAttemptResult(AttemptResult result) { + while (result.rv != ERR_IO_PENDING) { + LogResponse(result.attempt); + + switch (result.rv) { + case OK: + session_->RecordServerSuccess(result.attempt->server_index()); + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_DNS_TRANSACTION_QUERY, + result.rv); + DCHECK(result.attempt); + DCHECK(result.attempt->GetResponse()); + return result; + case ERR_NAME_NOT_RESOLVED: + session_->RecordServerSuccess(result.attempt->server_index()); + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_DNS_TRANSACTION_QUERY, + result.rv); + // Try next suffix. + qnames_.pop_front(); + if (qnames_.empty()) { + return AttemptResult(ERR_NAME_NOT_RESOLVED, NULL); + } else { + result = StartQuery(); + } + break; + case ERR_CONNECTION_REFUSED: + case ERR_DNS_TIMED_OUT: + if (result.attempt) + session_->RecordServerFailure(result.attempt->server_index()); + if (MoreAttemptsAllowed()) { + result = MakeAttempt(); + } else { + return result; + } + break; + case ERR_DNS_SERVER_REQUIRES_TCP: + result = MakeTCPAttempt(result.attempt); + break; + default: + // Server failure. + DCHECK(result.attempt); + if (result.attempt != attempts_.back()) { + // This attempt already timed out. Ignore it. + session_->RecordServerFailure(result.attempt->server_index()); + return AttemptResult(ERR_IO_PENDING, NULL); + } + if (MoreAttemptsAllowed()) { + result = MakeAttempt(); + } else if (result.rv == ERR_DNS_MALFORMED_RESPONSE && + !had_tcp_attempt_) { + // For UDP only, ignore the response and wait until the last attempt + // times out. + return AttemptResult(ERR_IO_PENDING, NULL); + } else { + return AttemptResult(result.rv, NULL); + } + break; + } + } + return result; + } + + void OnTimeout() { + if (callback_.is_null()) + return; + DCHECK(!attempts_.empty()); + AttemptResult result = ProcessAttemptResult( + AttemptResult(ERR_DNS_TIMED_OUT, attempts_.back())); + if (result.rv != ERR_IO_PENDING) + DoCallback(result); + } + + scoped_refptr<DnsSession> session_; + std::string hostname_; + uint16 qtype_; + // Cleared in DoCallback. + DnsTransactionFactory::CallbackType callback_; + + BoundNetLog net_log_; + + // Search list of fully-qualified DNS names to query next (in DNS format). + std::deque<std::string> qnames_; + size_t qnames_initial_size_; + + // List of attempts for the current name. + ScopedVector<DnsAttempt> attempts_; + // Count of attempts, not reset when |attempts_| vector is cleared. + int attempts_count_; + bool had_tcp_attempt_; + + // Index of the first server to try on each search query. + int first_server_index_; + + base::OneShotTimer<DnsTransactionImpl> timer_; + + DISALLOW_COPY_AND_ASSIGN(DnsTransactionImpl); +}; + +// ---------------------------------------------------------------------------- + +// Implementation of DnsTransactionFactory that returns instances of +// DnsTransactionImpl. +class DnsTransactionFactoryImpl : public DnsTransactionFactory { + public: + explicit DnsTransactionFactoryImpl(DnsSession* session) { + session_ = session; + } + + virtual scoped_ptr<DnsTransaction> CreateTransaction( + const std::string& hostname, + uint16 qtype, + const CallbackType& callback, + const BoundNetLog& net_log) OVERRIDE { + return scoped_ptr<DnsTransaction>(new DnsTransactionImpl( + session_.get(), hostname, qtype, callback, net_log)); + } + + private: + scoped_refptr<DnsSession> session_; +}; + +} // namespace + +// static +scoped_ptr<DnsTransactionFactory> DnsTransactionFactory::CreateFactory( + DnsSession* session) { + return scoped_ptr<DnsTransactionFactory>( + new DnsTransactionFactoryImpl(session)); +} + +} // namespace net diff --git a/chromium/net/dns/dns_transaction.h b/chromium/net/dns/dns_transaction.h new file mode 100644 index 00000000000..faf4f64e79d --- /dev/null +++ b/chromium/net/dns/dns_transaction.h @@ -0,0 +1,78 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_DNS_TRANSACTION_H_ +#define NET_DNS_DNS_TRANSACTION_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/callback_forward.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/net_export.h" + +namespace net { + +class BoundNetLog; +class DnsResponse; +class DnsSession; + +// DnsTransaction implements a stub DNS resolver as defined in RFC 1034. +// The DnsTransaction takes care of retransmissions, name server fallback (or +// round-robin), suffix search, and simple response validation ("does it match +// the query") to fight poisoning. +// +// Destroying DnsTransaction cancels the underlying network effort. +class NET_EXPORT_PRIVATE DnsTransaction { + public: + virtual ~DnsTransaction() {} + + // Returns the original |hostname|. + virtual const std::string& GetHostname() const = 0; + + // Returns the |qtype|. + virtual uint16 GetType() const = 0; + + // Starts the transaction. Always completes asynchronously. + virtual void Start() = 0; +}; + +// Creates DnsTransaction which performs asynchronous DNS search. +// It does NOT perform caching, aggregation or prioritization of transactions. +// +// Destroying the factory does NOT affect any already created DnsTransactions. +class NET_EXPORT_PRIVATE DnsTransactionFactory { + public: + // Called with the response or NULL if no matching response was received. + // Note that the |GetDottedName()| of the response may be different than the + // original |hostname| as a result of suffix search. + typedef base::Callback<void(DnsTransaction* transaction, + int neterror, + const DnsResponse* response)> CallbackType; + + virtual ~DnsTransactionFactory() {} + + // Creates DnsTransaction for the given |hostname| and |qtype| (assuming + // QCLASS is IN). |hostname| should be in the dotted form. A dot at the end + // implies the domain name is fully-qualified and will be exempt from suffix + // search. |hostname| should not be an IP literal. + // + // The transaction will run |callback| upon asynchronous completion. + // The |net_log| is used as the parent log. + virtual scoped_ptr<DnsTransaction> CreateTransaction( + const std::string& hostname, + uint16 qtype, + const CallbackType& callback, + const BoundNetLog& net_log) WARN_UNUSED_RESULT = 0; + + // Creates a DnsTransactionFactory which creates DnsTransactionImpl using the + // |session|. + static scoped_ptr<DnsTransactionFactory> CreateFactory( + DnsSession* session) WARN_UNUSED_RESULT; +}; + +} // namespace net + +#endif // NET_DNS_DNS_TRANSACTION_H_ + diff --git a/chromium/net/dns/dns_transaction_unittest.cc b/chromium/net/dns/dns_transaction_unittest.cc new file mode 100644 index 00000000000..7040e44be16 --- /dev/null +++ b/chromium/net/dns/dns_transaction_unittest.cc @@ -0,0 +1,940 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/dns_transaction.h" + +#include "base/bind.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/scoped_vector.h" +#include "base/rand_util.h" +#include "base/sys_byteorder.h" +#include "base/test/test_timeouts.h" +#include "net/base/big_endian.h" +#include "net/base/dns_util.h" +#include "net/base/net_log.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_session.h" +#include "net/dns/dns_test_util.h" +#include "net/socket/socket_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +std::string DomainFromDot(const base::StringPiece& dotted) { + std::string out; + EXPECT_TRUE(DNSDomainFromDot(dotted, &out)); + return out; +} + +// A SocketDataProvider builder. +class DnsSocketData { + public: + // The ctor takes parameters for the DnsQuery. + DnsSocketData(uint16 id, + const char* dotted_name, + uint16 qtype, + IoMode mode, + bool use_tcp) + : query_(new DnsQuery(id, DomainFromDot(dotted_name), qtype)), + use_tcp_(use_tcp) { + if (use_tcp_) { + scoped_ptr<uint16> length(new uint16); + *length = base::HostToNet16(query_->io_buffer()->size()); + writes_.push_back(MockWrite(mode, + reinterpret_cast<const char*>(length.get()), + sizeof(uint16))); + lengths_.push_back(length.release()); + } + writes_.push_back(MockWrite(mode, + query_->io_buffer()->data(), + query_->io_buffer()->size())); + } + ~DnsSocketData() {} + + // All responses must be added before GetProvider. + + // Adds pre-built DnsResponse. |tcp_length| will be used in TCP mode only. + void AddResponseWithLength(scoped_ptr<DnsResponse> response, IoMode mode, + uint16 tcp_length) { + CHECK(!provider_.get()); + if (use_tcp_) { + scoped_ptr<uint16> length(new uint16); + *length = base::HostToNet16(tcp_length); + reads_.push_back(MockRead(mode, + reinterpret_cast<const char*>(length.get()), + sizeof(uint16))); + lengths_.push_back(length.release()); + } + reads_.push_back(MockRead(mode, + response->io_buffer()->data(), + response->io_buffer()->size())); + responses_.push_back(response.release()); + } + + // Adds pre-built DnsResponse. + void AddResponse(scoped_ptr<DnsResponse> response, IoMode mode) { + uint16 tcp_length = response->io_buffer()->size(); + AddResponseWithLength(response.Pass(), mode, tcp_length); + } + + // Adds pre-built response from |data| buffer. + void AddResponseData(const uint8* data, size_t length, IoMode mode) { + CHECK(!provider_.get()); + AddResponse(make_scoped_ptr( + new DnsResponse(reinterpret_cast<const char*>(data), length, 0)), mode); + } + + // Add no-answer (RCODE only) response matching the query. + void AddRcode(int rcode, IoMode mode) { + scoped_ptr<DnsResponse> response( + new DnsResponse(query_->io_buffer()->data(), + query_->io_buffer()->size(), + 0)); + dns_protocol::Header* header = + reinterpret_cast<dns_protocol::Header*>(response->io_buffer()->data()); + header->flags |= base::HostToNet16(dns_protocol::kFlagResponse | rcode); + AddResponse(response.Pass(), mode); + } + + // Build, if needed, and return the SocketDataProvider. No new responses + // should be added afterwards. + SocketDataProvider* GetProvider() { + if (provider_.get()) + return provider_.get(); + // Terminate the reads with ERR_IO_PENDING to prevent overrun and default to + // timeout. + reads_.push_back(MockRead(ASYNC, ERR_IO_PENDING)); + provider_.reset(new DelayedSocketData(1, &reads_[0], reads_.size(), + &writes_[0], writes_.size())); + if (use_tcp_) { + provider_->set_connect_data(MockConnect(reads_[0].mode, OK)); + } + return provider_.get(); + } + + uint16 query_id() const { + return query_->id(); + } + + // Returns true if the expected query was written to the socket. + bool was_written() const { + CHECK(provider_.get()); + return provider_->write_index() > 0; + } + + private: + scoped_ptr<DnsQuery> query_; + bool use_tcp_; + ScopedVector<uint16> lengths_; + ScopedVector<DnsResponse> responses_; + std::vector<MockWrite> writes_; + std::vector<MockRead> reads_; + scoped_ptr<DelayedSocketData> provider_; + + DISALLOW_COPY_AND_ASSIGN(DnsSocketData); +}; + +class TestSocketFactory; + +// A variant of MockUDPClientSocket which always fails to Connect. +class FailingUDPClientSocket : public MockUDPClientSocket { + public: + FailingUDPClientSocket(SocketDataProvider* data, + net::NetLog* net_log) + : MockUDPClientSocket(data, net_log) { + } + virtual ~FailingUDPClientSocket() {} + virtual int Connect(const IPEndPoint& endpoint) OVERRIDE { + return ERR_CONNECTION_REFUSED; + } + + private: + DISALLOW_COPY_AND_ASSIGN(FailingUDPClientSocket); +}; + +// A variant of MockUDPClientSocket which notifies the factory OnConnect. +class TestUDPClientSocket : public MockUDPClientSocket { + public: + TestUDPClientSocket(TestSocketFactory* factory, + SocketDataProvider* data, + net::NetLog* net_log) + : MockUDPClientSocket(data, net_log), factory_(factory) { + } + virtual ~TestUDPClientSocket() {} + virtual int Connect(const IPEndPoint& endpoint) OVERRIDE; + + private: + TestSocketFactory* factory_; + + DISALLOW_COPY_AND_ASSIGN(TestUDPClientSocket); +}; + +// Creates TestUDPClientSockets and keeps endpoints reported via OnConnect. +class TestSocketFactory : public MockClientSocketFactory { + public: + TestSocketFactory() : fail_next_socket_(false) {} + virtual ~TestSocketFactory() {} + + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + net::NetLog* net_log, + const net::NetLog::Source& source) OVERRIDE { + if (fail_next_socket_) { + fail_next_socket_ = false; + return scoped_ptr<DatagramClientSocket>( + new FailingUDPClientSocket(&empty_data_, net_log)); + } + SocketDataProvider* data_provider = mock_data().GetNext(); + scoped_ptr<TestUDPClientSocket> socket( + new TestUDPClientSocket(this, data_provider, net_log)); + data_provider->set_socket(socket.get()); + return socket.PassAs<DatagramClientSocket>(); + } + + void OnConnect(const IPEndPoint& endpoint) { + remote_endpoints_.push_back(endpoint); + } + + std::vector<IPEndPoint> remote_endpoints_; + bool fail_next_socket_; + + private: + StaticSocketDataProvider empty_data_; + + DISALLOW_COPY_AND_ASSIGN(TestSocketFactory); +}; + +int TestUDPClientSocket::Connect(const IPEndPoint& endpoint) { + factory_->OnConnect(endpoint); + return MockUDPClientSocket::Connect(endpoint); +} + +// Helper class that holds a DnsTransaction and handles OnTransactionComplete. +class TransactionHelper { + public: + // If |expected_answer_count| < 0 then it is the expected net error. + TransactionHelper(const char* hostname, + uint16 qtype, + int expected_answer_count) + : hostname_(hostname), + qtype_(qtype), + expected_answer_count_(expected_answer_count), + cancel_in_callback_(false), + quit_in_callback_(false), + completed_(false) { + } + + // Mark that the transaction shall be destroyed immediately upon callback. + void set_cancel_in_callback() { + cancel_in_callback_ = true; + } + + // Mark to call MessageLoop::Quit() upon callback. + void set_quit_in_callback() { + quit_in_callback_ = true; + } + + void StartTransaction(DnsTransactionFactory* factory) { + EXPECT_EQ(NULL, transaction_.get()); + transaction_ = factory->CreateTransaction( + hostname_, + qtype_, + base::Bind(&TransactionHelper::OnTransactionComplete, + base::Unretained(this)), + BoundNetLog()); + EXPECT_EQ(hostname_, transaction_->GetHostname()); + EXPECT_EQ(qtype_, transaction_->GetType()); + transaction_->Start(); + } + + void Cancel() { + ASSERT_TRUE(transaction_.get() != NULL); + transaction_.reset(NULL); + } + + void OnTransactionComplete(DnsTransaction* t, + int rv, + const DnsResponse* response) { + EXPECT_FALSE(completed_); + EXPECT_EQ(transaction_.get(), t); + + completed_ = true; + + if (cancel_in_callback_) { + Cancel(); + return; + } + + // Tell MessageLoop to quit now, in case any ASSERT_* fails. + if (quit_in_callback_) + base::MessageLoop::current()->Quit(); + + if (expected_answer_count_ >= 0) { + ASSERT_EQ(OK, rv); + ASSERT_TRUE(response != NULL); + EXPECT_EQ(static_cast<unsigned>(expected_answer_count_), + response->answer_count()); + EXPECT_EQ(qtype_, response->qtype()); + + DnsRecordParser parser = response->Parser(); + DnsResourceRecord record; + for (int i = 0; i < expected_answer_count_; ++i) { + EXPECT_TRUE(parser.ReadRecord(&record)); + } + } else { + EXPECT_EQ(expected_answer_count_, rv); + } + } + + bool has_completed() const { + return completed_; + } + + // Shorthands for commonly used commands. + + bool Run(DnsTransactionFactory* factory) { + StartTransaction(factory); + base::MessageLoop::current()->RunUntilIdle(); + return has_completed(); + } + + // Use when some of the responses are timeouts. + bool RunUntilDone(DnsTransactionFactory* factory) { + set_quit_in_callback(); + StartTransaction(factory); + base::MessageLoop::current()->Run(); + return has_completed(); + } + + private: + std::string hostname_; + uint16 qtype_; + scoped_ptr<DnsTransaction> transaction_; + int expected_answer_count_; + bool cancel_in_callback_; + bool quit_in_callback_; + + bool completed_; +}; + +class DnsTransactionTest : public testing::Test { + public: + DnsTransactionTest() {} + + // Generates |nameservers| for DnsConfig. + void ConfigureNumServers(unsigned num_servers) { + CHECK_LE(num_servers, 255u); + config_.nameservers.clear(); + IPAddressNumber dns_ip; + { + bool rv = ParseIPLiteralToNumber("192.168.1.0", &dns_ip); + EXPECT_TRUE(rv); + } + for (unsigned i = 0; i < num_servers; ++i) { + dns_ip[3] = i; + config_.nameservers.push_back(IPEndPoint(dns_ip, + dns_protocol::kDefaultPort)); + } + } + + // Called after fully configuring |config|. + void ConfigureFactory() { + socket_factory_.reset(new TestSocketFactory()); + session_ = new DnsSession( + config_, + DnsSocketPool::CreateNull(socket_factory_.get()), + base::Bind(&DnsTransactionTest::GetNextId, base::Unretained(this)), + NULL /* NetLog */); + transaction_factory_ = DnsTransactionFactory::CreateFactory(session_.get()); + } + + void AddSocketData(scoped_ptr<DnsSocketData> data) { + CHECK(socket_factory_.get()); + transaction_ids_.push_back(data->query_id()); + socket_factory_->AddSocketDataProvider(data->GetProvider()); + socket_data_.push_back(data.release()); + } + + // Add expected query for |dotted_name| and |qtype| with |id| and response + // taken verbatim from |data| of |data_length| bytes. The transaction id in + // |data| should equal |id|, unless testing mismatched response. + void AddQueryAndResponse(uint16 id, + const char* dotted_name, + uint16 qtype, + const uint8* response_data, + size_t response_length, + IoMode mode, + bool use_tcp) { + CHECK(socket_factory_.get()); + scoped_ptr<DnsSocketData> data( + new DnsSocketData(id, dotted_name, qtype, mode, use_tcp)); + data->AddResponseData(response_data, response_length, mode); + AddSocketData(data.Pass()); + } + + void AddAsyncQueryAndResponse(uint16 id, + const char* dotted_name, + uint16 qtype, + const uint8* data, + size_t data_length) { + AddQueryAndResponse(id, dotted_name, qtype, data, data_length, ASYNC, + false); + } + + void AddSyncQueryAndResponse(uint16 id, + const char* dotted_name, + uint16 qtype, + const uint8* data, + size_t data_length) { + AddQueryAndResponse(id, dotted_name, qtype, data, data_length, SYNCHRONOUS, + false); + } + + // Add expected query of |dotted_name| and |qtype| and no response. + void AddQueryAndTimeout(const char* dotted_name, uint16 qtype) { + uint16 id = base::RandInt(0, kuint16max); + scoped_ptr<DnsSocketData> data( + new DnsSocketData(id, dotted_name, qtype, ASYNC, false)); + AddSocketData(data.Pass()); + } + + // Add expected query of |dotted_name| and |qtype| and matching response with + // no answer and RCODE set to |rcode|. The id will be generated randomly. + void AddQueryAndRcode(const char* dotted_name, + uint16 qtype, + int rcode, + IoMode mode, + bool use_tcp) { + CHECK_NE(dns_protocol::kRcodeNOERROR, rcode); + uint16 id = base::RandInt(0, kuint16max); + scoped_ptr<DnsSocketData> data( + new DnsSocketData(id, dotted_name, qtype, mode, use_tcp)); + data->AddRcode(rcode, mode); + AddSocketData(data.Pass()); + } + + void AddAsyncQueryAndRcode(const char* dotted_name, uint16 qtype, int rcode) { + AddQueryAndRcode(dotted_name, qtype, rcode, ASYNC, false); + } + + void AddSyncQueryAndRcode(const char* dotted_name, uint16 qtype, int rcode) { + AddQueryAndRcode(dotted_name, qtype, rcode, SYNCHRONOUS, false); + } + + // Checks if the sockets were connected in the order matching the indices in + // |servers|. + void CheckServerOrder(const unsigned* servers, size_t num_attempts) { + ASSERT_EQ(num_attempts, socket_factory_->remote_endpoints_.size()); + for (size_t i = 0; i < num_attempts; ++i) { + EXPECT_EQ(socket_factory_->remote_endpoints_[i], + session_->config().nameservers[servers[i]]); + } + } + + virtual void SetUp() OVERRIDE { + // By default set one server, + ConfigureNumServers(1); + // and no retransmissions, + config_.attempts = 1; + // but long enough timeout for memory tests. + config_.timeout = TestTimeouts::action_timeout(); + ConfigureFactory(); + } + + virtual void TearDown() OVERRIDE { + // Check that all socket data was at least written to. + for (size_t i = 0; i < socket_data_.size(); ++i) { + EXPECT_TRUE(socket_data_[i]->was_written()) << i; + } + } + + protected: + int GetNextId(int min, int max) { + EXPECT_FALSE(transaction_ids_.empty()); + int id = transaction_ids_.front(); + transaction_ids_.pop_front(); + EXPECT_GE(id, min); + EXPECT_LE(id, max); + return id; + } + + DnsConfig config_; + + ScopedVector<DnsSocketData> socket_data_; + + std::deque<int> transaction_ids_; + scoped_ptr<TestSocketFactory> socket_factory_; + scoped_refptr<DnsSession> session_; + scoped_ptr<DnsTransactionFactory> transaction_factory_; +}; + +TEST_F(DnsTransactionTest, Lookup) { + AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype, + kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); + + TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +// Concurrent lookup tests assume that DnsTransaction::Start immediately +// consumes a socket from ClientSocketFactory. +TEST_F(DnsTransactionTest, ConcurrentLookup) { + AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype, + kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); + AddAsyncQueryAndResponse(1 /* id */, kT1HostName, kT1Qtype, + kT1ResponseDatagram, arraysize(kT1ResponseDatagram)); + + TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); + helper0.StartTransaction(transaction_factory_.get()); + TransactionHelper helper1(kT1HostName, kT1Qtype, kT1RecordCount); + helper1.StartTransaction(transaction_factory_.get()); + + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_TRUE(helper0.has_completed()); + EXPECT_TRUE(helper1.has_completed()); +} + +TEST_F(DnsTransactionTest, CancelLookup) { + AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype, + kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); + AddAsyncQueryAndResponse(1 /* id */, kT1HostName, kT1Qtype, + kT1ResponseDatagram, arraysize(kT1ResponseDatagram)); + + TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); + helper0.StartTransaction(transaction_factory_.get()); + TransactionHelper helper1(kT1HostName, kT1Qtype, kT1RecordCount); + helper1.StartTransaction(transaction_factory_.get()); + + helper0.Cancel(); + + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_FALSE(helper0.has_completed()); + EXPECT_TRUE(helper1.has_completed()); +} + +TEST_F(DnsTransactionTest, DestroyFactory) { + AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype, + kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); + + TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); + helper0.StartTransaction(transaction_factory_.get()); + + // Destroying the client does not affect running requests. + transaction_factory_.reset(NULL); + + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_TRUE(helper0.has_completed()); +} + +TEST_F(DnsTransactionTest, CancelFromCallback) { + AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype, + kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); + + TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); + helper0.set_cancel_in_callback(); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, MismatchedResponseSync) { + config_.attempts = 2; + config_.timeout = TestTimeouts::tiny_timeout(); + ConfigureFactory(); + + // Attempt receives mismatched response followed by valid response. + scoped_ptr<DnsSocketData> data( + new DnsSocketData(0 /* id */, kT0HostName, kT0Qtype, SYNCHRONOUS, false)); + data->AddResponseData(kT1ResponseDatagram, + arraysize(kT1ResponseDatagram), SYNCHRONOUS); + data->AddResponseData(kT0ResponseDatagram, + arraysize(kT0ResponseDatagram), SYNCHRONOUS); + AddSocketData(data.Pass()); + + TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); + EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, MismatchedResponseAsync) { + config_.attempts = 2; + config_.timeout = TestTimeouts::tiny_timeout(); + ConfigureFactory(); + + // First attempt receives mismatched response followed by valid response. + // Second attempt times out. + scoped_ptr<DnsSocketData> data( + new DnsSocketData(0 /* id */, kT0HostName, kT0Qtype, ASYNC, false)); + data->AddResponseData(kT1ResponseDatagram, + arraysize(kT1ResponseDatagram), ASYNC); + data->AddResponseData(kT0ResponseDatagram, + arraysize(kT0ResponseDatagram), ASYNC); + AddSocketData(data.Pass()); + AddQueryAndTimeout(kT0HostName, kT0Qtype); + + TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); + EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, MismatchedResponseFail) { + config_.timeout = TestTimeouts::tiny_timeout(); + ConfigureFactory(); + + // Attempt receives mismatched response but times out because only one attempt + // is allowed. + AddAsyncQueryAndResponse(1 /* id */, kT0HostName, kT0Qtype, + kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); + + TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_TIMED_OUT); + EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, ServerFail) { + AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeSERVFAIL); + + TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_SERVER_FAILED); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, NoDomain) { + AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeNXDOMAIN); + + TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_NAME_NOT_RESOLVED); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, Timeout) { + config_.attempts = 3; + // Use short timeout to speed up the test. + config_.timeout = TestTimeouts::tiny_timeout(); + ConfigureFactory(); + + AddQueryAndTimeout(kT0HostName, kT0Qtype); + AddQueryAndTimeout(kT0HostName, kT0Qtype); + AddQueryAndTimeout(kT0HostName, kT0Qtype); + + TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_TIMED_OUT); + EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get())); + EXPECT_TRUE(base::MessageLoop::current()->IsIdleForTesting()); +} + +TEST_F(DnsTransactionTest, ServerFallbackAndRotate) { + // Test that we fallback on both server failure and timeout. + config_.attempts = 2; + // The next request should start from the next server. + config_.rotate = true; + ConfigureNumServers(3); + // Use short timeout to speed up the test. + config_.timeout = TestTimeouts::tiny_timeout(); + ConfigureFactory(); + + // Responses for first request. + AddQueryAndTimeout(kT0HostName, kT0Qtype); + AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeSERVFAIL); + AddQueryAndTimeout(kT0HostName, kT0Qtype); + AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeSERVFAIL); + AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeNXDOMAIN); + // Responses for second request. + AddAsyncQueryAndRcode(kT1HostName, kT1Qtype, dns_protocol::kRcodeSERVFAIL); + AddAsyncQueryAndRcode(kT1HostName, kT1Qtype, dns_protocol::kRcodeSERVFAIL); + AddAsyncQueryAndRcode(kT1HostName, kT1Qtype, dns_protocol::kRcodeNXDOMAIN); + + TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_NAME_NOT_RESOLVED); + TransactionHelper helper1(kT1HostName, kT1Qtype, ERR_NAME_NOT_RESOLVED); + + EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get())); + EXPECT_TRUE(helper1.Run(transaction_factory_.get())); + + unsigned kOrder[] = { + 0, 1, 2, 0, 1, // The first transaction. + 1, 2, 0, // The second transaction starts from the next server. + }; + CheckServerOrder(kOrder, arraysize(kOrder)); +} + +TEST_F(DnsTransactionTest, SuffixSearchAboveNdots) { + config_.ndots = 2; + config_.search.push_back("a"); + config_.search.push_back("b"); + config_.search.push_back("c"); + config_.rotate = true; + ConfigureNumServers(2); + ConfigureFactory(); + + AddAsyncQueryAndRcode("x.y.z", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndRcode("x.y.z.a", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndRcode("x.y.z.b", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndRcode("x.y.z.c", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + + TransactionHelper helper0("x.y.z", dns_protocol::kTypeA, + ERR_NAME_NOT_RESOLVED); + + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); + + // Also check if suffix search causes server rotation. + unsigned kOrder0[] = { 0, 1, 0, 1 }; + CheckServerOrder(kOrder0, arraysize(kOrder0)); +} + +TEST_F(DnsTransactionTest, SuffixSearchBelowNdots) { + config_.ndots = 2; + config_.search.push_back("a"); + config_.search.push_back("b"); + config_.search.push_back("c"); + ConfigureFactory(); + + // Responses for first transaction. + AddAsyncQueryAndRcode("x.y.a", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndRcode("x.y.b", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndRcode("x.y.c", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndRcode("x.y", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + // Responses for second transaction. + AddAsyncQueryAndRcode("x.a", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndRcode("x.b", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndRcode("x.c", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + // Responses for third transaction. + AddAsyncQueryAndRcode("x", dns_protocol::kTypeAAAA, + dns_protocol::kRcodeNXDOMAIN); + + TransactionHelper helper0("x.y", dns_protocol::kTypeA, ERR_NAME_NOT_RESOLVED); + + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); + + // A single-label name. + TransactionHelper helper1("x", dns_protocol::kTypeA, ERR_NAME_NOT_RESOLVED); + + EXPECT_TRUE(helper1.Run(transaction_factory_.get())); + + // A fully-qualified name. + TransactionHelper helper2("x.", dns_protocol::kTypeAAAA, + ERR_NAME_NOT_RESOLVED); + + EXPECT_TRUE(helper2.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, EmptySuffixSearch) { + // Responses for first transaction. + AddAsyncQueryAndRcode("x", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + + // A fully-qualified name. + TransactionHelper helper0("x.", dns_protocol::kTypeA, ERR_NAME_NOT_RESOLVED); + + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); + + // A single label name is not even attempted. + TransactionHelper helper1("singlelabel", dns_protocol::kTypeA, + ERR_DNS_SEARCH_EMPTY); + + helper1.Run(transaction_factory_.get()); + EXPECT_TRUE(helper1.has_completed()); +} + +TEST_F(DnsTransactionTest, DontAppendToMultiLabelName) { + config_.search.push_back("a"); + config_.search.push_back("b"); + config_.search.push_back("c"); + config_.append_to_multi_label_name = false; + ConfigureFactory(); + + // Responses for first transaction. + AddAsyncQueryAndRcode("x.y.z", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + // Responses for second transaction. + AddAsyncQueryAndRcode("x.y", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + // Responses for third transaction. + AddAsyncQueryAndRcode("x.a", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndRcode("x.b", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndRcode("x.c", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + + TransactionHelper helper0("x.y.z", dns_protocol::kTypeA, + ERR_NAME_NOT_RESOLVED); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); + + TransactionHelper helper1("x.y", dns_protocol::kTypeA, ERR_NAME_NOT_RESOLVED); + EXPECT_TRUE(helper1.Run(transaction_factory_.get())); + + TransactionHelper helper2("x", dns_protocol::kTypeA, ERR_NAME_NOT_RESOLVED); + EXPECT_TRUE(helper2.Run(transaction_factory_.get())); +} + +const uint8 kResponseNoData[] = { + 0x00, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, + // Question + 0x01, 'x', 0x01, 'y', 0x01, 'z', 0x01, 'b', 0x00, 0x00, 0x01, 0x00, 0x01, + // Authority section, SOA record, TTL 0x3E6 + 0x01, 'z', 0x00, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x03, 0xE6, + // Minimal RDATA, 18 bytes + 0x00, 0x12, + 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, +}; + +TEST_F(DnsTransactionTest, SuffixSearchStop) { + config_.ndots = 2; + config_.search.push_back("a"); + config_.search.push_back("b"); + config_.search.push_back("c"); + ConfigureFactory(); + + AddAsyncQueryAndRcode("x.y.z", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndRcode("x.y.z.a", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddAsyncQueryAndResponse(0 /* id */, "x.y.z.b", dns_protocol::kTypeA, + kResponseNoData, arraysize(kResponseNoData)); + + TransactionHelper helper0("x.y.z", dns_protocol::kTypeA, 0 /* answers */); + + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, SyncFirstQuery) { + config_.search.push_back("lab.ccs.neu.edu"); + config_.search.push_back("ccs.neu.edu"); + ConfigureFactory(); + + AddSyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype, + kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); + + TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, SyncFirstQueryWithSearch) { + config_.search.push_back("lab.ccs.neu.edu"); + config_.search.push_back("ccs.neu.edu"); + ConfigureFactory(); + + AddSyncQueryAndRcode("www.lab.ccs.neu.edu", kT2Qtype, + dns_protocol::kRcodeNXDOMAIN); + // "www.ccs.neu.edu" + AddAsyncQueryAndResponse(2 /* id */, kT2HostName, kT2Qtype, + kT2ResponseDatagram, arraysize(kT2ResponseDatagram)); + + TransactionHelper helper0("www", kT2Qtype, kT2RecordCount); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, SyncSearchQuery) { + config_.search.push_back("lab.ccs.neu.edu"); + config_.search.push_back("ccs.neu.edu"); + ConfigureFactory(); + + AddAsyncQueryAndRcode("www.lab.ccs.neu.edu", dns_protocol::kTypeA, + dns_protocol::kRcodeNXDOMAIN); + AddSyncQueryAndResponse(2 /* id */, kT2HostName, kT2Qtype, + kT2ResponseDatagram, arraysize(kT2ResponseDatagram)); + + TransactionHelper helper0("www", kT2Qtype, kT2RecordCount); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, ConnectFailure) { + socket_factory_->fail_next_socket_ = true; + transaction_ids_.push_back(0); // Needed to make a DnsUDPAttempt. + TransactionHelper helper0("www.chromium.org", dns_protocol::kTypeA, + ERR_CONNECTION_REFUSED); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, ConnectFailureFollowedBySuccess) { + // Retry after server failure. + config_.attempts = 2; + ConfigureFactory(); + // First server connection attempt fails. + transaction_ids_.push_back(0); // Needed to make a DnsUDPAttempt. + socket_factory_->fail_next_socket_ = true; + // Second DNS query succeeds. + AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype, + kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); + TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, TCPLookup) { + AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, + dns_protocol::kRcodeNOERROR | dns_protocol::kFlagTC); + AddQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype, + kT0ResponseDatagram, arraysize(kT0ResponseDatagram), + ASYNC, true /* use_tcp */); + + TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, TCPFailure) { + AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, + dns_protocol::kRcodeNOERROR | dns_protocol::kFlagTC); + AddQueryAndRcode(kT0HostName, kT0Qtype, dns_protocol::kRcodeSERVFAIL, + ASYNC, true /* use_tcp */); + + TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_SERVER_FAILED); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, TCPMalformed) { + AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, + dns_protocol::kRcodeNOERROR | dns_protocol::kFlagTC); + scoped_ptr<DnsSocketData> data( + new DnsSocketData(0 /* id */, kT0HostName, kT0Qtype, ASYNC, true)); + // Valid response but length too short. + data->AddResponseWithLength( + make_scoped_ptr( + new DnsResponse(reinterpret_cast<const char*>(kT0ResponseDatagram), + arraysize(kT0ResponseDatagram), 0)), + ASYNC, + static_cast<uint16>(kT0QuerySize - 1)); + AddSocketData(data.Pass()); + + TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_MALFORMED_RESPONSE); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, TCPTimeout) { + config_.timeout = TestTimeouts::tiny_timeout(); + ConfigureFactory(); + AddAsyncQueryAndRcode(kT0HostName, kT0Qtype, + dns_protocol::kRcodeNOERROR | dns_protocol::kFlagTC); + AddSocketData(make_scoped_ptr( + new DnsSocketData(1 /* id */, kT0HostName, kT0Qtype, ASYNC, true))); + + TransactionHelper helper0(kT0HostName, kT0Qtype, ERR_DNS_TIMED_OUT); + EXPECT_TRUE(helper0.RunUntilDone(transaction_factory_.get())); +} + +TEST_F(DnsTransactionTest, InvalidQuery) { + config_.timeout = TestTimeouts::tiny_timeout(); + ConfigureFactory(); + + TransactionHelper helper0(".", dns_protocol::kTypeA, ERR_INVALID_ARGUMENT); + EXPECT_TRUE(helper0.Run(transaction_factory_.get())); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/dns/host_cache.cc b/chromium/net/dns/host_cache.cc new file mode 100644 index 00000000000..0e6ff15cd5d --- /dev/null +++ b/chromium/net/dns/host_cache.cc @@ -0,0 +1,122 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/host_cache.h" + +#include "base/logging.h" +#include "base/metrics/field_trial.h" +#include "base/metrics/histogram.h" +#include "base/strings/string_number_conversions.h" +#include "net/base/net_errors.h" + +namespace net { + +//----------------------------------------------------------------------------- + +HostCache::Entry::Entry(int error, const AddressList& addrlist, + base::TimeDelta ttl) + : error(error), + addrlist(addrlist), + ttl(ttl) { + DCHECK(ttl >= base::TimeDelta()); +} + +HostCache::Entry::Entry(int error, const AddressList& addrlist) + : error(error), + addrlist(addrlist), + ttl(base::TimeDelta::FromSeconds(-1)) { +} + +HostCache::Entry::~Entry() { +} + +//----------------------------------------------------------------------------- + +HostCache::HostCache(size_t max_entries) + : entries_(max_entries) { +} + +HostCache::~HostCache() { +} + +const HostCache::Entry* HostCache::Lookup(const Key& key, + base::TimeTicks now) { + DCHECK(CalledOnValidThread()); + if (caching_is_disabled()) + return NULL; + + return entries_.Get(key, now); +} + +void HostCache::Set(const Key& key, + const Entry& entry, + base::TimeTicks now, + base::TimeDelta ttl) { + DCHECK(CalledOnValidThread()); + if (caching_is_disabled()) + return; + + entries_.Put(key, entry, now, now + ttl); +} + +void HostCache::clear() { + DCHECK(CalledOnValidThread()); + entries_.Clear(); +} + +size_t HostCache::size() const { + DCHECK(CalledOnValidThread()); + return entries_.size(); +} + +size_t HostCache::max_entries() const { + DCHECK(CalledOnValidThread()); + return entries_.max_entries(); +} + +// Note that this map may contain expired entries. +const HostCache::EntryMap& HostCache::entries() const { + DCHECK(CalledOnValidThread()); + return entries_; +} + +// static +scoped_ptr<HostCache> HostCache::CreateDefaultCache() { + // Cache capacity is determined by the field trial. +#if defined(ENABLE_BUILT_IN_DNS) + const size_t kDefaultMaxEntries = 1000; +#else + const size_t kDefaultMaxEntries = 100; +#endif + const size_t kSaneMaxEntries = 1 << 20; + size_t max_entries = 0; + base::StringToSizeT(base::FieldTrialList::FindFullName("HostCacheSize"), + &max_entries); + if ((max_entries == 0) || (max_entries > kSaneMaxEntries)) + max_entries = kDefaultMaxEntries; + return make_scoped_ptr(new HostCache(max_entries)); +} + +void HostCache::EvictionHandler::Handle( + const Key& key, + const Entry& entry, + const base::TimeTicks& expiration, + const base::TimeTicks& now, + bool on_get) const { + if (on_get) { + DCHECK(now >= expiration); + UMA_HISTOGRAM_CUSTOM_TIMES("DNS.CacheExpiredOnGet", now - expiration, + base::TimeDelta::FromSeconds(1), base::TimeDelta::FromDays(1), 100); + return; + } + if (expiration > now) { + UMA_HISTOGRAM_CUSTOM_TIMES("DNS.CacheEvicted", expiration - now, + base::TimeDelta::FromSeconds(1), base::TimeDelta::FromDays(1), 100); + } else { + UMA_HISTOGRAM_CUSTOM_TIMES("DNS.CacheExpired", now - expiration, + base::TimeDelta::FromSeconds(1), base::TimeDelta::FromDays(1), 100); + } +} + +} // namespace net diff --git a/chromium/net/dns/host_cache.h b/chromium/net/dns/host_cache.h new file mode 100644 index 00000000000..e8628fb04ec --- /dev/null +++ b/chromium/net/dns/host_cache.h @@ -0,0 +1,124 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_HOST_CACHE_H_ +#define NET_DNS_HOST_CACHE_H_ + +#include <functional> +#include <string> + +#include "base/gtest_prod_util.h" +#include "base/memory/scoped_ptr.h" +#include "base/threading/non_thread_safe.h" +#include "base/time/time.h" +#include "net/base/address_family.h" +#include "net/base/address_list.h" +#include "net/base/expiring_cache.h" +#include "net/base/net_export.h" + +namespace net { + +// Cache used by HostResolver to map hostnames to their resolved result. +class NET_EXPORT HostCache : NON_EXPORTED_BASE(public base::NonThreadSafe) { + public: + // Stores the latest address list that was looked up for a hostname. + struct NET_EXPORT Entry { + Entry(int error, const AddressList& addrlist, base::TimeDelta ttl); + // Use when |ttl| is unknown. + Entry(int error, const AddressList& addrlist); + ~Entry(); + + bool has_ttl() const { return ttl >= base::TimeDelta(); } + + // The resolve results for this entry. + int error; + AddressList addrlist; + // TTL obtained from the nameserver. Negative if unknown. + base::TimeDelta ttl; + }; + + struct Key { + Key(const std::string& hostname, AddressFamily address_family, + HostResolverFlags host_resolver_flags) + : hostname(hostname), + address_family(address_family), + host_resolver_flags(host_resolver_flags) {} + + bool operator<(const Key& other) const { + // |address_family| and |host_resolver_flags| are compared before + // |hostname| under assumption that integer comparisons are faster than + // string comparisons. + if (address_family != other.address_family) + return address_family < other.address_family; + if (host_resolver_flags != other.host_resolver_flags) + return host_resolver_flags < other.host_resolver_flags; + return hostname < other.hostname; + } + + std::string hostname; + AddressFamily address_family; + HostResolverFlags host_resolver_flags; + }; + + struct EvictionHandler { + void Handle(const Key& key, + const Entry& entry, + const base::TimeTicks& expiration, + const base::TimeTicks& now, + bool onGet) const; + }; + + typedef ExpiringCache<Key, Entry, base::TimeTicks, + std::less<base::TimeTicks>, + EvictionHandler> EntryMap; + + // Constructs a HostCache that stores up to |max_entries|. + explicit HostCache(size_t max_entries); + + ~HostCache(); + + // Returns a pointer to the entry for |key|, which is valid at time + // |now|. If there is no such entry, returns NULL. + const Entry* Lookup(const Key& key, base::TimeTicks now); + + // Overwrites or creates an entry for |key|. + // |entry| is the value to set, |now| is the current time + // |ttl| is the "time to live". + void Set(const Key& key, + const Entry& entry, + base::TimeTicks now, + base::TimeDelta ttl); + + // Empties the cache + void clear(); + + // Returns the number of entries in the cache. + size_t size() const; + + // Following are used by net_internals UI. + size_t max_entries() const; + + const EntryMap& entries() const; + + // Creates a default cache. + static scoped_ptr<HostCache> CreateDefaultCache(); + + private: + FRIEND_TEST_ALL_PREFIXES(HostCacheTest, NoCache); + + // Returns true if this HostCache can contain no entries. + bool caching_is_disabled() const { + return entries_.max_entries() == 0; + } + + // Map from hostname (presumably in lowercase canonicalized format) to + // a resolved result entry. + EntryMap entries_; + + DISALLOW_COPY_AND_ASSIGN(HostCache); +}; + +} // namespace net + +#endif // NET_DNS_HOST_CACHE_H_ diff --git a/chromium/net/dns/host_cache_unittest.cc b/chromium/net/dns/host_cache_unittest.cc new file mode 100644 index 00000000000..34309c1223c --- /dev/null +++ b/chromium/net/dns/host_cache_unittest.cc @@ -0,0 +1,388 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/host_cache.h" + +#include "base/format_macros.h" +#include "base/stl_util.h" +#include "base/strings/string_util.h" +#include "base/strings/stringprintf.h" +#include "net/base/net_errors.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +const int kMaxCacheEntries = 10; + +// Builds a key for |hostname|, defaulting the address family to unspecified. +HostCache::Key Key(const std::string& hostname) { + return HostCache::Key(hostname, ADDRESS_FAMILY_UNSPECIFIED, 0); +} + +} // namespace + +TEST(HostCacheTest, Basic) { + const base::TimeDelta kTTL = base::TimeDelta::FromSeconds(10); + + HostCache cache(kMaxCacheEntries); + + // Start at t=0. + base::TimeTicks now; + + HostCache::Key key1 = Key("foobar.com"); + HostCache::Key key2 = Key("foobar2.com"); + HostCache::Entry entry = HostCache::Entry(OK, AddressList()); + + EXPECT_EQ(0U, cache.size()); + + // Add an entry for "foobar.com" at t=0. + EXPECT_FALSE(cache.Lookup(key1, now)); + cache.Set(key1, entry, now, kTTL); + EXPECT_TRUE(cache.Lookup(key1, now)); + EXPECT_TRUE(cache.Lookup(key1, now)->error == entry.error); + + EXPECT_EQ(1U, cache.size()); + + // Advance to t=5. + now += base::TimeDelta::FromSeconds(5); + + // Add an entry for "foobar2.com" at t=5. + EXPECT_FALSE(cache.Lookup(key2, now)); + cache.Set(key2, entry, now, kTTL); + EXPECT_TRUE(cache.Lookup(key2, now)); + EXPECT_EQ(2U, cache.size()); + + // Advance to t=9 + now += base::TimeDelta::FromSeconds(4); + + // Verify that the entries we added are still retrievable, and usable. + EXPECT_TRUE(cache.Lookup(key1, now)); + EXPECT_TRUE(cache.Lookup(key2, now)); + EXPECT_NE(cache.Lookup(key1, now), cache.Lookup(key2, now)); + + // Advance to t=10; key is now expired. + now += base::TimeDelta::FromSeconds(1); + + EXPECT_FALSE(cache.Lookup(key1, now)); + EXPECT_TRUE(cache.Lookup(key2, now)); + + // Update key1, so it is no longer expired. + cache.Set(key1, entry, now, kTTL); + EXPECT_TRUE(cache.Lookup(key1, now)); + EXPECT_EQ(2U, cache.size()); + + // Both entries should still be retrievable and usable. + EXPECT_TRUE(cache.Lookup(key1, now)); + EXPECT_TRUE(cache.Lookup(key2, now)); + + // Advance to t=20; both entries are now expired. + now += base::TimeDelta::FromSeconds(10); + + EXPECT_FALSE(cache.Lookup(key1, now)); + EXPECT_FALSE(cache.Lookup(key2, now)); +} + +// Try caching entries for a failed resolve attempt -- since we set the TTL of +// such entries to 0 it won't store, but it will kick out the previous result. +TEST(HostCacheTest, NoCacheZeroTTL) { + const base::TimeDelta kSuccessEntryTTL = base::TimeDelta::FromSeconds(10); + const base::TimeDelta kFailureEntryTTL = base::TimeDelta::FromSeconds(0); + + HostCache cache(kMaxCacheEntries); + + // Set t=0. + base::TimeTicks now; + + HostCache::Key key1 = Key("foobar.com"); + HostCache::Key key2 = Key("foobar2.com"); + HostCache::Entry entry = HostCache::Entry(OK, AddressList()); + + EXPECT_FALSE(cache.Lookup(key1, now)); + cache.Set(key1, entry, now, kFailureEntryTTL); + EXPECT_EQ(1U, cache.size()); + + // We disallow use of negative entries. + EXPECT_FALSE(cache.Lookup(key1, now)); + + // Now overwrite with a valid entry, and then overwrite with negative entry + // again -- the valid entry should be kicked out. + cache.Set(key1, entry, now, kSuccessEntryTTL); + EXPECT_TRUE(cache.Lookup(key1, now)); + cache.Set(key1, entry, now, kFailureEntryTTL); + EXPECT_FALSE(cache.Lookup(key1, now)); +} + +// Try caching entries for a failed resolves for 10 seconds. +TEST(HostCacheTest, CacheNegativeEntry) { + const base::TimeDelta kFailureEntryTTL = base::TimeDelta::FromSeconds(10); + + HostCache cache(kMaxCacheEntries); + + // Start at t=0. + base::TimeTicks now; + + HostCache::Key key1 = Key("foobar.com"); + HostCache::Key key2 = Key("foobar2.com"); + HostCache::Entry entry = HostCache::Entry(OK, AddressList()); + + EXPECT_EQ(0U, cache.size()); + + // Add an entry for "foobar.com" at t=0. + EXPECT_FALSE(cache.Lookup(key1, now)); + cache.Set(key1, entry, now, kFailureEntryTTL); + EXPECT_TRUE(cache.Lookup(key1, now)); + EXPECT_EQ(1U, cache.size()); + + // Advance to t=5. + now += base::TimeDelta::FromSeconds(5); + + // Add an entry for "foobar2.com" at t=5. + EXPECT_FALSE(cache.Lookup(key2, now)); + cache.Set(key2, entry, now, kFailureEntryTTL); + EXPECT_TRUE(cache.Lookup(key2, now)); + EXPECT_EQ(2U, cache.size()); + + // Advance to t=9 + now += base::TimeDelta::FromSeconds(4); + + // Verify that the entries we added are still retrievable, and usable. + EXPECT_TRUE(cache.Lookup(key1, now)); + EXPECT_TRUE(cache.Lookup(key2, now)); + + // Advance to t=10; key1 is now expired. + now += base::TimeDelta::FromSeconds(1); + + EXPECT_FALSE(cache.Lookup(key1, now)); + EXPECT_TRUE(cache.Lookup(key2, now)); + + // Update key1, so it is no longer expired. + cache.Set(key1, entry, now, kFailureEntryTTL); + // Re-uses existing entry storage. + EXPECT_TRUE(cache.Lookup(key1, now)); + EXPECT_EQ(2U, cache.size()); + + // Both entries should still be retrievable and usable. + EXPECT_TRUE(cache.Lookup(key1, now)); + EXPECT_TRUE(cache.Lookup(key2, now)); + + // Advance to t=20; both entries are now expired. + now += base::TimeDelta::FromSeconds(10); + + EXPECT_FALSE(cache.Lookup(key1, now)); + EXPECT_FALSE(cache.Lookup(key2, now)); +} + +// Tests that the same hostname can be duplicated in the cache, so long as +// the address family differs. +TEST(HostCacheTest, AddressFamilyIsPartOfKey) { + const base::TimeDelta kSuccessEntryTTL = base::TimeDelta::FromSeconds(10); + + HostCache cache(kMaxCacheEntries); + + // t=0. + base::TimeTicks now; + + HostCache::Key key1("foobar.com", ADDRESS_FAMILY_UNSPECIFIED, 0); + HostCache::Key key2("foobar.com", ADDRESS_FAMILY_IPV4, 0); + HostCache::Entry entry = HostCache::Entry(OK, AddressList()); + + EXPECT_EQ(0U, cache.size()); + + // Add an entry for ("foobar.com", UNSPECIFIED) at t=0. + EXPECT_FALSE(cache.Lookup(key1, now)); + cache.Set(key1, entry, now, kSuccessEntryTTL); + EXPECT_TRUE(cache.Lookup(key1, now)); + EXPECT_EQ(1U, cache.size()); + + // Add an entry for ("foobar.com", IPV4_ONLY) at t=0. + EXPECT_FALSE(cache.Lookup(key2, now)); + cache.Set(key2, entry, now, kSuccessEntryTTL); + EXPECT_TRUE(cache.Lookup(key2, now)); + EXPECT_EQ(2U, cache.size()); + + // Even though the hostnames were the same, we should have two unique + // entries (because the address families differ). + EXPECT_NE(cache.Lookup(key1, now), cache.Lookup(key2, now)); +} + +// Tests that the same hostname can be duplicated in the cache, so long as +// the HostResolverFlags differ. +TEST(HostCacheTest, HostResolverFlagsArePartOfKey) { + const base::TimeDelta kTTL = base::TimeDelta::FromSeconds(10); + + HostCache cache(kMaxCacheEntries); + + // t=0. + base::TimeTicks now; + + HostCache::Key key1("foobar.com", ADDRESS_FAMILY_IPV4, 0); + HostCache::Key key2("foobar.com", ADDRESS_FAMILY_IPV4, + HOST_RESOLVER_CANONNAME); + HostCache::Key key3("foobar.com", ADDRESS_FAMILY_IPV4, + HOST_RESOLVER_LOOPBACK_ONLY); + HostCache::Entry entry = HostCache::Entry(OK, AddressList()); + + EXPECT_EQ(0U, cache.size()); + + // Add an entry for ("foobar.com", IPV4, NONE) at t=0. + EXPECT_FALSE(cache.Lookup(key1, now)); + cache.Set(key1, entry, now, kTTL); + EXPECT_TRUE(cache.Lookup(key1, now)); + EXPECT_EQ(1U, cache.size()); + + // Add an entry for ("foobar.com", IPV4, CANONNAME) at t=0. + EXPECT_FALSE(cache.Lookup(key2, now)); + cache.Set(key2, entry, now, kTTL); + EXPECT_TRUE(cache.Lookup(key2, now)); + EXPECT_EQ(2U, cache.size()); + + // Add an entry for ("foobar.com", IPV4, LOOPBACK_ONLY) at t=0. + EXPECT_FALSE(cache.Lookup(key3, now)); + cache.Set(key3, entry, now, kTTL); + EXPECT_TRUE(cache.Lookup(key3, now)); + EXPECT_EQ(3U, cache.size()); + + // Even though the hostnames were the same, we should have two unique + // entries (because the HostResolverFlags differ). + EXPECT_NE(cache.Lookup(key1, now), cache.Lookup(key2, now)); + EXPECT_NE(cache.Lookup(key1, now), cache.Lookup(key3, now)); + EXPECT_NE(cache.Lookup(key2, now), cache.Lookup(key3, now)); +} + +TEST(HostCacheTest, NoCache) { + // Disable caching. + const base::TimeDelta kTTL = base::TimeDelta::FromSeconds(10); + + HostCache cache(0); + EXPECT_TRUE(cache.caching_is_disabled()); + + // Set t=0. + base::TimeTicks now; + + HostCache::Entry entry = HostCache::Entry(OK, AddressList()); + + // Lookup and Set should have no effect. + EXPECT_FALSE(cache.Lookup(Key("foobar.com"),now)); + cache.Set(Key("foobar.com"), entry, now, kTTL); + EXPECT_FALSE(cache.Lookup(Key("foobar.com"), now)); + + EXPECT_EQ(0U, cache.size()); +} + +TEST(HostCacheTest, Clear) { + const base::TimeDelta kTTL = base::TimeDelta::FromSeconds(10); + + HostCache cache(kMaxCacheEntries); + + // Set t=0. + base::TimeTicks now; + + HostCache::Entry entry = HostCache::Entry(OK, AddressList()); + + EXPECT_EQ(0u, cache.size()); + + // Add three entries. + cache.Set(Key("foobar1.com"), entry, now, kTTL); + cache.Set(Key("foobar2.com"), entry, now, kTTL); + cache.Set(Key("foobar3.com"), entry, now, kTTL); + + EXPECT_EQ(3u, cache.size()); + + cache.clear(); + + EXPECT_EQ(0u, cache.size()); +} + +// Tests the less than and equal operators for HostCache::Key work. +TEST(HostCacheTest, KeyComparators) { + struct { + // Inputs. + HostCache::Key key1; + HostCache::Key key2; + + // Expectation. + // -1 means key1 is less than key2 + // 0 means key1 equals key2 + // 1 means key1 is greater than key2 + int expected_comparison; + } tests[] = { + { + HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0), + HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0), + 0 + }, + { + HostCache::Key("host1", ADDRESS_FAMILY_IPV4, 0), + HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0), + 1 + }, + { + HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0), + HostCache::Key("host1", ADDRESS_FAMILY_IPV4, 0), + -1 + }, + { + HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0), + HostCache::Key("host2", ADDRESS_FAMILY_UNSPECIFIED, 0), + -1 + }, + { + HostCache::Key("host1", ADDRESS_FAMILY_IPV4, 0), + HostCache::Key("host2", ADDRESS_FAMILY_UNSPECIFIED, 0), + 1 + }, + { + HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0), + HostCache::Key("host2", ADDRESS_FAMILY_IPV4, 0), + -1 + }, + { + HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0), + HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, + HOST_RESOLVER_CANONNAME), + -1 + }, + { + HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, + HOST_RESOLVER_CANONNAME), + HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, 0), + 1 + }, + { + HostCache::Key("host1", ADDRESS_FAMILY_UNSPECIFIED, + HOST_RESOLVER_CANONNAME), + HostCache::Key("host2", ADDRESS_FAMILY_UNSPECIFIED, + HOST_RESOLVER_CANONNAME), + -1 + }, + }; + + for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) { + SCOPED_TRACE(base::StringPrintf("Test[%" PRIuS "]", i)); + + const HostCache::Key& key1 = tests[i].key1; + const HostCache::Key& key2 = tests[i].key2; + + switch (tests[i].expected_comparison) { + case -1: + EXPECT_TRUE(key1 < key2); + EXPECT_FALSE(key2 < key1); + break; + case 0: + EXPECT_FALSE(key1 < key2); + EXPECT_FALSE(key2 < key1); + break; + case 1: + EXPECT_FALSE(key1 < key2); + EXPECT_TRUE(key2 < key1); + break; + default: + FAIL() << "Invalid expectation. Can be only -1, 0, 1"; + } + } +} + +} // namespace net diff --git a/chromium/net/dns/host_resolver.cc b/chromium/net/dns/host_resolver.cc new file mode 100644 index 00000000000..d74be91beb0 --- /dev/null +++ b/chromium/net/dns/host_resolver.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/host_resolver.h" + +#include "base/logging.h" +#include "base/metrics/field_trial.h" +#include "base/strings/string_number_conversions.h" +#include "base/strings/string_split.h" +#include "net/dns/dns_client.h" +#include "net/dns/dns_config_service.h" +#include "net/dns/host_cache.h" +#include "net/dns/host_resolver_impl.h" + +namespace net { + +namespace { + +// Maximum of 6 concurrent resolver threads (excluding retries). +// Some routers (or resolvers) appear to start to provide host-not-found if +// too many simultaneous resolutions are pending. This number needs to be +// further optimized, but 8 is what FF currently does. We found some routers +// that limit this to 6, so we're temporarily holding it at that level. +const size_t kDefaultMaxProcTasks = 6u; + +// When configuring from field trial, do not allow +const size_t kSaneMaxProcTasks = 20u; + +PrioritizedDispatcher::Limits GetDispatcherLimits( + const HostResolver::Options& options) { + PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, + options.max_concurrent_resolves); + + // If not using default, do not use the field trial. + if (limits.total_jobs != HostResolver::kDefaultParallelism) + return limits; + + // Default, without trial is no reserved slots. + limits.total_jobs = kDefaultMaxProcTasks; + + // Parallelism is determined by the field trial. + std::string group = base::FieldTrialList::FindFullName( + "HostResolverDispatch"); + + if (group.empty()) + return limits; + + // The format of the group name is a list of non-negative integers separated + // by ':'. Each of the elements in the list corresponds to an element in + // |reserved_slots|, except the last one which is the |total_jobs|. + + std::vector<std::string> group_parts; + base::SplitString(group, ':', &group_parts); + if (group_parts.size() != NUM_PRIORITIES + 1) { + NOTREACHED(); + return limits; + } + + std::vector<size_t> parsed(group_parts.size()); + size_t total_reserved_slots = 0; + + for (size_t i = 0; i < group_parts.size(); ++i) { + if (!base::StringToSizeT(group_parts[i], &parsed[i])) { + NOTREACHED(); + return limits; + } + } + + size_t total_jobs = parsed.back(); + parsed.pop_back(); + for (size_t i = 0; i < parsed.size(); ++i) { + total_reserved_slots += parsed[i]; + } + + // There must be some unreserved slots available for the all priorities. + if (total_reserved_slots > total_jobs || + (total_reserved_slots == total_jobs && parsed[MINIMUM_PRIORITY] == 0)) { + NOTREACHED(); + return limits; + } + + limits.total_jobs = total_jobs; + limits.reserved_slots = parsed; + return limits; +} + +} // namespace + +HostResolver::Options::Options() + : max_concurrent_resolves(kDefaultParallelism), + max_retry_attempts(kDefaultRetryAttempts), + enable_caching(true) { +} + +HostResolver::RequestInfo::RequestInfo(const HostPortPair& host_port_pair) + : host_port_pair_(host_port_pair), + address_family_(ADDRESS_FAMILY_UNSPECIFIED), + host_resolver_flags_(0), + allow_cached_response_(true), + is_speculative_(false), + priority_(MEDIUM) { +} + +HostResolver::~HostResolver() { +} + +AddressFamily HostResolver::GetDefaultAddressFamily() const { + return ADDRESS_FAMILY_UNSPECIFIED; +} + +void HostResolver::SetDnsClientEnabled(bool enabled) { +} + +HostCache* HostResolver::GetHostCache() { + return NULL; +} + +base::Value* HostResolver::GetDnsConfigAsValue() const { + return NULL; +} + +// static +scoped_ptr<HostResolver> +HostResolver::CreateSystemResolver(const Options& options, NetLog* net_log) { + scoped_ptr<HostCache> cache; + if (options.enable_caching) + cache = HostCache::CreateDefaultCache(); + return scoped_ptr<HostResolver>(new HostResolverImpl( + cache.Pass(), + GetDispatcherLimits(options), + HostResolverImpl::ProcTaskParams(NULL, options.max_retry_attempts), + net_log)); +} + +// static +scoped_ptr<HostResolver> +HostResolver::CreateDefaultResolver(NetLog* net_log) { + return CreateSystemResolver(Options(), net_log); +} + +HostResolver::HostResolver() { +} + +} // namespace net diff --git a/chromium/net/dns/host_resolver.h b/chromium/net/dns/host_resolver.h new file mode 100644 index 00000000000..558a1ddc4b2 --- /dev/null +++ b/chromium/net/dns/host_resolver.h @@ -0,0 +1,204 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_HOST_RESOLVER_H_ +#define NET_DNS_HOST_RESOLVER_H_ + +#include <string> + +#include "base/memory/scoped_ptr.h" +#include "net/base/address_family.h" +#include "net/base/completion_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/net_export.h" +#include "net/base/net_util.h" +#include "net/base/request_priority.h" + +namespace base { +class Value; +} + +namespace net { + +class AddressList; +class BoundNetLog; +class HostCache; +class HostResolverProc; +class NetLog; + +// This class represents the task of resolving hostnames (or IP address +// literal) to an AddressList object. +// +// HostResolver can handle multiple requests at a time, so when cancelling a +// request the RequestHandle that was returned by Resolve() needs to be +// given. A simpler alternative for consumers that only have 1 outstanding +// request at a time is to create a SingleRequestHostResolver wrapper around +// HostResolver (which will automatically cancel the single request when it +// goes out of scope). +class NET_EXPORT HostResolver { + public: + // |max_concurrent_resolves| is how many resolve requests will be allowed to + // run in parallel. Pass HostResolver::kDefaultParallelism to choose a + // default value. + // |max_retry_attempts| is the maximum number of times we will retry for host + // resolution. Pass HostResolver::kDefaultRetryAttempts to choose a default + // value. + // |enable_caching| controls whether a HostCache is used. + struct NET_EXPORT Options { + Options(); + + size_t max_concurrent_resolves; + size_t max_retry_attempts; + bool enable_caching; + }; + + // The parameters for doing a Resolve(). A hostname and port are required, + // the rest are optional (and have reasonable defaults). + class NET_EXPORT RequestInfo { + public: + explicit RequestInfo(const HostPortPair& host_port_pair); + + const HostPortPair& host_port_pair() const { return host_port_pair_; } + void set_host_port_pair(const HostPortPair& host_port_pair) { + host_port_pair_ = host_port_pair; + } + + int port() const { return host_port_pair_.port(); } + const std::string& hostname() const { return host_port_pair_.host(); } + + AddressFamily address_family() const { return address_family_; } + void set_address_family(AddressFamily address_family) { + address_family_ = address_family; + } + + HostResolverFlags host_resolver_flags() const { + return host_resolver_flags_; + } + void set_host_resolver_flags(HostResolverFlags host_resolver_flags) { + host_resolver_flags_ = host_resolver_flags; + } + + bool allow_cached_response() const { return allow_cached_response_; } + void set_allow_cached_response(bool b) { allow_cached_response_ = b; } + + bool is_speculative() const { return is_speculative_; } + void set_is_speculative(bool b) { is_speculative_ = b; } + + RequestPriority priority() const { return priority_; } + void set_priority(RequestPriority priority) { priority_ = priority; } + + private: + // The hostname to resolve, and the port to use in resulting sockaddrs. + HostPortPair host_port_pair_; + + // The address family to restrict results to. + AddressFamily address_family_; + + // Flags to use when resolving this request. + HostResolverFlags host_resolver_flags_; + + // Whether it is ok to return a result from the host cache. + bool allow_cached_response_; + + // Whether this request was started by the DNS prefetcher. + bool is_speculative_; + + // The priority for the request. + RequestPriority priority_; + }; + + // Opaque type used to cancel a request. + typedef void* RequestHandle; + + // This value can be passed into CreateSystemResolver as the + // |max_concurrent_resolves| parameter. It will select a default level of + // concurrency. + static const size_t kDefaultParallelism = 0; + + // This value can be passed into CreateSystemResolver as the + // |max_retry_attempts| parameter. + static const size_t kDefaultRetryAttempts = -1; + + // If any completion callbacks are pending when the resolver is destroyed, + // the host resolutions are cancelled, and the completion callbacks will not + // be called. + virtual ~HostResolver(); + + // Resolves the given hostname (or IP address literal), filling out the + // |addresses| object upon success. The |info.port| parameter will be set as + // the sin(6)_port field of the sockaddr_in{6} struct. Returns OK if + // successful or an error code upon failure. Returns + // ERR_NAME_NOT_RESOLVED if hostname is invalid, or if it is an + // incompatible IP literal (e.g. IPv6 is disabled and it is an IPv6 + // literal). + // + // If the operation cannot be completed synchronously, ERR_IO_PENDING will + // be returned and the real result code will be passed to the completion + // callback. Otherwise the result code is returned immediately from this + // call. + // + // If |out_req| is non-NULL, then |*out_req| will be filled with a handle to + // the async request. This handle is not valid after the request has + // completed. + // + // Profiling information for the request is saved to |net_log| if non-NULL. + virtual int Resolve(const RequestInfo& info, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* out_req, + const BoundNetLog& net_log) = 0; + + // Resolves the given hostname (or IP address literal) out of cache or HOSTS + // file (if enabled) only. This is guaranteed to complete synchronously. + // This acts like |Resolve()| if the hostname is IP literal, or cached value + // or HOSTS entry exists. Otherwise, ERR_DNS_CACHE_MISS is returned. + virtual int ResolveFromCache(const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& net_log) = 0; + + // Cancels the specified request. |req| is the handle returned by Resolve(). + // After a request is canceled, its completion callback will not be called. + // CancelRequest must NOT be called after the request's completion callback + // has already run or the request was canceled. + virtual void CancelRequest(RequestHandle req) = 0; + + // Sets the default AddressFamily to use when requests have left it + // unspecified. For example, this could be used to restrict resolution + // results to AF_INET by passing in ADDRESS_FAMILY_IPV4, or to + // AF_INET6 by passing in ADDRESS_FAMILY_IPV6. + virtual void SetDefaultAddressFamily(AddressFamily address_family) {} + virtual AddressFamily GetDefaultAddressFamily() const; + + // Enable or disable the built-in asynchronous DnsClient. + virtual void SetDnsClientEnabled(bool enabled); + + // Returns the HostResolverCache |this| uses, or NULL if there isn't one. + // Used primarily to clear the cache and for getting debug information. + virtual HostCache* GetHostCache(); + + // Returns the current DNS configuration |this| is using, as a Value, or NULL + // if it's configured to always use the system host resolver. Caller takes + // ownership of the returned Value. + virtual base::Value* GetDnsConfigAsValue() const; + + // Creates a HostResolver implementation that queries the underlying system. + // (Except if a unit-test has changed the global HostResolverProc using + // ScopedHostResolverProc to intercept requests to the system). + static scoped_ptr<HostResolver> CreateSystemResolver( + const Options& options, + NetLog* net_log); + + // As above, but uses default parameters. + static scoped_ptr<HostResolver> CreateDefaultResolver(NetLog* net_log); + + protected: + HostResolver(); + + private: + DISALLOW_COPY_AND_ASSIGN(HostResolver); +}; + +} // namespace net + +#endif // NET_DNS_HOST_RESOLVER_H_ diff --git a/chromium/net/dns/host_resolver_impl.cc b/chromium/net/dns/host_resolver_impl.cc new file mode 100644 index 00000000000..10631773291 --- /dev/null +++ b/chromium/net/dns/host_resolver_impl.cc @@ -0,0 +1,2206 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/host_resolver_impl.h" + +#if defined(OS_WIN) +#include <Winsock2.h> +#elif defined(OS_POSIX) +#include <netdb.h> +#endif + +#include <cmath> +#include <utility> +#include <vector> + +#include "base/basictypes.h" +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/debug/debugger.h" +#include "base/debug/stack_trace.h" +#include "base/message_loop/message_loop_proxy.h" +#include "base/metrics/field_trial.h" +#include "base/metrics/histogram.h" +#include "base/stl_util.h" +#include "base/strings/string_util.h" +#include "base/strings/utf_string_conversions.h" +#include "base/threading/worker_pool.h" +#include "base/time/time.h" +#include "base/values.h" +#include "net/base/address_family.h" +#include "net/base/address_list.h" +#include "net/base/dns_reloader.h" +#include "net/base/dns_util.h" +#include "net/base/host_port_pair.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/net_util.h" +#include "net/dns/address_sorter.h" +#include "net/dns/dns_client.h" +#include "net/dns/dns_config_service.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_transaction.h" +#include "net/dns/host_resolver_proc.h" +#include "net/socket/client_socket_factory.h" +#include "net/udp/datagram_client_socket.h" + +#if defined(OS_WIN) +#include "net/base/winsock_init.h" +#endif + +namespace net { + +namespace { + +// Limit the size of hostnames that will be resolved to combat issues in +// some platform's resolvers. +const size_t kMaxHostLength = 4096; + +// Default TTL for successful resolutions with ProcTask. +const unsigned kCacheEntryTTLSeconds = 60; + +// Default TTL for unsuccessful resolutions with ProcTask. +const unsigned kNegativeCacheEntryTTLSeconds = 0; + +// Minimum TTL for successful resolutions with DnsTask. +const unsigned kMinimumTTLSeconds = kCacheEntryTTLSeconds; + +// Number of consecutive failures of DnsTask (with successful fallback) before +// the DnsClient is disabled until the next DNS change. +const unsigned kMaximumDnsFailures = 16; + +// We use a separate histogram name for each platform to facilitate the +// display of error codes by their symbolic name (since each platform has +// different mappings). +const char kOSErrorsForGetAddrinfoHistogramName[] = +#if defined(OS_WIN) + "Net.OSErrorsForGetAddrinfo_Win"; +#elif defined(OS_MACOSX) + "Net.OSErrorsForGetAddrinfo_Mac"; +#elif defined(OS_LINUX) + "Net.OSErrorsForGetAddrinfo_Linux"; +#else + "Net.OSErrorsForGetAddrinfo"; +#endif + +// Gets a list of the likely error codes that getaddrinfo() can return +// (non-exhaustive). These are the error codes that we will track via +// a histogram. +std::vector<int> GetAllGetAddrinfoOSErrors() { + int os_errors[] = { +#if defined(OS_POSIX) +#if !defined(OS_FREEBSD) +#if !defined(OS_ANDROID) + // EAI_ADDRFAMILY has been declared obsolete in Android's and + // FreeBSD's netdb.h. + EAI_ADDRFAMILY, +#endif + // EAI_NODATA has been declared obsolete in FreeBSD's netdb.h. + EAI_NODATA, +#endif + EAI_AGAIN, + EAI_BADFLAGS, + EAI_FAIL, + EAI_FAMILY, + EAI_MEMORY, + EAI_NONAME, + EAI_SERVICE, + EAI_SOCKTYPE, + EAI_SYSTEM, +#elif defined(OS_WIN) + // See: http://msdn.microsoft.com/en-us/library/ms738520(VS.85).aspx + WSA_NOT_ENOUGH_MEMORY, + WSAEAFNOSUPPORT, + WSAEINVAL, + WSAESOCKTNOSUPPORT, + WSAHOST_NOT_FOUND, + WSANO_DATA, + WSANO_RECOVERY, + WSANOTINITIALISED, + WSATRY_AGAIN, + WSATYPE_NOT_FOUND, + // The following are not in doc, but might be to appearing in results :-(. + WSA_INVALID_HANDLE, +#endif + }; + + // Ensure all errors are positive, as histogram only tracks positive values. + for (size_t i = 0; i < arraysize(os_errors); ++i) { + os_errors[i] = std::abs(os_errors[i]); + } + + return base::CustomHistogram::ArrayToCustomRanges(os_errors, + arraysize(os_errors)); +} + +enum DnsResolveStatus { + RESOLVE_STATUS_DNS_SUCCESS = 0, + RESOLVE_STATUS_PROC_SUCCESS, + RESOLVE_STATUS_FAIL, + RESOLVE_STATUS_SUSPECT_NETBIOS, + RESOLVE_STATUS_MAX +}; + +void UmaAsyncDnsResolveStatus(DnsResolveStatus result) { + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ResolveStatus", + result, + RESOLVE_STATUS_MAX); +} + +bool ResemblesNetBIOSName(const std::string& hostname) { + return (hostname.size() < 16) && (hostname.find('.') == std::string::npos); +} + +// True if |hostname| ends with either ".local" or ".local.". +bool ResemblesMulticastDNSName(const std::string& hostname) { + DCHECK(!hostname.empty()); + const char kSuffix[] = ".local."; + const size_t kSuffixLen = sizeof(kSuffix) - 1; + const size_t kSuffixLenTrimmed = kSuffixLen - 1; + if (hostname[hostname.size() - 1] == '.') { + return hostname.size() > kSuffixLen && + !hostname.compare(hostname.size() - kSuffixLen, kSuffixLen, kSuffix); + } + return hostname.size() > kSuffixLenTrimmed && + !hostname.compare(hostname.size() - kSuffixLenTrimmed, kSuffixLenTrimmed, + kSuffix, kSuffixLenTrimmed); +} + +// Attempts to connect a UDP socket to |dest|:53. +bool IsGloballyReachable(const IPAddressNumber& dest, + const BoundNetLog& net_log) { + scoped_ptr<DatagramClientSocket> socket( + ClientSocketFactory::GetDefaultFactory()->CreateDatagramClientSocket( + DatagramSocket::DEFAULT_BIND, + RandIntCallback(), + net_log.net_log(), + net_log.source())); + int rv = socket->Connect(IPEndPoint(dest, 53)); + if (rv != OK) + return false; + IPEndPoint endpoint; + rv = socket->GetLocalAddress(&endpoint); + if (rv != OK) + return false; + DCHECK(endpoint.GetFamily() == ADDRESS_FAMILY_IPV6); + const IPAddressNumber& address = endpoint.address(); + bool is_link_local = (address[0] == 0xFE) && ((address[1] & 0xC0) == 0x80); + if (is_link_local) + return false; + const uint8 kTeredoPrefix[] = { 0x20, 0x01, 0, 0 }; + bool is_teredo = std::equal(kTeredoPrefix, + kTeredoPrefix + arraysize(kTeredoPrefix), + address.begin()); + if (is_teredo) + return false; + return true; +} + +// Provide a common macro to simplify code and readability. We must use a +// macro as the underlying HISTOGRAM macro creates static variables. +#define DNS_HISTOGRAM(name, time) UMA_HISTOGRAM_CUSTOM_TIMES(name, time, \ + base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromHours(1), 100) + +// A macro to simplify code and readability. +#define DNS_HISTOGRAM_BY_PRIORITY(basename, priority, time) \ + do { \ + switch (priority) { \ + case HIGHEST: DNS_HISTOGRAM(basename "_HIGHEST", time); break; \ + case MEDIUM: DNS_HISTOGRAM(basename "_MEDIUM", time); break; \ + case LOW: DNS_HISTOGRAM(basename "_LOW", time); break; \ + case LOWEST: DNS_HISTOGRAM(basename "_LOWEST", time); break; \ + case IDLE: DNS_HISTOGRAM(basename "_IDLE", time); break; \ + default: NOTREACHED(); break; \ + } \ + DNS_HISTOGRAM(basename, time); \ + } while (0) + +// Record time from Request creation until a valid DNS response. +void RecordTotalTime(bool had_dns_config, + bool speculative, + base::TimeDelta duration) { + if (had_dns_config) { + if (speculative) { + DNS_HISTOGRAM("AsyncDNS.TotalTime_speculative", duration); + } else { + DNS_HISTOGRAM("AsyncDNS.TotalTime", duration); + } + } else { + if (speculative) { + DNS_HISTOGRAM("DNS.TotalTime_speculative", duration); + } else { + DNS_HISTOGRAM("DNS.TotalTime", duration); + } + } +} + +void RecordTTL(base::TimeDelta ttl) { + UMA_HISTOGRAM_CUSTOM_TIMES("AsyncDNS.TTL", ttl, + base::TimeDelta::FromSeconds(1), + base::TimeDelta::FromDays(1), 100); +} + +bool ConfigureAsyncDnsNoFallbackFieldTrial() { + const bool kDefault = false; + + // Configure the AsyncDns field trial as follows: + // groups AsyncDnsNoFallbackA and AsyncDnsNoFallbackB: return true, + // groups AsyncDnsA and AsyncDnsB: return false, + // groups SystemDnsA and SystemDnsB: return false, + // otherwise (trial absent): return default. + std::string group_name = base::FieldTrialList::FindFullName("AsyncDns"); + if (!group_name.empty()) + return StartsWithASCII(group_name, "AsyncDnsNoFallback", false); + return kDefault; +} + +//----------------------------------------------------------------------------- + +AddressList EnsurePortOnAddressList(const AddressList& list, uint16 port) { + if (list.empty() || list.front().port() == port) + return list; + return AddressList::CopyWithPort(list, port); +} + +// Returns true if |addresses| contains only IPv4 loopback addresses. +bool IsAllIPv4Loopback(const AddressList& addresses) { + for (unsigned i = 0; i < addresses.size(); ++i) { + const IPAddressNumber& address = addresses[i].address(); + switch (addresses[i].GetFamily()) { + case ADDRESS_FAMILY_IPV4: + if (address[0] != 127) + return false; + break; + case ADDRESS_FAMILY_IPV6: + return false; + default: + NOTREACHED(); + return false; + } + } + return true; +} + +// Creates NetLog parameters when the resolve failed. +base::Value* NetLogProcTaskFailedCallback(uint32 attempt_number, + int net_error, + int os_error, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + if (attempt_number) + dict->SetInteger("attempt_number", attempt_number); + + dict->SetInteger("net_error", net_error); + + if (os_error) { + dict->SetInteger("os_error", os_error); +#if defined(OS_POSIX) + dict->SetString("os_error_string", gai_strerror(os_error)); +#elif defined(OS_WIN) + // Map the error code to a human-readable string. + LPWSTR error_string = NULL; + int size = FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_FROM_SYSTEM, + 0, // Use the internal message table. + os_error, + 0, // Use default language. + (LPWSTR)&error_string, + 0, // Buffer size. + 0); // Arguments (unused). + dict->SetString("os_error_string", WideToUTF8(error_string)); + LocalFree(error_string); +#endif + } + + return dict; +} + +// Creates NetLog parameters when the DnsTask failed. +base::Value* NetLogDnsTaskFailedCallback(int net_error, + int dns_error, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetInteger("net_error", net_error); + if (dns_error) + dict->SetInteger("dns_error", dns_error); + return dict; +}; + +// Creates NetLog parameters containing the information in a RequestInfo object, +// along with the associated NetLog::Source. +base::Value* NetLogRequestInfoCallback(const NetLog::Source& source, + const HostResolver::RequestInfo* info, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + source.AddToEventParameters(dict); + + dict->SetString("host", info->host_port_pair().ToString()); + dict->SetInteger("address_family", + static_cast<int>(info->address_family())); + dict->SetBoolean("allow_cached_response", info->allow_cached_response()); + dict->SetBoolean("is_speculative", info->is_speculative()); + dict->SetInteger("priority", info->priority()); + return dict; +} + +// Creates NetLog parameters for the creation of a HostResolverImpl::Job. +base::Value* NetLogJobCreationCallback(const NetLog::Source& source, + const std::string* host, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + source.AddToEventParameters(dict); + dict->SetString("host", *host); + return dict; +} + +// Creates NetLog parameters for HOST_RESOLVER_IMPL_JOB_ATTACH/DETACH events. +base::Value* NetLogJobAttachCallback(const NetLog::Source& source, + RequestPriority priority, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + source.AddToEventParameters(dict); + dict->SetInteger("priority", priority); + return dict; +} + +// Creates NetLog parameters for the DNS_CONFIG_CHANGED event. +base::Value* NetLogDnsConfigCallback(const DnsConfig* config, + NetLog::LogLevel /* log_level */) { + return config->ToValue(); +} + +// The logging routines are defined here because some requests are resolved +// without a Request object. + +// Logs when a request has just been started. +void LogStartRequest(const BoundNetLog& source_net_log, + const BoundNetLog& request_net_log, + const HostResolver::RequestInfo& info) { + source_net_log.BeginEvent( + NetLog::TYPE_HOST_RESOLVER_IMPL, + request_net_log.source().ToEventParametersCallback()); + + request_net_log.BeginEvent( + NetLog::TYPE_HOST_RESOLVER_IMPL_REQUEST, + base::Bind(&NetLogRequestInfoCallback, source_net_log.source(), &info)); +} + +// Logs when a request has just completed (before its callback is run). +void LogFinishRequest(const BoundNetLog& source_net_log, + const BoundNetLog& request_net_log, + const HostResolver::RequestInfo& info, + int net_error) { + request_net_log.EndEventWithNetErrorCode( + NetLog::TYPE_HOST_RESOLVER_IMPL_REQUEST, net_error); + source_net_log.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL); +} + +// Logs when a request has been cancelled. +void LogCancelRequest(const BoundNetLog& source_net_log, + const BoundNetLog& request_net_log, + const HostResolverImpl::RequestInfo& info) { + request_net_log.AddEvent(NetLog::TYPE_CANCELLED); + request_net_log.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_REQUEST); + source_net_log.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL); +} + +//----------------------------------------------------------------------------- + +// Keeps track of the highest priority. +class PriorityTracker { + public: + explicit PriorityTracker(RequestPriority initial_priority) + : highest_priority_(initial_priority), total_count_(0) { + memset(counts_, 0, sizeof(counts_)); + } + + RequestPriority highest_priority() const { + return highest_priority_; + } + + size_t total_count() const { + return total_count_; + } + + void Add(RequestPriority req_priority) { + ++total_count_; + ++counts_[req_priority]; + if (highest_priority_ < req_priority) + highest_priority_ = req_priority; + } + + void Remove(RequestPriority req_priority) { + DCHECK_GT(total_count_, 0u); + DCHECK_GT(counts_[req_priority], 0u); + --total_count_; + --counts_[req_priority]; + size_t i; + for (i = highest_priority_; i > MINIMUM_PRIORITY && !counts_[i]; --i); + highest_priority_ = static_cast<RequestPriority>(i); + + // In absence of requests, default to MINIMUM_PRIORITY. + if (total_count_ == 0) + DCHECK_EQ(MINIMUM_PRIORITY, highest_priority_); + } + + private: + RequestPriority highest_priority_; + size_t total_count_; + size_t counts_[NUM_PRIORITIES]; +}; + +} // namespace + +//----------------------------------------------------------------------------- + +// Holds the data for a request that could not be completed synchronously. +// It is owned by a Job. Canceled Requests are only marked as canceled rather +// than removed from the Job's |requests_| list. +class HostResolverImpl::Request { + public: + Request(const BoundNetLog& source_net_log, + const BoundNetLog& request_net_log, + const RequestInfo& info, + const CompletionCallback& callback, + AddressList* addresses) + : source_net_log_(source_net_log), + request_net_log_(request_net_log), + info_(info), + job_(NULL), + callback_(callback), + addresses_(addresses), + request_time_(base::TimeTicks::Now()) { + } + + // Mark the request as canceled. + void MarkAsCanceled() { + job_ = NULL; + addresses_ = NULL; + callback_.Reset(); + } + + bool was_canceled() const { + return callback_.is_null(); + } + + void set_job(Job* job) { + DCHECK(job); + // Identify which job the request is waiting on. + job_ = job; + } + + // Prepare final AddressList and call completion callback. + void OnComplete(int error, const AddressList& addr_list) { + DCHECK(!was_canceled()); + if (error == OK) + *addresses_ = EnsurePortOnAddressList(addr_list, info_.port()); + CompletionCallback callback = callback_; + MarkAsCanceled(); + callback.Run(error); + } + + Job* job() const { + return job_; + } + + // NetLog for the source, passed in HostResolver::Resolve. + const BoundNetLog& source_net_log() { + return source_net_log_; + } + + // NetLog for this request. + const BoundNetLog& request_net_log() { + return request_net_log_; + } + + const RequestInfo& info() const { + return info_; + } + + base::TimeTicks request_time() const { + return request_time_; + } + + private: + BoundNetLog source_net_log_; + BoundNetLog request_net_log_; + + // The request info that started the request. + RequestInfo info_; + + // The resolve job that this request is dependent on. + Job* job_; + + // The user's callback to invoke when the request completes. + CompletionCallback callback_; + + // The address list to save result into. + AddressList* addresses_; + + const base::TimeTicks request_time_; + + DISALLOW_COPY_AND_ASSIGN(Request); +}; + +//------------------------------------------------------------------------------ + +// Calls HostResolverProc on the WorkerPool. Performs retries if necessary. +// +// Whenever we try to resolve the host, we post a delayed task to check if host +// resolution (OnLookupComplete) is completed or not. If the original attempt +// hasn't completed, then we start another attempt for host resolution. We take +// the results from the first attempt that finishes and ignore the results from +// all other attempts. +// +// TODO(szym): Move to separate source file for testing and mocking. +// +class HostResolverImpl::ProcTask + : public base::RefCountedThreadSafe<HostResolverImpl::ProcTask> { + public: + typedef base::Callback<void(int net_error, + const AddressList& addr_list)> Callback; + + ProcTask(const Key& key, + const ProcTaskParams& params, + const Callback& callback, + const BoundNetLog& job_net_log) + : key_(key), + params_(params), + callback_(callback), + origin_loop_(base::MessageLoopProxy::current()), + attempt_number_(0), + completed_attempt_number_(0), + completed_attempt_error_(ERR_UNEXPECTED), + had_non_speculative_request_(false), + net_log_(job_net_log) { + if (!params_.resolver_proc.get()) + params_.resolver_proc = HostResolverProc::GetDefault(); + // If default is unset, use the system proc. + if (!params_.resolver_proc.get()) + params_.resolver_proc = new SystemHostResolverProc(); + } + + void Start() { + DCHECK(origin_loop_->BelongsToCurrentThread()); + net_log_.BeginEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_PROC_TASK); + StartLookupAttempt(); + } + + // Cancels this ProcTask. It will be orphaned. Any outstanding resolve + // attempts running on worker threads will continue running. Only once all the + // attempts complete will the final reference to this ProcTask be released. + void Cancel() { + DCHECK(origin_loop_->BelongsToCurrentThread()); + + if (was_canceled() || was_completed()) + return; + + callback_.Reset(); + net_log_.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_PROC_TASK); + } + + void set_had_non_speculative_request() { + DCHECK(origin_loop_->BelongsToCurrentThread()); + had_non_speculative_request_ = true; + } + + bool was_canceled() const { + DCHECK(origin_loop_->BelongsToCurrentThread()); + return callback_.is_null(); + } + + bool was_completed() const { + DCHECK(origin_loop_->BelongsToCurrentThread()); + return completed_attempt_number_ > 0; + } + + private: + friend class base::RefCountedThreadSafe<ProcTask>; + ~ProcTask() {} + + void StartLookupAttempt() { + DCHECK(origin_loop_->BelongsToCurrentThread()); + base::TimeTicks start_time = base::TimeTicks::Now(); + ++attempt_number_; + // Dispatch the lookup attempt to a worker thread. + if (!base::WorkerPool::PostTask( + FROM_HERE, + base::Bind(&ProcTask::DoLookup, this, start_time, attempt_number_), + true)) { + NOTREACHED(); + + // Since we could be running within Resolve() right now, we can't just + // call OnLookupComplete(). Instead we must wait until Resolve() has + // returned (IO_PENDING). + origin_loop_->PostTask( + FROM_HERE, + base::Bind(&ProcTask::OnLookupComplete, this, AddressList(), + start_time, attempt_number_, ERR_UNEXPECTED, 0)); + return; + } + + net_log_.AddEvent( + NetLog::TYPE_HOST_RESOLVER_IMPL_ATTEMPT_STARTED, + NetLog::IntegerCallback("attempt_number", attempt_number_)); + + // If we don't get the results within a given time, RetryIfNotComplete + // will start a new attempt on a different worker thread if none of our + // outstanding attempts have completed yet. + if (attempt_number_ <= params_.max_retry_attempts) { + origin_loop_->PostDelayedTask( + FROM_HERE, + base::Bind(&ProcTask::RetryIfNotComplete, this), + params_.unresponsive_delay); + } + } + + // WARNING: This code runs inside a worker pool. The shutdown code cannot + // wait for it to finish, so we must be very careful here about using other + // objects (like MessageLoops, Singletons, etc). During shutdown these objects + // may no longer exist. Multiple DoLookups() could be running in parallel, so + // any state inside of |this| must not mutate . + void DoLookup(const base::TimeTicks& start_time, + const uint32 attempt_number) { + AddressList results; + int os_error = 0; + // Running on the worker thread + int error = params_.resolver_proc->Resolve(key_.hostname, + key_.address_family, + key_.host_resolver_flags, + &results, + &os_error); + + origin_loop_->PostTask( + FROM_HERE, + base::Bind(&ProcTask::OnLookupComplete, this, results, start_time, + attempt_number, error, os_error)); + } + + // Makes next attempt if DoLookup() has not finished (runs on origin thread). + void RetryIfNotComplete() { + DCHECK(origin_loop_->BelongsToCurrentThread()); + + if (was_completed() || was_canceled()) + return; + + params_.unresponsive_delay *= params_.retry_factor; + StartLookupAttempt(); + } + + // Callback for when DoLookup() completes (runs on origin thread). + void OnLookupComplete(const AddressList& results, + const base::TimeTicks& start_time, + const uint32 attempt_number, + int error, + const int os_error) { + DCHECK(origin_loop_->BelongsToCurrentThread()); + // If results are empty, we should return an error. + bool empty_list_on_ok = (error == OK && results.empty()); + UMA_HISTOGRAM_BOOLEAN("DNS.EmptyAddressListAndNoError", empty_list_on_ok); + if (empty_list_on_ok) + error = ERR_NAME_NOT_RESOLVED; + + bool was_retry_attempt = attempt_number > 1; + + // Ideally the following code would be part of host_resolver_proc.cc, + // however it isn't safe to call NetworkChangeNotifier from worker threads. + // So we do it here on the IO thread instead. + if (error != OK && NetworkChangeNotifier::IsOffline()) + error = ERR_INTERNET_DISCONNECTED; + + // If this is the first attempt that is finishing later, then record data + // for the first attempt. Won't contaminate with retry attempt's data. + if (!was_retry_attempt) + RecordPerformanceHistograms(start_time, error, os_error); + + RecordAttemptHistograms(start_time, attempt_number, error, os_error); + + if (was_canceled()) + return; + + NetLog::ParametersCallback net_log_callback; + if (error != OK) { + net_log_callback = base::Bind(&NetLogProcTaskFailedCallback, + attempt_number, + error, + os_error); + } else { + net_log_callback = NetLog::IntegerCallback("attempt_number", + attempt_number); + } + net_log_.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_ATTEMPT_FINISHED, + net_log_callback); + + if (was_completed()) + return; + + // Copy the results from the first worker thread that resolves the host. + results_ = results; + completed_attempt_number_ = attempt_number; + completed_attempt_error_ = error; + + if (was_retry_attempt) { + // If retry attempt finishes before 1st attempt, then get stats on how + // much time is saved by having spawned an extra attempt. + retry_attempt_finished_time_ = base::TimeTicks::Now(); + } + + if (error != OK) { + net_log_callback = base::Bind(&NetLogProcTaskFailedCallback, + 0, error, os_error); + } else { + net_log_callback = results_.CreateNetLogCallback(); + } + net_log_.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_PROC_TASK, + net_log_callback); + + callback_.Run(error, results_); + } + + void RecordPerformanceHistograms(const base::TimeTicks& start_time, + const int error, + const int os_error) const { + DCHECK(origin_loop_->BelongsToCurrentThread()); + enum Category { // Used in HISTOGRAM_ENUMERATION. + RESOLVE_SUCCESS, + RESOLVE_FAIL, + RESOLVE_SPECULATIVE_SUCCESS, + RESOLVE_SPECULATIVE_FAIL, + RESOLVE_MAX, // Bounding value. + }; + int category = RESOLVE_MAX; // Illegal value for later DCHECK only. + + base::TimeDelta duration = base::TimeTicks::Now() - start_time; + if (error == OK) { + if (had_non_speculative_request_) { + category = RESOLVE_SUCCESS; + DNS_HISTOGRAM("DNS.ResolveSuccess", duration); + } else { + category = RESOLVE_SPECULATIVE_SUCCESS; + DNS_HISTOGRAM("DNS.ResolveSpeculativeSuccess", duration); + } + + // Log DNS lookups based on |address_family|. This will help us determine + // if IPv4 or IPv4/6 lookups are faster or slower. + switch(key_.address_family) { + case ADDRESS_FAMILY_IPV4: + DNS_HISTOGRAM("DNS.ResolveSuccess_FAMILY_IPV4", duration); + break; + case ADDRESS_FAMILY_IPV6: + DNS_HISTOGRAM("DNS.ResolveSuccess_FAMILY_IPV6", duration); + break; + case ADDRESS_FAMILY_UNSPECIFIED: + DNS_HISTOGRAM("DNS.ResolveSuccess_FAMILY_UNSPEC", duration); + break; + } + } else { + if (had_non_speculative_request_) { + category = RESOLVE_FAIL; + DNS_HISTOGRAM("DNS.ResolveFail", duration); + } else { + category = RESOLVE_SPECULATIVE_FAIL; + DNS_HISTOGRAM("DNS.ResolveSpeculativeFail", duration); + } + // Log DNS lookups based on |address_family|. This will help us determine + // if IPv4 or IPv4/6 lookups are faster or slower. + switch(key_.address_family) { + case ADDRESS_FAMILY_IPV4: + DNS_HISTOGRAM("DNS.ResolveFail_FAMILY_IPV4", duration); + break; + case ADDRESS_FAMILY_IPV6: + DNS_HISTOGRAM("DNS.ResolveFail_FAMILY_IPV6", duration); + break; + case ADDRESS_FAMILY_UNSPECIFIED: + DNS_HISTOGRAM("DNS.ResolveFail_FAMILY_UNSPEC", duration); + break; + } + UMA_HISTOGRAM_CUSTOM_ENUMERATION(kOSErrorsForGetAddrinfoHistogramName, + std::abs(os_error), + GetAllGetAddrinfoOSErrors()); + } + DCHECK_LT(category, static_cast<int>(RESOLVE_MAX)); // Be sure it was set. + + UMA_HISTOGRAM_ENUMERATION("DNS.ResolveCategory", category, RESOLVE_MAX); + } + + void RecordAttemptHistograms(const base::TimeTicks& start_time, + const uint32 attempt_number, + const int error, + const int os_error) const { + DCHECK(origin_loop_->BelongsToCurrentThread()); + bool first_attempt_to_complete = + completed_attempt_number_ == attempt_number; + bool is_first_attempt = (attempt_number == 1); + + if (first_attempt_to_complete) { + // If this was first attempt to complete, then record the resolution + // status of the attempt. + if (completed_attempt_error_ == OK) { + UMA_HISTOGRAM_ENUMERATION( + "DNS.AttemptFirstSuccess", attempt_number, 100); + } else { + UMA_HISTOGRAM_ENUMERATION( + "DNS.AttemptFirstFailure", attempt_number, 100); + } + } + + if (error == OK) + UMA_HISTOGRAM_ENUMERATION("DNS.AttemptSuccess", attempt_number, 100); + else + UMA_HISTOGRAM_ENUMERATION("DNS.AttemptFailure", attempt_number, 100); + + // If first attempt didn't finish before retry attempt, then calculate stats + // on how much time is saved by having spawned an extra attempt. + if (!first_attempt_to_complete && is_first_attempt && !was_canceled()) { + DNS_HISTOGRAM("DNS.AttemptTimeSavedByRetry", + base::TimeTicks::Now() - retry_attempt_finished_time_); + } + + if (was_canceled() || !first_attempt_to_complete) { + // Count those attempts which completed after the job was already canceled + // OR after the job was already completed by an earlier attempt (so in + // effect). + UMA_HISTOGRAM_ENUMERATION("DNS.AttemptDiscarded", attempt_number, 100); + + // Record if job is canceled. + if (was_canceled()) + UMA_HISTOGRAM_ENUMERATION("DNS.AttemptCancelled", attempt_number, 100); + } + + base::TimeDelta duration = base::TimeTicks::Now() - start_time; + if (error == OK) + DNS_HISTOGRAM("DNS.AttemptSuccessDuration", duration); + else + DNS_HISTOGRAM("DNS.AttemptFailDuration", duration); + } + + // Set on the origin thread, read on the worker thread. + Key key_; + + // Holds an owning reference to the HostResolverProc that we are going to use. + // This may not be the current resolver procedure by the time we call + // ResolveAddrInfo, but that's OK... we'll use it anyways, and the owning + // reference ensures that it remains valid until we are done. + ProcTaskParams params_; + + // The listener to the results of this ProcTask. + Callback callback_; + + // Used to post ourselves onto the origin thread. + scoped_refptr<base::MessageLoopProxy> origin_loop_; + + // Keeps track of the number of attempts we have made so far to resolve the + // host. Whenever we start an attempt to resolve the host, we increase this + // number. + uint32 attempt_number_; + + // The index of the attempt which finished first (or 0 if the job is still in + // progress). + uint32 completed_attempt_number_; + + // The result (a net error code) from the first attempt to complete. + int completed_attempt_error_; + + // The time when retry attempt was finished. + base::TimeTicks retry_attempt_finished_time_; + + // True if a non-speculative request was ever attached to this job + // (regardless of whether or not it was later canceled. + // This boolean is used for histogramming the duration of jobs used to + // service non-speculative requests. + bool had_non_speculative_request_; + + AddressList results_; + + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(ProcTask); +}; + +//----------------------------------------------------------------------------- + +// Wraps a call to HaveOnlyLoopbackAddresses to be executed on the WorkerPool as +// it takes 40-100ms and should not block initialization. +class HostResolverImpl::LoopbackProbeJob { + public: + explicit LoopbackProbeJob(const base::WeakPtr<HostResolverImpl>& resolver) + : resolver_(resolver), + result_(false) { + DCHECK(resolver.get()); + const bool kIsSlow = true; + base::WorkerPool::PostTaskAndReply( + FROM_HERE, + base::Bind(&LoopbackProbeJob::DoProbe, base::Unretained(this)), + base::Bind(&LoopbackProbeJob::OnProbeComplete, base::Owned(this)), + kIsSlow); + } + + virtual ~LoopbackProbeJob() {} + + private: + // Runs on worker thread. + void DoProbe() { + result_ = HaveOnlyLoopbackAddresses(); + } + + void OnProbeComplete() { + if (!resolver_.get()) + return; + resolver_->SetHaveOnlyLoopbackAddresses(result_); + } + + // Used/set only on origin thread. + base::WeakPtr<HostResolverImpl> resolver_; + + bool result_; + + DISALLOW_COPY_AND_ASSIGN(LoopbackProbeJob); +}; + +//----------------------------------------------------------------------------- + +// Resolves the hostname using DnsTransaction. +// TODO(szym): This could be moved to separate source file as well. +class HostResolverImpl::DnsTask : public base::SupportsWeakPtr<DnsTask> { + public: + typedef base::Callback<void(int net_error, + const AddressList& addr_list, + base::TimeDelta ttl)> Callback; + + DnsTask(DnsClient* client, + const Key& key, + const Callback& callback, + const BoundNetLog& job_net_log) + : client_(client), + family_(key.address_family), + callback_(callback), + net_log_(job_net_log) { + DCHECK(client); + DCHECK(!callback.is_null()); + + // If unspecified, do IPv4 first, because suffix search will be faster. + uint16 qtype = (family_ == ADDRESS_FAMILY_IPV6) ? + dns_protocol::kTypeAAAA : + dns_protocol::kTypeA; + transaction_ = client_->GetTransactionFactory()->CreateTransaction( + key.hostname, + qtype, + base::Bind(&DnsTask::OnTransactionComplete, base::Unretained(this), + true /* first_query */, base::TimeTicks::Now()), + net_log_); + } + + void Start() { + net_log_.BeginEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK); + transaction_->Start(); + } + + private: + void OnTransactionComplete(bool first_query, + const base::TimeTicks& start_time, + DnsTransaction* transaction, + int net_error, + const DnsResponse* response) { + DCHECK(transaction); + base::TimeDelta duration = base::TimeTicks::Now() - start_time; + // Run |callback_| last since the owning Job will then delete this DnsTask. + if (net_error != OK) { + DNS_HISTOGRAM("AsyncDNS.TransactionFailure", duration); + OnFailure(net_error, DnsResponse::DNS_PARSE_OK); + return; + } + + CHECK(response); + DNS_HISTOGRAM("AsyncDNS.TransactionSuccess", duration); + switch (transaction->GetType()) { + case dns_protocol::kTypeA: + DNS_HISTOGRAM("AsyncDNS.TransactionSuccess_A", duration); + break; + case dns_protocol::kTypeAAAA: + DNS_HISTOGRAM("AsyncDNS.TransactionSuccess_AAAA", duration); + break; + } + AddressList addr_list; + base::TimeDelta ttl; + DnsResponse::Result result = response->ParseToAddressList(&addr_list, &ttl); + UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ParseToAddressList", + result, + DnsResponse::DNS_PARSE_RESULT_MAX); + if (result != DnsResponse::DNS_PARSE_OK) { + // Fail even if the other query succeeds. + OnFailure(ERR_DNS_MALFORMED_RESPONSE, result); + return; + } + + bool needs_sort = false; + if (first_query) { + DCHECK(client_->GetConfig()) << + "Transaction should have been aborted when config changed!"; + if (family_ == ADDRESS_FAMILY_IPV6) { + needs_sort = (addr_list.size() > 1); + } else if (family_ == ADDRESS_FAMILY_UNSPECIFIED) { + first_addr_list_ = addr_list; + first_ttl_ = ttl; + // Use fully-qualified domain name to avoid search. + transaction_ = client_->GetTransactionFactory()->CreateTransaction( + response->GetDottedName() + ".", + dns_protocol::kTypeAAAA, + base::Bind(&DnsTask::OnTransactionComplete, base::Unretained(this), + false /* first_query */, base::TimeTicks::Now()), + net_log_); + transaction_->Start(); + return; + } + } else { + DCHECK_EQ(ADDRESS_FAMILY_UNSPECIFIED, family_); + bool has_ipv6_addresses = !addr_list.empty(); + if (!first_addr_list_.empty()) { + ttl = std::min(ttl, first_ttl_); + // Place IPv4 addresses after IPv6. + addr_list.insert(addr_list.end(), first_addr_list_.begin(), + first_addr_list_.end()); + } + needs_sort = (has_ipv6_addresses && addr_list.size() > 1); + } + + if (addr_list.empty()) { + // TODO(szym): Don't fallback to ProcTask in this case. + OnFailure(ERR_NAME_NOT_RESOLVED, DnsResponse::DNS_PARSE_OK); + return; + } + + if (needs_sort) { + // Sort could complete synchronously. + client_->GetAddressSorter()->Sort( + addr_list, + base::Bind(&DnsTask::OnSortComplete, + AsWeakPtr(), + base::TimeTicks::Now(), + ttl)); + } else { + OnSuccess(addr_list, ttl); + } + } + + void OnSortComplete(base::TimeTicks start_time, + base::TimeDelta ttl, + bool success, + const AddressList& addr_list) { + if (!success) { + DNS_HISTOGRAM("AsyncDNS.SortFailure", + base::TimeTicks::Now() - start_time); + OnFailure(ERR_DNS_SORT_ERROR, DnsResponse::DNS_PARSE_OK); + return; + } + + DNS_HISTOGRAM("AsyncDNS.SortSuccess", + base::TimeTicks::Now() - start_time); + + // AddressSorter prunes unusable destinations. + if (addr_list.empty()) { + LOG(WARNING) << "Address list empty after RFC3484 sort"; + OnFailure(ERR_NAME_NOT_RESOLVED, DnsResponse::DNS_PARSE_OK); + return; + } + + OnSuccess(addr_list, ttl); + } + + void OnFailure(int net_error, DnsResponse::Result result) { + DCHECK_NE(OK, net_error); + net_log_.EndEvent( + NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK, + base::Bind(&NetLogDnsTaskFailedCallback, net_error, result)); + callback_.Run(net_error, AddressList(), base::TimeDelta()); + } + + void OnSuccess(const AddressList& addr_list, base::TimeDelta ttl) { + net_log_.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK, + addr_list.CreateNetLogCallback()); + callback_.Run(OK, addr_list, ttl); + } + + DnsClient* client_; + AddressFamily family_; + // The listener to the results of this DnsTask. + Callback callback_; + const BoundNetLog net_log_; + + scoped_ptr<DnsTransaction> transaction_; + + // Results from the first transaction. Used only if |family_| is unspecified. + AddressList first_addr_list_; + base::TimeDelta first_ttl_; + + DISALLOW_COPY_AND_ASSIGN(DnsTask); +}; + +//----------------------------------------------------------------------------- + +// Aggregates all Requests for the same Key. Dispatched via PriorityDispatch. +class HostResolverImpl::Job : public PrioritizedDispatcher::Job { + public: + // Creates new job for |key| where |request_net_log| is bound to the + // request that spawned it. + Job(const base::WeakPtr<HostResolverImpl>& resolver, + const Key& key, + RequestPriority priority, + const BoundNetLog& request_net_log) + : resolver_(resolver), + key_(key), + priority_tracker_(priority), + had_non_speculative_request_(false), + had_dns_config_(false), + dns_task_error_(OK), + creation_time_(base::TimeTicks::Now()), + priority_change_time_(creation_time_), + net_log_(BoundNetLog::Make(request_net_log.net_log(), + NetLog::SOURCE_HOST_RESOLVER_IMPL_JOB)) { + request_net_log.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_CREATE_JOB); + + net_log_.BeginEvent( + NetLog::TYPE_HOST_RESOLVER_IMPL_JOB, + base::Bind(&NetLogJobCreationCallback, + request_net_log.source(), + &key_.hostname)); + } + + virtual ~Job() { + if (is_running()) { + // |resolver_| was destroyed with this Job still in flight. + // Clean-up, record in the log, but don't run any callbacks. + if (is_proc_running()) { + proc_task_->Cancel(); + proc_task_ = NULL; + } + // Clean up now for nice NetLog. + dns_task_.reset(NULL); + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB, + ERR_ABORTED); + } else if (is_queued()) { + // |resolver_| was destroyed without running this Job. + // TODO(szym): is there any benefit in having this distinction? + net_log_.AddEvent(NetLog::TYPE_CANCELLED); + net_log_.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB); + } + // else CompleteRequests logged EndEvent. + + // Log any remaining Requests as cancelled. + for (RequestsList::const_iterator it = requests_.begin(); + it != requests_.end(); ++it) { + Request* req = *it; + if (req->was_canceled()) + continue; + DCHECK_EQ(this, req->job()); + LogCancelRequest(req->source_net_log(), req->request_net_log(), + req->info()); + } + } + + // Add this job to the dispatcher. + void Schedule() { + handle_ = resolver_->dispatcher_.Add(this, priority()); + } + + void AddRequest(scoped_ptr<Request> req) { + DCHECK_EQ(key_.hostname, req->info().hostname()); + + req->set_job(this); + priority_tracker_.Add(req->info().priority()); + + req->request_net_log().AddEvent( + NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_ATTACH, + net_log_.source().ToEventParametersCallback()); + + net_log_.AddEvent( + NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_REQUEST_ATTACH, + base::Bind(&NetLogJobAttachCallback, + req->request_net_log().source(), + priority())); + + // TODO(szym): Check if this is still needed. + if (!req->info().is_speculative()) { + had_non_speculative_request_ = true; + if (proc_task_.get()) + proc_task_->set_had_non_speculative_request(); + } + + requests_.push_back(req.release()); + + UpdatePriority(); + } + + // Marks |req| as cancelled. If it was the last active Request, also finishes + // this Job, marking it as cancelled, and deletes it. + void CancelRequest(Request* req) { + DCHECK_EQ(key_.hostname, req->info().hostname()); + DCHECK(!req->was_canceled()); + + // Don't remove it from |requests_| just mark it canceled. + req->MarkAsCanceled(); + LogCancelRequest(req->source_net_log(), req->request_net_log(), + req->info()); + + priority_tracker_.Remove(req->info().priority()); + net_log_.AddEvent( + NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_REQUEST_DETACH, + base::Bind(&NetLogJobAttachCallback, + req->request_net_log().source(), + priority())); + + if (num_active_requests() > 0) { + UpdatePriority(); + } else { + // If we were called from a Request's callback within CompleteRequests, + // that Request could not have been cancelled, so num_active_requests() + // could not be 0. Therefore, we are not in CompleteRequests(). + CompleteRequestsWithError(OK /* cancelled */); + } + } + + // Called from AbortAllInProgressJobs. Completes all requests and destroys + // the job. This currently assumes the abort is due to a network change. + void Abort() { + DCHECK(is_running()); + CompleteRequestsWithError(ERR_NETWORK_CHANGED); + } + + // If DnsTask present, abort it and fall back to ProcTask. + void AbortDnsTask() { + if (dns_task_) { + dns_task_.reset(); + dns_task_error_ = OK; + StartProcTask(); + } + } + + // Called by HostResolverImpl when this job is evicted due to queue overflow. + // Completes all requests and destroys the job. + void OnEvicted() { + DCHECK(!is_running()); + DCHECK(is_queued()); + handle_.Reset(); + + net_log_.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_EVICTED); + + // This signals to CompleteRequests that this job never ran. + CompleteRequestsWithError(ERR_HOST_RESOLVER_QUEUE_TOO_LARGE); + } + + // Attempts to serve the job from HOSTS. Returns true if succeeded and + // this Job was destroyed. + bool ServeFromHosts() { + DCHECK_GT(num_active_requests(), 0u); + AddressList addr_list; + if (resolver_->ServeFromHosts(key(), + requests_.front()->info(), + &addr_list)) { + // This will destroy the Job. + CompleteRequests( + HostCache::Entry(OK, MakeAddressListForRequest(addr_list)), + base::TimeDelta()); + return true; + } + return false; + } + + const Key key() const { + return key_; + } + + bool is_queued() const { + return !handle_.is_null(); + } + + bool is_running() const { + return is_dns_running() || is_proc_running(); + } + + private: + void UpdatePriority() { + if (is_queued()) { + if (priority() != static_cast<RequestPriority>(handle_.priority())) + priority_change_time_ = base::TimeTicks::Now(); + handle_ = resolver_->dispatcher_.ChangePriority(handle_, priority()); + } + } + + AddressList MakeAddressListForRequest(const AddressList& list) const { + if (requests_.empty()) + return list; + return AddressList::CopyWithPort(list, requests_.front()->info().port()); + } + + // PriorityDispatch::Job: + virtual void Start() OVERRIDE { + DCHECK(!is_running()); + handle_.Reset(); + + net_log_.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_STARTED); + + had_dns_config_ = resolver_->HaveDnsConfig(); + + base::TimeTicks now = base::TimeTicks::Now(); + base::TimeDelta queue_time = now - creation_time_; + base::TimeDelta queue_time_after_change = now - priority_change_time_; + + if (had_dns_config_) { + DNS_HISTOGRAM_BY_PRIORITY("AsyncDNS.JobQueueTime", priority(), + queue_time); + DNS_HISTOGRAM_BY_PRIORITY("AsyncDNS.JobQueueTimeAfterChange", priority(), + queue_time_after_change); + } else { + DNS_HISTOGRAM_BY_PRIORITY("DNS.JobQueueTime", priority(), queue_time); + DNS_HISTOGRAM_BY_PRIORITY("DNS.JobQueueTimeAfterChange", priority(), + queue_time_after_change); + } + + // Caution: Job::Start must not complete synchronously. + if (had_dns_config_ && !ResemblesMulticastDNSName(key_.hostname)) { + StartDnsTask(); + } else { + StartProcTask(); + } + } + + // TODO(szym): Since DnsTransaction does not consume threads, we can increase + // the limits on |dispatcher_|. But in order to keep the number of WorkerPool + // threads low, we will need to use an "inner" PrioritizedDispatcher with + // tighter limits. + void StartProcTask() { + DCHECK(!is_dns_running()); + proc_task_ = new ProcTask( + key_, + resolver_->proc_params_, + base::Bind(&Job::OnProcTaskComplete, base::Unretained(this), + base::TimeTicks::Now()), + net_log_); + + if (had_non_speculative_request_) + proc_task_->set_had_non_speculative_request(); + // Start() could be called from within Resolve(), hence it must NOT directly + // call OnProcTaskComplete, for example, on synchronous failure. + proc_task_->Start(); + } + + // Called by ProcTask when it completes. + void OnProcTaskComplete(base::TimeTicks start_time, + int net_error, + const AddressList& addr_list) { + DCHECK(is_proc_running()); + + if (!resolver_->resolved_known_ipv6_hostname_ && + net_error == OK && + key_.address_family == ADDRESS_FAMILY_UNSPECIFIED) { + if (key_.hostname == "www.google.com") { + resolver_->resolved_known_ipv6_hostname_ = true; + bool got_ipv6_address = false; + for (size_t i = 0; i < addr_list.size(); ++i) { + if (addr_list[i].GetFamily() == ADDRESS_FAMILY_IPV6) { + got_ipv6_address = true; + break; + } + } + UMA_HISTOGRAM_BOOLEAN("Net.UnspecResolvedIPv6", got_ipv6_address); + } + } + + if (dns_task_error_ != OK) { + base::TimeDelta duration = base::TimeTicks::Now() - start_time; + if (net_error == OK) { + DNS_HISTOGRAM("AsyncDNS.FallbackSuccess", duration); + if ((dns_task_error_ == ERR_NAME_NOT_RESOLVED) && + ResemblesNetBIOSName(key_.hostname)) { + UmaAsyncDnsResolveStatus(RESOLVE_STATUS_SUSPECT_NETBIOS); + } else { + UmaAsyncDnsResolveStatus(RESOLVE_STATUS_PROC_SUCCESS); + } + UMA_HISTOGRAM_CUSTOM_ENUMERATION("AsyncDNS.ResolveError", + std::abs(dns_task_error_), + GetAllErrorCodesForUma()); + resolver_->OnDnsTaskResolve(dns_task_error_); + } else { + DNS_HISTOGRAM("AsyncDNS.FallbackFail", duration); + UmaAsyncDnsResolveStatus(RESOLVE_STATUS_FAIL); + } + } + + base::TimeDelta ttl = + base::TimeDelta::FromSeconds(kNegativeCacheEntryTTLSeconds); + if (net_error == OK) + ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds); + + // Don't store the |ttl| in cache since it's not obtained from the server. + CompleteRequests( + HostCache::Entry(net_error, MakeAddressListForRequest(addr_list)), + ttl); + } + + void StartDnsTask() { + DCHECK(resolver_->HaveDnsConfig()); + base::TimeTicks start_time = base::TimeTicks::Now(); + dns_task_.reset(new DnsTask( + resolver_->dns_client_.get(), + key_, + base::Bind(&Job::OnDnsTaskComplete, base::Unretained(this), start_time), + net_log_)); + + dns_task_->Start(); + } + + // Called if DnsTask fails. It is posted from StartDnsTask, so Job may be + // deleted before this callback. In this case dns_task is deleted as well, + // so we use it as indicator whether Job is still valid. + void OnDnsTaskFailure(const base::WeakPtr<DnsTask>& dns_task, + base::TimeDelta duration, + int net_error) { + DNS_HISTOGRAM("AsyncDNS.ResolveFail", duration); + + if (dns_task == NULL) + return; + + dns_task_error_ = net_error; + + // TODO(szym): Run ServeFromHosts now if nsswitch.conf says so. + // http://crbug.com/117655 + + // TODO(szym): Some net errors indicate lack of connectivity. Starting + // ProcTask in that case is a waste of time. + if (resolver_->fallback_to_proctask_) { + dns_task_.reset(); + StartProcTask(); + } else { + UmaAsyncDnsResolveStatus(RESOLVE_STATUS_FAIL); + CompleteRequestsWithError(net_error); + } + } + + // Called by DnsTask when it completes. + void OnDnsTaskComplete(base::TimeTicks start_time, + int net_error, + const AddressList& addr_list, + base::TimeDelta ttl) { + DCHECK(is_dns_running()); + + base::TimeDelta duration = base::TimeTicks::Now() - start_time; + if (net_error != OK) { + OnDnsTaskFailure(dns_task_->AsWeakPtr(), duration, net_error); + return; + } + DNS_HISTOGRAM("AsyncDNS.ResolveSuccess", duration); + // Log DNS lookups based on |address_family|. + switch(key_.address_family) { + case ADDRESS_FAMILY_IPV4: + DNS_HISTOGRAM("AsyncDNS.ResolveSuccess_FAMILY_IPV4", duration); + break; + case ADDRESS_FAMILY_IPV6: + DNS_HISTOGRAM("AsyncDNS.ResolveSuccess_FAMILY_IPV6", duration); + break; + case ADDRESS_FAMILY_UNSPECIFIED: + DNS_HISTOGRAM("AsyncDNS.ResolveSuccess_FAMILY_UNSPEC", duration); + break; + } + + UmaAsyncDnsResolveStatus(RESOLVE_STATUS_DNS_SUCCESS); + RecordTTL(ttl); + + resolver_->OnDnsTaskResolve(OK); + + base::TimeDelta bounded_ttl = + std::max(ttl, base::TimeDelta::FromSeconds(kMinimumTTLSeconds)); + + CompleteRequests( + HostCache::Entry(net_error, MakeAddressListForRequest(addr_list), ttl), + bounded_ttl); + } + + // Performs Job's last rites. Completes all Requests. Deletes this. + void CompleteRequests(const HostCache::Entry& entry, + base::TimeDelta ttl) { + CHECK(resolver_.get()); + + // This job must be removed from resolver's |jobs_| now to make room for a + // new job with the same key in case one of the OnComplete callbacks decides + // to spawn one. Consequently, the job deletes itself when CompleteRequests + // is done. + scoped_ptr<Job> self_deleter(this); + + resolver_->RemoveJob(this); + + if (is_running()) { + DCHECK(!is_queued()); + if (is_proc_running()) { + proc_task_->Cancel(); + proc_task_ = NULL; + } + dns_task_.reset(); + + // Signal dispatcher that a slot has opened. + resolver_->dispatcher_.OnJobFinished(); + } else if (is_queued()) { + resolver_->dispatcher_.Cancel(handle_); + handle_.Reset(); + } + + if (num_active_requests() == 0) { + net_log_.AddEvent(NetLog::TYPE_CANCELLED); + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB, + OK); + return; + } + + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB, + entry.error); + + DCHECK(!requests_.empty()); + + if (entry.error == OK) { + // Record this histogram here, when we know the system has a valid DNS + // configuration. + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.HaveDnsConfig", + resolver_->received_dns_config_); + } + + bool did_complete = (entry.error != ERR_NETWORK_CHANGED) && + (entry.error != ERR_HOST_RESOLVER_QUEUE_TOO_LARGE); + if (did_complete) + resolver_->CacheResult(key_, entry, ttl); + + // Complete all of the requests that were attached to the job. + for (RequestsList::const_iterator it = requests_.begin(); + it != requests_.end(); ++it) { + Request* req = *it; + + if (req->was_canceled()) + continue; + + DCHECK_EQ(this, req->job()); + // Update the net log and notify registered observers. + LogFinishRequest(req->source_net_log(), req->request_net_log(), + req->info(), entry.error); + if (did_complete) { + // Record effective total time from creation to completion. + RecordTotalTime(had_dns_config_, req->info().is_speculative(), + base::TimeTicks::Now() - req->request_time()); + } + req->OnComplete(entry.error, entry.addrlist); + + // Check if the resolver was destroyed as a result of running the + // callback. If it was, we could continue, but we choose to bail. + if (!resolver_.get()) + return; + } + } + + // Convenience wrapper for CompleteRequests in case of failure. + void CompleteRequestsWithError(int net_error) { + CompleteRequests(HostCache::Entry(net_error, AddressList()), + base::TimeDelta()); + } + + RequestPriority priority() const { + return priority_tracker_.highest_priority(); + } + + // Number of non-canceled requests in |requests_|. + size_t num_active_requests() const { + return priority_tracker_.total_count(); + } + + bool is_dns_running() const { + return dns_task_.get() != NULL; + } + + bool is_proc_running() const { + return proc_task_.get() != NULL; + } + + base::WeakPtr<HostResolverImpl> resolver_; + + Key key_; + + // Tracks the highest priority across |requests_|. + PriorityTracker priority_tracker_; + + bool had_non_speculative_request_; + + // Distinguishes measurements taken while DnsClient was fully configured. + bool had_dns_config_; + + // Result of DnsTask. + int dns_task_error_; + + const base::TimeTicks creation_time_; + base::TimeTicks priority_change_time_; + + BoundNetLog net_log_; + + // Resolves the host using a HostResolverProc. + scoped_refptr<ProcTask> proc_task_; + + // Resolves the host using a DnsTransaction. + scoped_ptr<DnsTask> dns_task_; + + // All Requests waiting for the result of this Job. Some can be canceled. + RequestsList requests_; + + // A handle used in |HostResolverImpl::dispatcher_|. + PrioritizedDispatcher::Handle handle_; +}; + +//----------------------------------------------------------------------------- + +HostResolverImpl::ProcTaskParams::ProcTaskParams( + HostResolverProc* resolver_proc, + size_t max_retry_attempts) + : resolver_proc(resolver_proc), + max_retry_attempts(max_retry_attempts), + unresponsive_delay(base::TimeDelta::FromMilliseconds(6000)), + retry_factor(2) { +} + +HostResolverImpl::ProcTaskParams::~ProcTaskParams() {} + +HostResolverImpl::HostResolverImpl( + scoped_ptr<HostCache> cache, + const PrioritizedDispatcher::Limits& job_limits, + const ProcTaskParams& proc_params, + NetLog* net_log) + : cache_(cache.Pass()), + dispatcher_(job_limits), + max_queued_jobs_(job_limits.total_jobs * 100u), + proc_params_(proc_params), + net_log_(net_log), + default_address_family_(ADDRESS_FAMILY_UNSPECIFIED), + weak_ptr_factory_(this), + probe_weak_ptr_factory_(this), + received_dns_config_(false), + num_dns_failures_(0), + probe_ipv6_support_(true), + resolved_known_ipv6_hostname_(false), + additional_resolver_flags_(0), + fallback_to_proctask_(true) { + + DCHECK_GE(dispatcher_.num_priorities(), static_cast<size_t>(NUM_PRIORITIES)); + + // Maximum of 4 retry attempts for host resolution. + static const size_t kDefaultMaxRetryAttempts = 4u; + + if (proc_params_.max_retry_attempts == HostResolver::kDefaultRetryAttempts) + proc_params_.max_retry_attempts = kDefaultMaxRetryAttempts; + +#if defined(OS_WIN) + EnsureWinsockInit(); +#endif +#if defined(OS_POSIX) && !defined(OS_MACOSX) && !defined(OS_ANDROID) + new LoopbackProbeJob(weak_ptr_factory_.GetWeakPtr()); +#endif + NetworkChangeNotifier::AddIPAddressObserver(this); + NetworkChangeNotifier::AddDNSObserver(this); +#if defined(OS_POSIX) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) && \ + !defined(OS_ANDROID) + EnsureDnsReloaderInit(); +#endif + + // TODO(szym): Remove when received_dns_config_ is removed, once + // http://crbug.com/137914 is resolved. + { + DnsConfig dns_config; + NetworkChangeNotifier::GetDnsConfig(&dns_config); + received_dns_config_ = dns_config.IsValid(); + } + + fallback_to_proctask_ = !ConfigureAsyncDnsNoFallbackFieldTrial(); +} + +HostResolverImpl::~HostResolverImpl() { + // This will also cancel all outstanding requests. + STLDeleteValues(&jobs_); + + NetworkChangeNotifier::RemoveIPAddressObserver(this); + NetworkChangeNotifier::RemoveDNSObserver(this); +} + +void HostResolverImpl::SetMaxQueuedJobs(size_t value) { + DCHECK_EQ(0u, dispatcher_.num_queued_jobs()); + DCHECK_GT(value, 0u); + max_queued_jobs_ = value; +} + +int HostResolverImpl::Resolve(const RequestInfo& info, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* out_req, + const BoundNetLog& source_net_log) { + DCHECK(addresses); + DCHECK(CalledOnValidThread()); + DCHECK_EQ(false, callback.is_null()); + + // Check that the caller supplied a valid hostname to resolve. + std::string labeled_hostname; + if (!DNSDomainFromDot(info.hostname(), &labeled_hostname)) + return ERR_NAME_NOT_RESOLVED; + + // Make a log item for the request. + BoundNetLog request_net_log = BoundNetLog::Make(net_log_, + NetLog::SOURCE_HOST_RESOLVER_IMPL_REQUEST); + + LogStartRequest(source_net_log, request_net_log, info); + + // Build a key that identifies the request in the cache and in the + // outstanding jobs map. + Key key = GetEffectiveKeyForRequest(info, request_net_log); + + int rv = ResolveHelper(key, info, addresses, request_net_log); + if (rv != ERR_DNS_CACHE_MISS) { + LogFinishRequest(source_net_log, request_net_log, info, rv); + RecordTotalTime(HaveDnsConfig(), info.is_speculative(), base::TimeDelta()); + return rv; + } + + // Next we need to attach our request to a "job". This job is responsible for + // calling "getaddrinfo(hostname)" on a worker thread. + + JobMap::iterator jobit = jobs_.find(key); + Job* job; + if (jobit == jobs_.end()) { + job = new Job(weak_ptr_factory_.GetWeakPtr(), key, info.priority(), + request_net_log); + job->Schedule(); + + // Check for queue overflow. + if (dispatcher_.num_queued_jobs() > max_queued_jobs_) { + Job* evicted = static_cast<Job*>(dispatcher_.EvictOldestLowest()); + DCHECK(evicted); + evicted->OnEvicted(); // Deletes |evicted|. + if (evicted == job) { + rv = ERR_HOST_RESOLVER_QUEUE_TOO_LARGE; + LogFinishRequest(source_net_log, request_net_log, info, rv); + return rv; + } + } + jobs_.insert(jobit, std::make_pair(key, job)); + } else { + job = jobit->second; + } + + // Can't complete synchronously. Create and attach request. + scoped_ptr<Request> req(new Request(source_net_log, + request_net_log, + info, + callback, + addresses)); + if (out_req) + *out_req = reinterpret_cast<RequestHandle>(req.get()); + + job->AddRequest(req.Pass()); + // Completion happens during Job::CompleteRequests(). + return ERR_IO_PENDING; +} + +int HostResolverImpl::ResolveHelper(const Key& key, + const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& request_net_log) { + // The result of |getaddrinfo| for empty hosts is inconsistent across systems. + // On Windows it gives the default interface's address, whereas on Linux it + // gives an error. We will make it fail on all platforms for consistency. + if (info.hostname().empty() || info.hostname().size() > kMaxHostLength) + return ERR_NAME_NOT_RESOLVED; + + int net_error = ERR_UNEXPECTED; + if (ResolveAsIP(key, info, &net_error, addresses)) + return net_error; + if (ServeFromCache(key, info, &net_error, addresses)) { + request_net_log.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_CACHE_HIT); + return net_error; + } + // TODO(szym): Do not do this if nsswitch.conf instructs not to. + // http://crbug.com/117655 + if (ServeFromHosts(key, info, addresses)) { + request_net_log.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_HOSTS_HIT); + return OK; + } + return ERR_DNS_CACHE_MISS; +} + +int HostResolverImpl::ResolveFromCache(const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& source_net_log) { + DCHECK(CalledOnValidThread()); + DCHECK(addresses); + + // Make a log item for the request. + BoundNetLog request_net_log = BoundNetLog::Make(net_log_, + NetLog::SOURCE_HOST_RESOLVER_IMPL_REQUEST); + + // Update the net log and notify registered observers. + LogStartRequest(source_net_log, request_net_log, info); + + Key key = GetEffectiveKeyForRequest(info, request_net_log); + + int rv = ResolveHelper(key, info, addresses, request_net_log); + LogFinishRequest(source_net_log, request_net_log, info, rv); + return rv; +} + +void HostResolverImpl::CancelRequest(RequestHandle req_handle) { + DCHECK(CalledOnValidThread()); + Request* req = reinterpret_cast<Request*>(req_handle); + DCHECK(req); + Job* job = req->job(); + DCHECK(job); + job->CancelRequest(req); +} + +void HostResolverImpl::SetDefaultAddressFamily(AddressFamily address_family) { + DCHECK(CalledOnValidThread()); + default_address_family_ = address_family; + probe_ipv6_support_ = false; +} + +AddressFamily HostResolverImpl::GetDefaultAddressFamily() const { + return default_address_family_; +} + +void HostResolverImpl::SetDnsClientEnabled(bool enabled) { + DCHECK(CalledOnValidThread()); +#if defined(ENABLE_BUILT_IN_DNS) + if (enabled && !dns_client_) { + SetDnsClient(DnsClient::CreateClient(net_log_)); + } else if (!enabled && dns_client_) { + SetDnsClient(scoped_ptr<DnsClient>()); + } +#endif +} + +HostCache* HostResolverImpl::GetHostCache() { + return cache_.get(); +} + +base::Value* HostResolverImpl::GetDnsConfigAsValue() const { + // Check if async DNS is disabled. + if (!dns_client_.get()) + return NULL; + + // Check if async DNS is enabled, but we currently have no configuration + // for it. + const DnsConfig* dns_config = dns_client_->GetConfig(); + if (dns_config == NULL) + return new base::DictionaryValue(); + + return dns_config->ToValue(); +} + +bool HostResolverImpl::ResolveAsIP(const Key& key, + const RequestInfo& info, + int* net_error, + AddressList* addresses) { + DCHECK(addresses); + DCHECK(net_error); + IPAddressNumber ip_number; + if (!ParseIPLiteralToNumber(key.hostname, &ip_number)) + return false; + + DCHECK_EQ(key.host_resolver_flags & + ~(HOST_RESOLVER_CANONNAME | HOST_RESOLVER_LOOPBACK_ONLY | + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6), + 0) << " Unhandled flag"; + bool ipv6_disabled = (default_address_family_ == ADDRESS_FAMILY_IPV4) && + !probe_ipv6_support_; + *net_error = OK; + if ((ip_number.size() == kIPv6AddressSize) && ipv6_disabled) { + *net_error = ERR_NAME_NOT_RESOLVED; + } else { + *addresses = AddressList::CreateFromIPAddress(ip_number, info.port()); + if (key.host_resolver_flags & HOST_RESOLVER_CANONNAME) + addresses->SetDefaultCanonicalName(); + } + return true; +} + +bool HostResolverImpl::ServeFromCache(const Key& key, + const RequestInfo& info, + int* net_error, + AddressList* addresses) { + DCHECK(addresses); + DCHECK(net_error); + if (!info.allow_cached_response() || !cache_.get()) + return false; + + const HostCache::Entry* cache_entry = cache_->Lookup( + key, base::TimeTicks::Now()); + if (!cache_entry) + return false; + + *net_error = cache_entry->error; + if (*net_error == OK) { + if (cache_entry->has_ttl()) + RecordTTL(cache_entry->ttl); + *addresses = EnsurePortOnAddressList(cache_entry->addrlist, info.port()); + } + return true; +} + +bool HostResolverImpl::ServeFromHosts(const Key& key, + const RequestInfo& info, + AddressList* addresses) { + DCHECK(addresses); + if (!HaveDnsConfig()) + return false; + addresses->clear(); + + // HOSTS lookups are case-insensitive. + std::string hostname = StringToLowerASCII(key.hostname); + + const DnsHosts& hosts = dns_client_->GetConfig()->hosts; + + // If |address_family| is ADDRESS_FAMILY_UNSPECIFIED other implementations + // (glibc and c-ares) return the first matching line. We have more + // flexibility, but lose implicit ordering. + // We prefer IPv6 because "happy eyeballs" will fall back to IPv4 if + // necessary. + if (key.address_family == ADDRESS_FAMILY_IPV6 || + key.address_family == ADDRESS_FAMILY_UNSPECIFIED) { + DnsHosts::const_iterator it = hosts.find( + DnsHostsKey(hostname, ADDRESS_FAMILY_IPV6)); + if (it != hosts.end()) + addresses->push_back(IPEndPoint(it->second, info.port())); + } + + if (key.address_family == ADDRESS_FAMILY_IPV4 || + key.address_family == ADDRESS_FAMILY_UNSPECIFIED) { + DnsHosts::const_iterator it = hosts.find( + DnsHostsKey(hostname, ADDRESS_FAMILY_IPV4)); + if (it != hosts.end()) + addresses->push_back(IPEndPoint(it->second, info.port())); + } + + // If got only loopback addresses and the family was restricted, resolve + // again, without restrictions. See SystemHostResolverCall for rationale. + if ((key.host_resolver_flags & + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6) && + IsAllIPv4Loopback(*addresses)) { + Key new_key(key); + new_key.address_family = ADDRESS_FAMILY_UNSPECIFIED; + new_key.host_resolver_flags &= + ~HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + return ServeFromHosts(new_key, info, addresses); + } + return !addresses->empty(); +} + +void HostResolverImpl::CacheResult(const Key& key, + const HostCache::Entry& entry, + base::TimeDelta ttl) { + if (cache_.get()) + cache_->Set(key, entry, base::TimeTicks::Now(), ttl); +} + +void HostResolverImpl::RemoveJob(Job* job) { + DCHECK(job); + JobMap::iterator it = jobs_.find(job->key()); + if (it != jobs_.end() && it->second == job) + jobs_.erase(it); +} + +void HostResolverImpl::SetHaveOnlyLoopbackAddresses(bool result) { + if (result) { + additional_resolver_flags_ |= HOST_RESOLVER_LOOPBACK_ONLY; + } else { + additional_resolver_flags_ &= ~HOST_RESOLVER_LOOPBACK_ONLY; + } +} + +HostResolverImpl::Key HostResolverImpl::GetEffectiveKeyForRequest( + const RequestInfo& info, const BoundNetLog& net_log) const { + HostResolverFlags effective_flags = + info.host_resolver_flags() | additional_resolver_flags_; + AddressFamily effective_address_family = info.address_family(); + + if (info.address_family() == ADDRESS_FAMILY_UNSPECIFIED) { + if (probe_ipv6_support_) { + base::TimeTicks start_time = base::TimeTicks::Now(); + // Google DNS address. + const uint8 kIPv6Address[] = + { 0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88 }; + IPAddressNumber address(kIPv6Address, + kIPv6Address + arraysize(kIPv6Address)); + bool rv6 = IsGloballyReachable(address, net_log); + if (rv6) + net_log.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_IPV6_SUPPORTED); + + UMA_HISTOGRAM_TIMES("Net.IPv6ConnectDuration", + base::TimeTicks::Now() - start_time); + if (rv6) { + UMA_HISTOGRAM_BOOLEAN("Net.IPv6ConnectSuccessMatch", + default_address_family_ == ADDRESS_FAMILY_UNSPECIFIED); + } else { + UMA_HISTOGRAM_BOOLEAN("Net.IPv6ConnectFailureMatch", + default_address_family_ != ADDRESS_FAMILY_UNSPECIFIED); + + effective_address_family = ADDRESS_FAMILY_IPV4; + effective_flags |= HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + } + } else { + effective_address_family = default_address_family_; + } + } + + return Key(info.hostname(), effective_address_family, effective_flags); +} + +void HostResolverImpl::AbortAllInProgressJobs() { + // In Abort, a Request callback could spawn new Jobs with matching keys, so + // first collect and remove all running jobs from |jobs_|. + ScopedVector<Job> jobs_to_abort; + for (JobMap::iterator it = jobs_.begin(); it != jobs_.end(); ) { + Job* job = it->second; + if (job->is_running()) { + jobs_to_abort.push_back(job); + jobs_.erase(it++); + } else { + DCHECK(job->is_queued()); + ++it; + } + } + + // Check if no dispatcher slots leaked out. + DCHECK_EQ(dispatcher_.num_running_jobs(), jobs_to_abort.size()); + + // Life check to bail once |this| is deleted. + base::WeakPtr<HostResolverImpl> self = weak_ptr_factory_.GetWeakPtr(); + + // Then Abort them. + for (size_t i = 0; self.get() && i < jobs_to_abort.size(); ++i) { + jobs_to_abort[i]->Abort(); + jobs_to_abort[i] = NULL; + } +} + +void HostResolverImpl::TryServingAllJobsFromHosts() { + if (!HaveDnsConfig()) + return; + + // TODO(szym): Do not do this if nsswitch.conf instructs not to. + // http://crbug.com/117655 + + // Life check to bail once |this| is deleted. + base::WeakPtr<HostResolverImpl> self = weak_ptr_factory_.GetWeakPtr(); + + for (JobMap::iterator it = jobs_.begin(); self.get() && it != jobs_.end();) { + Job* job = it->second; + ++it; + // This could remove |job| from |jobs_|, but iterator will remain valid. + job->ServeFromHosts(); + } +} + +void HostResolverImpl::OnIPAddressChanged() { + resolved_known_ipv6_hostname_ = false; + // Abandon all ProbeJobs. + probe_weak_ptr_factory_.InvalidateWeakPtrs(); + if (cache_.get()) + cache_->clear(); +#if defined(OS_POSIX) && !defined(OS_MACOSX) && !defined(OS_ANDROID) + new LoopbackProbeJob(probe_weak_ptr_factory_.GetWeakPtr()); +#endif + AbortAllInProgressJobs(); + // |this| may be deleted inside AbortAllInProgressJobs(). +} + +void HostResolverImpl::OnDNSChanged() { + DnsConfig dns_config; + NetworkChangeNotifier::GetDnsConfig(&dns_config); + + if (net_log_) { + net_log_->AddGlobalEntry( + NetLog::TYPE_DNS_CONFIG_CHANGED, + base::Bind(&NetLogDnsConfigCallback, &dns_config)); + } + + // TODO(szym): Remove once http://crbug.com/137914 is resolved. + received_dns_config_ = dns_config.IsValid(); + + num_dns_failures_ = 0; + + // We want a new DnsSession in place, before we Abort running Jobs, so that + // the newly started jobs use the new config. + if (dns_client_.get()) { + dns_client_->SetConfig(dns_config); + if (dns_config.IsValid()) + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", true); + } + + // If the DNS server has changed, existing cached info could be wrong so we + // have to drop our internal cache :( Note that OS level DNS caches, such + // as NSCD's cache should be dropped automatically by the OS when + // resolv.conf changes so we don't need to do anything to clear that cache. + if (cache_.get()) + cache_->clear(); + + // Life check to bail once |this| is deleted. + base::WeakPtr<HostResolverImpl> self = weak_ptr_factory_.GetWeakPtr(); + + // Existing jobs will have been sent to the original server so they need to + // be aborted. + AbortAllInProgressJobs(); + + // |this| may be deleted inside AbortAllInProgressJobs(). + if (self.get()) + TryServingAllJobsFromHosts(); +} + +bool HostResolverImpl::HaveDnsConfig() const { + // Use DnsClient only if it's fully configured and there is no override by + // ScopedDefaultHostResolverProc. + // The alternative is to use NetworkChangeNotifier to override DnsConfig, + // but that would introduce construction order requirements for NCN and SDHRP. + return (dns_client_.get() != NULL) && (dns_client_->GetConfig() != NULL) && + !(proc_params_.resolver_proc.get() == NULL && + HostResolverProc::GetDefault() != NULL); +} + +void HostResolverImpl::OnDnsTaskResolve(int net_error) { + DCHECK(dns_client_); + if (net_error == OK) { + num_dns_failures_ = 0; + return; + } + ++num_dns_failures_; + if (num_dns_failures_ < kMaximumDnsFailures) + return; + // Disable DnsClient until the next DNS change. + for (JobMap::iterator it = jobs_.begin(); it != jobs_.end(); ++it) + it->second->AbortDnsTask(); + dns_client_->SetConfig(DnsConfig()); + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", false); + UMA_HISTOGRAM_CUSTOM_ENUMERATION("AsyncDNS.DnsClientDisabledReason", + std::abs(net_error), + GetAllErrorCodesForUma()); +} + +void HostResolverImpl::SetDnsClient(scoped_ptr<DnsClient> dns_client) { + if (HaveDnsConfig()) { + for (JobMap::iterator it = jobs_.begin(); it != jobs_.end(); ++it) + it->second->AbortDnsTask(); + } + dns_client_ = dns_client.Pass(); + if (!dns_client_ || dns_client_->GetConfig() || + num_dns_failures_ >= kMaximumDnsFailures) { + return; + } + DnsConfig dns_config; + NetworkChangeNotifier::GetDnsConfig(&dns_config); + dns_client_->SetConfig(dns_config); + num_dns_failures_ = 0; + if (dns_config.IsValid()) + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", true); +} + +} // namespace net diff --git a/chromium/net/dns/host_resolver_impl.h b/chromium/net/dns/host_resolver_impl.h new file mode 100644 index 00000000000..928d07af8b8 --- /dev/null +++ b/chromium/net/dns/host_resolver_impl.h @@ -0,0 +1,285 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_HOST_RESOLVER_IMPL_H_ +#define NET_DNS_HOST_RESOLVER_IMPL_H_ + +#include <map> + +#include "base/basictypes.h" +#include "base/gtest_prod_util.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/scoped_vector.h" +#include "base/memory/weak_ptr.h" +#include "base/threading/non_thread_safe.h" +#include "base/time/time.h" +#include "net/base/capturing_net_log.h" +#include "net/base/net_export.h" +#include "net/base/network_change_notifier.h" +#include "net/base/prioritized_dispatcher.h" +#include "net/dns/host_cache.h" +#include "net/dns/host_resolver.h" +#include "net/dns/host_resolver_proc.h" + +namespace net { + +class BoundNetLog; +class DnsClient; +class NetLog; + +// For each hostname that is requested, HostResolver creates a +// HostResolverImpl::Job. When this job gets dispatched it creates a ProcTask +// which runs the given HostResolverProc on a WorkerPool thread. If requests for +// that same host are made during the job's lifetime, they are attached to the +// existing job rather than creating a new one. This avoids doing parallel +// resolves for the same host. +// +// The way these classes fit together is illustrated by: +// +// +// +----------- HostResolverImpl -------------+ +// | | | +// Job Job Job +// (for host1, fam1) (for host2, fam2) (for hostx, famx) +// / | | / | | / | | +// Request ... Request Request ... Request Request ... Request +// (port1) (port2) (port3) (port4) (port5) (portX) +// +// When a HostResolverImpl::Job finishes, the callbacks of each waiting request +// are run on the origin thread. +// +// Thread safety: This class is not threadsafe, and must only be called +// from one thread! +// +// The HostResolverImpl enforces limits on the maximum number of concurrent +// threads using PrioritizedDispatcher::Limits. +// +// Jobs are ordered in the queue based on their priority and order of arrival. +class NET_EXPORT HostResolverImpl + : public HostResolver, + NON_EXPORTED_BASE(public base::NonThreadSafe), + public NetworkChangeNotifier::IPAddressObserver, + public NetworkChangeNotifier::DNSObserver { + public: + // Parameters for ProcTask which resolves hostnames using HostResolveProc. + // + // |resolver_proc| is used to perform the actual resolves; it must be + // thread-safe since it is run from multiple worker threads. If + // |resolver_proc| is NULL then the default host resolver procedure is + // used (which is SystemHostResolverProc except if overridden). + // + // For each attempt, we could start another attempt if host is not resolved + // within |unresponsive_delay| time. We keep attempting to resolve the host + // for |max_retry_attempts|. For every retry attempt, we grow the + // |unresponsive_delay| by the |retry_factor| amount (that is retry interval + // is multiplied by the retry factor each time). Once we have retried + // |max_retry_attempts|, we give up on additional attempts. + // + struct NET_EXPORT_PRIVATE ProcTaskParams { + // Sets up defaults. + ProcTaskParams(HostResolverProc* resolver_proc, size_t max_retry_attempts); + + ~ProcTaskParams(); + + // The procedure to use for resolving host names. This will be NULL, except + // in the case of unit-tests which inject custom host resolving behaviors. + scoped_refptr<HostResolverProc> resolver_proc; + + // Maximum number retry attempts to resolve the hostname. + // Pass HostResolver::kDefaultRetryAttempts to choose a default value. + size_t max_retry_attempts; + + // This is the limit after which we make another attempt to resolve the host + // if the worker thread has not responded yet. + base::TimeDelta unresponsive_delay; + + // Factor to grow |unresponsive_delay| when we re-re-try. + uint32 retry_factor; + }; + + // Creates a HostResolver that first uses the local cache |cache|, and then + // falls back to |proc_params.resolver_proc|. + // + // If |cache| is NULL, then no caching is used. Otherwise we take + // ownership of the |cache| pointer, and will free it during destruction. + // + // |job_limits| specifies the maximum number of jobs that the resolver will + // run at once. This upper-bounds the total number of outstanding + // DNS transactions (not counting retransmissions and retries). + // + // |net_log| must remain valid for the life of the HostResolverImpl. + HostResolverImpl(scoped_ptr<HostCache> cache, + const PrioritizedDispatcher::Limits& job_limits, + const ProcTaskParams& proc_params, + NetLog* net_log); + + // If any completion callbacks are pending when the resolver is destroyed, + // the host resolutions are cancelled, and the completion callbacks will not + // be called. + virtual ~HostResolverImpl(); + + // Configures maximum number of Jobs in the queue. Exposed for testing. + // Only allowed when the queue is empty. + void SetMaxQueuedJobs(size_t value); + + // Set the DnsClient to be used for resolution. In case of failure, the + // HostResolverProc from ProcTaskParams will be queried. If the DnsClient is + // not pre-configured with a valid DnsConfig, a new config is fetched from + // NetworkChangeNotifier. + void SetDnsClient(scoped_ptr<DnsClient> dns_client); + + // HostResolver methods: + virtual int Resolve(const RequestInfo& info, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* out_req, + const BoundNetLog& source_net_log) OVERRIDE; + virtual int ResolveFromCache(const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& source_net_log) OVERRIDE; + virtual void CancelRequest(RequestHandle req) OVERRIDE; + virtual void SetDefaultAddressFamily(AddressFamily address_family) OVERRIDE; + virtual AddressFamily GetDefaultAddressFamily() const OVERRIDE; + virtual void SetDnsClientEnabled(bool enabled) OVERRIDE; + virtual HostCache* GetHostCache() OVERRIDE; + virtual base::Value* GetDnsConfigAsValue() const OVERRIDE; + + private: + friend class HostResolverImplTest; + class Job; + class ProcTask; + class LoopbackProbeJob; + class DnsTask; + class Request; + typedef HostCache::Key Key; + typedef std::map<Key, Job*> JobMap; + typedef ScopedVector<Request> RequestsList; + + // Helper used by |Resolve()| and |ResolveFromCache()|. Performs IP + // literal, cache and HOSTS lookup (if enabled), returns OK if successful, + // ERR_NAME_NOT_RESOLVED if either hostname is invalid or IP literal is + // incompatible, ERR_DNS_CACHE_MISS if entry was not found in cache and HOSTS. + int ResolveHelper(const Key& key, + const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& request_net_log); + + // Tries to resolve |key| as an IP, returns true and sets |net_error| if + // succeeds, returns false otherwise. + bool ResolveAsIP(const Key& key, + const RequestInfo& info, + int* net_error, + AddressList* addresses); + + // If |key| is not found in cache returns false, otherwise returns + // true, sets |net_error| to the cached error code and fills |addresses| + // if it is a positive entry. + bool ServeFromCache(const Key& key, + const RequestInfo& info, + int* net_error, + AddressList* addresses); + + // If we have a DnsClient with a valid DnsConfig, and |key| is found in the + // HOSTS file, returns true and fills |addresses|. Otherwise returns false. + bool ServeFromHosts(const Key& key, + const RequestInfo& info, + AddressList* addresses); + + // Callback from HaveOnlyLoopbackAddresses probe. + void SetHaveOnlyLoopbackAddresses(bool result); + + // Returns the (hostname, address_family) key to use for |info|, choosing an + // "effective" address family by inheriting the resolver's default address + // family when the request leaves it unspecified. + Key GetEffectiveKeyForRequest(const RequestInfo& info, + const BoundNetLog& net_log) const; + + // Records the result in cache if cache is present. + void CacheResult(const Key& key, + const HostCache::Entry& entry, + base::TimeDelta ttl); + + // Removes |job| from |jobs_|, only if it exists. + void RemoveJob(Job* job); + + // Aborts all in progress jobs with ERR_NETWORK_CHANGED and notifies their + // requests. Might start new jobs. + void AbortAllInProgressJobs(); + + // Attempts to serve each Job in |jobs_| from the HOSTS file if we have + // a DnsClient with a valid DnsConfig. + void TryServingAllJobsFromHosts(); + + // NetworkChangeNotifier::IPAddressObserver: + virtual void OnIPAddressChanged() OVERRIDE; + + // NetworkChangeNotifier::DNSObserver: + virtual void OnDNSChanged() OVERRIDE; + + // True if have a DnsClient with a valid DnsConfig. + bool HaveDnsConfig() const; + + // Called when a host name is successfully resolved and DnsTask was run on it + // and resulted in |net_error|. + void OnDnsTaskResolve(int net_error); + + // Allows the tests to catch slots leaking out of the dispatcher. + size_t num_running_jobs_for_tests() const { + return dispatcher_.num_running_jobs(); + } + + // Cache of host resolution results. + scoped_ptr<HostCache> cache_; + + // Map from HostCache::Key to a Job. + JobMap jobs_; + + // Starts Jobs according to their priority and the configured limits. + PrioritizedDispatcher dispatcher_; + + // Limit on the maximum number of jobs queued in |dispatcher_|. + size_t max_queued_jobs_; + + // Parameters for ProcTask. + ProcTaskParams proc_params_; + + NetLog* net_log_; + + // Address family to use when the request doesn't specify one. + AddressFamily default_address_family_; + + base::WeakPtrFactory<HostResolverImpl> weak_ptr_factory_; + + base::WeakPtrFactory<HostResolverImpl> probe_weak_ptr_factory_; + + // If present, used by DnsTask and ServeFromHosts to resolve requests. + scoped_ptr<DnsClient> dns_client_; + + // True if received valid config from |dns_config_service_|. Temporary, used + // to measure performance of DnsConfigService: http://crbug.com/125599 + bool received_dns_config_; + + // Number of consecutive failures of DnsTask, counted when fallback succeeds. + unsigned num_dns_failures_; + + // True if probing is done for each Request to set address family. When false, + // explicit setting in |default_address_family_| is used. + bool probe_ipv6_support_; + + // True iff ProcTask has successfully resolved a hostname known to have IPv6 + // addresses using ADDRESS_FAMILY_UNSPECIFIED. Reset on IP address change. + bool resolved_known_ipv6_hostname_; + + // Any resolver flags that should be added to a request by default. + HostResolverFlags additional_resolver_flags_; + + // Allow fallback to ProcTask if DnsTask fails. + bool fallback_to_proctask_; + + DISALLOW_COPY_AND_ASSIGN(HostResolverImpl); +}; + +} // namespace net + +#endif // NET_DNS_HOST_RESOLVER_IMPL_H_ diff --git a/chromium/net/dns/host_resolver_impl_unittest.cc b/chromium/net/dns/host_resolver_impl_unittest.cc new file mode 100644 index 00000000000..f6b7f690675 --- /dev/null +++ b/chromium/net/dns/host_resolver_impl_unittest.cc @@ -0,0 +1,1641 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/host_resolver_impl.h" + +#include <algorithm> +#include <string> + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_vector.h" +#include "base/message_loop/message_loop.h" +#include "base/strings/string_util.h" +#include "base/strings/stringprintf.h" +#include "base/synchronization/condition_variable.h" +#include "base/synchronization/lock.h" +#include "base/test/test_timeouts.h" +#include "base/time/time.h" +#include "net/base/address_list.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/dns/dns_client.h" +#include "net/dns/dns_test_util.h" +#include "net/dns/host_cache.h" +#include "net/dns/mock_host_resolver.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +const size_t kMaxJobs = 10u; +const size_t kMaxRetryAttempts = 4u; + +PrioritizedDispatcher::Limits DefaultLimits() { + PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, kMaxJobs); + return limits; +} + +HostResolverImpl::ProcTaskParams DefaultParams( + HostResolverProc* resolver_proc) { + return HostResolverImpl::ProcTaskParams(resolver_proc, kMaxRetryAttempts); +} + +// A HostResolverProc that pushes each host mapped into a list and allows +// waiting for a specific number of requests. Unlike RuleBasedHostResolverProc +// it never calls SystemHostResolverCall. By default resolves all hostnames to +// "127.0.0.1". After AddRule(), it resolves only names explicitly specified. +class MockHostResolverProc : public HostResolverProc { + public: + struct ResolveKey { + ResolveKey(const std::string& hostname, AddressFamily address_family) + : hostname(hostname), address_family(address_family) {} + bool operator<(const ResolveKey& other) const { + return address_family < other.address_family || + (address_family == other.address_family && hostname < other.hostname); + } + std::string hostname; + AddressFamily address_family; + }; + + typedef std::vector<ResolveKey> CaptureList; + + MockHostResolverProc() + : HostResolverProc(NULL), + num_requests_waiting_(0), + num_slots_available_(0), + requests_waiting_(&lock_), + slots_available_(&lock_) { + } + + // Waits until |count| calls to |Resolve| are blocked. Returns false when + // timed out. + bool WaitFor(unsigned count) { + base::AutoLock lock(lock_); + base::Time start_time = base::Time::Now(); + while (num_requests_waiting_ < count) { + requests_waiting_.TimedWait(TestTimeouts::action_timeout()); + if (base::Time::Now() > start_time + TestTimeouts::action_timeout()) + return false; + } + return true; + } + + // Signals |count| waiting calls to |Resolve|. First come first served. + void SignalMultiple(unsigned count) { + base::AutoLock lock(lock_); + num_slots_available_ += count; + slots_available_.Broadcast(); + } + + // Signals all waiting calls to |Resolve|. Beware of races. + void SignalAll() { + base::AutoLock lock(lock_); + num_slots_available_ = num_requests_waiting_; + slots_available_.Broadcast(); + } + + void AddRule(const std::string& hostname, AddressFamily family, + const AddressList& result) { + base::AutoLock lock(lock_); + rules_[ResolveKey(hostname, family)] = result; + } + + void AddRule(const std::string& hostname, AddressFamily family, + const std::string& ip_list) { + AddressList result; + int rv = ParseAddressList(ip_list, std::string(), &result); + DCHECK_EQ(OK, rv); + AddRule(hostname, family, result); + } + + void AddRuleForAllFamilies(const std::string& hostname, + const std::string& ip_list) { + AddressList result; + int rv = ParseAddressList(ip_list, std::string(), &result); + DCHECK_EQ(OK, rv); + AddRule(hostname, ADDRESS_FAMILY_UNSPECIFIED, result); + AddRule(hostname, ADDRESS_FAMILY_IPV4, result); + AddRule(hostname, ADDRESS_FAMILY_IPV6, result); + } + + virtual int Resolve(const std::string& hostname, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + AddressList* addrlist, + int* os_error) OVERRIDE { + base::AutoLock lock(lock_); + capture_list_.push_back(ResolveKey(hostname, address_family)); + ++num_requests_waiting_; + requests_waiting_.Broadcast(); + while (!num_slots_available_) + slots_available_.Wait(); + DCHECK_GT(num_requests_waiting_, 0u); + --num_slots_available_; + --num_requests_waiting_; + if (rules_.empty()) { + int rv = ParseAddressList("127.0.0.1", std::string(), addrlist); + DCHECK_EQ(OK, rv); + return OK; + } + ResolveKey key(hostname, address_family); + if (rules_.count(key) == 0) + return ERR_NAME_NOT_RESOLVED; + *addrlist = rules_[key]; + return OK; + } + + CaptureList GetCaptureList() const { + CaptureList copy; + { + base::AutoLock lock(lock_); + copy = capture_list_; + } + return copy; + } + + bool HasBlockedRequests() const { + base::AutoLock lock(lock_); + return num_requests_waiting_ > num_slots_available_; + } + + protected: + virtual ~MockHostResolverProc() {} + + private: + mutable base::Lock lock_; + std::map<ResolveKey, AddressList> rules_; + CaptureList capture_list_; + unsigned num_requests_waiting_; + unsigned num_slots_available_; + base::ConditionVariable requests_waiting_; + base::ConditionVariable slots_available_; + + DISALLOW_COPY_AND_ASSIGN(MockHostResolverProc); +}; + +bool AddressListContains(const AddressList& list, const std::string& address, + int port) { + IPAddressNumber ip; + bool rv = ParseIPLiteralToNumber(address, &ip); + DCHECK(rv); + return std::find(list.begin(), + list.end(), + IPEndPoint(ip, port)) != list.end(); +} + +// A wrapper for requests to a HostResolver. +class Request { + public: + // Base class of handlers to be executed on completion of requests. + struct Handler { + virtual ~Handler() {} + virtual void Handle(Request* request) = 0; + }; + + Request(const HostResolver::RequestInfo& info, + size_t index, + HostResolver* resolver, + Handler* handler) + : info_(info), + index_(index), + resolver_(resolver), + handler_(handler), + quit_on_complete_(false), + result_(ERR_UNEXPECTED), + handle_(NULL) {} + + int Resolve() { + DCHECK(resolver_); + DCHECK(!handle_); + list_ = AddressList(); + result_ = resolver_->Resolve( + info_, &list_, base::Bind(&Request::OnComplete, base::Unretained(this)), + &handle_, BoundNetLog()); + if (!list_.empty()) + EXPECT_EQ(OK, result_); + return result_; + } + + int ResolveFromCache() { + DCHECK(resolver_); + DCHECK(!handle_); + return resolver_->ResolveFromCache(info_, &list_, BoundNetLog()); + } + + void Cancel() { + DCHECK(resolver_); + DCHECK(handle_); + resolver_->CancelRequest(handle_); + handle_ = NULL; + } + + const HostResolver::RequestInfo& info() const { return info_; } + size_t index() const { return index_; } + const AddressList& list() const { return list_; } + int result() const { return result_; } + bool completed() const { return result_ != ERR_IO_PENDING; } + bool pending() const { return handle_ != NULL; } + + bool HasAddress(const std::string& address, int port) const { + return AddressListContains(list_, address, port); + } + + // Returns the number of addresses in |list_|. + unsigned NumberOfAddresses() const { + return list_.size(); + } + + bool HasOneAddress(const std::string& address, int port) const { + return HasAddress(address, port) && (NumberOfAddresses() == 1u); + } + + // Returns ERR_UNEXPECTED if timed out. + int WaitForResult() { + if (completed()) + return result_; + base::CancelableClosure closure(base::MessageLoop::QuitClosure()); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, closure.callback(), TestTimeouts::action_max_timeout()); + quit_on_complete_ = true; + base::MessageLoop::current()->Run(); + bool did_quit = !quit_on_complete_; + quit_on_complete_ = false; + closure.Cancel(); + if (did_quit) + return result_; + else + return ERR_UNEXPECTED; + } + + private: + void OnComplete(int rv) { + EXPECT_TRUE(pending()); + EXPECT_EQ(ERR_IO_PENDING, result_); + EXPECT_NE(ERR_IO_PENDING, rv); + result_ = rv; + handle_ = NULL; + if (!list_.empty()) { + EXPECT_EQ(OK, result_); + EXPECT_EQ(info_.port(), list_.front().port()); + } + if (handler_) + handler_->Handle(this); + if (quit_on_complete_) { + base::MessageLoop::current()->Quit(); + quit_on_complete_ = false; + } + } + + HostResolver::RequestInfo info_; + size_t index_; + HostResolver* resolver_; + Handler* handler_; + bool quit_on_complete_; + + AddressList list_; + int result_; + HostResolver::RequestHandle handle_; + + DISALLOW_COPY_AND_ASSIGN(Request); +}; + +// Using LookupAttemptHostResolverProc simulate very long lookups, and control +// which attempt resolves the host. +class LookupAttemptHostResolverProc : public HostResolverProc { + public: + LookupAttemptHostResolverProc(HostResolverProc* previous, + int attempt_number_to_resolve, + int total_attempts) + : HostResolverProc(previous), + attempt_number_to_resolve_(attempt_number_to_resolve), + current_attempt_number_(0), + total_attempts_(total_attempts), + total_attempts_resolved_(0), + resolved_attempt_number_(0), + all_done_(&lock_) { + } + + // Test harness will wait for all attempts to finish before checking the + // results. + void WaitForAllAttemptsToFinish(const base::TimeDelta& wait_time) { + base::TimeTicks end_time = base::TimeTicks::Now() + wait_time; + { + base::AutoLock auto_lock(lock_); + while (total_attempts_resolved_ != total_attempts_ && + base::TimeTicks::Now() < end_time) { + all_done_.TimedWait(end_time - base::TimeTicks::Now()); + } + } + } + + // All attempts will wait for an attempt to resolve the host. + void WaitForAnAttemptToComplete() { + base::TimeDelta wait_time = base::TimeDelta::FromSeconds(60); + base::TimeTicks end_time = base::TimeTicks::Now() + wait_time; + { + base::AutoLock auto_lock(lock_); + while (resolved_attempt_number_ == 0 && base::TimeTicks::Now() < end_time) + all_done_.TimedWait(end_time - base::TimeTicks::Now()); + } + all_done_.Broadcast(); // Tell all waiting attempts to proceed. + } + + // Returns the number of attempts that have finished the Resolve() method. + int total_attempts_resolved() { return total_attempts_resolved_; } + + // Returns the first attempt that that has resolved the host. + int resolved_attempt_number() { return resolved_attempt_number_; } + + // HostResolverProc methods. + virtual int Resolve(const std::string& host, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + AddressList* addrlist, + int* os_error) OVERRIDE { + bool wait_for_right_attempt_to_complete = true; + { + base::AutoLock auto_lock(lock_); + ++current_attempt_number_; + if (current_attempt_number_ == attempt_number_to_resolve_) { + resolved_attempt_number_ = current_attempt_number_; + wait_for_right_attempt_to_complete = false; + } + } + + if (wait_for_right_attempt_to_complete) + // Wait for the attempt_number_to_resolve_ attempt to resolve. + WaitForAnAttemptToComplete(); + + int result = ResolveUsingPrevious(host, address_family, host_resolver_flags, + addrlist, os_error); + + { + base::AutoLock auto_lock(lock_); + ++total_attempts_resolved_; + } + + all_done_.Broadcast(); // Tell all attempts to proceed. + + // Since any negative number is considered a network error, with -1 having + // special meaning (ERR_IO_PENDING). We could return the attempt that has + // resolved the host as a negative number. For example, if attempt number 3 + // resolves the host, then this method returns -4. + if (result == OK) + return -1 - resolved_attempt_number_; + else + return result; + } + + protected: + virtual ~LookupAttemptHostResolverProc() {} + + private: + int attempt_number_to_resolve_; + int current_attempt_number_; // Incremented whenever Resolve is called. + int total_attempts_; + int total_attempts_resolved_; + int resolved_attempt_number_; + + // All attempts wait for right attempt to be resolve. + base::Lock lock_; + base::ConditionVariable all_done_; +}; + +} // namespace + +class HostResolverImplTest : public testing::Test { + public: + static const int kDefaultPort = 80; + + HostResolverImplTest() : proc_(new MockHostResolverProc()) {} + + protected: + // A Request::Handler which is a proxy to the HostResolverImplTest fixture. + struct Handler : public Request::Handler { + virtual ~Handler() {} + + // Proxy functions so that classes derived from Handler can access them. + Request* CreateRequest(const HostResolver::RequestInfo& info) { + return test->CreateRequest(info); + } + Request* CreateRequest(const std::string& hostname, int port) { + return test->CreateRequest(hostname, port); + } + Request* CreateRequest(const std::string& hostname) { + return test->CreateRequest(hostname); + } + ScopedVector<Request>& requests() { return test->requests_; } + + void DeleteResolver() { test->resolver_.reset(); } + + HostResolverImplTest* test; + }; + + void CreateResolver() { + resolver_.reset(new HostResolverImpl(HostCache::CreateDefaultCache(), + DefaultLimits(), + DefaultParams(proc_.get()), + NULL)); + } + + // This HostResolverImpl will only allow 1 outstanding resolve at a time and + // perform no retries. + void CreateSerialResolver() { + HostResolverImpl::ProcTaskParams params = DefaultParams(proc_.get()); + params.max_retry_attempts = 0u; + PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 1); + resolver_.reset(new HostResolverImpl( + HostCache::CreateDefaultCache(), + limits, + params, + NULL)); + } + + // The Request will not be made until a call to |Resolve()|, and the Job will + // not start until released by |proc_->SignalXXX|. + Request* CreateRequest(const HostResolver::RequestInfo& info) { + Request* req = new Request(info, requests_.size(), resolver_.get(), + handler_.get()); + requests_.push_back(req); + return req; + } + + Request* CreateRequest(const std::string& hostname, + int port, + RequestPriority priority, + AddressFamily family) { + HostResolver::RequestInfo info(HostPortPair(hostname, port)); + info.set_priority(priority); + info.set_address_family(family); + return CreateRequest(info); + } + + Request* CreateRequest(const std::string& hostname, + int port, + RequestPriority priority) { + return CreateRequest(hostname, port, priority, ADDRESS_FAMILY_UNSPECIFIED); + } + + Request* CreateRequest(const std::string& hostname, int port) { + return CreateRequest(hostname, port, MEDIUM); + } + + Request* CreateRequest(const std::string& hostname) { + return CreateRequest(hostname, kDefaultPort); + } + + virtual void SetUp() OVERRIDE { + CreateResolver(); + } + + virtual void TearDown() OVERRIDE { + if (resolver_.get()) + EXPECT_EQ(0u, resolver_->num_running_jobs_for_tests()); + EXPECT_FALSE(proc_->HasBlockedRequests()); + } + + void set_handler(Handler* handler) { + handler_.reset(handler); + handler_->test = this; + } + + // Friendship is not inherited, so use proxies to access those. + size_t num_running_jobs() const { + DCHECK(resolver_.get()); + return resolver_->num_running_jobs_for_tests(); + } + + void set_fallback_to_proctask(bool fallback_to_proctask) { + DCHECK(resolver_.get()); + resolver_->fallback_to_proctask_ = fallback_to_proctask; + } + + scoped_refptr<MockHostResolverProc> proc_; + scoped_ptr<HostResolverImpl> resolver_; + ScopedVector<Request> requests_; + + scoped_ptr<Handler> handler_; +}; + +TEST_F(HostResolverImplTest, AsynchronousLookup) { + proc_->AddRuleForAllFamilies("just.testing", "192.168.1.42"); + proc_->SignalMultiple(1u); + + Request* req = CreateRequest("just.testing", 80); + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + EXPECT_EQ(OK, req->WaitForResult()); + + EXPECT_TRUE(req->HasOneAddress("192.168.1.42", 80)); + + EXPECT_EQ("just.testing", proc_->GetCaptureList()[0].hostname); +} + +TEST_F(HostResolverImplTest, EmptyListMeansNameNotResolved) { + proc_->AddRuleForAllFamilies("just.testing", ""); + proc_->SignalMultiple(1u); + + Request* req = CreateRequest("just.testing", 80); + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->WaitForResult()); + EXPECT_EQ(0u, req->NumberOfAddresses()); + EXPECT_EQ("just.testing", proc_->GetCaptureList()[0].hostname); +} + +TEST_F(HostResolverImplTest, FailedAsynchronousLookup) { + proc_->AddRuleForAllFamilies(std::string(), + "0.0.0.0"); // Default to failures. + proc_->SignalMultiple(1u); + + Request* req = CreateRequest("just.testing", 80); + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->WaitForResult()); + + EXPECT_EQ("just.testing", proc_->GetCaptureList()[0].hostname); + + // Also test that the error is not cached. + EXPECT_EQ(ERR_DNS_CACHE_MISS, req->ResolveFromCache()); +} + +TEST_F(HostResolverImplTest, AbortedAsynchronousLookup) { + Request* req0 = CreateRequest("just.testing", 80); + EXPECT_EQ(ERR_IO_PENDING, req0->Resolve()); + + EXPECT_TRUE(proc_->WaitFor(1u)); + + // Resolver is destroyed while job is running on WorkerPool. + resolver_.reset(); + + proc_->SignalAll(); + + // To ensure there was no spurious callback, complete with a new resolver. + CreateResolver(); + Request* req1 = CreateRequest("just.testing", 80); + EXPECT_EQ(ERR_IO_PENDING, req1->Resolve()); + + proc_->SignalMultiple(2u); + + EXPECT_EQ(OK, req1->WaitForResult()); + + // This request was canceled. + EXPECT_FALSE(req0->completed()); +} + +TEST_F(HostResolverImplTest, NumericIPv4Address) { + // Stevens says dotted quads with AI_UNSPEC resolve to a single sockaddr_in. + Request* req = CreateRequest("127.1.2.3", 5555); + EXPECT_EQ(OK, req->Resolve()); + + EXPECT_TRUE(req->HasOneAddress("127.1.2.3", 5555)); +} + +TEST_F(HostResolverImplTest, NumericIPv6Address) { + // Resolve a plain IPv6 address. Don't worry about [brackets], because + // the caller should have removed them. + Request* req = CreateRequest("2001:db8::1", 5555); + EXPECT_EQ(OK, req->Resolve()); + + EXPECT_TRUE(req->HasOneAddress("2001:db8::1", 5555)); +} + +TEST_F(HostResolverImplTest, EmptyHost) { + Request* req = CreateRequest(std::string(), 5555); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->Resolve()); +} + +TEST_F(HostResolverImplTest, EmptyDotsHost) { + for (int i = 0; i < 16; ++i) { + Request* req = CreateRequest(std::string(i, '.'), 5555); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->Resolve()); + } +} + +TEST_F(HostResolverImplTest, LongHost) { + Request* req = CreateRequest(std::string(4097, 'a'), 5555); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->Resolve()); +} + +TEST_F(HostResolverImplTest, DeDupeRequests) { + // Start 5 requests, duplicating hosts "a" and "b". Since the resolver_proc is + // blocked, these should all pile up until we signal it. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 81)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 82)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 83)->Resolve()); + + proc_->SignalMultiple(2u); // One for "a", one for "b". + + for (size_t i = 0; i < requests_.size(); ++i) { + EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i; + } +} + +TEST_F(HostResolverImplTest, CancelMultipleRequests) { + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 81)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 82)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 83)->Resolve()); + + // Cancel everything except request for ("a", 82). + requests_[0]->Cancel(); + requests_[1]->Cancel(); + requests_[2]->Cancel(); + requests_[4]->Cancel(); + + proc_->SignalMultiple(2u); // One for "a", one for "b". + + EXPECT_EQ(OK, requests_[3]->WaitForResult()); +} + +TEST_F(HostResolverImplTest, CanceledRequestsReleaseJobSlots) { + // Fill up the dispatcher and queue. + for (unsigned i = 0; i < kMaxJobs + 1; ++i) { + std::string hostname = "a_"; + hostname[1] = 'a' + i; + EXPECT_EQ(ERR_IO_PENDING, CreateRequest(hostname, 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest(hostname, 81)->Resolve()); + } + + EXPECT_TRUE(proc_->WaitFor(kMaxJobs)); + + // Cancel all but last two. + for (unsigned i = 0; i < requests_.size() - 2; ++i) { + requests_[i]->Cancel(); + } + + EXPECT_TRUE(proc_->WaitFor(kMaxJobs + 1)); + + proc_->SignalAll(); + + size_t num_requests = requests_.size(); + EXPECT_EQ(OK, requests_[num_requests - 1]->WaitForResult()); + EXPECT_EQ(OK, requests_[num_requests - 2]->result()); +} + +TEST_F(HostResolverImplTest, CancelWithinCallback) { + struct MyHandler : public Handler { + virtual void Handle(Request* req) OVERRIDE { + // Port 80 is the first request that the callback will be invoked for. + // While we are executing within that callback, cancel the other requests + // in the job and start another request. + if (req->index() == 0) { + // Once "a:80" completes, it will cancel "a:81" and "a:82". + requests()[1]->Cancel(); + requests()[2]->Cancel(); + } + } + }; + set_handler(new MyHandler()); + + for (size_t i = 0; i < 4; ++i) { + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80 + i)->Resolve()) << i; + } + + proc_->SignalMultiple(2u); // One for "a". One for "finalrequest". + + EXPECT_EQ(OK, requests_[0]->WaitForResult()); + + Request* final_request = CreateRequest("finalrequest", 70); + EXPECT_EQ(ERR_IO_PENDING, final_request->Resolve()); + EXPECT_EQ(OK, final_request->WaitForResult()); + EXPECT_TRUE(requests_[3]->completed()); +} + +TEST_F(HostResolverImplTest, DeleteWithinCallback) { + struct MyHandler : public Handler { + virtual void Handle(Request* req) OVERRIDE { + EXPECT_EQ("a", req->info().hostname()); + EXPECT_EQ(80, req->info().port()); + + DeleteResolver(); + + // Quit after returning from OnCompleted (to give it a chance at + // incorrectly running the cancelled tasks). + base::MessageLoop::current()->PostTask(FROM_HERE, + base::MessageLoop::QuitClosure()); + } + }; + set_handler(new MyHandler()); + + for (size_t i = 0; i < 4; ++i) { + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80 + i)->Resolve()) << i; + } + + proc_->SignalMultiple(1u); // One for "a". + + // |MyHandler| will send quit message once all the requests have finished. + base::MessageLoop::current()->Run(); +} + +TEST_F(HostResolverImplTest, DeleteWithinAbortedCallback) { + struct MyHandler : public Handler { + virtual void Handle(Request* req) OVERRIDE { + EXPECT_EQ("a", req->info().hostname()); + EXPECT_EQ(80, req->info().port()); + + DeleteResolver(); + + // Quit after returning from OnCompleted (to give it a chance at + // incorrectly running the cancelled tasks). + base::MessageLoop::current()->PostTask(FROM_HERE, + base::MessageLoop::QuitClosure()); + } + }; + set_handler(new MyHandler()); + + // This test assumes that the Jobs will be Aborted in order ["a", "b"] + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80)->Resolve()); + // HostResolverImpl will be deleted before later Requests can complete. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 81)->Resolve()); + // Job for 'b' will be aborted before it can complete. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 82)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b", 83)->Resolve()); + + EXPECT_TRUE(proc_->WaitFor(1u)); + + // Triggering an IP address change. + NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests(); + + // |MyHandler| will send quit message once all the requests have finished. + base::MessageLoop::current()->Run(); + + EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[0]->result()); + EXPECT_EQ(ERR_IO_PENDING, requests_[1]->result()); + EXPECT_EQ(ERR_IO_PENDING, requests_[2]->result()); + EXPECT_EQ(ERR_IO_PENDING, requests_[3]->result()); + // Clean up. + proc_->SignalMultiple(requests_.size()); +} + +TEST_F(HostResolverImplTest, StartWithinCallback) { + struct MyHandler : public Handler { + virtual void Handle(Request* req) OVERRIDE { + if (req->index() == 0) { + // On completing the first request, start another request for "a". + // Since caching is disabled, this will result in another async request. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 70)->Resolve()); + } + } + }; + set_handler(new MyHandler()); + + // Turn off caching for this host resolver. + resolver_.reset(new HostResolverImpl(scoped_ptr<HostCache>(), + DefaultLimits(), + DefaultParams(proc_.get()), + NULL)); + + for (size_t i = 0; i < 4; ++i) { + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80 + i)->Resolve()) << i; + } + + proc_->SignalMultiple(2u); // One for "a". One for the second "a". + + EXPECT_EQ(OK, requests_[0]->WaitForResult()); + ASSERT_EQ(5u, requests_.size()); + EXPECT_EQ(OK, requests_.back()->WaitForResult()); + + EXPECT_EQ(2u, proc_->GetCaptureList().size()); +} + +TEST_F(HostResolverImplTest, BypassCache) { + struct MyHandler : public Handler { + virtual void Handle(Request* req) OVERRIDE { + if (req->index() == 0) { + // On completing the first request, start another request for "a". + // Since caching is enabled, this should complete synchronously. + std::string hostname = req->info().hostname(); + EXPECT_EQ(OK, CreateRequest(hostname, 70)->Resolve()); + EXPECT_EQ(OK, CreateRequest(hostname, 75)->ResolveFromCache()); + + // Ok good. Now make sure that if we ask to bypass the cache, it can no + // longer service the request synchronously. + HostResolver::RequestInfo info(HostPortPair(hostname, 71)); + info.set_allow_cached_response(false); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest(info)->Resolve()); + } else if (71 == req->info().port()) { + // Test is done. + base::MessageLoop::current()->Quit(); + } else { + FAIL() << "Unexpected request"; + } + } + }; + set_handler(new MyHandler()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a", 80)->Resolve()); + proc_->SignalMultiple(3u); // Only need two, but be generous. + + // |verifier| will send quit message once all the requests have finished. + base::MessageLoop::current()->Run(); + EXPECT_EQ(2u, proc_->GetCaptureList().size()); +} + +// Test that IP address changes flush the cache. +TEST_F(HostResolverImplTest, FlushCacheOnIPAddressChange) { + proc_->SignalMultiple(2u); // One before the flush, one after. + + Request* req = CreateRequest("host1", 70); + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + EXPECT_EQ(OK, req->WaitForResult()); + + req = CreateRequest("host1", 75); + EXPECT_EQ(OK, req->Resolve()); // Should complete synchronously. + + // Flush cache by triggering an IP address change. + NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests(); + base::MessageLoop::current()->RunUntilIdle(); // Notification happens async. + + // Resolve "host1" again -- this time it won't be served from cache, so it + // will complete asynchronously. + req = CreateRequest("host1", 80); + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + EXPECT_EQ(OK, req->WaitForResult()); +} + +// Test that IP address changes send ERR_NETWORK_CHANGED to pending requests. +TEST_F(HostResolverImplTest, AbortOnIPAddressChanged) { + Request* req = CreateRequest("host1", 70); + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + + EXPECT_TRUE(proc_->WaitFor(1u)); + // Triggering an IP address change. + NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests(); + base::MessageLoop::current()->RunUntilIdle(); // Notification happens async. + proc_->SignalAll(); + + EXPECT_EQ(ERR_NETWORK_CHANGED, req->WaitForResult()); + EXPECT_EQ(0u, resolver_->GetHostCache()->size()); +} + +// Obey pool constraints after IP address has changed. +TEST_F(HostResolverImplTest, ObeyPoolConstraintsAfterIPAddressChange) { + // Runs at most one job at a time. + CreateSerialResolver(); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("a")->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("b")->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("c")->Resolve()); + + EXPECT_TRUE(proc_->WaitFor(1u)); + // Triggering an IP address change. + NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests(); + base::MessageLoop::current()->RunUntilIdle(); // Notification happens async. + proc_->SignalMultiple(3u); // Let the false-start go so that we can catch it. + + EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[0]->WaitForResult()); + + EXPECT_EQ(1u, num_running_jobs()); + + EXPECT_FALSE(requests_[1]->completed()); + EXPECT_FALSE(requests_[2]->completed()); + + EXPECT_EQ(OK, requests_[2]->WaitForResult()); + EXPECT_EQ(OK, requests_[1]->result()); +} + +// Tests that a new Request made from the callback of a previously aborted one +// will not be aborted. +TEST_F(HostResolverImplTest, AbortOnlyExistingRequestsOnIPAddressChange) { + struct MyHandler : public Handler { + virtual void Handle(Request* req) OVERRIDE { + // Start new request for a different hostname to ensure that the order + // of jobs in HostResolverImpl is not stable. + std::string hostname; + if (req->index() == 0) + hostname = "zzz"; + else if (req->index() == 1) + hostname = "aaa"; + else if (req->index() == 2) + hostname = "eee"; + else + return; // A request started from within MyHandler. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest(hostname)->Resolve()) << hostname; + } + }; + set_handler(new MyHandler()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("bbb")->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("eee")->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ccc")->Resolve()); + + // Wait until all are blocked; + EXPECT_TRUE(proc_->WaitFor(3u)); + // Trigger an IP address change. + NetworkChangeNotifier::NotifyObserversOfIPAddressChangeForTests(); + // This should abort all running jobs. + base::MessageLoop::current()->RunUntilIdle(); + EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[0]->result()); + EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[1]->result()); + EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[2]->result()); + ASSERT_EQ(6u, requests_.size()); + // Unblock all calls to proc. + proc_->SignalMultiple(requests_.size()); + // Run until the re-started requests finish. + EXPECT_EQ(OK, requests_[3]->WaitForResult()); + EXPECT_EQ(OK, requests_[4]->WaitForResult()); + EXPECT_EQ(OK, requests_[5]->WaitForResult()); + // Verify that results of aborted Jobs were not cached. + EXPECT_EQ(6u, proc_->GetCaptureList().size()); + EXPECT_EQ(3u, resolver_->GetHostCache()->size()); +} + +// Tests that when the maximum threads is set to 1, requests are dequeued +// in order of priority. +TEST_F(HostResolverImplTest, HigherPriorityRequestsStartedFirst) { + CreateSerialResolver(); + + // Note that at this point the MockHostResolverProc is blocked, so any + // requests we make will not complete. + CreateRequest("req0", 80, LOW); + CreateRequest("req1", 80, MEDIUM); + CreateRequest("req2", 80, MEDIUM); + CreateRequest("req3", 80, LOW); + CreateRequest("req4", 80, HIGHEST); + CreateRequest("req5", 80, LOW); + CreateRequest("req6", 80, LOW); + CreateRequest("req5", 80, HIGHEST); + + for (size_t i = 0; i < requests_.size(); ++i) { + EXPECT_EQ(ERR_IO_PENDING, requests_[i]->Resolve()) << i; + } + + // Unblock the resolver thread so the requests can run. + proc_->SignalMultiple(requests_.size()); // More than needed. + + // Wait for all the requests to complete succesfully. + for (size_t i = 0; i < requests_.size(); ++i) { + EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i; + } + + // Since we have restricted to a single concurrent thread in the jobpool, + // the requests should complete in order of priority (with the exception + // of the first request, which gets started right away, since there is + // nothing outstanding). + MockHostResolverProc::CaptureList capture_list = proc_->GetCaptureList(); + ASSERT_EQ(7u, capture_list.size()); + + EXPECT_EQ("req0", capture_list[0].hostname); + EXPECT_EQ("req4", capture_list[1].hostname); + EXPECT_EQ("req5", capture_list[2].hostname); + EXPECT_EQ("req1", capture_list[3].hostname); + EXPECT_EQ("req2", capture_list[4].hostname); + EXPECT_EQ("req3", capture_list[5].hostname); + EXPECT_EQ("req6", capture_list[6].hostname); +} + +// Try cancelling a job which has not started yet. +TEST_F(HostResolverImplTest, CancelPendingRequest) { + CreateSerialResolver(); + + CreateRequest("req0", 80, LOWEST); + CreateRequest("req1", 80, HIGHEST); // Will cancel. + CreateRequest("req2", 80, MEDIUM); + CreateRequest("req3", 80, LOW); + CreateRequest("req4", 80, HIGHEST); // Will cancel. + CreateRequest("req5", 80, LOWEST); // Will cancel. + CreateRequest("req6", 80, MEDIUM); + + // Start all of the requests. + for (size_t i = 0; i < requests_.size(); ++i) { + EXPECT_EQ(ERR_IO_PENDING, requests_[i]->Resolve()) << i; + } + + // Cancel some requests + requests_[1]->Cancel(); + requests_[4]->Cancel(); + requests_[5]->Cancel(); + + // Unblock the resolver thread so the requests can run. + proc_->SignalMultiple(requests_.size()); // More than needed. + + // Wait for all the requests to complete succesfully. + for (size_t i = 0; i < requests_.size(); ++i) { + if (!requests_[i]->pending()) + continue; // Don't wait for the requests we cancelled. + EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i; + } + + // Verify that they called out the the resolver proc (which runs on the + // resolver thread) in the expected order. + MockHostResolverProc::CaptureList capture_list = proc_->GetCaptureList(); + ASSERT_EQ(4u, capture_list.size()); + + EXPECT_EQ("req0", capture_list[0].hostname); + EXPECT_EQ("req2", capture_list[1].hostname); + EXPECT_EQ("req6", capture_list[2].hostname); + EXPECT_EQ("req3", capture_list[3].hostname); +} + +// Test that when too many requests are enqueued, old ones start to be aborted. +TEST_F(HostResolverImplTest, QueueOverflow) { + CreateSerialResolver(); + + // Allow only 3 queued jobs. + const size_t kMaxPendingJobs = 3u; + resolver_->SetMaxQueuedJobs(kMaxPendingJobs); + + // Note that at this point the MockHostResolverProc is blocked, so any + // requests we make will not complete. + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req0", 80, LOWEST)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req1", 80, HIGHEST)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req2", 80, MEDIUM)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req3", 80, MEDIUM)->Resolve()); + + // At this point, there are 3 enqueued jobs. + // Insertion of subsequent requests will cause evictions + // based on priority. + + EXPECT_EQ(ERR_HOST_RESOLVER_QUEUE_TOO_LARGE, + CreateRequest("req4", 80, LOW)->Resolve()); // Evicts itself! + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req5", 80, MEDIUM)->Resolve()); + EXPECT_EQ(ERR_HOST_RESOLVER_QUEUE_TOO_LARGE, requests_[2]->result()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req6", 80, HIGHEST)->Resolve()); + EXPECT_EQ(ERR_HOST_RESOLVER_QUEUE_TOO_LARGE, requests_[3]->result()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("req7", 80, MEDIUM)->Resolve()); + EXPECT_EQ(ERR_HOST_RESOLVER_QUEUE_TOO_LARGE, requests_[5]->result()); + + // Unblock the resolver thread so the requests can run. + proc_->SignalMultiple(4u); + + // The rest should succeed. + EXPECT_EQ(OK, requests_[7]->WaitForResult()); + EXPECT_EQ(OK, requests_[0]->result()); + EXPECT_EQ(OK, requests_[1]->result()); + EXPECT_EQ(OK, requests_[6]->result()); + + // Verify that they called out the the resolver proc (which runs on the + // resolver thread) in the expected order. + MockHostResolverProc::CaptureList capture_list = proc_->GetCaptureList(); + ASSERT_EQ(4u, capture_list.size()); + + EXPECT_EQ("req0", capture_list[0].hostname); + EXPECT_EQ("req1", capture_list[1].hostname); + EXPECT_EQ("req6", capture_list[2].hostname); + EXPECT_EQ("req7", capture_list[3].hostname); + + // Verify that the evicted (incomplete) requests were not cached. + EXPECT_EQ(4u, resolver_->GetHostCache()->size()); + + for (size_t i = 0; i < requests_.size(); ++i) { + EXPECT_TRUE(requests_[i]->completed()) << i; + } +} + +// Tests that after changing the default AddressFamily to IPV4, requests +// with UNSPECIFIED address family map to IPV4. +TEST_F(HostResolverImplTest, SetDefaultAddressFamily_IPv4) { + CreateSerialResolver(); // To guarantee order of resolutions. + + proc_->AddRule("h1", ADDRESS_FAMILY_IPV4, "1.0.0.1"); + proc_->AddRule("h1", ADDRESS_FAMILY_IPV6, "::2"); + + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4); + + CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_UNSPECIFIED); + CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_IPV4); + CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_IPV6); + + // Start all of the requests. + for (size_t i = 0; i < requests_.size(); ++i) { + EXPECT_EQ(ERR_IO_PENDING, requests_[i]->Resolve()) << i; + } + + proc_->SignalMultiple(requests_.size()); + + // Wait for all the requests to complete. + for (size_t i = 0u; i < requests_.size(); ++i) { + EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i; + } + + // Since the requests all had the same priority and we limited the thread + // count to 1, they should have completed in the same order as they were + // requested. Moreover, request0 and request1 will have been serviced by + // the same job. + + MockHostResolverProc::CaptureList capture_list = proc_->GetCaptureList(); + ASSERT_EQ(2u, capture_list.size()); + + EXPECT_EQ("h1", capture_list[0].hostname); + EXPECT_EQ(ADDRESS_FAMILY_IPV4, capture_list[0].address_family); + + EXPECT_EQ("h1", capture_list[1].hostname); + EXPECT_EQ(ADDRESS_FAMILY_IPV6, capture_list[1].address_family); + + // Now check that the correct resolved IP addresses were returned. + EXPECT_TRUE(requests_[0]->HasOneAddress("1.0.0.1", 80)); + EXPECT_TRUE(requests_[1]->HasOneAddress("1.0.0.1", 80)); + EXPECT_TRUE(requests_[2]->HasOneAddress("::2", 80)); +} + +// This is the exact same test as SetDefaultAddressFamily_IPv4, except the +// default family is set to IPv6 and the family of requests is flipped where +// specified. +TEST_F(HostResolverImplTest, SetDefaultAddressFamily_IPv6) { + CreateSerialResolver(); // To guarantee order of resolutions. + + // Don't use IPv6 replacements here since some systems don't support it. + proc_->AddRule("h1", ADDRESS_FAMILY_IPV4, "1.0.0.1"); + proc_->AddRule("h1", ADDRESS_FAMILY_IPV6, "::2"); + + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV6); + + CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_UNSPECIFIED); + CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_IPV6); + CreateRequest("h1", 80, MEDIUM, ADDRESS_FAMILY_IPV4); + + // Start all of the requests. + for (size_t i = 0; i < requests_.size(); ++i) { + EXPECT_EQ(ERR_IO_PENDING, requests_[i]->Resolve()) << i; + } + + proc_->SignalMultiple(requests_.size()); + + // Wait for all the requests to complete. + for (size_t i = 0u; i < requests_.size(); ++i) { + EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i; + } + + // Since the requests all had the same priority and we limited the thread + // count to 1, they should have completed in the same order as they were + // requested. Moreover, request0 and request1 will have been serviced by + // the same job. + + MockHostResolverProc::CaptureList capture_list = proc_->GetCaptureList(); + ASSERT_EQ(2u, capture_list.size()); + + EXPECT_EQ("h1", capture_list[0].hostname); + EXPECT_EQ(ADDRESS_FAMILY_IPV6, capture_list[0].address_family); + + EXPECT_EQ("h1", capture_list[1].hostname); + EXPECT_EQ(ADDRESS_FAMILY_IPV4, capture_list[1].address_family); + + // Now check that the correct resolved IP addresses were returned. + EXPECT_TRUE(requests_[0]->HasOneAddress("::2", 80)); + EXPECT_TRUE(requests_[1]->HasOneAddress("::2", 80)); + EXPECT_TRUE(requests_[2]->HasOneAddress("1.0.0.1", 80)); +} + +TEST_F(HostResolverImplTest, ResolveFromCache) { + proc_->AddRuleForAllFamilies("just.testing", "192.168.1.42"); + proc_->SignalMultiple(1u); // Need only one. + + HostResolver::RequestInfo info(HostPortPair("just.testing", 80)); + + // First hit will miss the cache. + EXPECT_EQ(ERR_DNS_CACHE_MISS, CreateRequest(info)->ResolveFromCache()); + + // This time, we fetch normally. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest(info)->Resolve()); + EXPECT_EQ(OK, requests_[1]->WaitForResult()); + + // Now we should be able to fetch from the cache. + EXPECT_EQ(OK, CreateRequest(info)->ResolveFromCache()); + EXPECT_TRUE(requests_[2]->HasOneAddress("192.168.1.42", 80)); +} + +// Test the retry attempts simulating host resolver proc that takes too long. +TEST_F(HostResolverImplTest, MultipleAttempts) { + // Total number of attempts would be 3 and we want the 3rd attempt to resolve + // the host. First and second attempt will be forced to sleep until they get + // word that a resolution has completed. The 3rd resolution attempt will try + // to get done ASAP, and won't sleep.. + int kAttemptNumberToResolve = 3; + int kTotalAttempts = 3; + + scoped_refptr<LookupAttemptHostResolverProc> resolver_proc( + new LookupAttemptHostResolverProc( + NULL, kAttemptNumberToResolve, kTotalAttempts)); + + HostResolverImpl::ProcTaskParams params = DefaultParams(resolver_proc.get()); + + // Specify smaller interval for unresponsive_delay_ for HostResolverImpl so + // that unit test runs faster. For example, this test finishes in 1.5 secs + // (500ms * 3). + params.unresponsive_delay = base::TimeDelta::FromMilliseconds(500); + + resolver_.reset( + new HostResolverImpl(HostCache::CreateDefaultCache(), + DefaultLimits(), + params, + NULL)); + + // Resolve "host1". + HostResolver::RequestInfo info(HostPortPair("host1", 70)); + Request* req = CreateRequest(info); + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + + // Resolve returns -4 to indicate that 3rd attempt has resolved the host. + EXPECT_EQ(-4, req->WaitForResult()); + + resolver_proc->WaitForAllAttemptsToFinish( + base::TimeDelta::FromMilliseconds(60000)); + base::MessageLoop::current()->RunUntilIdle(); + + EXPECT_EQ(resolver_proc->total_attempts_resolved(), kTotalAttempts); + EXPECT_EQ(resolver_proc->resolved_attempt_number(), kAttemptNumberToResolve); +} + +DnsConfig CreateValidDnsConfig() { + IPAddressNumber dns_ip; + bool rv = ParseIPLiteralToNumber("192.168.1.0", &dns_ip); + EXPECT_TRUE(rv); + + DnsConfig config; + config.nameservers.push_back(IPEndPoint(dns_ip, dns_protocol::kDefaultPort)); + EXPECT_TRUE(config.IsValid()); + return config; +} + +// Specialized fixture for tests of DnsTask. +class HostResolverImplDnsTest : public HostResolverImplTest { + protected: + virtual void SetUp() OVERRIDE { + AddDnsRule("nx", dns_protocol::kTypeA, MockDnsClientRule::FAIL); + AddDnsRule("nx", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL); + AddDnsRule("ok", dns_protocol::kTypeA, MockDnsClientRule::OK); + AddDnsRule("ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK); + AddDnsRule("4ok", dns_protocol::kTypeA, MockDnsClientRule::OK); + AddDnsRule("4ok", dns_protocol::kTypeAAAA, MockDnsClientRule::EMPTY); + AddDnsRule("6ok", dns_protocol::kTypeA, MockDnsClientRule::EMPTY); + AddDnsRule("6ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK); + AddDnsRule("4nx", dns_protocol::kTypeA, MockDnsClientRule::OK); + AddDnsRule("4nx", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL); + CreateResolver(); + } + + void CreateResolver() { + resolver_.reset(new HostResolverImpl(HostCache::CreateDefaultCache(), + DefaultLimits(), + DefaultParams(proc_.get()), + NULL)); + // Disable IPv6 support probing. + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + resolver_->SetDnsClient(CreateMockDnsClient(DnsConfig(), dns_rules_)); + } + + // Adds a rule to |dns_rules_|. Must be followed by |CreateResolver| to apply. + void AddDnsRule(const std::string& prefix, + uint16 qtype, + MockDnsClientRule::Result result) { + dns_rules_.push_back(MockDnsClientRule(prefix, qtype, result)); + } + + void ChangeDnsConfig(const DnsConfig& config) { + NetworkChangeNotifier::SetDnsConfig(config); + // Notification is delivered asynchronously. + base::MessageLoop::current()->RunUntilIdle(); + } + + MockDnsClientRuleList dns_rules_; +}; + +// TODO(szym): Test AbortAllInProgressJobs due to DnsConfig change. + +// TODO(cbentzel): Test a mix of requests with different HostResolverFlags. + +// Test successful and fallback resolutions in HostResolverImpl::DnsTask. +TEST_F(HostResolverImplDnsTest, DnsTask) { + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4); + + proc_->AddRuleForAllFamilies("nx_succeed", "192.168.1.102"); + // All other hostnames will fail in proc_. + + // Initially there is no config, so client should not be invoked. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok_fail", 80)->Resolve()); + proc_->SignalMultiple(requests_.size()); + + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, requests_[0]->WaitForResult()); + + ChangeDnsConfig(CreateValidDnsConfig()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok_fail", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_fail", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_succeed", 80)->Resolve()); + + proc_->SignalMultiple(requests_.size()); + + for (size_t i = 1; i < requests_.size(); ++i) + EXPECT_NE(ERR_UNEXPECTED, requests_[i]->WaitForResult()) << i; + + EXPECT_EQ(OK, requests_[1]->result()); + // Resolved by MockDnsClient. + EXPECT_TRUE(requests_[1]->HasOneAddress("127.0.0.1", 80)); + // Fallback to ProcTask. + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, requests_[2]->result()); + EXPECT_EQ(OK, requests_[3]->result()); + EXPECT_TRUE(requests_[3]->HasOneAddress("192.168.1.102", 80)); +} + +// Test successful and failing resolutions in HostResolverImpl::DnsTask when +// fallback to ProcTask is disabled. +TEST_F(HostResolverImplDnsTest, NoFallbackToProcTask) { + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4); + set_fallback_to_proctask(false); + + proc_->AddRuleForAllFamilies("nx_succeed", "192.168.1.102"); + // All other hostnames will fail in proc_. + + // Set empty DnsConfig. + ChangeDnsConfig(DnsConfig()); + // Initially there is no config, so client should not be invoked. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok_fail", 80)->Resolve()); + // There is no config, so fallback to ProcTask must work. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_succeed", 80)->Resolve()); + proc_->SignalMultiple(requests_.size()); + + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, requests_[0]->WaitForResult()); + EXPECT_EQ(OK, requests_[1]->WaitForResult()); + EXPECT_TRUE(requests_[1]->HasOneAddress("192.168.1.102", 80)); + + ChangeDnsConfig(CreateValidDnsConfig()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok_abort", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_abort", 80)->Resolve()); + + // Simulate the case when the preference or policy has disabled the DNS client + // causing AbortDnsTasks. + resolver_->SetDnsClient(CreateMockDnsClient(DnsConfig(), dns_rules_)); + ChangeDnsConfig(CreateValidDnsConfig()); + + // First request is resolved by MockDnsClient, others should fail due to + // disabled fallback to ProcTask. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok_fail", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_fail", 80)->Resolve()); + proc_->SignalMultiple(requests_.size()); + + // Aborted due to Network Change. + EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[2]->WaitForResult()); + EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[3]->WaitForResult()); + // Resolved by MockDnsClient. + EXPECT_EQ(OK, requests_[4]->WaitForResult()); + EXPECT_TRUE(requests_[4]->HasOneAddress("127.0.0.1", 80)); + // Fallback to ProcTask is disabled. + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, requests_[5]->WaitForResult()); +} + +// Test behavior of OnDnsTaskFailure when Job is aborted. +TEST_F(HostResolverImplDnsTest, OnDnsTaskFailureAbortedJob) { + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4); + ChangeDnsConfig(CreateValidDnsConfig()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_abort", 80)->Resolve()); + // Abort all jobs here. + CreateResolver(); + proc_->SignalMultiple(requests_.size()); + // Run to completion. + base::MessageLoop::current()->RunUntilIdle(); // Notification happens async. + // It shouldn't crash during OnDnsTaskFailure callbacks. + EXPECT_EQ(ERR_IO_PENDING, requests_[0]->result()); + + // Repeat test with Fallback to ProcTask disabled + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4); + set_fallback_to_proctask(false); + ChangeDnsConfig(CreateValidDnsConfig()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("nx_abort", 80)->Resolve()); + // Abort all jobs here. + CreateResolver(); + // Run to completion. + base::MessageLoop::current()->RunUntilIdle(); // Notification happens async. + // It shouldn't crash during OnDnsTaskFailure callbacks. + EXPECT_EQ(ERR_IO_PENDING, requests_[1]->result()); +} + +TEST_F(HostResolverImplDnsTest, DnsTaskUnspec) { + ChangeDnsConfig(CreateValidDnsConfig()); + + proc_->AddRuleForAllFamilies("4nx", "192.168.1.101"); + // All other hostnames will fail in proc_. + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("4ok", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("6ok", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("4nx", 80)->Resolve()); + + proc_->SignalMultiple(requests_.size()); + + for (size_t i = 0; i < requests_.size(); ++i) + EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i; + + EXPECT_EQ(2u, requests_[0]->NumberOfAddresses()); + EXPECT_TRUE(requests_[0]->HasAddress("127.0.0.1", 80)); + EXPECT_TRUE(requests_[0]->HasAddress("::1", 80)); + EXPECT_EQ(1u, requests_[1]->NumberOfAddresses()); + EXPECT_TRUE(requests_[1]->HasAddress("127.0.0.1", 80)); + EXPECT_EQ(1u, requests_[2]->NumberOfAddresses()); + EXPECT_TRUE(requests_[2]->HasAddress("::1", 80)); + EXPECT_EQ(1u, requests_[3]->NumberOfAddresses()); + EXPECT_TRUE(requests_[3]->HasAddress("192.168.1.101", 80)); +} + +TEST_F(HostResolverImplDnsTest, ServeFromHosts) { + // Initially, use empty HOSTS file. + DnsConfig config = CreateValidDnsConfig(); + ChangeDnsConfig(config); + + proc_->AddRuleForAllFamilies(std::string(), + std::string()); // Default to failures. + proc_->SignalMultiple(1u); // For the first request which misses. + + Request* req0 = CreateRequest("nx_ipv4", 80); + EXPECT_EQ(ERR_IO_PENDING, req0->Resolve()); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req0->WaitForResult()); + + IPAddressNumber local_ipv4, local_ipv6; + ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &local_ipv4)); + ASSERT_TRUE(ParseIPLiteralToNumber("::1", &local_ipv6)); + + DnsHosts hosts; + hosts[DnsHostsKey("nx_ipv4", ADDRESS_FAMILY_IPV4)] = local_ipv4; + hosts[DnsHostsKey("nx_ipv6", ADDRESS_FAMILY_IPV6)] = local_ipv6; + hosts[DnsHostsKey("nx_both", ADDRESS_FAMILY_IPV4)] = local_ipv4; + hosts[DnsHostsKey("nx_both", ADDRESS_FAMILY_IPV6)] = local_ipv6; + + // Update HOSTS file. + config.hosts = hosts; + ChangeDnsConfig(config); + + Request* req1 = CreateRequest("nx_ipv4", 80); + EXPECT_EQ(OK, req1->Resolve()); + EXPECT_TRUE(req1->HasOneAddress("127.0.0.1", 80)); + + Request* req2 = CreateRequest("nx_ipv6", 80); + EXPECT_EQ(OK, req2->Resolve()); + EXPECT_TRUE(req2->HasOneAddress("::1", 80)); + + Request* req3 = CreateRequest("nx_both", 80); + EXPECT_EQ(OK, req3->Resolve()); + EXPECT_TRUE(req3->HasAddress("127.0.0.1", 80) && + req3->HasAddress("::1", 80)); + + // Requests with specified AddressFamily. + Request* req4 = CreateRequest("nx_ipv4", 80, MEDIUM, ADDRESS_FAMILY_IPV4); + EXPECT_EQ(OK, req4->Resolve()); + EXPECT_TRUE(req4->HasOneAddress("127.0.0.1", 80)); + + Request* req5 = CreateRequest("nx_ipv6", 80, MEDIUM, ADDRESS_FAMILY_IPV6); + EXPECT_EQ(OK, req5->Resolve()); + EXPECT_TRUE(req5->HasOneAddress("::1", 80)); + + // Request with upper case. + Request* req6 = CreateRequest("nx_IPV4", 80); + EXPECT_EQ(OK, req6->Resolve()); + EXPECT_TRUE(req6->HasOneAddress("127.0.0.1", 80)); +} + +TEST_F(HostResolverImplDnsTest, BypassDnsTask) { + ChangeDnsConfig(CreateValidDnsConfig()); + + proc_->AddRuleForAllFamilies(std::string(), + std::string()); // Default to failures. + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok.local", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok.local.", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("oklocal", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("oklocal.", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok", 80)->Resolve()); + + proc_->SignalMultiple(requests_.size()); + + for (size_t i = 0; i < 2; ++i) + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, requests_[i]->WaitForResult()) << i; + + for (size_t i = 2; i < requests_.size(); ++i) + EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i; +} + +TEST_F(HostResolverImplDnsTest, DisableDnsClientOnPersistentFailure) { + ChangeDnsConfig(CreateValidDnsConfig()); + + proc_->AddRuleForAllFamilies(std::string(), + std::string()); // Default to failures. + + // Check that DnsTask works. + Request* req = CreateRequest("ok_1", 80); + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + EXPECT_EQ(OK, req->WaitForResult()); + + for (unsigned i = 0; i < 20; ++i) { + // Use custom names to require separate Jobs. + std::string hostname = base::StringPrintf("nx_%u", i); + // Ensure fallback to ProcTask succeeds. + proc_->AddRuleForAllFamilies(hostname, "192.168.1.101"); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest(hostname, 80)->Resolve()) << i; + } + + proc_->SignalMultiple(requests_.size()); + + for (size_t i = 0; i < requests_.size(); ++i) + EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i; + + ASSERT_FALSE(proc_->HasBlockedRequests()); + + // DnsTask should be disabled by now. + req = CreateRequest("ok_2", 80); + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + proc_->SignalMultiple(1u); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, req->WaitForResult()); + + // Check that it is re-enabled after DNS change. + ChangeDnsConfig(CreateValidDnsConfig()); + req = CreateRequest("ok_3", 80); + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + EXPECT_EQ(OK, req->WaitForResult()); +} + +TEST_F(HostResolverImplDnsTest, DontDisableDnsClientOnSporadicFailure) { + ChangeDnsConfig(CreateValidDnsConfig()); + + // |proc_| defaults to successes. + + // 20 failures interleaved with 20 successes. + for (unsigned i = 0; i < 40; ++i) { + // Use custom names to require separate Jobs. + std::string hostname = (i % 2) == 0 ? base::StringPrintf("nx_%u", i) + : base::StringPrintf("ok_%u", i); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest(hostname, 80)->Resolve()) << i; + } + + proc_->SignalMultiple(requests_.size()); + + for (size_t i = 0; i < requests_.size(); ++i) + EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i; + + // Make |proc_| default to failures. + proc_->AddRuleForAllFamilies(std::string(), std::string()); + + // DnsTask should still be enabled. + Request* req = CreateRequest("ok_last", 80); + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + EXPECT_EQ(OK, req->WaitForResult()); +} + +// Confirm that resolving "localhost" is unrestricted even if there are no +// global IPv6 address. See SystemHostResolverCall for rationale. +// Test both the DnsClient and system host resolver paths. +TEST_F(HostResolverImplDnsTest, DualFamilyLocalhost) { + // Use regular SystemHostResolverCall! + scoped_refptr<HostResolverProc> proc(new SystemHostResolverProc()); + resolver_.reset(new HostResolverImpl(HostCache::CreateDefaultCache(), + DefaultLimits(), + DefaultParams(proc.get()), + NULL)); + resolver_->SetDnsClient(CreateMockDnsClient(DnsConfig(), dns_rules_)); + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4); + + // Get the expected output. + AddressList addrlist; + int rv = proc->Resolve("localhost", ADDRESS_FAMILY_UNSPECIFIED, 0, &addrlist, + NULL); + if (rv != OK) + return; + + for (unsigned i = 0; i < addrlist.size(); ++i) + LOG(WARNING) << addrlist[i].ToString(); + + bool saw_ipv4 = AddressListContains(addrlist, "127.0.0.1", 0); + bool saw_ipv6 = AddressListContains(addrlist, "::1", 0); + if (!saw_ipv4 && !saw_ipv6) + return; + + HostResolver::RequestInfo info(HostPortPair("localhost", 80)); + info.set_address_family(ADDRESS_FAMILY_UNSPECIFIED); + info.set_host_resolver_flags(HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6); + + // Try without DnsClient. + ChangeDnsConfig(DnsConfig()); + Request* req = CreateRequest(info); + // It is resolved via getaddrinfo, so expect asynchronous result. + EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); + EXPECT_EQ(OK, req->WaitForResult()); + + EXPECT_EQ(saw_ipv4, req->HasAddress("127.0.0.1", 80)); + EXPECT_EQ(saw_ipv6, req->HasAddress("::1", 80)); + + // Configure DnsClient with dual-host HOSTS file. + DnsConfig config = CreateValidDnsConfig(); + DnsHosts hosts; + IPAddressNumber local_ipv4, local_ipv6; + ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &local_ipv4)); + ASSERT_TRUE(ParseIPLiteralToNumber("::1", &local_ipv6)); + if (saw_ipv4) + hosts[DnsHostsKey("localhost", ADDRESS_FAMILY_IPV4)] = local_ipv4; + if (saw_ipv6) + hosts[DnsHostsKey("localhost", ADDRESS_FAMILY_IPV6)] = local_ipv6; + config.hosts = hosts; + + ChangeDnsConfig(config); + req = CreateRequest(info); + // Expect synchronous resolution from DnsHosts. + EXPECT_EQ(OK, req->Resolve()); + + EXPECT_EQ(saw_ipv4, req->HasAddress("127.0.0.1", 80)); + EXPECT_EQ(saw_ipv6, req->HasAddress("::1", 80)); +} + +} // namespace net diff --git a/chromium/net/dns/host_resolver_proc.cc b/chromium/net/dns/host_resolver_proc.cc new file mode 100644 index 00000000000..f2b10c649ba --- /dev/null +++ b/chromium/net/dns/host_resolver_proc.cc @@ -0,0 +1,267 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/host_resolver_proc.h" + +#include "build/build_config.h" + +#include "base/logging.h" +#include "base/sys_byteorder.h" +#include "net/base/address_list.h" +#include "net/base/dns_reloader.h" +#include "net/base/net_errors.h" +#include "net/base/sys_addrinfo.h" + +#if defined(OS_OPENBSD) +#define AI_ADDRCONFIG 0 +#endif + +namespace net { + +namespace { + +bool IsAllLocalhostOfOneFamily(const struct addrinfo* ai) { + bool saw_v4_localhost = false; + bool saw_v6_localhost = false; + for (; ai != NULL; ai = ai->ai_next) { + switch (ai->ai_family) { + case AF_INET: { + const struct sockaddr_in* addr_in = + reinterpret_cast<struct sockaddr_in*>(ai->ai_addr); + if ((base::NetToHost32(addr_in->sin_addr.s_addr) & 0xff000000) == + 0x7f000000) + saw_v4_localhost = true; + else + return false; + break; + } + case AF_INET6: { + const struct sockaddr_in6* addr_in6 = + reinterpret_cast<struct sockaddr_in6*>(ai->ai_addr); + if (IN6_IS_ADDR_LOOPBACK(&addr_in6->sin6_addr)) + saw_v6_localhost = true; + else + return false; + break; + } + default: + NOTREACHED(); + return false; + } + } + + return saw_v4_localhost != saw_v6_localhost; +} + +} // namespace + +HostResolverProc* HostResolverProc::default_proc_ = NULL; + +HostResolverProc::HostResolverProc(HostResolverProc* previous) { + SetPreviousProc(previous); + + // Implicitly fall-back to the global default procedure. + if (!previous) + SetPreviousProc(default_proc_); +} + +HostResolverProc::~HostResolverProc() { +} + +int HostResolverProc::ResolveUsingPrevious( + const std::string& host, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + AddressList* addrlist, + int* os_error) { + if (previous_proc_.get()) { + return previous_proc_->Resolve( + host, address_family, host_resolver_flags, addrlist, os_error); + } + + // Final fallback is the system resolver. + return SystemHostResolverCall(host, address_family, host_resolver_flags, + addrlist, os_error); +} + +void HostResolverProc::SetPreviousProc(HostResolverProc* proc) { + HostResolverProc* current_previous = previous_proc_.get(); + previous_proc_ = NULL; + // Now that we've guaranteed |this| is the last proc in a chain, we can + // detect potential cycles using GetLastProc(). + previous_proc_ = (GetLastProc(proc) == this) ? current_previous : proc; +} + +void HostResolverProc::SetLastProc(HostResolverProc* proc) { + GetLastProc(this)->SetPreviousProc(proc); +} + +// static +HostResolverProc* HostResolverProc::GetLastProc(HostResolverProc* proc) { + if (proc == NULL) + return NULL; + HostResolverProc* last_proc = proc; + while (last_proc->previous_proc_.get() != NULL) + last_proc = last_proc->previous_proc_.get(); + return last_proc; +} + +// static +HostResolverProc* HostResolverProc::SetDefault(HostResolverProc* proc) { + HostResolverProc* old = default_proc_; + default_proc_ = proc; + return old; +} + +// static +HostResolverProc* HostResolverProc::GetDefault() { + return default_proc_; +} + +int SystemHostResolverCall(const std::string& host, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + AddressList* addrlist, + int* os_error) { + if (os_error) + *os_error = 0; + + struct addrinfo* ai = NULL; + struct addrinfo hints = {0}; + + switch (address_family) { + case ADDRESS_FAMILY_IPV4: + hints.ai_family = AF_INET; + break; + case ADDRESS_FAMILY_IPV6: + hints.ai_family = AF_INET6; + break; + case ADDRESS_FAMILY_UNSPECIFIED: + hints.ai_family = AF_UNSPEC; + break; + default: + NOTREACHED(); + hints.ai_family = AF_UNSPEC; + } + +#if defined(OS_WIN) + // DO NOT USE AI_ADDRCONFIG ON WINDOWS. + // + // The following comment in <winsock2.h> is the best documentation I found + // on AI_ADDRCONFIG for Windows: + // Flags used in "hints" argument to getaddrinfo() + // - AI_ADDRCONFIG is supported starting with Vista + // - default is AI_ADDRCONFIG ON whether the flag is set or not + // because the performance penalty in not having ADDRCONFIG in + // the multi-protocol stack environment is severe; + // this defaulting may be disabled by specifying the AI_ALL flag, + // in that case AI_ADDRCONFIG must be EXPLICITLY specified to + // enable ADDRCONFIG behavior + // + // Not only is AI_ADDRCONFIG unnecessary, but it can be harmful. If the + // computer is not connected to a network, AI_ADDRCONFIG causes getaddrinfo + // to fail with WSANO_DATA (11004) for "localhost", probably because of the + // following note on AI_ADDRCONFIG in the MSDN getaddrinfo page: + // The IPv4 or IPv6 loopback address is not considered a valid global + // address. + // See http://crbug.com/5234. + // + // OpenBSD does not support it, either. + hints.ai_flags = 0; +#else + hints.ai_flags = AI_ADDRCONFIG; +#endif + + // On Linux AI_ADDRCONFIG doesn't consider loopback addreses, even if only + // loopback addresses are configured. So don't use it when there are only + // loopback addresses. + if (host_resolver_flags & HOST_RESOLVER_LOOPBACK_ONLY) + hints.ai_flags &= ~AI_ADDRCONFIG; + + if (host_resolver_flags & HOST_RESOLVER_CANONNAME) + hints.ai_flags |= AI_CANONNAME; + + // Restrict result set to only this socket type to avoid duplicates. + hints.ai_socktype = SOCK_STREAM; + +#if defined(OS_POSIX) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) && \ + !defined(OS_ANDROID) + DnsReloaderMaybeReload(); +#endif + int err = getaddrinfo(host.c_str(), NULL, &hints, &ai); + bool should_retry = false; + // If the lookup was restricted (either by address family, or address + // detection), and the results where all localhost of a single family, + // maybe we should retry. There were several bugs related to these + // issues, for example http://crbug.com/42058 and http://crbug.com/49024 + if ((hints.ai_family != AF_UNSPEC || hints.ai_flags & AI_ADDRCONFIG) && + err == 0 && IsAllLocalhostOfOneFamily(ai)) { + if (host_resolver_flags & HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6) { + hints.ai_family = AF_UNSPEC; + should_retry = true; + } + if (hints.ai_flags & AI_ADDRCONFIG) { + hints.ai_flags &= ~AI_ADDRCONFIG; + should_retry = true; + } + } + if (should_retry) { + if (ai != NULL) { + freeaddrinfo(ai); + ai = NULL; + } + err = getaddrinfo(host.c_str(), NULL, &hints, &ai); + } + + if (err) { +#if defined(OS_WIN) + err = WSAGetLastError(); +#endif + + // Return the OS error to the caller. + if (os_error) + *os_error = err; + + // If the call to getaddrinfo() failed because of a system error, report + // it separately from ERR_NAME_NOT_RESOLVED. +#if defined(OS_WIN) + if (err != WSAHOST_NOT_FOUND && err != WSANO_DATA) + return ERR_NAME_RESOLUTION_FAILED; +#elif defined(OS_POSIX) && !defined(OS_FREEBSD) + if (err != EAI_NONAME && err != EAI_NODATA) + return ERR_NAME_RESOLUTION_FAILED; +#endif + + return ERR_NAME_NOT_RESOLVED; + } + +#if defined(OS_ANDROID) + // Workaround for Android's getaddrinfo leaving ai==NULL without an error. + // http://crbug.com/134142 + if (ai == NULL) + return ERR_NAME_NOT_RESOLVED; +#endif + + *addrlist = AddressList::CreateFromAddrinfo(ai); + freeaddrinfo(ai); + return OK; +} + +SystemHostResolverProc::SystemHostResolverProc() : HostResolverProc(NULL) {} + +int SystemHostResolverProc::Resolve(const std::string& hostname, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + AddressList* addr_list, + int* os_error) { + return SystemHostResolverCall(hostname, + address_family, + host_resolver_flags, + addr_list, + os_error); +} + +SystemHostResolverProc::~SystemHostResolverProc() {} + +} // namespace net diff --git a/chromium/net/dns/host_resolver_proc.h b/chromium/net/dns/host_resolver_proc.h new file mode 100644 index 00000000000..014a720e375 --- /dev/null +++ b/chromium/net/dns/host_resolver_proc.h @@ -0,0 +1,111 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_HOST_RESOLVER_PROC_H_ +#define NET_DNS_HOST_RESOLVER_PROC_H_ + +#include <string> + +#include "base/memory/ref_counted.h" +#include "net/base/address_family.h" +#include "net/base/net_export.h" + +namespace net { + +class AddressList; + +// Interface for a getaddrinfo()-like procedure. This is used by unit-tests +// to control the underlying resolutions in HostResolverImpl. HostResolverProcs +// can be chained together; they fallback to the next procedure in the chain +// by calling ResolveUsingPrevious(). +// +// Note that implementations of HostResolverProc *MUST BE THREADSAFE*, since +// the HostResolver implementation using them can be multi-threaded. +class NET_EXPORT HostResolverProc + : public base::RefCountedThreadSafe<HostResolverProc> { + public: + explicit HostResolverProc(HostResolverProc* previous); + + // Resolves |host| to an address list, restricting the results to addresses + // in |address_family|. If successful returns OK and fills |addrlist| with + // a list of socket addresses. Otherwise returns a network error code, and + // fills |os_error| with a more specific error if it was non-NULL. + virtual int Resolve(const std::string& host, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + AddressList* addrlist, + int* os_error) = 0; + + protected: + friend class base::RefCountedThreadSafe<HostResolverProc>; + + virtual ~HostResolverProc(); + + // Asks the fallback procedure (if set) to do the resolve. + int ResolveUsingPrevious(const std::string& host, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + AddressList* addrlist, + int* os_error); + + private: + friend class HostResolverImpl; + friend class MockHostResolverBase; + friend class ScopedDefaultHostResolverProc; + + // Sets the previous procedure in the chain. Aborts if this would result in a + // cycle. + void SetPreviousProc(HostResolverProc* proc); + + // Sets the last procedure in the chain, i.e. appends |proc| to the end of the + // current chain. Aborts if this would result in a cycle. + void SetLastProc(HostResolverProc* proc); + + // Returns the last procedure in the chain starting at |proc|. Will return + // NULL iff |proc| is NULL. + static HostResolverProc* GetLastProc(HostResolverProc* proc); + + // Sets the default host resolver procedure that is used by HostResolverImpl. + // This can be used through ScopedDefaultHostResolverProc to set a catch-all + // DNS block in unit-tests (individual tests should use MockHostResolver to + // prevent hitting the network). + static HostResolverProc* SetDefault(HostResolverProc* proc); + static HostResolverProc* GetDefault(); + + scoped_refptr<HostResolverProc> previous_proc_; + static HostResolverProc* default_proc_; + + DISALLOW_COPY_AND_ASSIGN(HostResolverProc); +}; + +// Resolves |host| to an address list, using the system's default host resolver. +// (i.e. this calls out to getaddrinfo()). If successful returns OK and fills +// |addrlist| with a list of socket addresses. Otherwise returns a +// network error code, and fills |os_error| with a more specific error if it +// was non-NULL. +NET_EXPORT_PRIVATE int SystemHostResolverCall( + const std::string& host, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + AddressList* addrlist, + int* os_error); + +// Wraps call to SystemHostResolverCall as an instance of HostResolverProc. +class NET_EXPORT_PRIVATE SystemHostResolverProc : public HostResolverProc { + public: + SystemHostResolverProc(); + virtual int Resolve(const std::string& hostname, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + AddressList* addr_list, + int* os_error) OVERRIDE; + protected: + virtual ~SystemHostResolverProc(); + + DISALLOW_COPY_AND_ASSIGN(SystemHostResolverProc); +}; + +} // namespace net + +#endif // NET_DNS_HOST_RESOLVER_PROC_H_ diff --git a/chromium/net/dns/mapped_host_resolver.cc b/chromium/net/dns/mapped_host_resolver.cc new file mode 100644 index 00000000000..4db7bc97928 --- /dev/null +++ b/chromium/net/dns/mapped_host_resolver.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/mapped_host_resolver.h" + +#include "base/strings/string_util.h" +#include "net/base/host_port_pair.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" + +namespace net { + +MappedHostResolver::MappedHostResolver(scoped_ptr<HostResolver> impl) + : impl_(impl.Pass()) { +} + +MappedHostResolver::~MappedHostResolver() { +} + +int MappedHostResolver::Resolve(const RequestInfo& original_info, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* out_req, + const BoundNetLog& net_log) { + RequestInfo info = original_info; + int rv = ApplyRules(&info); + if (rv != OK) + return rv; + + return impl_->Resolve(info, addresses, callback, out_req, net_log); +} + +int MappedHostResolver::ResolveFromCache(const RequestInfo& original_info, + AddressList* addresses, + const BoundNetLog& net_log) { + RequestInfo info = original_info; + int rv = ApplyRules(&info); + if (rv != OK) + return rv; + + return impl_->ResolveFromCache(info, addresses, net_log); +} + +void MappedHostResolver::CancelRequest(RequestHandle req) { + impl_->CancelRequest(req); +} + +HostCache* MappedHostResolver::GetHostCache() { + return impl_->GetHostCache(); +} + +int MappedHostResolver::ApplyRules(RequestInfo* info) const { + HostPortPair host_port(info->host_port_pair()); + if (rules_.RewriteHost(&host_port)) { + if (host_port.host() == "~NOTFOUND") + return ERR_NAME_NOT_RESOLVED; + info->set_host_port_pair(host_port); + } + return OK; +} + +} // namespace net diff --git a/chromium/net/dns/mapped_host_resolver.h b/chromium/net/dns/mapped_host_resolver.h new file mode 100644 index 00000000000..50062a9848e --- /dev/null +++ b/chromium/net/dns/mapped_host_resolver.h @@ -0,0 +1,71 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_MAPPED_HOST_RESOLVER_H_ +#define NET_DNS_MAPPED_HOST_RESOLVER_H_ + +#include <string> + +#include "base/memory/scoped_ptr.h" +#include "net/base/host_mapping_rules.h" +#include "net/base/net_export.h" +#include "net/dns/host_resolver.h" + +namespace net { + +// This class wraps an existing HostResolver instance, but modifies the +// request before passing it off to |impl|. This is different from +// MockHostResolver which does the remapping at the HostResolverProc +// layer, so it is able to preserve the effectiveness of the cache. +class NET_EXPORT MappedHostResolver : public HostResolver { + public: + // Creates a MappedHostResolver that forwards all of its requests through + // |impl|. + explicit MappedHostResolver(scoped_ptr<HostResolver> impl); + virtual ~MappedHostResolver(); + + // Adds a rule to this mapper. The format of the rule can be one of: + // + // "MAP" <hostname_pattern> <replacement_host> [":" <replacement_port>] + // "EXCLUDE" <hostname_pattern> + // + // The <replacement_host> can be either a hostname, or an IP address literal, + // or "~NOTFOUND". If it is "~NOTFOUND" then all matched hostnames will fail + // to be resolved with ERR_NAME_NOT_RESOLVED. + // + // Returns true if the rule was successfully parsed and added. + bool AddRuleFromString(const std::string& rule_string) { + return rules_.AddRuleFromString(rule_string); + } + + // Takes a comma separated list of rules, and assigns them to this resolver. + void SetRulesFromString(const std::string& rules_string) { + rules_.SetRulesFromString(rules_string); + } + + // HostResolver methods: + virtual int Resolve(const RequestInfo& info, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* out_req, + const BoundNetLog& net_log) OVERRIDE; + virtual int ResolveFromCache(const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& net_log) OVERRIDE; + virtual void CancelRequest(RequestHandle req) OVERRIDE; + virtual HostCache* GetHostCache() OVERRIDE; + + private: + // Modify the request |info| according to |rules_|. Returns either OK or + // the network error code that the hostname's resolution mapped to. + int ApplyRules(RequestInfo* info) const; + + scoped_ptr<HostResolver> impl_; + + HostMappingRules rules_; +}; + +} // namespace net + +#endif // NET_DNS_MAPPED_HOST_RESOLVER_H_ diff --git a/chromium/net/dns/mapped_host_resolver_unittest.cc b/chromium/net/dns/mapped_host_resolver_unittest.cc new file mode 100644 index 00000000000..d8594663c9d --- /dev/null +++ b/chromium/net/dns/mapped_host_resolver_unittest.cc @@ -0,0 +1,219 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/mapped_host_resolver.h" + +#include "net/base/address_list.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "net/dns/mock_host_resolver.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +std::string FirstAddress(const AddressList& address_list) { + if (address_list.empty()) + return std::string(); + return address_list.front().ToString(); +} + +TEST(MappedHostResolverTest, Inclusion) { + // Create a mock host resolver, with specific hostname to IP mappings. + scoped_ptr<MockHostResolver> resolver_impl(new MockHostResolver()); + resolver_impl->rules()->AddSimulatedFailure("*google.com"); + resolver_impl->rules()->AddRule("baz.com", "192.168.1.5"); + resolver_impl->rules()->AddRule("foo.com", "192.168.1.8"); + resolver_impl->rules()->AddRule("proxy", "192.168.1.11"); + + // Create a remapped resolver that uses |resolver_impl|. + scoped_ptr<MappedHostResolver> resolver( + new MappedHostResolver(resolver_impl.PassAs<HostResolver>())); + + int rv; + AddressList address_list; + + // Try resolving "www.google.com:80". There are no mappings yet, so this + // hits |resolver_impl| and fails. + TestCompletionCallback callback; + rv = resolver->Resolve(HostResolver::RequestInfo( + HostPortPair("www.google.com", 80)), + &address_list, callback.callback(), NULL, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv); + + // Remap *.google.com to baz.com. + EXPECT_TRUE(resolver->AddRuleFromString("map *.google.com baz.com")); + + // Try resolving "www.google.com:80". Should be remapped to "baz.com:80". + rv = resolver->Resolve(HostResolver::RequestInfo( + HostPortPair("www.google.com", 80)), + &address_list, callback.callback(), NULL, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_EQ("192.168.1.5:80", FirstAddress(address_list)); + + // Try resolving "foo.com:77". This will NOT be remapped, so result + // is "foo.com:77". + rv = resolver->Resolve(HostResolver::RequestInfo(HostPortPair("foo.com", 77)), + &address_list, callback.callback(), NULL, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_EQ("192.168.1.8:77", FirstAddress(address_list)); + + // Remap "*.org" to "proxy:99". + EXPECT_TRUE(resolver->AddRuleFromString("Map *.org proxy:99")); + + // Try resolving "chromium.org:61". Should be remapped to "proxy:99". + rv = resolver->Resolve(HostResolver::RequestInfo + (HostPortPair("chromium.org", 61)), + &address_list, callback.callback(), NULL, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_EQ("192.168.1.11:99", FirstAddress(address_list)); +} + +// Tests that exclusions are respected. +TEST(MappedHostResolverTest, Exclusion) { + // Create a mock host resolver, with specific hostname to IP mappings. + scoped_ptr<MockHostResolver> resolver_impl(new MockHostResolver()); + resolver_impl->rules()->AddRule("baz", "192.168.1.5"); + resolver_impl->rules()->AddRule("www.google.com", "192.168.1.3"); + + // Create a remapped resolver that uses |resolver_impl|. + scoped_ptr<MappedHostResolver> resolver( + new MappedHostResolver(resolver_impl.PassAs<HostResolver>())); + + int rv; + AddressList address_list; + TestCompletionCallback callback; + + // Remap "*.com" to "baz". + EXPECT_TRUE(resolver->AddRuleFromString("map *.com baz")); + + // Add an exclusion for "*.google.com". + EXPECT_TRUE(resolver->AddRuleFromString("EXCLUDE *.google.com")); + + // Try resolving "www.google.com". Should not be remapped due to exclusion). + rv = resolver->Resolve(HostResolver::RequestInfo( + HostPortPair("www.google.com", 80)), + &address_list, callback.callback(), NULL, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_EQ("192.168.1.3:80", FirstAddress(address_list)); + + // Try resolving "chrome.com:80". Should be remapped to "baz:80". + rv = resolver->Resolve(HostResolver::RequestInfo( + HostPortPair("chrome.com", 80)), + &address_list, callback.callback(), NULL, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_EQ("192.168.1.5:80", FirstAddress(address_list)); +} + +TEST(MappedHostResolverTest, SetRulesFromString) { + // Create a mock host resolver, with specific hostname to IP mappings. + scoped_ptr<MockHostResolver> resolver_impl(new MockHostResolver()); + resolver_impl->rules()->AddRule("baz", "192.168.1.7"); + resolver_impl->rules()->AddRule("bar", "192.168.1.9"); + + // Create a remapped resolver that uses |resolver_impl|. + scoped_ptr<MappedHostResolver> resolver( + new MappedHostResolver(resolver_impl.PassAs<HostResolver>())); + + int rv; + AddressList address_list; + TestCompletionCallback callback; + + // Remap "*.com" to "baz", and *.net to "bar:60". + resolver->SetRulesFromString("map *.com baz , map *.net bar:60"); + + // Try resolving "www.google.com". Should be remapped to "baz". + rv = resolver->Resolve(HostResolver::RequestInfo( + HostPortPair("www.google.com", 80)), + &address_list, callback.callback(), NULL, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_EQ("192.168.1.7:80", FirstAddress(address_list)); + + // Try resolving "chrome.net:80". Should be remapped to "bar:60". + rv = resolver->Resolve(HostResolver::RequestInfo( + HostPortPair("chrome.net", 80)), + &address_list, callback.callback(), NULL, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_EQ("192.168.1.9:60", FirstAddress(address_list)); +} + +// Parsing bad rules should silently discard the rule (and never crash). +TEST(MappedHostResolverTest, ParseInvalidRules) { + scoped_ptr<MappedHostResolver> resolver( + new MappedHostResolver(scoped_ptr<HostResolver>())); + + EXPECT_FALSE(resolver->AddRuleFromString("xyz")); + EXPECT_FALSE(resolver->AddRuleFromString(std::string())); + EXPECT_FALSE(resolver->AddRuleFromString(" ")); + EXPECT_FALSE(resolver->AddRuleFromString("EXCLUDE")); + EXPECT_FALSE(resolver->AddRuleFromString("EXCLUDE foo bar")); + EXPECT_FALSE(resolver->AddRuleFromString("INCLUDE")); + EXPECT_FALSE(resolver->AddRuleFromString("INCLUDE x")); + EXPECT_FALSE(resolver->AddRuleFromString("INCLUDE x :10")); +} + +// Test mapping hostnames to resolving failures. +TEST(MappedHostResolverTest, MapToError) { + scoped_ptr<MockHostResolver> resolver_impl(new MockHostResolver()); + resolver_impl->rules()->AddRule("*", "192.168.1.5"); + + scoped_ptr<MappedHostResolver> resolver( + new MappedHostResolver(resolver_impl.PassAs<HostResolver>())); + + int rv; + AddressList address_list; + + // Remap *.google.com to resolving failures. + EXPECT_TRUE(resolver->AddRuleFromString("MAP *.google.com ~NOTFOUND")); + + // Try resolving www.google.com --> Should give an error. + TestCompletionCallback callback1; + rv = resolver->Resolve(HostResolver::RequestInfo( + HostPortPair("www.google.com", 80)), + &address_list, callback1.callback(), NULL, + BoundNetLog()); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv); + + // Try resolving www.foo.com --> Should succeed. + TestCompletionCallback callback2; + rv = resolver->Resolve(HostResolver::RequestInfo( + HostPortPair("www.foo.com", 80)), + &address_list, callback2.callback(), NULL, + BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback2.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_EQ("192.168.1.5:80", FirstAddress(address_list)); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/dns/mdns_cache.cc b/chromium/net/dns/mdns_cache.cc new file mode 100644 index 00000000000..010a34f45d4 --- /dev/null +++ b/chromium/net/dns/mdns_cache.cc @@ -0,0 +1,212 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/mdns_cache.h" + +#include <algorithm> +#include <utility> + +#include "base/stl_util.h" +#include "base/strings/string_number_conversions.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/record_parsed.h" +#include "net/dns/record_rdata.h" + +// TODO(noamsml): Recursive CNAME closure (backwards and forwards). + +namespace net { + +// The effective TTL given to records with a nominal zero TTL. +// Allows time for hosts to send updated records, as detailed in RFC 6762 +// Section 10.1. +static const unsigned kZeroTTLSeconds = 1; + +MDnsCache::Key::Key(unsigned type, const std::string& name, + const std::string& optional) + : type_(type), name_(name), optional_(optional) { +} + +MDnsCache::Key::Key( + const MDnsCache::Key& other) + : type_(other.type_), name_(other.name_), optional_(other.optional_) { +} + + +MDnsCache::Key& MDnsCache::Key::operator=( + const MDnsCache::Key& other) { + type_ = other.type_; + name_ = other.name_; + optional_ = other.optional_; + return *this; +} + +MDnsCache::Key::~Key() { +} + +bool MDnsCache::Key::operator<(const MDnsCache::Key& key) const { + if (name_ != key.name_) + return name_ < key.name_; + + if (type_ != key.type_) + return type_ < key.type_; + + if (optional_ != key.optional_) + return optional_ < key.optional_; + return false; // keys are equal +} + +bool MDnsCache::Key::operator==(const MDnsCache::Key& key) const { + return type_ == key.type_ && name_ == key.name_ && optional_ == key.optional_; +} + +// static +MDnsCache::Key MDnsCache::Key::CreateFor(const RecordParsed* record) { + return Key(record->type(), + record->name(), + GetOptionalFieldForRecord(record)); +} + + +MDnsCache::MDnsCache() { +} + +MDnsCache::~MDnsCache() { + Clear(); +} + +void MDnsCache::Clear() { + next_expiration_ = base::Time(); + STLDeleteValues(&mdns_cache_); +} + +const RecordParsed* MDnsCache::LookupKey(const Key& key) { + RecordMap::iterator found = mdns_cache_.find(key); + if (found != mdns_cache_.end()) { + return found->second; + } + return NULL; +} + +MDnsCache::UpdateType MDnsCache::UpdateDnsRecord( + scoped_ptr<const RecordParsed> record) { + Key cache_key = Key::CreateFor(record.get()); + + // Ignore "goodbye" packets for records not in cache. + if (record->ttl() == 0 && mdns_cache_.find(cache_key) == mdns_cache_.end()) + return NoChange; + + base::Time new_expiration = GetEffectiveExpiration(record.get()); + if (next_expiration_ != base::Time()) + new_expiration = std::min(new_expiration, next_expiration_); + + std::pair<RecordMap::iterator, bool> insert_result = + mdns_cache_.insert(std::make_pair(cache_key, (const RecordParsed*)NULL)); + UpdateType type = NoChange; + if (insert_result.second) { + type = RecordAdded; + } else { + const RecordParsed* other_record = insert_result.first->second; + + if (record->ttl() != 0 && !record->IsEqual(other_record, true)) { + type = RecordChanged; + } + delete other_record; + } + + insert_result.first->second = record.release(); + next_expiration_ = new_expiration; + return type; +} + +void MDnsCache::CleanupRecords( + base::Time now, + const RecordRemovedCallback& record_removed_callback) { + base::Time next_expiration; + + // We are guaranteed that |next_expiration_| will be at or before the next + // expiration. This allows clients to eagrely call CleanupRecords with + // impunity. + if (now < next_expiration_) return; + + for (RecordMap::iterator i = mdns_cache_.begin(); + i != mdns_cache_.end(); ) { + base::Time expiration = GetEffectiveExpiration(i->second); + if (now >= expiration) { + record_removed_callback.Run(i->second); + delete i->second; + mdns_cache_.erase(i++); + } else { + if (next_expiration == base::Time() || expiration < next_expiration) { + next_expiration = expiration; + } + ++i; + } + } + + next_expiration_ = next_expiration; +} + +void MDnsCache::FindDnsRecords(unsigned type, + const std::string& name, + std::vector<const RecordParsed*>* results, + base::Time now) const { + DCHECK(results); + results->clear(); + + RecordMap::const_iterator i = mdns_cache_.lower_bound(Key(type, name, "")); + for (; i != mdns_cache_.end(); ++i) { + if (i->first.name() != name || + (type != 0 && i->first.type() != type)) { + break; + } + + const RecordParsed* record = i->second; + + // Records are deleted only upon request. + if (now >= GetEffectiveExpiration(record)) continue; + + results->push_back(record); + } +} + +scoped_ptr<const RecordParsed> MDnsCache::RemoveRecord( + const RecordParsed* record) { + Key key = Key::CreateFor(record); + RecordMap::iterator found = mdns_cache_.find(key); + + if (found != mdns_cache_.end() && found->second == record) { + mdns_cache_.erase(key); + return scoped_ptr<const RecordParsed>(record); + } + + return scoped_ptr<const RecordParsed>(); +} + +// static +std::string MDnsCache::GetOptionalFieldForRecord( + const RecordParsed* record) { + switch (record->type()) { + case PtrRecordRdata::kType: { + const PtrRecordRdata* rdata = record->rdata<PtrRecordRdata>(); + return rdata->ptrdomain(); + } + default: // Most records are considered unique for our purposes + return ""; + } +} + +// static +base::Time MDnsCache::GetEffectiveExpiration(const RecordParsed* record) { + base::TimeDelta ttl; + + if (record->ttl()) { + ttl = base::TimeDelta::FromSeconds(record->ttl()); + } else { + ttl = base::TimeDelta::FromSeconds(kZeroTTLSeconds); + } + + return record->time_created() + ttl; +} + +} // namespace net diff --git a/chromium/net/dns/mdns_cache.h b/chromium/net/dns/mdns_cache.h new file mode 100644 index 00000000000..27a14d8fa44 --- /dev/null +++ b/chromium/net/dns/mdns_cache.h @@ -0,0 +1,119 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_MDNS_CACHE_H_ +#define NET_DNS_MDNS_CACHE_H_ + +#include <map> +#include <string> +#include <vector> + +#include "base/callback.h" +#include "base/memory/scoped_ptr.h" +#include "base/time/time.h" +#include "net/base/net_export.h" + +namespace net { + +class ParsedDnsRecord; +class RecordParsed; + +// mDNS Cache +// This is a cache of mDNS records. It keeps track of expiration times and is +// guaranteed not to return expired records. It also has facilities for timely +// record expiration. +class NET_EXPORT_PRIVATE MDnsCache { + public: + // Key type for the record map. It is a 3-tuple of type, name and optional + // value ordered by type, then name, then optional value. This allows us to + // query for all records of a certain type and name, while also allowing us + // to set records of a certain type, name and optionally value as unique. + class Key { + public: + Key(unsigned type, const std::string& name, const std::string& optional); + Key(const Key&); + Key& operator=(const Key&); + ~Key(); + bool operator<(const Key& key) const; + bool operator==(const Key& key) const; + + unsigned type() const { return type_; } + const std::string& name() const { return name_; } + const std::string& optional() const { return optional_; } + + // Create the cache key corresponding to |record|. + static Key CreateFor(const RecordParsed* record); + private: + unsigned type_; + std::string name_; + std::string optional_; + }; + + typedef base::Callback<void(const RecordParsed*)> RecordRemovedCallback; + + enum UpdateType { + RecordAdded, + RecordChanged, + NoChange + }; + + MDnsCache(); + ~MDnsCache(); + + // Return value indicates whether the record was added, changed + // (existed previously with different value) or not changed (existed + // previously with same value). + UpdateType UpdateDnsRecord(scoped_ptr<const RecordParsed> record); + + // Check cache for record with key |key|. Return the record if it exists, or + // NULL if it doesn't. + const RecordParsed* LookupKey(const Key& key); + + // Return records with type |type| and name |name|. Expired records will not + // be returned. If |type| is zero, return all records with name |name|. + void FindDnsRecords(unsigned type, + const std::string& name, + std::vector<const RecordParsed*>* records, + base::Time now) const; + + // Remove expired records, call |record_removed_callback| for every removed + // record. + void CleanupRecords(base::Time now, + const RecordRemovedCallback& record_removed_callback); + + // Returns a time less than or equal to the next time a record will expire. + // Is updated when CleanupRecords or UpdateDnsRecord are called. Returns + // base::Time when the cache is empty. + base::Time next_expiration() const { return next_expiration_; } + + // Remove a record from the cache. Returns a scoped version of the pointer + // passed in if it was removed, scoped null otherwise. + scoped_ptr<const RecordParsed> RemoveRecord(const RecordParsed* record); + + void Clear(); + + private: + typedef std::map<Key, const RecordParsed*> RecordMap; + + // Get the effective expiration of a cache entry, based on its creation time + // and TTL. Does adjustments so entries with a TTL of zero will have a + // nonzero TTL, as explained in RFC 6762 Section 10.1. + static base::Time GetEffectiveExpiration(const RecordParsed* entry); + + // Get optional part of the DNS key for shared records. For example, in PTR + // records this is the pointed domain, since multiple PTR records may exist + // for the same name. + static std::string GetOptionalFieldForRecord( + const RecordParsed* record); + + RecordMap mdns_cache_; + + base::Time next_expiration_; + + DISALLOW_COPY_AND_ASSIGN(MDnsCache); +}; + +} // namespace net + +#endif // NET_DNS_MDNS_CACHE_H_ diff --git a/chromium/net/dns/mdns_cache_unittest.cc b/chromium/net/dns/mdns_cache_unittest.cc new file mode 100644 index 00000000000..c12ad6b6ec3 --- /dev/null +++ b/chromium/net/dns/mdns_cache_unittest.cc @@ -0,0 +1,375 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <algorithm> + +#include "base/bind.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_test_util.h" +#include "net/dns/mdns_cache.h" +#include "net/dns/record_parsed.h" +#include "net/dns/record_rdata.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using ::testing::Return; +using ::testing::StrictMock; + +namespace net { + +static const uint8 kTestResponsesDifferentAnswers[] = { + // Answer 1 + // ghs.l.google.com in DNS format. + 3, 'g', 'h', 's', + 1, 'l', + 6, 'g', 'o', 'o', 'g', 'l', 'e', + 3, 'c', 'o', 'm', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds. + 0, 4, // RDLENGTH is 4 bytes. + 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121 + + // Answer 2 + // Pointer to answer 1 + 0xc0, 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds. + 0, 4, // RDLENGTH is 4 bytes. + 74, 125, 95, 122, // RDATA is the IP: 74.125.95.122 +}; + +static const uint8 kTestResponsesSameAnswers[] = { + // Answer 1 + // ghs.l.google.com in DNS format. + 3, 'g', 'h', 's', + 1, 'l', + 6, 'g', 'o', 'o', 'g', 'l', 'e', + 3, 'c', 'o', 'm', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds. + 0, 4, // RDLENGTH is 4 bytes. + 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121 + + // Answer 2 + // Pointer to answer 1 + 0xc0, 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0, 0, 0, 112, // TTL (4 bytes) is 112 seconds. + 0, 4, // RDLENGTH is 4 bytes. + 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121 +}; + +static const uint8 kTestResponseTwoRecords[] = { + // Answer 1 + // ghs.l.google.com in DNS format. (A) + 3, 'g', 'h', 's', + 1, 'l', + 6, 'g', 'o', 'o', 'g', 'l', 'e', + 3, 'c', 'o', 'm', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds. + 0, 4, // RDLENGTH is 4 bytes. + 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121 + + // Answer 2 + // ghs.l.google.com in DNS format. (AAAA) + 3, 'g', 'h', 's', + 1, 'l', + 6, 'g', 'o', 'o', 'g', 'l', 'e', + 3, 'c', 'o', 'm', + 0x00, + 0x00, 0x1c, // TYPE is AAA. + 0x00, 0x01, // CLASS is IN. + 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds. + 0, 16, // RDLENGTH is 16 bytes. + 0x4a, 0x7d, 0x4a, 0x7d, + 0x5f, 0x79, 0x5f, 0x79, + 0x5f, 0x79, 0x5f, 0x79, + 0x5f, 0x79, 0x5f, 0x79, +}; + +static const uint8 kTestResponsesGoodbyePacket[] = { + // Answer 1 + // ghs.l.google.com in DNS format. (Goodbye packet) + 3, 'g', 'h', 's', + 1, 'l', + 6, 'g', 'o', 'o', 'g', 'l', 'e', + 3, 'c', 'o', 'm', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0, 0, 0, 0, // TTL (4 bytes) is zero. + 0, 4, // RDLENGTH is 4 bytes. + 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121 + + // Answer 2 + // ghs.l.google.com in DNS format. + 3, 'g', 'h', 's', + 1, 'l', + 6, 'g', 'o', 'o', 'g', 'l', 'e', + 3, 'c', 'o', 'm', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0, 0, 0, 53, // TTL (4 bytes) is 53 seconds. + 0, 4, // RDLENGTH is 4 bytes. + 74, 125, 95, 121, // RDATA is the IP: 74.125.95.121 +}; + +class RecordRemovalMock { + public: + MOCK_METHOD1(OnRecordRemoved, void(const RecordParsed*)); +}; + +class MDnsCacheTest : public ::testing::Test { + public: + MDnsCacheTest() + : default_time_(base::Time::FromDoubleT(1234.0)) {} + virtual ~MDnsCacheTest() {} + + protected: + base::Time default_time_; + StrictMock<RecordRemovalMock> record_removal_; + MDnsCache cache_; +}; + +// Test a single insert, corresponding lookup, and unsuccessful lookup. +TEST_F(MDnsCacheTest, InsertLookupSingle) { + DnsRecordParser parser(kT1ResponseDatagram, sizeof(kT1ResponseDatagram), + sizeof(dns_protocol::Header)); + parser.SkipQuestion(); + + scoped_ptr<const RecordParsed> record1; + scoped_ptr<const RecordParsed> record2; + std::vector<const RecordParsed*> results; + + record1 = RecordParsed::CreateFrom(&parser, default_time_); + record2 = RecordParsed::CreateFrom(&parser, default_time_); + + EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass())); + + EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record2.Pass())); + + cache_.FindDnsRecords(ARecordRdata::kType, "ghs.l.google.com", &results, + default_time_); + + EXPECT_EQ(1u, results.size()); + EXPECT_EQ(default_time_, results.front()->time_created()); + + EXPECT_EQ("ghs.l.google.com", results.front()->name()); + + results.clear(); + cache_.FindDnsRecords(PtrRecordRdata::kType, "ghs.l.google.com", &results, + default_time_); + + EXPECT_EQ(0u, results.size()); +} + +// Test that records expire when their ttl has passed. +TEST_F(MDnsCacheTest, Expiration) { + DnsRecordParser parser(kT1ResponseDatagram, sizeof(kT1ResponseDatagram), + sizeof(dns_protocol::Header)); + parser.SkipQuestion(); + scoped_ptr<const RecordParsed> record1; + scoped_ptr<const RecordParsed> record2; + + std::vector<const RecordParsed*> results; + const RecordParsed* record_to_be_deleted; + + record1 = RecordParsed::CreateFrom(&parser, default_time_); + base::TimeDelta ttl1 = base::TimeDelta::FromSeconds(record1->ttl()); + + record2 = RecordParsed::CreateFrom(&parser, default_time_); + base::TimeDelta ttl2 = base::TimeDelta::FromSeconds(record2->ttl()); + record_to_be_deleted = record2.get(); + + EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass())); + EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record2.Pass())); + + cache_.FindDnsRecords(ARecordRdata::kType, "ghs.l.google.com", &results, + default_time_); + + EXPECT_EQ(1u, results.size()); + + EXPECT_EQ(default_time_ + ttl2, cache_.next_expiration()); + + + cache_.FindDnsRecords(ARecordRdata::kType, "ghs.l.google.com", &results, + default_time_ + ttl2); + + EXPECT_EQ(0u, results.size()); + + EXPECT_CALL(record_removal_, OnRecordRemoved(record_to_be_deleted)); + + cache_.CleanupRecords(default_time_ + ttl2, base::Bind( + &RecordRemovalMock::OnRecordRemoved, base::Unretained(&record_removal_))); + + // To make sure that we've indeed removed them from the map, check no funny + // business happens once they're deleted for good. + + EXPECT_EQ(default_time_ + ttl1, cache_.next_expiration()); + cache_.FindDnsRecords(ARecordRdata::kType, "ghs.l.google.com", &results, + default_time_ + ttl2); + + EXPECT_EQ(0u, results.size()); +} + +// Test that a new record replacing one with the same identity (name/rrtype for +// unique records) causes the cache to output a "record changed" event. +TEST_F(MDnsCacheTest, RecordChange) { + DnsRecordParser parser(kTestResponsesDifferentAnswers, + sizeof(kTestResponsesDifferentAnswers), + 0); + + scoped_ptr<const RecordParsed> record1; + scoped_ptr<const RecordParsed> record2; + std::vector<const RecordParsed*> results; + + record1 = RecordParsed::CreateFrom(&parser, default_time_); + record2 = RecordParsed::CreateFrom(&parser, default_time_); + + EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass())); + EXPECT_EQ(MDnsCache::RecordChanged, + cache_.UpdateDnsRecord(record2.Pass())); +} + +// Test that a new record replacing an otherwise identical one already in the +// cache causes the cache to output a "no change" event. +TEST_F(MDnsCacheTest, RecordNoChange) { + DnsRecordParser parser(kTestResponsesSameAnswers, + sizeof(kTestResponsesSameAnswers), + 0); + + scoped_ptr<const RecordParsed> record1; + scoped_ptr<const RecordParsed> record2; + std::vector<const RecordParsed*> results; + + record1 = RecordParsed::CreateFrom(&parser, default_time_); + record2 = RecordParsed::CreateFrom(&parser, default_time_ + + base::TimeDelta::FromSeconds(1)); + + EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass())); + EXPECT_EQ(MDnsCache::NoChange, cache_.UpdateDnsRecord(record2.Pass())); +} + +// Test that the next expiration time of the cache is updated properly on record +// insertion. +TEST_F(MDnsCacheTest, RecordPreemptExpirationTime) { + DnsRecordParser parser(kTestResponsesSameAnswers, + sizeof(kTestResponsesSameAnswers), + 0); + + scoped_ptr<const RecordParsed> record1; + scoped_ptr<const RecordParsed> record2; + std::vector<const RecordParsed*> results; + + record1 = RecordParsed::CreateFrom(&parser, default_time_); + record2 = RecordParsed::CreateFrom(&parser, default_time_); + base::TimeDelta ttl1 = base::TimeDelta::FromSeconds(record1->ttl()); + base::TimeDelta ttl2 = base::TimeDelta::FromSeconds(record2->ttl()); + + EXPECT_EQ(base::Time(), cache_.next_expiration()); + EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record2.Pass())); + EXPECT_EQ(default_time_ + ttl2, cache_.next_expiration()); + EXPECT_EQ(MDnsCache::NoChange, cache_.UpdateDnsRecord(record1.Pass())); + EXPECT_EQ(default_time_ + ttl1, cache_.next_expiration()); +} + +// Test that the cache handles mDNS "goodbye" packets correctly, not adding the +// records to the cache if they are not already there, and eventually removing +// records from the cache if they are. +TEST_F(MDnsCacheTest, GoodbyePacket) { + DnsRecordParser parser(kTestResponsesGoodbyePacket, + sizeof(kTestResponsesGoodbyePacket), + 0); + + scoped_ptr<const RecordParsed> record_goodbye; + scoped_ptr<const RecordParsed> record_hello; + scoped_ptr<const RecordParsed> record_goodbye2; + std::vector<const RecordParsed*> results; + + record_goodbye = RecordParsed::CreateFrom(&parser, default_time_); + record_hello = RecordParsed::CreateFrom(&parser, default_time_); + parser = DnsRecordParser(kTestResponsesGoodbyePacket, + sizeof(kTestResponsesGoodbyePacket), + 0); + record_goodbye2 = RecordParsed::CreateFrom(&parser, default_time_); + + base::TimeDelta ttl = base::TimeDelta::FromSeconds(record_hello->ttl()); + + EXPECT_EQ(base::Time(), cache_.next_expiration()); + EXPECT_EQ(MDnsCache::NoChange, cache_.UpdateDnsRecord(record_goodbye.Pass())); + EXPECT_EQ(base::Time(), cache_.next_expiration()); + EXPECT_EQ(MDnsCache::RecordAdded, + cache_.UpdateDnsRecord(record_hello.Pass())); + EXPECT_EQ(default_time_ + ttl, cache_.next_expiration()); + EXPECT_EQ(MDnsCache::NoChange, + cache_.UpdateDnsRecord(record_goodbye2.Pass())); + EXPECT_EQ(default_time_ + base::TimeDelta::FromSeconds(1), + cache_.next_expiration()); +} + +TEST_F(MDnsCacheTest, AnyRRType) { + DnsRecordParser parser(kTestResponseTwoRecords, + sizeof(kTestResponseTwoRecords), + 0); + + scoped_ptr<const RecordParsed> record1; + scoped_ptr<const RecordParsed> record2; + std::vector<const RecordParsed*> results; + + record1 = RecordParsed::CreateFrom(&parser, default_time_); + record2 = RecordParsed::CreateFrom(&parser, default_time_); + EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass())); + EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record2.Pass())); + + cache_.FindDnsRecords(0, "ghs.l.google.com", &results, default_time_); + + EXPECT_EQ(2u, results.size()); + EXPECT_EQ(default_time_, results.front()->time_created()); + + EXPECT_EQ("ghs.l.google.com", results[0]->name()); + EXPECT_EQ("ghs.l.google.com", results[1]->name()); + EXPECT_EQ(dns_protocol::kTypeA, + std::min(results[0]->type(), results[1]->type())); + EXPECT_EQ(dns_protocol::kTypeAAAA, + std::max(results[0]->type(), results[1]->type())); +} + +TEST_F(MDnsCacheTest, RemoveRecord) { + DnsRecordParser parser(kT1ResponseDatagram, sizeof(kT1ResponseDatagram), + sizeof(dns_protocol::Header)); + parser.SkipQuestion(); + + scoped_ptr<const RecordParsed> record1; + std::vector<const RecordParsed*> results; + + record1 = RecordParsed::CreateFrom(&parser, default_time_); + EXPECT_EQ(MDnsCache::RecordAdded, cache_.UpdateDnsRecord(record1.Pass())); + + cache_.FindDnsRecords(dns_protocol::kTypeCNAME, "codereview.chromium.org", + &results, default_time_); + + EXPECT_EQ(1u, results.size()); + + scoped_ptr<const RecordParsed> record_out = + cache_.RemoveRecord(results.front()); + + EXPECT_EQ(record_out.get(), results.front()); + + cache_.FindDnsRecords(dns_protocol::kTypeCNAME, "codereview.chromium.org", + &results, default_time_); + + EXPECT_EQ(0u, results.size()); +} + +} // namespace net diff --git a/chromium/net/dns/mdns_client.cc b/chromium/net/dns/mdns_client.cc new file mode 100644 index 00000000000..631b01a706e --- /dev/null +++ b/chromium/net/dns/mdns_client.cc @@ -0,0 +1,17 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/mdns_client.h" + +#include "net/dns/mdns_client_impl.h" + +namespace net { + +// static +scoped_ptr<MDnsClient> MDnsClient::CreateDefault() { + return scoped_ptr<MDnsClient>( + new MDnsClientImpl(MDnsConnection::SocketFactory::CreateDefault())); +} + +} // namespace net diff --git a/chromium/net/dns/mdns_client.h b/chromium/net/dns/mdns_client.h new file mode 100644 index 00000000000..f12c6e292cd --- /dev/null +++ b/chromium/net/dns/mdns_client.h @@ -0,0 +1,158 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_MDNS_CLIENT_H_ +#define NET_DNS_MDNS_CLIENT_H_ + +#include <string> +#include <vector> + +#include "base/callback.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_response.h" +#include "net/dns/record_parsed.h" + +namespace net { + +class RecordParsed; + +// Represents a one-time record lookup. A transaction takes one +// associated callback (see |MDnsClient::CreateTransaction|) and calls it +// whenever a matching record has been found, either from the cache or +// by querying the network (it may choose to query either or both based on its +// creation flags, see MDnsTransactionFlags). Network-based transactions will +// time out after a reasonable number of seconds. +class NET_EXPORT MDnsTransaction { + public: + // Used to signify what type of result the transaction has recieved. + enum Result { + // Passed whenever a record is found. + RESULT_RECORD, + // The transaction is done. Applies to non-single-valued transactions. Is + // called when the transaction has finished (this is the last call to the + // callback). + RESULT_DONE, + // No results have been found. Applies to single-valued transactions. Is + // called when the transaction has finished without finding any results. + // For transactions that use the network, this happens when a timeout + // occurs, for transactions that are cache-only, this happens when no + // results are in the cache. + RESULT_NO_RESULTS, + // Called when an NSec record is read for this transaction's + // query. This means there cannot possibly be a record of the type + // and name for this transaction. + RESULT_NSEC + }; + + // Used when creating an MDnsTransaction. + enum Flags { + // Transaction should return only one result, and stop listening after it. + // Note that single result transactions will signal when their timeout is + // reached, whereas multi-result transactions will not. + SINGLE_RESULT = 1 << 0, + // Query the cache or the network. May both be used. One must be present. + QUERY_CACHE = 1 << 1, + QUERY_NETWORK = 1 << 2, + // TODO(noamsml): Add flag for flushing cache when feature is implemented + // Mask of all possible flags on MDnsTransaction. + FLAG_MASK = (1 << 3) - 1, + }; + + typedef base::Callback<void(Result, const RecordParsed*)> + ResultCallback; + + // Destroying the transaction cancels it. + virtual ~MDnsTransaction() {} + + // Start the transaction. Return true on success. Cache-based transactions + // will execute the callback synchronously. + virtual bool Start() = 0; + + // Get the host or service name for the transaction. + virtual const std::string& GetName() const = 0; + + // Get the type for this transaction (SRV, TXT, A, AAA, etc) + virtual uint16 GetType() const = 0; +}; + +// A listener listens for updates regarding a specific record or set of records. +// Created by the MDnsClient (see |MDnsClient::CreateListener|) and used to keep +// track of listeners. +class NET_EXPORT MDnsListener { + public: + // Used in the MDnsListener delegate to signify what type of change has been + // made to a record. + enum UpdateType { + RECORD_ADDED, + RECORD_CHANGED, + RECORD_REMOVED + }; + + class Delegate { + public: + virtual ~Delegate() {} + + // Called when a record is added, removed or updated. + virtual void OnRecordUpdate(UpdateType update, + const RecordParsed* record) = 0; + + // Called when a record is marked nonexistent by an NSEC record. + virtual void OnNsecRecord(const std::string& name, unsigned type) = 0; + + // Called when the cache is purged (due, for example, ot the network + // disconnecting). + virtual void OnCachePurged() = 0; + }; + + // Destroying the listener stops listening. + virtual ~MDnsListener() {} + + // Start the listener. Return true on success. + virtual bool Start() = 0; + + // Get the host or service name for this query. + // Return an empty string for no name. + virtual const std::string& GetName() const = 0; + + // Get the type for this query (SRV, TXT, A, AAA, etc) + virtual uint16 GetType() const = 0; +}; + +// Listens for Multicast DNS on the local network. You can access information +// regarding multicast DNS either by creating an |MDnsListener| to be notified +// of new records, or by creating an |MDnsTransaction| to look up the value of a +// specific records. When all listeners and active transactions are destroyed, +// the client stops listening on the network and destroys the cache. +class NET_EXPORT MDnsClient { + public: + virtual ~MDnsClient() {} + + // Create listener object for RRType |rrtype| and name |name|. + virtual scoped_ptr<MDnsListener> CreateListener( + uint16 rrtype, + const std::string& name, + MDnsListener::Delegate* delegate) = 0; + + // Create a transaction that can be used to query either the MDns cache, the + // network, or both for records of type |rrtype| and name |name|. |flags| is + // defined by MDnsTransactionFlags. + virtual scoped_ptr<MDnsTransaction> CreateTransaction( + uint16 rrtype, + const std::string& name, + int flags, + const MDnsTransaction::ResultCallback& callback) = 0; + + virtual bool StartListening() = 0; + + // Do not call this inside callbacks from related MDnsListener and + // MDnsTransaction objects. + virtual void StopListening() = 0; + virtual bool IsListening() const = 0; + + // Create the default MDnsClient + static scoped_ptr<MDnsClient> CreateDefault(); +}; + +} // namespace net +#endif // NET_DNS_MDNS_CLIENT_H_ diff --git a/chromium/net/dns/mdns_client_impl.cc b/chromium/net/dns/mdns_client_impl.cc new file mode 100644 index 00000000000..8f79edf4cdf --- /dev/null +++ b/chromium/net/dns/mdns_client_impl.cc @@ -0,0 +1,671 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/mdns_client_impl.h" + +#include "base/bind.h" +#include "base/message_loop/message_loop_proxy.h" +#include "base/stl_util.h" +#include "base/time/default_clock.h" +#include "net/base/dns_util.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/rand_callback.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/record_rdata.h" +#include "net/udp/datagram_socket.h" + +// TODO(gene): Remove this temporary method of disabling NSEC support once it +// becomes clear whether this feature should be +// supported. http://crbug.com/255232 +#define ENABLE_NSEC + +namespace net { + +namespace { +const char kMDnsMulticastGroupIPv4[] = "224.0.0.251"; +const char kMDnsMulticastGroupIPv6[] = "FF02::FB"; +const unsigned MDnsTransactionTimeoutSeconds = 3; +} + +MDnsConnection::SocketHandler::SocketHandler( + MDnsConnection* connection, const IPEndPoint& multicast_addr, + MDnsConnection::SocketFactory* socket_factory) + : socket_(socket_factory->CreateSocket()), connection_(connection), + response_(new DnsResponse(dns_protocol::kMaxMulticastSize)), + multicast_addr_(multicast_addr) { +} + +MDnsConnection::SocketHandler::~SocketHandler() { +} + +int MDnsConnection::SocketHandler::Start() { + int rv = BindSocket(); + if (rv != OK) { + return rv; + } + + return DoLoop(0); +} + +int MDnsConnection::SocketHandler::DoLoop(int rv) { + do { + if (rv > 0) + connection_->OnDatagramReceived(response_.get(), recv_addr_, rv); + + rv = socket_->RecvFrom( + response_->io_buffer(), + response_->io_buffer()->size(), + &recv_addr_, + base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived, + base::Unretained(this))); + } while (rv > 0); + + if (rv != ERR_IO_PENDING) + return rv; + + return OK; +} + +void MDnsConnection::SocketHandler::OnDatagramReceived(int rv) { + if (rv >= OK) + rv = DoLoop(rv); + + if (rv != OK) + connection_->OnError(this, rv); +} + +int MDnsConnection::SocketHandler::Send(IOBuffer* buffer, unsigned size) { + return socket_->SendTo( + buffer, size, multicast_addr_, + base::Bind(&MDnsConnection::SocketHandler::SendDone, + base::Unretained(this) )); +} + +void MDnsConnection::SocketHandler::SendDone(int rv) { + // TODO(noamsml): Retry logic. +} + +int MDnsConnection::SocketHandler::BindSocket() { + IPAddressNumber address_any(multicast_addr_.address().size()); + + IPEndPoint bind_endpoint(address_any, multicast_addr_.port()); + + socket_->AllowAddressReuse(); + int rv = socket_->Listen(bind_endpoint); + + if (rv < OK) return rv; + + socket_->SetMulticastLoopbackMode(false); + + return socket_->JoinGroup(multicast_addr_.address()); +} + +MDnsConnection::MDnsConnection(MDnsConnection::SocketFactory* socket_factory, + MDnsConnection::Delegate* delegate) + : socket_handler_ipv4_(this, + GetMDnsIPEndPoint(kMDnsMulticastGroupIPv4), + socket_factory), + socket_handler_ipv6_(this, + GetMDnsIPEndPoint(kMDnsMulticastGroupIPv6), + socket_factory), + delegate_(delegate) { +} + +MDnsConnection::~MDnsConnection() { +} + +int MDnsConnection::Init() { + int rv; + + rv = socket_handler_ipv4_.Start(); + if (rv != OK) return rv; + rv = socket_handler_ipv6_.Start(); + if (rv != OK) return rv; + + return OK; +} + +int MDnsConnection::Send(IOBuffer* buffer, unsigned size) { + int rv; + + rv = socket_handler_ipv4_.Send(buffer, size); + if (rv < OK && rv != ERR_IO_PENDING) return rv; + + rv = socket_handler_ipv6_.Send(buffer, size); + if (rv < OK && rv != ERR_IO_PENDING) return rv; + + return OK; +} + +void MDnsConnection::OnError(SocketHandler* loop, + int error) { + // TODO(noamsml): Specific handling of intermittent errors that can be handled + // in the connection. + delegate_->OnConnectionError(error); +} + +IPEndPoint MDnsConnection::GetMDnsIPEndPoint(const char* address) { + IPAddressNumber multicast_group_number; + bool success = ParseIPLiteralToNumber(address, + &multicast_group_number); + DCHECK(success); + return IPEndPoint(multicast_group_number, + dns_protocol::kDefaultPortMulticast); +} + +void MDnsConnection::OnDatagramReceived( + DnsResponse* response, + const IPEndPoint& recv_addr, + int bytes_read) { + // TODO(noamsml): More sophisticated error handling. + DCHECK_GT(bytes_read, 0); + delegate_->HandlePacket(response, bytes_read); +} + +class MDnsConnectionSocketFactoryImpl + : public MDnsConnection::SocketFactory { + public: + MDnsConnectionSocketFactoryImpl(); + virtual ~MDnsConnectionSocketFactoryImpl(); + + virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE; +}; + +MDnsConnectionSocketFactoryImpl::MDnsConnectionSocketFactoryImpl() { +} + +MDnsConnectionSocketFactoryImpl::~MDnsConnectionSocketFactoryImpl() { +} + +scoped_ptr<DatagramServerSocket> +MDnsConnectionSocketFactoryImpl::CreateSocket() { + return scoped_ptr<DatagramServerSocket>(new UDPServerSocket( + NULL, NetLog::Source())); +} + +// static +scoped_ptr<MDnsConnection::SocketFactory> +MDnsConnection::SocketFactory::CreateDefault() { + return scoped_ptr<MDnsConnection::SocketFactory>( + new MDnsConnectionSocketFactoryImpl); +} + +MDnsClientImpl::Core::Core(MDnsClientImpl* client, + MDnsConnection::SocketFactory* socket_factory) + : client_(client), connection_(new MDnsConnection(socket_factory, this)) { +} + +MDnsClientImpl::Core::~Core() { + STLDeleteValues(&listeners_); +} + +bool MDnsClientImpl::Core::Init() { + return connection_->Init() == OK; +} + +bool MDnsClientImpl::Core::SendQuery(uint16 rrtype, std::string name) { + std::string name_dns; + if (!DNSDomainFromDot(name, &name_dns)) + return false; + + DnsQuery query(0, name_dns, rrtype); + query.set_flags(0); // Remove the RD flag from the query. It is unneeded. + + return connection_->Send(query.io_buffer(), query.io_buffer()->size()) == OK; +} + +void MDnsClientImpl::Core::HandlePacket(DnsResponse* response, + int bytes_read) { + unsigned offset; + // Note: We store cache keys rather than record pointers to avoid + // erroneous behavior in case a packet contains multiple exclusive + // records with the same type and name. + std::map<MDnsCache::Key, MDnsListener::UpdateType> update_keys; + + if (!response->InitParseWithoutQuery(bytes_read)) { + LOG(WARNING) << "Could not understand an mDNS packet."; + return; // Message is unreadable. + } + + // TODO(noamsml): duplicate query suppression. + if (!(response->flags() & dns_protocol::kFlagResponse)) + return; // Message is a query. ignore it. + + DnsRecordParser parser = response->Parser(); + unsigned answer_count = response->answer_count() + + response->additional_answer_count(); + + for (unsigned i = 0; i < answer_count; i++) { + offset = parser.GetOffset(); + scoped_ptr<const RecordParsed> record = RecordParsed::CreateFrom( + &parser, base::Time::Now()); + + if (!record) { + LOG(WARNING) << "Could not understand an mDNS record."; + + if (offset == parser.GetOffset()) { + LOG(WARNING) << "Abandoned parsing the rest of the packet."; + return; // The parser did not advance, abort reading the packet. + } else { + continue; // We may be able to extract other records from the packet. + } + } + + if ((record->klass() & dns_protocol::kMDnsClassMask) != + dns_protocol::kClassIN) { + LOG(WARNING) << "Received an mDNS record with non-IN class. Ignoring."; + continue; // Ignore all records not in the IN class. + } + + MDnsCache::Key update_key = MDnsCache::Key::CreateFor(record.get()); + MDnsCache::UpdateType update = cache_.UpdateDnsRecord(record.Pass()); + + // Cleanup time may have changed. + ScheduleCleanup(cache_.next_expiration()); + + if (update != MDnsCache::NoChange) { + MDnsListener::UpdateType update_external; + + switch (update) { + case MDnsCache::RecordAdded: + update_external = MDnsListener::RECORD_ADDED; + break; + case MDnsCache::RecordChanged: + update_external = MDnsListener::RECORD_CHANGED; + break; + case MDnsCache::NoChange: + default: + NOTREACHED(); + // Dummy assignment to suppress compiler warning. + update_external = MDnsListener::RECORD_CHANGED; + break; + } + + update_keys.insert(std::make_pair(update_key, update_external)); + } + } + + for (std::map<MDnsCache::Key, MDnsListener::UpdateType>::iterator i = + update_keys.begin(); i != update_keys.end(); i++) { + const RecordParsed* record = cache_.LookupKey(i->first); + if (!record) + continue; + + if (record->type() == dns_protocol::kTypeNSEC) { +#if defined(ENABLE_NSEC) + NotifyNsecRecord(record); +#endif + } else { + AlertListeners(i->second, ListenerKey(record->name(), record->type()), + record); + } + } +} + +void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) { + DCHECK_EQ(dns_protocol::kTypeNSEC, record->type()); + const NsecRecordRdata* rdata = record->rdata<NsecRecordRdata>(); + DCHECK(rdata); + + // Remove all cached records matching the nonexistent RR types. + std::vector<const RecordParsed*> records_to_remove; + + cache_.FindDnsRecords(0, record->name(), &records_to_remove, + base::Time::Now()); + + for (std::vector<const RecordParsed*>::iterator i = records_to_remove.begin(); + i != records_to_remove.end(); i++) { + if ((*i)->type() == dns_protocol::kTypeNSEC) + continue; + if (!rdata->GetBit((*i)->type())) { + scoped_ptr<const RecordParsed> record_removed = cache_.RemoveRecord((*i)); + DCHECK(record_removed); + OnRecordRemoved(record_removed.get()); + } + } + + // Alert all listeners waiting for the nonexistent RR types. + ListenerMap::iterator i = + listeners_.upper_bound(ListenerKey(record->name(), 0)); + for (; i != listeners_.end() && i->first.first == record->name(); i++) { + if (!rdata->GetBit(i->first.second)) { + FOR_EACH_OBSERVER(MDnsListenerImpl, *i->second, AlertNsecRecord()); + } + } +} + +void MDnsClientImpl::Core::OnConnectionError(int error) { + // TODO(noamsml): On connection error, recreate connection and flush cache. +} + +void MDnsClientImpl::Core::AlertListeners( + MDnsListener::UpdateType update_type, + const ListenerKey& key, + const RecordParsed* record) { + ListenerMap::iterator listener_map_iterator = listeners_.find(key); + if (listener_map_iterator == listeners_.end()) return; + + FOR_EACH_OBSERVER(MDnsListenerImpl, *listener_map_iterator->second, + AlertDelegate(update_type, record)); +} + +void MDnsClientImpl::Core::AddListener( + MDnsListenerImpl* listener) { + ListenerKey key(listener->GetName(), listener->GetType()); + std::pair<ListenerMap::iterator, bool> observer_insert_result = + listeners_.insert( + make_pair(key, static_cast<ObserverList<MDnsListenerImpl>*>(NULL))); + + // If an equivalent key does not exist, actually create the observer list. + if (observer_insert_result.second) + observer_insert_result.first->second = new ObserverList<MDnsListenerImpl>(); + + ObserverList<MDnsListenerImpl>* observer_list = + observer_insert_result.first->second; + + observer_list->AddObserver(listener); +} + +void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) { + ListenerKey key(listener->GetName(), listener->GetType()); + ListenerMap::iterator observer_list_iterator = listeners_.find(key); + + DCHECK(observer_list_iterator != listeners_.end()); + DCHECK(observer_list_iterator->second->HasObserver(listener)); + + observer_list_iterator->second->RemoveObserver(listener); + + // Remove the observer list from the map if it is empty + if (observer_list_iterator->second->size() == 0) { + // Schedule the actual removal for later in case the listener removal + // happens while iterating over the observer list. + base::MessageLoop::current()->PostTask( + FROM_HERE, base::Bind( + &MDnsClientImpl::Core::CleanupObserverList, AsWeakPtr(), key)); + } +} + +void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) { + ListenerMap::iterator found = listeners_.find(key); + if (found != listeners_.end() && found->second->size() == 0) { + delete found->second; + listeners_.erase(found); + } +} + +void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) { + // Cleanup is already scheduled, no need to do anything. + if (cleanup == scheduled_cleanup_) return; + scheduled_cleanup_ = cleanup; + + // This cancels the previously scheduled cleanup. + cleanup_callback_.Reset(base::Bind( + &MDnsClientImpl::Core::DoCleanup, base::Unretained(this))); + + // If |cleanup| is empty, then no cleanup necessary. + if (cleanup != base::Time()) { + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + cleanup_callback_.callback(), + cleanup - base::Time::Now()); + } +} + +void MDnsClientImpl::Core::DoCleanup() { + cache_.CleanupRecords(base::Time::Now(), base::Bind( + &MDnsClientImpl::Core::OnRecordRemoved, base::Unretained(this))); + + ScheduleCleanup(cache_.next_expiration()); +} + +void MDnsClientImpl::Core::OnRecordRemoved( + const RecordParsed* record) { + AlertListeners(MDnsListener::RECORD_REMOVED, + ListenerKey(record->name(), record->type()), record); +} + +void MDnsClientImpl::Core::QueryCache( + uint16 rrtype, const std::string& name, + std::vector<const RecordParsed*>* records) const { + cache_.FindDnsRecords(rrtype, name, records, base::Time::Now()); +} + +MDnsClientImpl::MDnsClientImpl( + scoped_ptr<MDnsConnection::SocketFactory> socket_factory) + : socket_factory_(socket_factory.Pass()) { +} + +MDnsClientImpl::~MDnsClientImpl() { +} + +bool MDnsClientImpl::StartListening() { + DCHECK(!core_.get()); + core_.reset(new Core(this, socket_factory_.get())); + if (!core_->Init()) { + core_.reset(); + return false; + } + return true; +} + +void MDnsClientImpl::StopListening() { + core_.reset(); +} + +bool MDnsClientImpl::IsListening() const { + return core_.get() != NULL; +} + +scoped_ptr<MDnsListener> MDnsClientImpl::CreateListener( + uint16 rrtype, + const std::string& name, + MDnsListener::Delegate* delegate) { + return scoped_ptr<net::MDnsListener>( + new MDnsListenerImpl(rrtype, name, delegate, this)); +} + +scoped_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction( + uint16 rrtype, + const std::string& name, + int flags, + const MDnsTransaction::ResultCallback& callback) { + return scoped_ptr<MDnsTransaction>( + new MDnsTransactionImpl(rrtype, name, flags, callback, this)); +} + +MDnsListenerImpl::MDnsListenerImpl( + uint16 rrtype, + const std::string& name, + MDnsListener::Delegate* delegate, + MDnsClientImpl* client) + : rrtype_(rrtype), name_(name), client_(client), delegate_(delegate), + started_(false) { +} + +bool MDnsListenerImpl::Start() { + DCHECK(!started_); + + started_ = true; + + DCHECK(client_->core()); + client_->core()->AddListener(this); + + return true; +} + +MDnsListenerImpl::~MDnsListenerImpl() { + if (started_) { + DCHECK(client_->core()); + client_->core()->RemoveListener(this); + } +} + +const std::string& MDnsListenerImpl::GetName() const { + return name_; +} + +uint16 MDnsListenerImpl::GetType() const { + return rrtype_; +} + +void MDnsListenerImpl::AlertDelegate(MDnsListener::UpdateType update_type, + const RecordParsed* record) { + DCHECK(started_); + delegate_->OnRecordUpdate(update_type, record); +} + +void MDnsListenerImpl::AlertNsecRecord() { + DCHECK(started_); + delegate_->OnNsecRecord(name_, rrtype_); +} + +MDnsTransactionImpl::MDnsTransactionImpl( + uint16 rrtype, + const std::string& name, + int flags, + const MDnsTransaction::ResultCallback& callback, + MDnsClientImpl* client) + : rrtype_(rrtype), name_(name), callback_(callback), client_(client), + started_(false), flags_(flags) { + DCHECK((flags_ & MDnsTransaction::FLAG_MASK) == flags_); + DCHECK(flags_ & MDnsTransaction::QUERY_CACHE || + flags_ & MDnsTransaction::QUERY_NETWORK); +} + +MDnsTransactionImpl::~MDnsTransactionImpl() { + timeout_.Cancel(); +} + +bool MDnsTransactionImpl::Start() { + DCHECK(!started_); + started_ = true; + + base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr(); + if (flags_ & MDnsTransaction::QUERY_CACHE) { + ServeRecordsFromCache(); + + if (!weak_this || !is_active()) return true; + } + + if (flags_ & MDnsTransaction::QUERY_NETWORK) { + return QueryAndListen(); + } + + // If this is a cache only query, signal that the transaction is over + // immediately. + SignalTransactionOver(); + return true; +} + +const std::string& MDnsTransactionImpl::GetName() const { + return name_; +} + +uint16 MDnsTransactionImpl::GetType() const { + return rrtype_; +} + +void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) { + DCHECK(started_); + OnRecordUpdate(MDnsListener::RECORD_ADDED, record); +} + +void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result, + const RecordParsed* record) { + DCHECK(started_); + if (!is_active()) return; + + // Ensure callback is run after touching all class state, so that + // the callback can delete the transaction. + MDnsTransaction::ResultCallback callback = callback_; + + // Reset the transaction if it expects a single result, or if the result + // is a final one (everything except for a record). + if (flags_ & MDnsTransaction::SINGLE_RESULT || + result != MDnsTransaction::RESULT_RECORD) { + Reset(); + } + + callback.Run(result, record); +} + +void MDnsTransactionImpl::Reset() { + callback_.Reset(); + listener_.reset(); + timeout_.Cancel(); +} + +void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update, + const RecordParsed* record) { + DCHECK(started_); + if (update == MDnsListener::RECORD_ADDED || + update == MDnsListener::RECORD_CHANGED) + TriggerCallback(MDnsTransaction::RESULT_RECORD, record); +} + +void MDnsTransactionImpl::SignalTransactionOver() { + DCHECK(started_); + if (flags_ & MDnsTransaction::SINGLE_RESULT) { + TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS, NULL); + } else { + TriggerCallback(MDnsTransaction::RESULT_DONE, NULL); + } +} + +void MDnsTransactionImpl::ServeRecordsFromCache() { + std::vector<const RecordParsed*> records; + base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr(); + + if (client_->core()) { + client_->core()->QueryCache(rrtype_, name_, &records); + for (std::vector<const RecordParsed*>::iterator i = records.begin(); + i != records.end() && weak_this; ++i) { + weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD, *i); + } + +#if defined(ENABLE_NSEC) + if (records.empty()) { + DCHECK(weak_this); + client_->core()->QueryCache(dns_protocol::kTypeNSEC, name_, &records); + if (!records.empty()) { + const NsecRecordRdata* rdata = + records.front()->rdata<NsecRecordRdata>(); + DCHECK(rdata); + if (!rdata->GetBit(rrtype_)) + weak_this->TriggerCallback(MDnsTransaction::RESULT_NSEC, NULL); + } + } +#endif + } +} + +bool MDnsTransactionImpl::QueryAndListen() { + listener_ = client_->CreateListener(rrtype_, name_, this); + if (!listener_->Start()) + return false; + + DCHECK(client_->core()); + if (!client_->core()->SendQuery(rrtype_, name_)) + return false; + + timeout_.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver, + AsWeakPtr())); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + timeout_.callback(), + base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds)); + + return true; +} + +void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) { + TriggerCallback(RESULT_NSEC, NULL); +} + +void MDnsTransactionImpl::OnCachePurged() { + // TODO(noamsml): Cache purge situations not yet implemented +} + +} // namespace net diff --git a/chromium/net/dns/mdns_client_impl.h b/chromium/net/dns/mdns_client_impl.h new file mode 100644 index 00000000000..9fe3f99e7dd --- /dev/null +++ b/chromium/net/dns/mdns_client_impl.h @@ -0,0 +1,298 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_MDNS_CLIENT_IMPL_H_ +#define NET_DNS_MDNS_CLIENT_IMPL_H_ + +#include <map> +#include <string> +#include <utility> +#include <vector> + +#include "base/cancelable_callback.h" +#include "base/observer_list.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/dns/mdns_cache.h" +#include "net/dns/mdns_client.h" +#include "net/udp/datagram_server_socket.h" +#include "net/udp/udp_server_socket.h" +#include "net/udp/udp_socket.h" + +namespace net { + +// A connection to the network for multicast DNS clients. It reads data into +// DnsResponse objects and alerts the delegate that a packet has been received. +class NET_EXPORT_PRIVATE MDnsConnection { + public: + class SocketFactory { + public: + virtual ~SocketFactory() {} + + virtual scoped_ptr<DatagramServerSocket> CreateSocket() = 0; + + static scoped_ptr<SocketFactory> CreateDefault(); + }; + + class Delegate { + public: + // Handle an mDNS packet buffered in |response| with a size of |bytes_read|. + virtual void HandlePacket(DnsResponse* response, int bytes_read) = 0; + virtual void OnConnectionError(int error) = 0; + virtual ~Delegate() {} + }; + + explicit MDnsConnection(SocketFactory* socket_factory, + MDnsConnection::Delegate* delegate); + + virtual ~MDnsConnection(); + + int Init(); + int Send(IOBuffer* buffer, unsigned size); + + private: + class SocketHandler { + public: + SocketHandler(MDnsConnection* connection, + const IPEndPoint& multicast_addr, + SocketFactory* socket_factory); + ~SocketHandler(); + int DoLoop(int rv); + int Start(); + + int Send(IOBuffer* buffer, unsigned size); + + private: + int BindSocket(); + void OnDatagramReceived(int rv); + + // Callback for when sending a query has finished. + void SendDone(int rv); + + scoped_ptr<DatagramServerSocket> socket_; + + MDnsConnection* connection_; + IPEndPoint recv_addr_; + scoped_ptr<DnsResponse> response_; + IPEndPoint multicast_addr_; + }; + + // Callback for handling a datagram being received on either ipv4 or ipv6. + void OnDatagramReceived(DnsResponse* response, + const IPEndPoint& recv_addr, + int bytes_read); + + void OnError(SocketHandler* loop, int error); + + IPEndPoint GetMDnsIPEndPoint(const char* address); + + SocketHandler socket_handler_ipv4_; + SocketHandler socket_handler_ipv6_; + + Delegate* delegate_; + + DISALLOW_COPY_AND_ASSIGN(MDnsConnection); +}; + +class MDnsListenerImpl; + +class NET_EXPORT_PRIVATE MDnsClientImpl : public MDnsClient { + public: + // The core object exists while the MDnsClient is listening, and is deleted + // whenever the number of listeners reaches zero. The deletion happens + // asychronously, so destroying the last listener does not immediately + // invalidate the core. + class Core : public base::SupportsWeakPtr<Core>, MDnsConnection::Delegate { + public: + Core(MDnsClientImpl* client, + MDnsConnection::SocketFactory* socket_factory); + virtual ~Core(); + + // Initialize the core. Returns true on success. + bool Init(); + + // Send a query with a specific rrtype and name. Returns true on success. + bool SendQuery(uint16 rrtype, std::string name); + + // Add/remove a listener to the list of listeners. + void AddListener(MDnsListenerImpl* listener); + void RemoveListener(MDnsListenerImpl* listener); + + // Query the cache for records of a specific type and name. + void QueryCache(uint16 rrtype, const std::string& name, + std::vector<const RecordParsed*>* records) const; + + // Parse the response and alert relevant listeners. + virtual void HandlePacket(DnsResponse* response, int bytes_read) OVERRIDE; + + virtual void OnConnectionError(int error) OVERRIDE; + + private: + typedef std::pair<std::string, uint16> ListenerKey; + typedef std::map<ListenerKey, ObserverList<MDnsListenerImpl>* > + ListenerMap; + + // Alert listeners of an update to the cache. + void AlertListeners(MDnsListener::UpdateType update_type, + const ListenerKey& key, const RecordParsed* record); + + // Schedule a cache cleanup to a specific time, cancelling other cleanups. + void ScheduleCleanup(base::Time cleanup); + + // Clean up the cache and schedule a new cleanup. + void DoCleanup(); + + // Callback for when a record is removed from the cache. + void OnRecordRemoved(const RecordParsed* record); + + void NotifyNsecRecord(const RecordParsed* record); + + // Delete and erase the observer list for |key|. Only deletes the observer + // list if is empty. + void CleanupObserverList(const ListenerKey& key); + + ListenerMap listeners_; + + MDnsClientImpl* client_; + MDnsCache cache_; + + base::CancelableCallback<void()> cleanup_callback_; + base::Time scheduled_cleanup_; + + scoped_ptr<MDnsConnection> connection_; + + DISALLOW_COPY_AND_ASSIGN(Core); + }; + + explicit MDnsClientImpl( + scoped_ptr<MDnsConnection::SocketFactory> socket_factory_); + virtual ~MDnsClientImpl(); + + // MDnsClient implementation: + virtual scoped_ptr<MDnsListener> CreateListener( + uint16 rrtype, + const std::string& name, + MDnsListener::Delegate* delegate) OVERRIDE; + + virtual scoped_ptr<MDnsTransaction> CreateTransaction( + uint16 rrtype, + const std::string& name, + int flags, + const MDnsTransaction::ResultCallback& callback) OVERRIDE; + + virtual bool StartListening() OVERRIDE; + virtual void StopListening() OVERRIDE; + virtual bool IsListening() const OVERRIDE; + + Core* core() { return core_.get(); } + + private: + scoped_ptr<Core> core_; + + scoped_ptr<MDnsConnection::SocketFactory> socket_factory_; + + DISALLOW_COPY_AND_ASSIGN(MDnsClientImpl); +}; + +class MDnsListenerImpl : public MDnsListener, + public base::SupportsWeakPtr<MDnsListenerImpl> { + public: + MDnsListenerImpl(uint16 rrtype, + const std::string& name, + MDnsListener::Delegate* delegate, + MDnsClientImpl* client); + + virtual ~MDnsListenerImpl(); + + // MDnsListener implementation: + virtual bool Start() OVERRIDE; + + virtual const std::string& GetName() const OVERRIDE; + + virtual uint16 GetType() const OVERRIDE; + + MDnsListener::Delegate* delegate() { return delegate_; } + + // Alert the delegate of a record update. + void AlertDelegate(MDnsListener::UpdateType update_type, + const RecordParsed* record_parsed); + + // Alert the delegate of the existence of an Nsec record. + void AlertNsecRecord(); + + private: + uint16 rrtype_; + std::string name_; + MDnsClientImpl* client_; + MDnsListener::Delegate* delegate_; + + bool started_; + DISALLOW_COPY_AND_ASSIGN(MDnsListenerImpl); +}; + +class MDnsTransactionImpl : public base::SupportsWeakPtr<MDnsTransactionImpl>, + public MDnsTransaction, + public MDnsListener::Delegate { + public: + MDnsTransactionImpl(uint16 rrtype, + const std::string& name, + int flags, + const MDnsTransaction::ResultCallback& callback, + MDnsClientImpl* client); + virtual ~MDnsTransactionImpl(); + + // MDnsTransaction implementation: + virtual bool Start() OVERRIDE; + + virtual const std::string& GetName() const OVERRIDE; + virtual uint16 GetType() const OVERRIDE; + + // MDnsListener::Delegate implementation: + virtual void OnRecordUpdate(MDnsListener::UpdateType update, + const RecordParsed* record) OVERRIDE; + virtual void OnNsecRecord(const std::string& name, unsigned type) OVERRIDE; + + virtual void OnCachePurged() OVERRIDE; + + private: + bool is_active() { return !callback_.is_null(); } + + void Reset(); + + // Trigger the callback and reset all related variables. + void TriggerCallback(MDnsTransaction::Result result, + const RecordParsed* record); + + // Internal callback for when a cache record is found. + void CacheRecordFound(const RecordParsed* record); + + // Signal the transactionis over and release all related resources. + void SignalTransactionOver(); + + // Reads records from the cache and calls the callback for every + // record read. + void ServeRecordsFromCache(); + + // Send a query to the network and set up a timeout to time out the + // transaction. Returns false if it fails to start listening on the network + // or if it fails to send a query. + bool QueryAndListen(); + + uint16 rrtype_; + std::string name_; + MDnsTransaction::ResultCallback callback_; + + scoped_ptr<MDnsListener> listener_; + base::CancelableCallback<void()> timeout_; + + MDnsClientImpl* client_; + + bool started_; + int flags_; + + DISALLOW_COPY_AND_ASSIGN(MDnsTransactionImpl); +}; + +} // namespace net +#endif // NET_DNS_MDNS_CLIENT_IMPL_H_ diff --git a/chromium/net/dns/mdns_client_unittest.cc b/chromium/net/dns/mdns_client_unittest.cc new file mode 100644 index 00000000000..324b4dfbee0 --- /dev/null +++ b/chromium/net/dns/mdns_client_unittest.cc @@ -0,0 +1,1176 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <queue> + +#include "base/memory/ref_counted.h" +#include "base/message_loop/message_loop.h" +#include "net/base/rand_callback.h" +#include "net/base/test_completion_callback.h" +#include "net/dns/mdns_client_impl.h" +#include "net/dns/mock_mdns_socket_factory.h" +#include "net/dns/record_rdata.h" +#include "net/udp/udp_client_socket.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using ::testing::Invoke; +using ::testing::InvokeWithoutArgs; +using ::testing::StrictMock; +using ::testing::NiceMock; +using ::testing::Exactly; +using ::testing::Return; +using ::testing::SaveArg; +using ::testing::_; + +namespace net { + +namespace { + +const uint8 kSamplePacket1[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is 1 second; + 0x00, 0x01, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x0c, + + // Answer 2 + 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r', + 0xc0, 0x14, // Pointer to "._tcp.local" + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds. + 0x24, 0x75, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x32 +}; + +const uint8 kCorruptedPacketBadQuestion[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x01, // One question + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Question is corrupted and cannot be read. + 0x99, 'h', 'e', 'l', 'l', 'o', + 0x00, + 0x00, 0x00, + 0x00, 0x00, + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x99, // RDLENGTH is impossible + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x0c, + + // Answer 2 + 0x08, '_', 'p', 'r', // Useless trailing data. +}; + +const uint8 kCorruptedPacketUnsalvagable[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x99, // RDLENGTH is impossible + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x0c, + + // Answer 2 + 0x08, '_', 'p', 'r', // Useless trailing data. +}; + +const uint8 kCorruptedPacketDoubleRecord[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x06, 'p', 'r', 'i', 'v', 'e', 't', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x04, // RDLENGTH is 4 + 0x05, 0x03, + 0xc0, 0x0c, + + // Answer 2 -- Same key + 0x06, 'p', 'r', 'i', 'v', 'e', 't', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x04, // RDLENGTH is 4 + 0x02, 0x03, + 0x04, 0x05, +}; + +const uint8 kCorruptedPacketSalvagable[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x99, 'h', 'e', 'l', 'l', 'o', // Bad RDATA format. + 0xc0, 0x0c, + + // Answer 2 + 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r', + 0xc0, 0x14, // Pointer to "._tcp.local" + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 49 seconds. + 0x24, 0x75, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x32 +}; + +const uint8 kSamplePacket2[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x02, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'z', 'z', 'z', 'z', 'z', + 0xc0, 0x0c, + + // Answer 2 + 0x08, '_', 'p', 'r', 'i', 'n', 't', 'e', 'r', + 0xc0, 0x14, // Pointer to "._tcp.local" + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'z', 'z', 'z', 'z', 'z', + 0xc0, 0x32 +}; + +const uint8 kQueryPacketPrivet[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x00, 0x00, // No flags. + 0x00, 0x01, // One question. + 0x00, 0x00, // 0 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Question + // This part is echoed back from the respective query. + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. +}; + +const uint8 kSamplePacketAdditionalOnly[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x00, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x01, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'h', 'e', 'l', 'l', 'o', + 0xc0, 0x0c, +}; + +const uint8 kSamplePacketNsec[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x01, // 1 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x2f, // TYPE is NSEC. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x06, // RDLENGTH is 6 bytes. + 0xc0, 0x0c, + 0x00, 0x02, 0x00, 0x08 // Only A record present +}; + +const uint8 kSamplePacketAPrivet[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x01, // 1 RR (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x01, // TYPE is A. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x04, // RDLENGTH is 4 bytes. + 0xc0, 0x0c, + 0x00, 0x02, +}; + +const uint8 kSamplePacketGoodbye[] = { + // Header + 0x00, 0x00, // ID is zeroed out + 0x81, 0x80, // Standard query response, RA, no error + 0x00, 0x00, // No questions (for simplicity) + 0x00, 0x01, // 2 RRs (answers) + 0x00, 0x00, // 0 authority RRs + 0x00, 0x00, // 0 additional RRs + + // Answer 1 + 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', + 0x04, '_', 't', 'c', 'p', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x0c, // TYPE is PTR. + 0x00, 0x01, // CLASS is IN. + 0x00, 0x00, // TTL (4 bytes) is zero; + 0x00, 0x00, + 0x00, 0x08, // RDLENGTH is 8 bytes. + 0x05, 'z', 'z', 'z', 'z', 'z', + 0xc0, 0x0c, +}; + +std::string MakeString(const uint8* data, unsigned size) { + return std::string(reinterpret_cast<const char*>(data), size); +} + +class PtrRecordCopyContainer { + public: + PtrRecordCopyContainer() {} + ~PtrRecordCopyContainer() {} + + bool is_set() const { return set_; } + + void SaveWithDummyArg(int unused, const RecordParsed* value) { + Save(value); + } + + void Save(const RecordParsed* value) { + set_ = true; + name_ = value->name(); + ptrdomain_ = value->rdata<PtrRecordRdata>()->ptrdomain(); + ttl_ = value->ttl(); + } + + bool IsRecordWith(std::string name, std::string ptrdomain) { + return set_ && name_ == name && ptrdomain_ == ptrdomain; + } + + const std::string& name() { return name_; } + const std::string& ptrdomain() { return ptrdomain_; } + int ttl() { return ttl_; } + + private: + bool set_; + std::string name_; + std::string ptrdomain_; + int ttl_; +}; + +class MDnsTest : public ::testing::Test { + public: + MDnsTest(); + virtual ~MDnsTest(); + virtual void SetUp() OVERRIDE; + virtual void TearDown() OVERRIDE; + void DeleteTransaction(); + void DeleteBothListeners(); + void RunFor(base::TimeDelta time_period); + void Stop(); + + MOCK_METHOD2(MockableRecordCallback, void(MDnsTransaction::Result result, + const RecordParsed* record)); + + MOCK_METHOD2(MockableRecordCallback2, void(MDnsTransaction::Result result, + const RecordParsed* record)); + + + protected: + void ExpectPacket(const uint8* packet, unsigned size); + void SimulatePacketReceive(const uint8* packet, unsigned size); + + scoped_ptr<MDnsClientImpl> test_client_; + IPEndPoint mdns_ipv4_endpoint_; + StrictMock<MockMDnsSocketFactory>* socket_factory_; + + // Transactions and listeners that can be deleted by class methods for + // reentrancy tests. + scoped_ptr<MDnsTransaction> transaction_; + scoped_ptr<MDnsListener> listener1_; + scoped_ptr<MDnsListener> listener2_; +}; + +class MockListenerDelegate : public MDnsListener::Delegate { + public: + MOCK_METHOD2(OnRecordUpdate, + void(MDnsListener::UpdateType update, + const RecordParsed* records)); + MOCK_METHOD2(OnNsecRecord, void(const std::string&, unsigned)); + MOCK_METHOD0(OnCachePurged, void()); +}; + +MDnsTest::MDnsTest() { + socket_factory_ = new StrictMock<MockMDnsSocketFactory>(); + test_client_.reset(new MDnsClientImpl( + scoped_ptr<MDnsConnection::SocketFactory>(socket_factory_))); +} + +MDnsTest::~MDnsTest() { +} + +void MDnsTest::SetUp() { + test_client_->StartListening(); +} + +void MDnsTest::TearDown() { +} + +void MDnsTest::SimulatePacketReceive(const uint8* packet, unsigned size) { + socket_factory_->SimulateReceive(packet, size); +} + +void MDnsTest::ExpectPacket(const uint8* packet, unsigned size) { + EXPECT_CALL(*socket_factory_, OnSendTo(MakeString(packet, size))) + .Times(2); +} + +void MDnsTest::DeleteTransaction() { + transaction_.reset(); +} + +void MDnsTest::DeleteBothListeners() { + listener1_.reset(); + listener2_.reset(); +} + +void MDnsTest::RunFor(base::TimeDelta time_period) { + base::CancelableCallback<void()> callback(base::Bind(&MDnsTest::Stop, + base::Unretained(this))); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, callback.callback(), time_period); + + base::MessageLoop::current()->Run(); + callback.Cancel(); +} + +void MDnsTest::Stop() { + base::MessageLoop::current()->Quit(); +} + +TEST_F(MDnsTest, PassiveListeners) { + StrictMock<MockListenerDelegate> delegate_privet; + StrictMock<MockListenerDelegate> delegate_printer; + + PtrRecordCopyContainer record_privet; + PtrRecordCopyContainer record_printer; + + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet); + scoped_ptr<MDnsListener> listener_printer = test_client_->CreateListener( + dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer); + + ASSERT_TRUE(listener_privet->Start()); + ASSERT_TRUE(listener_printer->Start()); + + // Send the same packet twice to ensure no records are double-counted. + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(Invoke( + &record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(Invoke( + &record_printer, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); + + EXPECT_TRUE(record_printer.IsRecordWith("_printer._tcp.local", + "hello._printer._tcp.local")); + + listener_privet.reset(); + listener_printer.reset(); +} + +TEST_F(MDnsTest, PassiveListenersCacheCleanup) { + StrictMock<MockListenerDelegate> delegate_privet; + + PtrRecordCopyContainer record_privet; + PtrRecordCopyContainer record_privet2; + + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet); + + ASSERT_TRUE(listener_privet->Start()); + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(Invoke( + &record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); + + // Expect record is removed when its TTL expires. + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _)) + .Times(Exactly(1)) + .WillOnce(DoAll(InvokeWithoutArgs(this, &MDnsTest::Stop), + Invoke(&record_privet2, + &PtrRecordCopyContainer::SaveWithDummyArg))); + + RunFor(base::TimeDelta::FromSeconds(record_privet.ttl() + 1)); + + EXPECT_TRUE(record_privet2.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); +} + +TEST_F(MDnsTest, MalformedPacket) { + StrictMock<MockListenerDelegate> delegate_printer; + + PtrRecordCopyContainer record_printer; + + scoped_ptr<MDnsListener> listener_printer = test_client_->CreateListener( + dns_protocol::kTypePTR, "_printer._tcp.local", &delegate_printer); + + ASSERT_TRUE(listener_printer->Start()); + + EXPECT_CALL(delegate_printer, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(Invoke( + &record_printer, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + // First, send unsalvagable packet to ensure we can deal with it. + SimulatePacketReceive(kCorruptedPacketUnsalvagable, + sizeof(kCorruptedPacketUnsalvagable)); + + // Regression test: send a packet where the question cannot be read. + SimulatePacketReceive(kCorruptedPacketBadQuestion, + sizeof(kCorruptedPacketBadQuestion)); + + // Then send salvagable packet to ensure we can extract useful records. + SimulatePacketReceive(kCorruptedPacketSalvagable, + sizeof(kCorruptedPacketSalvagable)); + + EXPECT_TRUE(record_printer.IsRecordWith("_printer._tcp.local", + "hello._printer._tcp.local")); +} + +TEST_F(MDnsTest, TransactionWithEmptyCache) { + ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet)); + + scoped_ptr<MDnsTransaction> transaction_privet = + test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + ASSERT_TRUE(transaction_privet->Start()); + + PtrRecordCopyContainer record_privet; + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _)) + .Times(Exactly(1)) + .WillOnce(Invoke(&record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); +} + +TEST_F(MDnsTest, TransactionCacheOnlyNoResult) { + scoped_ptr<MDnsTransaction> transaction_privet = + test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + EXPECT_CALL(*this, + MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, _)) + .Times(Exactly(1)); + + ASSERT_TRUE(transaction_privet->Start()); +} + +TEST_F(MDnsTest, TransactionWithCache) { + // Listener to force the client to listen + StrictMock<MockListenerDelegate> delegate_irrelevant; + scoped_ptr<MDnsListener> listener_irrelevant = test_client_->CreateListener( + dns_protocol::kTypeA, "codereview.chromium.local", + &delegate_irrelevant); + + ASSERT_TRUE(listener_irrelevant->Start()); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + + PtrRecordCopyContainer record_privet; + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _)) + .WillOnce(Invoke(&record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + scoped_ptr<MDnsTransaction> transaction_privet = + test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + ASSERT_TRUE(transaction_privet->Start()); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); +} + +TEST_F(MDnsTest, AdditionalRecords) { + StrictMock<MockListenerDelegate> delegate_privet; + + PtrRecordCopyContainer record_privet; + + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", + &delegate_privet); + + ASSERT_TRUE(listener_privet->Start()); + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(Invoke( + &record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + SimulatePacketReceive(kSamplePacketAdditionalOnly, + sizeof(kSamplePacketAdditionalOnly)); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); +} + +TEST_F(MDnsTest, TransactionTimeout) { + ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet)); + + scoped_ptr<MDnsTransaction> transaction_privet = + test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + ASSERT_TRUE(transaction_privet->Start()); + + EXPECT_CALL(*this, + MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, NULL)) + .Times(Exactly(1)) + .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop)); + + RunFor(base::TimeDelta::FromSeconds(4)); +} + +TEST_F(MDnsTest, TransactionMultipleRecords) { + ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet)); + + scoped_ptr<MDnsTransaction> transaction_privet = + test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE , + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + ASSERT_TRUE(transaction_privet->Start()); + + PtrRecordCopyContainer record_privet; + PtrRecordCopyContainer record_privet2; + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _)) + .Times(Exactly(2)) + .WillOnce(Invoke(&record_privet, + &PtrRecordCopyContainer::SaveWithDummyArg)) + .WillOnce(Invoke(&record_privet2, + &PtrRecordCopyContainer::SaveWithDummyArg)); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2)); + + EXPECT_TRUE(record_privet.IsRecordWith("_privet._tcp.local", + "hello._privet._tcp.local")); + + EXPECT_TRUE(record_privet2.IsRecordWith("_privet._tcp.local", + "zzzzz._privet._tcp.local")); + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_DONE, NULL)) + .WillOnce(InvokeWithoutArgs(this, &MDnsTest::Stop)); + + RunFor(base::TimeDelta::FromSeconds(4)); +} + +TEST_F(MDnsTest, TransactionReentrantDelete) { + ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet)); + + transaction_ = test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + ASSERT_TRUE(transaction_->Start()); + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_NO_RESULTS, + NULL)) + .Times(Exactly(1)) + .WillOnce(DoAll(InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction), + InvokeWithoutArgs(this, &MDnsTest::Stop))); + + RunFor(base::TimeDelta::FromSeconds(4)); + + EXPECT_EQ(NULL, transaction_.get()); +} + +TEST_F(MDnsTest, TransactionReentrantDeleteFromCache) { + StrictMock<MockListenerDelegate> delegate_irrelevant; + scoped_ptr<MDnsListener> listener_irrelevant = test_client_->CreateListener( + dns_protocol::kTypeA, "codereview.chromium.local", + &delegate_irrelevant); + ASSERT_TRUE(listener_irrelevant->Start()); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + transaction_ = test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, _)) + .Times(Exactly(1)) + .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteTransaction)); + + ASSERT_TRUE(transaction_->Start()); + + EXPECT_EQ(NULL, transaction_.get()); +} + +TEST_F(MDnsTest, TransactionReentrantCacheLookupStart) { + ExpectPacket(kQueryPacketPrivet, sizeof(kQueryPacketPrivet)); + + scoped_ptr<MDnsTransaction> transaction1 = test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + scoped_ptr<MDnsTransaction> transaction2 = test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_printer._tcp.local", + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback2, + base::Unretained(this))); + + EXPECT_CALL(*this, MockableRecordCallback2(MDnsTransaction::RESULT_RECORD, + _)) + .Times(Exactly(1)); + + EXPECT_CALL(*this, MockableRecordCallback(MDnsTransaction::RESULT_RECORD, + _)) + .Times(Exactly(1)) + .WillOnce(IgnoreResult(InvokeWithoutArgs(transaction2.get(), + &MDnsTransaction::Start))); + + ASSERT_TRUE(transaction1->Start()); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); +} + +TEST_F(MDnsTest, GoodbyePacketNotification) { + StrictMock<MockListenerDelegate> delegate_privet; + + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet); + ASSERT_TRUE(listener_privet->Start()); + + SimulatePacketReceive(kSamplePacketGoodbye, sizeof(kSamplePacketGoodbye)); + + RunFor(base::TimeDelta::FromSeconds(2)); +} + +TEST_F(MDnsTest, GoodbyePacketRemoval) { + StrictMock<MockListenerDelegate> delegate_privet; + + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet); + ASSERT_TRUE(listener_privet->Start()); + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)); + + SimulatePacketReceive(kSamplePacket2, sizeof(kSamplePacket2)); + + SimulatePacketReceive(kSamplePacketGoodbye, sizeof(kSamplePacketGoodbye)); + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _)) + .Times(Exactly(1)); + + RunFor(base::TimeDelta::FromSeconds(2)); +} + +// In order to reliably test reentrant listener deletes, we create two listeners +// and have each of them delete both, so we're guaranteed to try and deliver a +// callback to at least one deleted listener. + +TEST_F(MDnsTest, ListenerReentrantDelete) { + StrictMock<MockListenerDelegate> delegate_privet; + + listener1_ = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", + &delegate_privet); + + listener2_ = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", + &delegate_privet); + + ASSERT_TRUE(listener1_->Start()); + + ASSERT_TRUE(listener2_->Start()); + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(InvokeWithoutArgs(this, &MDnsTest::DeleteBothListeners)); + + SimulatePacketReceive(kSamplePacket1, sizeof(kSamplePacket1)); + + EXPECT_EQ(NULL, listener1_.get()); + EXPECT_EQ(NULL, listener2_.get()); +} + +ACTION_P(SaveIPAddress, ip_container) { + ::testing::StaticAssertTypeEq<const RecordParsed*, arg1_type>(); + ::testing::StaticAssertTypeEq<IPAddressNumber*, ip_container_type>(); + + *ip_container = arg1->template rdata<ARecordRdata>()->address(); +} + +TEST_F(MDnsTest, DoubleRecordDisagreeing) { + IPAddressNumber address; + StrictMock<MockListenerDelegate> delegate_privet; + + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypeA, "privet.local", &delegate_privet); + + ASSERT_TRUE(listener_privet->Start()); + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .Times(Exactly(1)) + .WillOnce(SaveIPAddress(&address)); + + SimulatePacketReceive(kCorruptedPacketDoubleRecord, + sizeof(kCorruptedPacketDoubleRecord)); + + EXPECT_EQ("2.3.4.5", IPAddressToString(address)); +} + +TEST_F(MDnsTest, NsecWithListener) { + StrictMock<MockListenerDelegate> delegate_privet; + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypeA, "_privet._tcp.local", &delegate_privet); + + // Test to make sure nsec callback is NOT called for PTR + // (which is marked as existing). + StrictMock<MockListenerDelegate> delegate_privet2; + scoped_ptr<MDnsListener> listener_privet2 = test_client_->CreateListener( + dns_protocol::kTypePTR, "_privet._tcp.local", &delegate_privet2); + + ASSERT_TRUE(listener_privet->Start()); + + EXPECT_CALL(delegate_privet, + OnNsecRecord("_privet._tcp.local", dns_protocol::kTypeA)); + + SimulatePacketReceive(kSamplePacketNsec, + sizeof(kSamplePacketNsec)); +} + +TEST_F(MDnsTest, NsecWithTransactionFromNetwork) { + scoped_ptr<MDnsTransaction> transaction_privet = + test_client_->CreateTransaction( + dns_protocol::kTypeA, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + EXPECT_CALL(*socket_factory_, OnSendTo(_)) + .Times(2); + + ASSERT_TRUE(transaction_privet->Start()); + + EXPECT_CALL(*this, + MockableRecordCallback(MDnsTransaction::RESULT_NSEC, NULL)); + + SimulatePacketReceive(kSamplePacketNsec, + sizeof(kSamplePacketNsec)); +} + +TEST_F(MDnsTest, NsecWithTransactionFromCache) { + // Force mDNS to listen. + StrictMock<MockListenerDelegate> delegate_irrelevant; + scoped_ptr<MDnsListener> listener_irrelevant = + test_client_->CreateListener(dns_protocol::kTypePTR, "_privet._tcp.local", + &delegate_irrelevant); + listener_irrelevant->Start(); + + SimulatePacketReceive(kSamplePacketNsec, + sizeof(kSamplePacketNsec)); + + EXPECT_CALL(*this, + MockableRecordCallback(MDnsTransaction::RESULT_NSEC, NULL)); + + scoped_ptr<MDnsTransaction> transaction_privet_a = + test_client_->CreateTransaction( + dns_protocol::kTypeA, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + ASSERT_TRUE(transaction_privet_a->Start()); + + // Test that a PTR transaction does NOT consider the same NSEC record to be a + // valid answer to the query + + scoped_ptr<MDnsTransaction> transaction_privet_ptr = + test_client_->CreateTransaction( + dns_protocol::kTypePTR, "_privet._tcp.local", + MDnsTransaction::QUERY_NETWORK | + MDnsTransaction::QUERY_CACHE | + MDnsTransaction::SINGLE_RESULT, + base::Bind(&MDnsTest::MockableRecordCallback, + base::Unretained(this))); + + EXPECT_CALL(*socket_factory_, OnSendTo(_)) + .Times(2); + + ASSERT_TRUE(transaction_privet_ptr->Start()); +} + +TEST_F(MDnsTest, NsecConflictRemoval) { + StrictMock<MockListenerDelegate> delegate_privet; + scoped_ptr<MDnsListener> listener_privet = test_client_->CreateListener( + dns_protocol::kTypeA, "_privet._tcp.local", &delegate_privet); + + ASSERT_TRUE(listener_privet->Start()); + + const RecordParsed* record1; + const RecordParsed* record2; + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_ADDED, _)) + .WillOnce(SaveArg<1>(&record1)); + + SimulatePacketReceive(kSamplePacketAPrivet, + sizeof(kSamplePacketAPrivet)); + + EXPECT_CALL(delegate_privet, OnRecordUpdate(MDnsListener::RECORD_REMOVED, _)) + .WillOnce(SaveArg<1>(&record2)); + + EXPECT_CALL(delegate_privet, + OnNsecRecord("_privet._tcp.local", dns_protocol::kTypeA)); + + SimulatePacketReceive(kSamplePacketNsec, + sizeof(kSamplePacketNsec)); + + EXPECT_EQ(record1, record2); +} + + +// Note: These tests assume that the ipv4 socket will always be created first. +// This is a simplifying assumption based on the way the code works now. + +class SimpleMockSocketFactory + : public MDnsConnection::SocketFactory { + public: + SimpleMockSocketFactory() { + } + virtual ~SimpleMockSocketFactory() { + } + + virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE { + scoped_ptr<MockMDnsDatagramServerSocket> socket( + new StrictMock<MockMDnsDatagramServerSocket>); + sockets_.push(socket.get()); + return socket.PassAs<DatagramServerSocket>(); + } + + MockMDnsDatagramServerSocket* PopFirstSocket() { + MockMDnsDatagramServerSocket* socket = sockets_.front(); + sockets_.pop(); + return socket; + } + + size_t num_sockets() { + return sockets_.size(); + } + + private: + std::queue<MockMDnsDatagramServerSocket*> sockets_; +}; + +class MockMDnsConnectionDelegate : public MDnsConnection::Delegate { + public: + virtual void HandlePacket(DnsResponse* response, int size) { + HandlePacketInternal(std::string(response->io_buffer()->data(), size)); + } + + MOCK_METHOD1(HandlePacketInternal, void(std::string packet)); + + MOCK_METHOD1(OnConnectionError, void(int error)); +}; + +class MDnsConnectionTest : public ::testing::Test { + public: + MDnsConnectionTest() : connection_(&factory_, &delegate_) { + } + + protected: + // Follow successful connection initialization. + virtual void SetUp() OVERRIDE { + ASSERT_EQ(2u, factory_.num_sockets()); + + socket_ipv4_ = factory_.PopFirstSocket(); + socket_ipv6_ = factory_.PopFirstSocket(); + } + + bool InitConnection() { + EXPECT_CALL(*socket_ipv4_, AllowAddressReuse()); + EXPECT_CALL(*socket_ipv6_, AllowAddressReuse()); + + EXPECT_CALL(*socket_ipv4_, SetMulticastLoopbackMode(false)); + EXPECT_CALL(*socket_ipv6_, SetMulticastLoopbackMode(false)); + + EXPECT_CALL(*socket_ipv4_, ListenInternal("0.0.0.0:5353")) + .WillOnce(Return(OK)); + EXPECT_CALL(*socket_ipv6_, ListenInternal("[::]:5353")) + .WillOnce(Return(OK)); + + EXPECT_CALL(*socket_ipv4_, JoinGroupInternal("224.0.0.251")) + .WillOnce(Return(OK)); + EXPECT_CALL(*socket_ipv6_, JoinGroupInternal("ff02::fb")) + .WillOnce(Return(OK)); + + return connection_.Init() == OK; + } + + StrictMock<MockMDnsConnectionDelegate> delegate_; + + MockMDnsDatagramServerSocket* socket_ipv4_; + MockMDnsDatagramServerSocket* socket_ipv6_; + SimpleMockSocketFactory factory_; + MDnsConnection connection_; + TestCompletionCallback callback_; +}; + +TEST_F(MDnsConnectionTest, ReceiveSynchronous) { + std::string sample_packet = MakeString(kSamplePacket1, + sizeof(kSamplePacket1)); + + socket_ipv6_->SetResponsePacket(sample_packet); + EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _)) + .WillOnce(Return(ERR_IO_PENDING)); + EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _)) + .WillOnce( + Invoke(socket_ipv6_, &MockMDnsDatagramServerSocket::HandleRecvNow)) + .WillOnce(Return(ERR_IO_PENDING)); + + EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet)); + + ASSERT_TRUE(InitConnection()); +} + +TEST_F(MDnsConnectionTest, ReceiveAsynchronous) { + std::string sample_packet = MakeString(kSamplePacket1, + sizeof(kSamplePacket1)); + socket_ipv6_->SetResponsePacket(sample_packet); + EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _)) + .WillOnce(Return(ERR_IO_PENDING)); + EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _)) + .WillOnce( + Invoke(socket_ipv6_, &MockMDnsDatagramServerSocket::HandleRecvLater)) + .WillOnce(Return(ERR_IO_PENDING)); + + ASSERT_TRUE(InitConnection()); + + EXPECT_CALL(delegate_, HandlePacketInternal(sample_packet)); + + base::MessageLoop::current()->RunUntilIdle(); +} + +TEST_F(MDnsConnectionTest, Send) { + std::string sample_packet = MakeString(kSamplePacket1, + sizeof(kSamplePacket1)); + + scoped_refptr<IOBufferWithSize> buf( + new IOBufferWithSize(sizeof kSamplePacket1)); + memcpy(buf->data(), kSamplePacket1, sizeof(kSamplePacket1)); + + EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _)) + .WillOnce(Return(ERR_IO_PENDING)); + EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _)) + .WillOnce(Return(ERR_IO_PENDING)); + + ASSERT_TRUE(InitConnection()); + + EXPECT_CALL(*socket_ipv4_, + SendToInternal(sample_packet, "224.0.0.251:5353", _)); + EXPECT_CALL(*socket_ipv6_, + SendToInternal(sample_packet, "[ff02::fb]:5353", _)); + + connection_.Send(buf, buf->size()); +} + +TEST_F(MDnsConnectionTest, Error) { + CompletionCallback callback; + + EXPECT_CALL(*socket_ipv4_, RecvFrom(_, _, _, _)) + .WillOnce(Return(ERR_IO_PENDING)); + EXPECT_CALL(*socket_ipv6_, RecvFrom(_, _, _, _)) + .WillOnce(DoAll(SaveArg<3>(&callback), Return(ERR_IO_PENDING))); + + ASSERT_TRUE(InitConnection()); + + EXPECT_CALL(delegate_, OnConnectionError(ERR_SOCKET_NOT_CONNECTED)); + callback.Run(ERR_SOCKET_NOT_CONNECTED); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/dns/mock_host_resolver.cc b/chromium/net/dns/mock_host_resolver.cc new file mode 100644 index 00000000000..b3d1489c9d7 --- /dev/null +++ b/chromium/net/dns/mock_host_resolver.cc @@ -0,0 +1,408 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/mock_host_resolver.h" + +#include <string> +#include <vector> + +#include "base/bind.h" +#include "base/memory/ref_counted.h" +#include "base/message_loop/message_loop.h" +#include "base/stl_util.h" +#include "base/strings/string_split.h" +#include "base/strings/string_util.h" +#include "base/threading/platform_thread.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "net/dns/host_cache.h" + +#if defined(OS_WIN) +#include "net/base/winsock_init.h" +#endif + +namespace net { + +namespace { + +// Cache size for the MockCachingHostResolver. +const unsigned kMaxCacheEntries = 100; +// TTL for the successful resolutions. Failures are not cached. +const unsigned kCacheEntryTTLSeconds = 60; + +} // namespace + +int ParseAddressList(const std::string& host_list, + const std::string& canonical_name, + AddressList* addrlist) { + *addrlist = AddressList(); + std::vector<std::string> addresses; + base::SplitString(host_list, ',', &addresses); + addrlist->set_canonical_name(canonical_name); + for (size_t index = 0; index < addresses.size(); ++index) { + IPAddressNumber ip_number; + if (!ParseIPLiteralToNumber(addresses[index], &ip_number)) { + LOG(WARNING) << "Not a supported IP literal: " << addresses[index]; + return ERR_UNEXPECTED; + } + addrlist->push_back(IPEndPoint(ip_number, -1)); + } + return OK; +} + +struct MockHostResolverBase::Request { + Request(const RequestInfo& req_info, + AddressList* addr, + const CompletionCallback& cb) + : info(req_info), addresses(addr), callback(cb) {} + RequestInfo info; + AddressList* addresses; + CompletionCallback callback; +}; + +MockHostResolverBase::~MockHostResolverBase() { + STLDeleteValues(&requests_); +} + +int MockHostResolverBase::Resolve(const RequestInfo& info, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* handle, + const BoundNetLog& net_log) { + DCHECK(CalledOnValidThread()); + num_resolve_++; + size_t id = next_request_id_++; + int rv = ResolveFromIPLiteralOrCache(info, addresses); + if (rv != ERR_DNS_CACHE_MISS) { + return rv; + } + if (synchronous_mode_) { + return ResolveProc(id, info, addresses); + } + // Store the request for asynchronous resolution + Request* req = new Request(info, addresses, callback); + requests_[id] = req; + if (handle) + *handle = reinterpret_cast<RequestHandle>(id); + + if (!ondemand_mode_) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id)); + } + + return ERR_IO_PENDING; +} + +int MockHostResolverBase::ResolveFromCache(const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& net_log) { + num_resolve_from_cache_++; + DCHECK(CalledOnValidThread()); + next_request_id_++; + int rv = ResolveFromIPLiteralOrCache(info, addresses); + return rv; +} + +void MockHostResolverBase::CancelRequest(RequestHandle handle) { + DCHECK(CalledOnValidThread()); + size_t id = reinterpret_cast<size_t>(handle); + RequestMap::iterator it = requests_.find(id); + if (it != requests_.end()) { + scoped_ptr<Request> req(it->second); + requests_.erase(it); + } else { + NOTREACHED() << "CancelRequest must NOT be called after request is " + "complete or canceled."; + } +} + +HostCache* MockHostResolverBase::GetHostCache() { + return cache_.get(); +} + +void MockHostResolverBase::ResolveAllPending() { + DCHECK(CalledOnValidThread()); + DCHECK(ondemand_mode_); + for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first)); + } +} + +// start id from 1 to distinguish from NULL RequestHandle +MockHostResolverBase::MockHostResolverBase(bool use_caching) + : synchronous_mode_(false), + ondemand_mode_(false), + next_request_id_(1), + num_resolve_(0), + num_resolve_from_cache_(0) { + rules_ = CreateCatchAllHostResolverProc(); + + if (use_caching) { + cache_.reset(new HostCache(kMaxCacheEntries)); + } +} + +int MockHostResolverBase::ResolveFromIPLiteralOrCache(const RequestInfo& info, + AddressList* addresses) { + IPAddressNumber ip; + if (ParseIPLiteralToNumber(info.hostname(), &ip)) { + *addresses = AddressList::CreateFromIPAddress(ip, info.port()); + if (info.host_resolver_flags() & HOST_RESOLVER_CANONNAME) + addresses->SetDefaultCanonicalName(); + return OK; + } + int rv = ERR_DNS_CACHE_MISS; + if (cache_.get() && info.allow_cached_response()) { + HostCache::Key key(info.hostname(), + info.address_family(), + info.host_resolver_flags()); + const HostCache::Entry* entry = cache_->Lookup(key, base::TimeTicks::Now()); + if (entry) { + rv = entry->error; + if (rv == OK) + *addresses = AddressList::CopyWithPort(entry->addrlist, info.port()); + } + } + return rv; +} + +int MockHostResolverBase::ResolveProc(size_t id, + const RequestInfo& info, + AddressList* addresses) { + AddressList addr; + int rv = rules_->Resolve(info.hostname(), + info.address_family(), + info.host_resolver_flags(), + &addr, + NULL); + if (cache_.get()) { + HostCache::Key key(info.hostname(), + info.address_family(), + info.host_resolver_flags()); + // Storing a failure with TTL 0 so that it overwrites previous value. + base::TimeDelta ttl; + if (rv == OK) + ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds); + cache_->Set(key, HostCache::Entry(rv, addr), base::TimeTicks::Now(), ttl); + } + if (rv == OK) + *addresses = AddressList::CopyWithPort(addr, info.port()); + return rv; +} + +void MockHostResolverBase::ResolveNow(size_t id) { + RequestMap::iterator it = requests_.find(id); + if (it == requests_.end()) + return; // was canceled + + scoped_ptr<Request> req(it->second); + requests_.erase(it); + int rv = ResolveProc(id, req->info, req->addresses); + if (!req->callback.is_null()) + req->callback.Run(rv); +} + +//----------------------------------------------------------------------------- + +RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous) + : HostResolverProc(previous) { +} + +void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern, + const std::string& replacement) { + AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED, + replacement); +} + +void RuleBasedHostResolverProc::AddRuleForAddressFamily( + const std::string& host_pattern, + AddressFamily address_family, + const std::string& replacement) { + DCHECK(!replacement.empty()); + HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + Rule rule(Rule::kResolverTypeSystem, + host_pattern, + address_family, + flags, + replacement, + std::string(), + 0); + rules_.push_back(rule); +} + +void RuleBasedHostResolverProc::AddIPLiteralRule( + const std::string& host_pattern, + const std::string& ip_literal, + const std::string& canonical_name) { + // Literals are always resolved to themselves by HostResolverImpl, + // consequently we do not support remapping them. + IPAddressNumber ip_number; + DCHECK(!ParseIPLiteralToNumber(host_pattern, &ip_number)); + HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + if (!canonical_name.empty()) + flags |= HOST_RESOLVER_CANONNAME; + Rule rule(Rule::kResolverTypeIPLiteral, host_pattern, + ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name, + 0); + rules_.push_back(rule); +} + +void RuleBasedHostResolverProc::AddRuleWithLatency( + const std::string& host_pattern, + const std::string& replacement, + int latency_ms) { + DCHECK(!replacement.empty()); + HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + Rule rule(Rule::kResolverTypeSystem, + host_pattern, + ADDRESS_FAMILY_UNSPECIFIED, + flags, + replacement, + std::string(), + latency_ms); + rules_.push_back(rule); +} + +void RuleBasedHostResolverProc::AllowDirectLookup( + const std::string& host_pattern) { + HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + Rule rule(Rule::kResolverTypeSystem, + host_pattern, + ADDRESS_FAMILY_UNSPECIFIED, + flags, + std::string(), + std::string(), + 0); + rules_.push_back(rule); +} + +void RuleBasedHostResolverProc::AddSimulatedFailure( + const std::string& host_pattern) { + HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | + HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + Rule rule(Rule::kResolverTypeFail, + host_pattern, + ADDRESS_FAMILY_UNSPECIFIED, + flags, + std::string(), + std::string(), + 0); + rules_.push_back(rule); +} + +void RuleBasedHostResolverProc::ClearRules() { + rules_.clear(); +} + +int RuleBasedHostResolverProc::Resolve(const std::string& host, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + AddressList* addrlist, + int* os_error) { + RuleList::iterator r; + for (r = rules_.begin(); r != rules_.end(); ++r) { + bool matches_address_family = + r->address_family == ADDRESS_FAMILY_UNSPECIFIED || + r->address_family == address_family; + // Flags match if all of the bitflags in host_resolver_flags are enabled + // in the rule's host_resolver_flags. However, the rule may have additional + // flags specified, in which case the flags should still be considered a + // match. + bool matches_flags = (r->host_resolver_flags & host_resolver_flags) == + host_resolver_flags; + if (matches_flags && matches_address_family && + MatchPattern(host, r->host_pattern)) { + if (r->latency_ms != 0) { + base::PlatformThread::Sleep( + base::TimeDelta::FromMilliseconds(r->latency_ms)); + } + + // Remap to a new host. + const std::string& effective_host = + r->replacement.empty() ? host : r->replacement; + + // Apply the resolving function to the remapped hostname. + switch (r->resolver_type) { + case Rule::kResolverTypeFail: + return ERR_NAME_NOT_RESOLVED; + case Rule::kResolverTypeSystem: +#if defined(OS_WIN) + net::EnsureWinsockInit(); +#endif + return SystemHostResolverCall(effective_host, + address_family, + host_resolver_flags, + addrlist, os_error); + case Rule::kResolverTypeIPLiteral: + return ParseAddressList(effective_host, + r->canonical_name, + addrlist); + default: + NOTREACHED(); + return ERR_UNEXPECTED; + } + } + } + return ResolveUsingPrevious(host, address_family, + host_resolver_flags, addrlist, os_error); +} + +RuleBasedHostResolverProc::~RuleBasedHostResolverProc() { +} + +RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() { + RuleBasedHostResolverProc* catchall = new RuleBasedHostResolverProc(NULL); + catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost"); + + // Next add a rules-based layer the use controls. + return new RuleBasedHostResolverProc(catchall); +} + +//----------------------------------------------------------------------------- + +int HangingHostResolver::Resolve(const RequestInfo& info, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* out_req, + const BoundNetLog& net_log) { + return ERR_IO_PENDING; +} + +int HangingHostResolver::ResolveFromCache(const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& net_log) { + return ERR_DNS_CACHE_MISS; +} + +//----------------------------------------------------------------------------- + +ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() {} + +ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc( + HostResolverProc* proc) { + Init(proc); +} + +ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() { + HostResolverProc* old_proc = + HostResolverProc::SetDefault(previous_proc_.get()); + // The lifetimes of multiple instances must be nested. + CHECK_EQ(old_proc, current_proc_); +} + +void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) { + current_proc_ = proc; + previous_proc_ = HostResolverProc::SetDefault(current_proc_.get()); + current_proc_->SetLastProc(previous_proc_.get()); +} + +} // namespace net diff --git a/chromium/net/dns/mock_host_resolver.h b/chromium/net/dns/mock_host_resolver.h new file mode 100644 index 00000000000..282521cc3b9 --- /dev/null +++ b/chromium/net/dns/mock_host_resolver.h @@ -0,0 +1,284 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_MOCK_HOST_RESOLVER_H_ +#define NET_DNS_MOCK_HOST_RESOLVER_H_ + +#include <list> +#include <map> + +#include "base/memory/weak_ptr.h" +#include "base/synchronization/waitable_event.h" +#include "base/threading/non_thread_safe.h" +#include "net/dns/host_resolver.h" +#include "net/dns/host_resolver_proc.h" + +namespace net { + +class HostCache; +class RuleBasedHostResolverProc; + +// Fills |*addrlist| with a socket address for |host_list| which should be a +// comma-separated list of IPv4 or IPv6 literal(s) without enclosing brackets. +// If |canonical_name| is non-empty it is used as the DNS canonical name for +// the host. Returns OK on success, ERR_UNEXPECTED otherwise. +int ParseAddressList(const std::string& host_list, + const std::string& canonical_name, + AddressList* addrlist); + +// In most cases, it is important that unit tests avoid relying on making actual +// DNS queries since the resulting tests can be flaky, especially if the network +// is unreliable for some reason. To simplify writing tests that avoid making +// actual DNS queries, pass a MockHostResolver as the HostResolver dependency. +// The socket addresses returned can be configured using the +// RuleBasedHostResolverProc: +// +// host_resolver->rules()->AddRule("foo.com", "1.2.3.4"); +// host_resolver->rules()->AddRule("bar.com", "2.3.4.5"); +// +// The above rules define a static mapping from hostnames to IP address +// literals. The first parameter to AddRule specifies a host pattern to match +// against, and the second parameter indicates what value should be used to +// replace the given hostname. So, the following is also supported: +// +// host_mapper->AddRule("*.com", "127.0.0.1"); +// +// Replacement doesn't have to be string representing an IP address. It can +// re-map one hostname to another as well. +// +// By default, MockHostResolvers include a single rule that maps all hosts to +// 127.0.0.1. + +// Base class shared by MockHostResolver and MockCachingHostResolver. +class MockHostResolverBase : public HostResolver, + public base::SupportsWeakPtr<MockHostResolverBase>, + public base::NonThreadSafe { + public: + virtual ~MockHostResolverBase(); + + RuleBasedHostResolverProc* rules() { return rules_.get(); } + void set_rules(RuleBasedHostResolverProc* rules) { rules_ = rules; } + + // Controls whether resolutions complete synchronously or asynchronously. + void set_synchronous_mode(bool is_synchronous) { + synchronous_mode_ = is_synchronous; + } + + // Asynchronous requests are automatically resolved by default. + // If set_ondemand_mode() is set then Resolve() returns IO_PENDING and + // ResolveAllPending() must be explicitly invoked to resolve all requests + // that are pending. + void set_ondemand_mode(bool is_ondemand) { + ondemand_mode_ = is_ondemand; + } + + // HostResolver methods: + virtual int Resolve(const RequestInfo& info, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* out_req, + const BoundNetLog& net_log) OVERRIDE; + virtual int ResolveFromCache(const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& net_log) OVERRIDE; + virtual void CancelRequest(RequestHandle req) OVERRIDE; + virtual HostCache* GetHostCache() OVERRIDE; + + // Resolves all pending requests. It is only valid to invoke this if + // set_ondemand_mode was set before. The requests are resolved asynchronously, + // after this call returns. + void ResolveAllPending(); + + // Returns true if there are pending requests that can be resolved by invoking + // ResolveAllPending(). + bool has_pending_requests() const { return !requests_.empty(); } + + // The number of times that Resolve() has been called. + size_t num_resolve() const { + return num_resolve_; + } + + // The number of times that ResolveFromCache() has been called. + size_t num_resolve_from_cache() const { + return num_resolve_from_cache_; + } + + protected: + explicit MockHostResolverBase(bool use_caching); + + private: + struct Request; + typedef std::map<size_t, Request*> RequestMap; + + // Resolve as IP or from |cache_| return cached error or + // DNS_CACHE_MISS if failed. + int ResolveFromIPLiteralOrCache(const RequestInfo& info, + AddressList* addresses); + // Resolve via |proc_|. + int ResolveProc(size_t id, const RequestInfo& info, AddressList* addresses); + // Resolve request stored in |requests_|. Pass rv to callback. + void ResolveNow(size_t id); + + bool synchronous_mode_; + bool ondemand_mode_; + scoped_refptr<RuleBasedHostResolverProc> rules_; + scoped_ptr<HostCache> cache_; + RequestMap requests_; + size_t next_request_id_; + + size_t num_resolve_; + size_t num_resolve_from_cache_; + + DISALLOW_COPY_AND_ASSIGN(MockHostResolverBase); +}; + +class MockHostResolver : public MockHostResolverBase { + public: + MockHostResolver() : MockHostResolverBase(false /*use_caching*/) {} + virtual ~MockHostResolver() {} +}; + +// Same as MockHostResolver, except internally it uses a host-cache. +// +// Note that tests are advised to use MockHostResolver instead, since it is +// more predictable. (MockHostResolver also can be put into synchronous +// operation mode in case that is what you needed from the caching version). +class MockCachingHostResolver : public MockHostResolverBase { + public: + MockCachingHostResolver() : MockHostResolverBase(true /*use_caching*/) {} + virtual ~MockCachingHostResolver() {} +}; + +// RuleBasedHostResolverProc applies a set of rules to map a host string to +// a replacement host string. It then uses the system host resolver to return +// a socket address. Generally the replacement should be an IPv4 literal so +// there is no network dependency. +class RuleBasedHostResolverProc : public HostResolverProc { + public: + explicit RuleBasedHostResolverProc(HostResolverProc* previous); + + // Any hostname matching the given pattern will be replaced with the given + // replacement value. Usually, replacement should be an IP address literal. + void AddRule(const std::string& host_pattern, + const std::string& replacement); + + // Same as AddRule(), but further restricts to |address_family|. + void AddRuleForAddressFamily(const std::string& host_pattern, + AddressFamily address_family, + const std::string& replacement); + + // Same as AddRule(), but the replacement is expected to be an IPv4 or IPv6 + // literal. This can be used in place of AddRule() to bypass the system's + // host resolver (the address list will be constructed manually). + // If |canonical_name| is non-empty, it is copied to the resulting AddressList + // but does not impact DNS resolution. + // |ip_literal| can be a single IP address like "192.168.1.1" or a comma + // separated list of IP addresses, like "::1,192:168.1.2". + void AddIPLiteralRule(const std::string& host_pattern, + const std::string& ip_literal, + const std::string& canonical_name); + + void AddRuleWithLatency(const std::string& host_pattern, + const std::string& replacement, + int latency_ms); + + // Make sure that |host| will not be re-mapped or even processed by underlying + // host resolver procedures. It can also be a pattern. + void AllowDirectLookup(const std::string& host); + + // Simulate a lookup failure for |host| (it also can be a pattern). + void AddSimulatedFailure(const std::string& host); + + // Deletes all the rules that have been added. + void ClearRules(); + + // HostResolverProc methods: + virtual int Resolve(const std::string& host, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + AddressList* addrlist, + int* os_error) OVERRIDE; + + private: + struct Rule { + enum ResolverType { + kResolverTypeFail, + kResolverTypeSystem, + kResolverTypeIPLiteral, + }; + + ResolverType resolver_type; + std::string host_pattern; + AddressFamily address_family; + HostResolverFlags host_resolver_flags; + std::string replacement; + std::string canonical_name; + int latency_ms; // In milliseconds. + + Rule(ResolverType resolver_type, + const std::string& host_pattern, + AddressFamily address_family, + HostResolverFlags host_resolver_flags, + const std::string& replacement, + const std::string& canonical_name, + int latency_ms) + : resolver_type(resolver_type), + host_pattern(host_pattern), + address_family(address_family), + host_resolver_flags(host_resolver_flags), + replacement(replacement), + canonical_name(canonical_name), + latency_ms(latency_ms) {} + }; + + typedef std::list<Rule> RuleList; + + virtual ~RuleBasedHostResolverProc(); + + RuleList rules_; +}; + +// Create rules that map all requests to localhost. +RuleBasedHostResolverProc* CreateCatchAllHostResolverProc(); + +// HangingHostResolver never completes its |Resolve| request. +class HangingHostResolver : public HostResolver { + public: + virtual int Resolve(const RequestInfo& info, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* out_req, + const BoundNetLog& net_log) OVERRIDE; + virtual int ResolveFromCache(const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& net_log) OVERRIDE; + virtual void CancelRequest(RequestHandle req) OVERRIDE {} +}; + +// This class sets the default HostResolverProc for a particular scope. The +// chain of resolver procs starting at |proc| is placed in front of any existing +// default resolver proc(s). This means that if multiple +// ScopedDefaultHostResolverProcs are declared, then resolving will start with +// the procs given to the last-allocated one, then fall back to the procs given +// to the previously-allocated one, and so forth. +// +// NOTE: Only use this as a catch-all safety net. Individual tests should use +// MockHostResolver. +class ScopedDefaultHostResolverProc { + public: + ScopedDefaultHostResolverProc(); + explicit ScopedDefaultHostResolverProc(HostResolverProc* proc); + + ~ScopedDefaultHostResolverProc(); + + void Init(HostResolverProc* proc); + + private: + scoped_refptr<HostResolverProc> current_proc_; + scoped_refptr<HostResolverProc> previous_proc_; +}; + +} // namespace net + +#endif // NET_DNS_MOCK_HOST_RESOLVER_H_ diff --git a/chromium/net/dns/mock_mdns_socket_factory.cc b/chromium/net/dns/mock_mdns_socket_factory.cc new file mode 100644 index 00000000000..8c08c150cd1 --- /dev/null +++ b/chromium/net/dns/mock_mdns_socket_factory.cc @@ -0,0 +1,115 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <algorithm> + +#include "net/base/net_errors.h" +#include "net/dns/mock_mdns_socket_factory.h" + +using testing::_; +using testing::Invoke; + +namespace net { + +MockMDnsDatagramServerSocket::MockMDnsDatagramServerSocket() { +} + +MockMDnsDatagramServerSocket::~MockMDnsDatagramServerSocket() { +} + +int MockMDnsDatagramServerSocket::SendTo(IOBuffer* buf, int buf_len, + const IPEndPoint& address, + const CompletionCallback& callback) { + return SendToInternal(std::string(buf->data(), buf_len), address.ToString(), + callback); +} + +int MockMDnsDatagramServerSocket::Listen(const IPEndPoint& address) { + return ListenInternal(address.ToString()); +} + +int MockMDnsDatagramServerSocket::JoinGroup( + const IPAddressNumber& group_address) const { + return JoinGroupInternal(IPAddressToString(group_address)); +} + +int MockMDnsDatagramServerSocket::LeaveGroup( + const IPAddressNumber& group_address) const { + return LeaveGroupInternal(IPAddressToString(group_address)); +} + +void MockMDnsDatagramServerSocket::SetResponsePacket( + std::string response_packet) { + response_packet_ = response_packet; +} + +int MockMDnsDatagramServerSocket::HandleRecvNow( + IOBuffer* buffer, int size, IPEndPoint* address, + const CompletionCallback& callback) { + int size_returned = + std::min(response_packet_.size(), static_cast<size_t>(size)); + memcpy(buffer->data(), response_packet_.data(), size_returned); + return size_returned; +} + +int MockMDnsDatagramServerSocket::HandleRecvLater( + IOBuffer* buffer, int size, IPEndPoint* address, + const CompletionCallback& callback) { + int rv = HandleRecvNow(buffer, size, address, callback); + base::MessageLoop::current()->PostTask(FROM_HERE, base::Bind(callback, rv)); + return ERR_IO_PENDING; +} + +MockMDnsSocketFactory::MockMDnsSocketFactory() { +} + +MockMDnsSocketFactory::~MockMDnsSocketFactory() { +} + +scoped_ptr<DatagramServerSocket> MockMDnsSocketFactory::CreateSocket() { + scoped_ptr<MockMDnsDatagramServerSocket> new_socket( + new testing:: NiceMock<MockMDnsDatagramServerSocket>); + + ON_CALL(*new_socket, SendToInternal(_, _, _)) + .WillByDefault(Invoke( + this, + &MockMDnsSocketFactory::SendToInternal)); + + ON_CALL(*new_socket, RecvFrom(_, _, _, _)) + .WillByDefault(Invoke( + this, + &MockMDnsSocketFactory::RecvFromInternal)); + + return new_socket.PassAs<DatagramServerSocket>(); +} + +void MockMDnsSocketFactory::SimulateReceive(const uint8* packet, int size) { + DCHECK(recv_buffer_size_ >= size); + DCHECK(recv_buffer_.get()); + DCHECK(!recv_callback_.is_null()); + + memcpy(recv_buffer_->data(), packet, size); + CompletionCallback recv_callback = recv_callback_; + recv_callback_.Reset(); + recv_callback.Run(size); +} + +int MockMDnsSocketFactory::RecvFromInternal( + IOBuffer* buffer, int size, + IPEndPoint* address, + const CompletionCallback& callback) { + recv_buffer_ = buffer; + recv_buffer_size_ = size; + recv_callback_ = callback; + return ERR_IO_PENDING; +} + +int MockMDnsSocketFactory::SendToInternal( + const std::string& packet, const std::string& address, + const CompletionCallback& callback) { + OnSendTo(packet); + return packet.size(); +} + +} // namespace net diff --git a/chromium/net/dns/mock_mdns_socket_factory.h b/chromium/net/dns/mock_mdns_socket_factory.h new file mode 100644 index 00000000000..f60b08c591b --- /dev/null +++ b/chromium/net/dns/mock_mdns_socket_factory.h @@ -0,0 +1,101 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_MOCK_MDNS_SOCKET_FACTORY_H_ +#define NET_DNS_MOCK_MDNS_SOCKET_FACTORY_H_ + +#include <string> + +#include "net/dns/mdns_client_impl.h" +#include "testing/gmock/include/gmock/gmock.h" + +namespace net { + +class MockMDnsDatagramServerSocket : public DatagramServerSocket { + public: + MockMDnsDatagramServerSocket(); + ~MockMDnsDatagramServerSocket(); + + // DatagramServerSocket implementation: + int Listen(const IPEndPoint& address); + + MOCK_METHOD1(ListenInternal, int(const std::string& address)); + + MOCK_METHOD4(RecvFrom, int(IOBuffer* buffer, int size, + IPEndPoint* address, + const CompletionCallback& callback)); + + int SendTo(IOBuffer* buf, int buf_len, const IPEndPoint& address, + const CompletionCallback& callback); + + MOCK_METHOD3(SendToInternal, int(const std::string& packet, + const std::string address, + const CompletionCallback& callback)); + + MOCK_METHOD1(SetReceiveBufferSize, bool(int32 size)); + MOCK_METHOD1(SetSendBufferSize, bool(int32 size)); + + MOCK_METHOD0(Close, void()); + + MOCK_CONST_METHOD1(GetPeerAddress, int(IPEndPoint* address)); + MOCK_CONST_METHOD1(GetLocalAddress, int(IPEndPoint* address)); + MOCK_CONST_METHOD0(NetLog, const BoundNetLog&()); + + MOCK_METHOD0(AllowAddressReuse, void()); + MOCK_METHOD0(AllowBroadcast, void()); + + int JoinGroup(const IPAddressNumber& group_address) const; + + MOCK_CONST_METHOD1(JoinGroupInternal, int(const std::string& group)); + + int LeaveGroup(const IPAddressNumber& group_address) const; + + MOCK_CONST_METHOD1(LeaveGroupInternal, int(const std::string& group)); + + MOCK_METHOD1(SetMulticastTimeToLive, int(int ttl)); + + MOCK_METHOD1(SetMulticastLoopbackMode, int(bool loopback)); + + void SetResponsePacket(std::string response_packet); + + int HandleRecvNow(IOBuffer* buffer, int size, IPEndPoint* address, + const CompletionCallback& callback); + + int HandleRecvLater(IOBuffer* buffer, int size, IPEndPoint* address, + const CompletionCallback& callback); + + private: + std::string response_packet_; +}; + +class MockMDnsSocketFactory : public MDnsConnection::SocketFactory { + public: + MockMDnsSocketFactory(); + + virtual ~MockMDnsSocketFactory(); + + virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE; + + void SimulateReceive(const uint8* packet, int size); + + MOCK_METHOD1(OnSendTo, void(const std::string&)); + + private: + int SendToInternal(const std::string& packet, const std::string& address, + const CompletionCallback& callback); + + // The latest receive callback is always saved, since the MDnsConnection + // does not care which socket a packet is received on. + int RecvFromInternal(IOBuffer* buffer, int size, + IPEndPoint* address, + const CompletionCallback& callback); + + scoped_refptr<IOBuffer> recv_buffer_; + int recv_buffer_size_; + CompletionCallback recv_callback_; +}; + +} // namespace net + +#endif // NET_DNS_MOCK_MDNS_SOCKET_FACTORY_H_ diff --git a/chromium/net/dns/notify_watcher_mac.cc b/chromium/net/dns/notify_watcher_mac.cc new file mode 100644 index 00000000000..286f18bb7bb --- /dev/null +++ b/chromium/net/dns/notify_watcher_mac.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/notify_watcher_mac.h" + +#include <notify.h> + +#include "base/logging.h" +#include "base/posix/eintr_wrapper.h" + +namespace net { + +NotifyWatcherMac::NotifyWatcherMac() : notify_fd_(-1), notify_token_(-1) {} + +NotifyWatcherMac::~NotifyWatcherMac() { + Cancel(); +} + +bool NotifyWatcherMac::Watch(const char* key, const CallbackType& callback) { + DCHECK(key); + DCHECK(!callback.is_null()); + Cancel(); + uint32_t status = notify_register_file_descriptor( + key, ¬ify_fd_, 0, ¬ify_token_); + if (status != NOTIFY_STATUS_OK) + return false; + DCHECK_GE(notify_fd_, 0); + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + notify_fd_, + true, + base::MessageLoopForIO::WATCH_READ, + &watcher_, + this)) { + Cancel(); + return false; + } + callback_ = callback; + return true; +} + +void NotifyWatcherMac::Cancel() { + if (notify_fd_ >= 0) { + notify_cancel(notify_token_); // Also closes |notify_fd_|. + notify_fd_ = -1; + callback_.Reset(); + watcher_.StopWatchingFileDescriptor(); + } +} + +void NotifyWatcherMac::OnFileCanReadWithoutBlocking(int fd) { + int token; + int status = HANDLE_EINTR(read(notify_fd_, &token, sizeof(token))); + if (status != sizeof(token)) { + Cancel(); + callback_.Run(false); + return; + } + // Ignoring |token| value to avoid possible endianness mismatch: + // http://openradar.appspot.com/8821081 + callback_.Run(true); +} + +} // namespace net diff --git a/chromium/net/dns/notify_watcher_mac.h b/chromium/net/dns/notify_watcher_mac.h new file mode 100644 index 00000000000..01375d56637 --- /dev/null +++ b/chromium/net/dns/notify_watcher_mac.h @@ -0,0 +1,47 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_NOTIFY_WATCHER_MAC_H_ +#define NET_DNS_NOTIFY_WATCHER_MAC_H_ + +#include "base/callback.h" +#include "base/message_loop/message_loop.h" + +namespace net { + +// Watches for notifications from Libnotify and delivers them to a Callback. +// After failure the watch is cancelled and will have to be restarted. +class NotifyWatcherMac : public base::MessageLoopForIO::Watcher { + public: + // Called on received notification with true on success and false on error. + typedef base::Callback<void(bool succeeded)> CallbackType; + + NotifyWatcherMac(); + + // When deleted, automatically cancels. + virtual ~NotifyWatcherMac(); + + // Registers for notifications for |key|. Returns true if succeeds. If so, + // will deliver asynchronous notifications and errors to |callback|. + bool Watch(const char* key, const CallbackType& callback); + + // Cancels the watch. + void Cancel(); + + private: + // MessageLoopForIO::Watcher: + virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE; + virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE {} + + int notify_fd_; + int notify_token_; + CallbackType callback_; + base::MessageLoopForIO::FileDescriptorWatcher watcher_; + + DISALLOW_COPY_AND_ASSIGN(NotifyWatcherMac); +}; + +} // namespace net + +#endif // NET_DNS_NOTIFY_WATCHER_MAC_H_ diff --git a/chromium/net/dns/record_parsed.cc b/chromium/net/dns/record_parsed.cc new file mode 100644 index 00000000000..bee6c7aac63 --- /dev/null +++ b/chromium/net/dns/record_parsed.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/record_parsed.h" + +#include "base/logging.h" +#include "net/dns/dns_response.h" +#include "net/dns/record_rdata.h" + +namespace net { + +RecordParsed::RecordParsed(const std::string& name, uint16 type, uint16 klass, + uint32 ttl, scoped_ptr<const RecordRdata> rdata, + base::Time time_created) + : name_(name), type_(type), klass_(klass), ttl_(ttl), rdata_(rdata.Pass()), + time_created_(time_created) { +} + +RecordParsed::~RecordParsed() { +} + +// static +scoped_ptr<const RecordParsed> RecordParsed::CreateFrom( + DnsRecordParser* parser, + base::Time time_created) { + DnsResourceRecord record; + scoped_ptr<const RecordRdata> rdata; + + if (!parser->ReadRecord(&record)) + return scoped_ptr<const RecordParsed>(); + + switch (record.type) { + case ARecordRdata::kType: + rdata = ARecordRdata::Create(record.rdata, *parser); + break; + case AAAARecordRdata::kType: + rdata = AAAARecordRdata::Create(record.rdata, *parser); + break; + case CnameRecordRdata::kType: + rdata = CnameRecordRdata::Create(record.rdata, *parser); + break; + case PtrRecordRdata::kType: + rdata = PtrRecordRdata::Create(record.rdata, *parser); + break; + case SrvRecordRdata::kType: + rdata = SrvRecordRdata::Create(record.rdata, *parser); + break; + case TxtRecordRdata::kType: + rdata = TxtRecordRdata::Create(record.rdata, *parser); + break; + case NsecRecordRdata::kType: + rdata = NsecRecordRdata::Create(record.rdata, *parser); + break; + default: + LOG(WARNING) << "Unknown RData type for recieved record: " << record.type; + return scoped_ptr<const RecordParsed>(); + } + + if (!rdata.get()) + return scoped_ptr<const RecordParsed>(); + + return scoped_ptr<const RecordParsed>(new RecordParsed(record.name, + record.type, + record.klass, + record.ttl, + rdata.Pass(), + time_created)); +} + +bool RecordParsed::IsEqual(const RecordParsed* other, bool is_mdns) const { + DCHECK(other); + uint16 klass = klass_; + uint16 other_klass = other->klass_; + + if (is_mdns) { + klass &= dns_protocol::kMDnsClassMask; + other_klass &= dns_protocol::kMDnsClassMask; + } + + return name_ == other->name_ && + klass == other_klass && + type_ == other->type_ && + rdata_->IsEqual(other->rdata_.get()); +} +} diff --git a/chromium/net/dns/record_parsed.h b/chromium/net/dns/record_parsed.h new file mode 100644 index 00000000000..016c4910bfb --- /dev/null +++ b/chromium/net/dns/record_parsed.h @@ -0,0 +1,64 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_RECORD_PARSED_H_ +#define NET_DNS_RECORD_PARSED_H_ + +#include <string> + +#include "base/memory/scoped_ptr.h" +#include "base/time/time.h" +#include "net/base/net_export.h" + +namespace net { + +class DnsRecordParser; +class RecordRdata; + +// Parsed record. This is a form of DnsResourceRecord where the rdata section +// has been parsed into a data structure. +class NET_EXPORT_PRIVATE RecordParsed { + public: + virtual ~RecordParsed(); + + // All records are inherently immutable. Return a const pointer. + static scoped_ptr<const RecordParsed> CreateFrom(DnsRecordParser* parser, + base::Time time_created); + + const std::string& name() const { return name_; } + uint16 type() const { return type_; } + uint16 klass() const { return klass_; } + uint32 ttl() const { return ttl_; } + + base::Time time_created() const { return time_created_; } + + template <class T> const T* rdata() const { + if (T::kType != type_) + return NULL; + return static_cast<const T*>(rdata_.get()); + } + + // Check if two records have the same data. Ignores time_created and ttl. + // If |is_mdns| is true, ignore the top bit of the class + // (the cache flush bit). + bool IsEqual(const RecordParsed* other, bool is_mdns) const; + + private: + RecordParsed(const std::string& name, uint16 type, uint16 klass, + uint32 ttl, scoped_ptr<const RecordRdata> rdata, + base::Time time_created); + + std::string name_; // in dotted form + const uint16 type_; + const uint16 klass_; + const uint32 ttl_; + + const scoped_ptr<const RecordRdata> rdata_; + + const base::Time time_created_; +}; + +} // namespace net + +#endif // NET_DNS_RECORD_PARSED_H_ diff --git a/chromium/net/dns/record_parsed_unittest.cc b/chromium/net/dns/record_parsed_unittest.cc new file mode 100644 index 00000000000..2864dcbe761 --- /dev/null +++ b/chromium/net/dns/record_parsed_unittest.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/record_parsed.h" + +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_test_util.h" +#include "net/dns/record_rdata.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +static const uint8 kT1ResponseWithCacheFlushBit[] = { + 0x0a, 'c', 'o', 'd', 'e', 'r', 'e', 'v', 'i', 'e', 'w', + 0x08, 'c', 'h', 'r', 'o', 'm', 'i', 'u', 'm', + 0x03, 'o', 'r', 'g', + 0x00, + 0x00, 0x05, // TYPE is CNAME. + 0x80, 0x01, // CLASS is IN with cache flush bit set. + 0x00, 0x01, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. + 0x24, 0x74, + 0x00, 0x12, // RDLENGTH is 18 bytes. + // ghs.l.google.com in DNS format. + 0x03, 'g', 'h', 's', + 0x01, 'l', + 0x06, 'g', 'o', 'o', 'g', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00 +}; + +TEST(RecordParsedTest, ParseSingleRecord) { + DnsRecordParser parser(kT1ResponseDatagram, sizeof(kT1ResponseDatagram), + sizeof(dns_protocol::Header)); + scoped_ptr<const RecordParsed> record; + const CnameRecordRdata* rdata; + + parser.SkipQuestion(); + record = RecordParsed::CreateFrom(&parser, base::Time()); + EXPECT_TRUE(record != NULL); + + ASSERT_EQ("codereview.chromium.org", record->name()); + ASSERT_EQ(dns_protocol::kTypeCNAME, record->type()); + ASSERT_EQ(dns_protocol::kClassIN, record->klass()); + + rdata = record->rdata<CnameRecordRdata>(); + ASSERT_TRUE(rdata != NULL); + ASSERT_EQ(kT1CanonName, rdata->cname()); + + ASSERT_FALSE(record->rdata<SrvRecordRdata>()); + ASSERT_TRUE(record->IsEqual(record.get(), true)); +} + +TEST(RecordParsedTest, CacheFlushBitCompare) { + DnsRecordParser parser1(kT1ResponseDatagram, sizeof(kT1ResponseDatagram), + sizeof(dns_protocol::Header)); + parser1.SkipQuestion(); + scoped_ptr<const RecordParsed> record1 = + RecordParsed::CreateFrom(&parser1, base::Time()); + + DnsRecordParser parser2(kT1ResponseWithCacheFlushBit, + sizeof(kT1ResponseWithCacheFlushBit), + 0); + + scoped_ptr<const RecordParsed> record2 = + RecordParsed::CreateFrom(&parser2, base::Time()); + + EXPECT_FALSE(record1->IsEqual(record2.get(), false)); + EXPECT_TRUE(record1->IsEqual(record2.get(), true)); + EXPECT_FALSE(record2->IsEqual(record1.get(), false)); + EXPECT_TRUE(record2->IsEqual(record1.get(), true)); +} + +} //namespace net diff --git a/chromium/net/dns/record_rdata.cc b/chromium/net/dns/record_rdata.cc new file mode 100644 index 00000000000..4ebc643a377 --- /dev/null +++ b/chromium/net/dns/record_rdata.cc @@ -0,0 +1,287 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/record_rdata.h" + +#include "net/base/big_endian.h" +#include "net/base/dns_util.h" +#include "net/dns/dns_protocol.h" +#include "net/dns/dns_response.h" + +namespace net { + +static const size_t kSrvRecordMinimumSize = 6; + +RecordRdata::RecordRdata() { +} + +SrvRecordRdata::SrvRecordRdata() : priority_(0), weight_(0), port_(0) { +} + +SrvRecordRdata::~SrvRecordRdata() {} + +// static +scoped_ptr<SrvRecordRdata> SrvRecordRdata::Create( + const base::StringPiece& data, + const DnsRecordParser& parser) { + if (data.size() < kSrvRecordMinimumSize) return scoped_ptr<SrvRecordRdata>(); + + scoped_ptr<SrvRecordRdata> rdata(new SrvRecordRdata); + + BigEndianReader reader(data.data(), data.size()); + // 2 bytes for priority, 2 bytes for weight, 2 bytes for port. + reader.ReadU16(&rdata->priority_); + reader.ReadU16(&rdata->weight_); + reader.ReadU16(&rdata->port_); + + if (!parser.ReadName(data.substr(kSrvRecordMinimumSize).begin(), + &rdata->target_)) + return scoped_ptr<SrvRecordRdata>(); + + return rdata.Pass(); +} + +uint16 SrvRecordRdata::Type() const { + return SrvRecordRdata::kType; +} + +bool SrvRecordRdata::IsEqual(const RecordRdata* other) const { + if (other->Type() != Type()) return false; + const SrvRecordRdata* srv_other = static_cast<const SrvRecordRdata*>(other); + return weight_ == srv_other->weight_ && + port_ == srv_other->port_ && + priority_ == srv_other->priority_ && + target_ == srv_other->target_; +} + +ARecordRdata::ARecordRdata() { +} + +ARecordRdata::~ARecordRdata() { +} + +// static +scoped_ptr<ARecordRdata> ARecordRdata::Create( + const base::StringPiece& data, + const DnsRecordParser& parser) { + if (data.size() != kIPv4AddressSize) + return scoped_ptr<ARecordRdata>(); + + scoped_ptr<ARecordRdata> rdata(new ARecordRdata); + + rdata->address_.resize(kIPv4AddressSize); + for (unsigned i = 0; i < kIPv4AddressSize; ++i) { + rdata->address_[i] = data[i]; + } + + return rdata.Pass(); +} + +uint16 ARecordRdata::Type() const { + return ARecordRdata::kType; +} + +bool ARecordRdata::IsEqual(const RecordRdata* other) const { + if (other->Type() != Type()) return false; + const ARecordRdata* a_other = static_cast<const ARecordRdata*>(other); + return address_ == a_other->address_; +} + +AAAARecordRdata::AAAARecordRdata() { +} + +AAAARecordRdata::~AAAARecordRdata() { +} + +// static +scoped_ptr<AAAARecordRdata> AAAARecordRdata::Create( + const base::StringPiece& data, + const DnsRecordParser& parser) { + if (data.size() != kIPv6AddressSize) + return scoped_ptr<AAAARecordRdata>(); + + scoped_ptr<AAAARecordRdata> rdata(new AAAARecordRdata); + + rdata->address_.resize(kIPv6AddressSize); + for (unsigned i = 0; i < kIPv6AddressSize; ++i) { + rdata->address_[i] = data[i]; + } + + return rdata.Pass(); +} + +uint16 AAAARecordRdata::Type() const { + return AAAARecordRdata::kType; +} + +bool AAAARecordRdata::IsEqual(const RecordRdata* other) const { + if (other->Type() != Type()) return false; + const AAAARecordRdata* a_other = static_cast<const AAAARecordRdata*>(other); + return address_ == a_other->address_; +} + +CnameRecordRdata::CnameRecordRdata() { +} + +CnameRecordRdata::~CnameRecordRdata() { +} + +// static +scoped_ptr<CnameRecordRdata> CnameRecordRdata::Create( + const base::StringPiece& data, + const DnsRecordParser& parser) { + scoped_ptr<CnameRecordRdata> rdata(new CnameRecordRdata); + + if (!parser.ReadName(data.begin(), &rdata->cname_)) + return scoped_ptr<CnameRecordRdata>(); + + return rdata.Pass(); +} + +uint16 CnameRecordRdata::Type() const { + return CnameRecordRdata::kType; +} + +bool CnameRecordRdata::IsEqual(const RecordRdata* other) const { + if (other->Type() != Type()) return false; + const CnameRecordRdata* cname_other = + static_cast<const CnameRecordRdata*>(other); + return cname_ == cname_other->cname_; +} + +PtrRecordRdata::PtrRecordRdata() { +} + +PtrRecordRdata::~PtrRecordRdata() { +} + +// static +scoped_ptr<PtrRecordRdata> PtrRecordRdata::Create( + const base::StringPiece& data, + const DnsRecordParser& parser) { + scoped_ptr<PtrRecordRdata> rdata(new PtrRecordRdata); + + if (!parser.ReadName(data.begin(), &rdata->ptrdomain_)) + return scoped_ptr<PtrRecordRdata>(); + + return rdata.Pass(); +} + +uint16 PtrRecordRdata::Type() const { + return PtrRecordRdata::kType; +} + +bool PtrRecordRdata::IsEqual(const RecordRdata* other) const { + if (other->Type() != Type()) return false; + const PtrRecordRdata* ptr_other = static_cast<const PtrRecordRdata*>(other); + return ptrdomain_ == ptr_other->ptrdomain_; +} + +TxtRecordRdata::TxtRecordRdata() { +} + +TxtRecordRdata::~TxtRecordRdata() { +} + +// static +scoped_ptr<TxtRecordRdata> TxtRecordRdata::Create( + const base::StringPiece& data, + const DnsRecordParser& parser) { + scoped_ptr<TxtRecordRdata> rdata(new TxtRecordRdata); + + for (size_t i = 0; i < data.size(); ) { + uint8 length = data[i]; + + if (i + length >= data.size()) + return scoped_ptr<TxtRecordRdata>(); + + rdata->texts_.push_back(data.substr(i + 1, length).as_string()); + + // Move to the next string. + i += length + 1; + } + + return rdata.Pass(); +} + +uint16 TxtRecordRdata::Type() const { + return TxtRecordRdata::kType; +} + +bool TxtRecordRdata::IsEqual(const RecordRdata* other) const { + if (other->Type() != Type()) return false; + const TxtRecordRdata* txt_other = static_cast<const TxtRecordRdata*>(other); + return texts_ == txt_other->texts_; +} + +NsecRecordRdata::NsecRecordRdata() { +} + +NsecRecordRdata::~NsecRecordRdata() { +} + +// static +scoped_ptr<NsecRecordRdata> NsecRecordRdata::Create( + const base::StringPiece& data, + const DnsRecordParser& parser) { + scoped_ptr<NsecRecordRdata> rdata(new NsecRecordRdata); + + // Read the "next domain". This part for the NSEC record format is + // ignored for mDNS, since it has no semantic meaning. + unsigned next_domain_length = parser.ReadName(data.data(), NULL); + + // If we did not succeed in getting the next domain or the data length + // is too short for reading the bitmap header, return. + if (next_domain_length == 0 || data.length() < next_domain_length + 2) + return scoped_ptr<NsecRecordRdata>(); + + struct BitmapHeader { + uint8 block_number; // The block number should be zero. + uint8 length; // Bitmap length in bytes. Between 1 and 32. + }; + + const BitmapHeader* header = reinterpret_cast<const BitmapHeader*>( + data.data() + next_domain_length); + + // The block number must be zero in mDns-specific NSEC records. The bitmap + // length must be between 1 and 32. + if (header->block_number != 0 || header->length == 0 || header->length > 32) + return scoped_ptr<NsecRecordRdata>(); + + base::StringPiece bitmap_data = data.substr(next_domain_length + 2); + + // Since we may only have one block, the data length must be exactly equal to + // the domain length plus bitmap size. + if (bitmap_data.length() != header->length) + return scoped_ptr<NsecRecordRdata>(); + + rdata->bitmap_.insert(rdata->bitmap_.begin(), + bitmap_data.begin(), + bitmap_data.end()); + + return rdata.Pass(); +} + +uint16 NsecRecordRdata::Type() const { + return NsecRecordRdata::kType; +} + +bool NsecRecordRdata::IsEqual(const RecordRdata* other) const { + if (other->Type() != Type()) + return false; + const NsecRecordRdata* nsec_other = + static_cast<const NsecRecordRdata*>(other); + return bitmap_ == nsec_other->bitmap_; +} + +bool NsecRecordRdata::GetBit(unsigned i) const { + unsigned byte_num = i/8; + if (bitmap_.size() < byte_num + 1) + return false; + + unsigned bit_num = 7 - i % 8; + return (bitmap_[byte_num] & (1 << bit_num)) != 0; +} + +} // namespace net diff --git a/chromium/net/dns/record_rdata.h b/chromium/net/dns/record_rdata.h new file mode 100644 index 00000000000..f83a48650e1 --- /dev/null +++ b/chromium/net/dns/record_rdata.h @@ -0,0 +1,217 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_RECORD_RDATA_H_ +#define NET_DNS_RECORD_RDATA_H_ + +#include <string> +#include <vector> + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/scoped_ptr.h" +#include "base/strings/string_piece.h" +#include "net/base/big_endian.h" +#include "net/base/net_export.h" +#include "net/base/net_util.h" +#include "net/dns/dns_protocol.h" + +namespace net { + +class DnsRecordParser; + +// Parsed represenation of the extra data in a record. Does not include standard +// DNS record data such as TTL, Name, Type and Class. +class NET_EXPORT_PRIVATE RecordRdata { + public: + virtual ~RecordRdata() {} + + virtual bool IsEqual(const RecordRdata* other) const = 0; + virtual uint16 Type() const = 0; + + protected: + RecordRdata(); + + DISALLOW_COPY_AND_ASSIGN(RecordRdata); +}; + +// SRV record format (http://www.ietf.org/rfc/rfc2782.txt): +// 2 bytes network-order unsigned priority +// 2 bytes network-order unsigned weight +// 2 bytes network-order unsigned port +// target: domain name (on-the-wire representation) +class NET_EXPORT_PRIVATE SrvRecordRdata : public RecordRdata { + public: + static const uint16 kType = dns_protocol::kTypeSRV; + + virtual ~SrvRecordRdata(); + static scoped_ptr<SrvRecordRdata> Create(const base::StringPiece& data, + const DnsRecordParser& parser); + + virtual bool IsEqual(const RecordRdata* other) const OVERRIDE; + virtual uint16 Type() const OVERRIDE; + + uint16 priority() const { return priority_; } + uint16 weight() const { return weight_; } + uint16 port() const { return port_; } + + const std::string& target() const { return target_; } + + private: + SrvRecordRdata(); + + uint16 priority_; + uint16 weight_; + uint16 port_; + + std::string target_; + + DISALLOW_COPY_AND_ASSIGN(SrvRecordRdata); +}; + +// A Record format (http://www.ietf.org/rfc/rfc1035.txt): +// 4 bytes for IP address. +class NET_EXPORT_PRIVATE ARecordRdata : public RecordRdata { + public: + static const uint16 kType = dns_protocol::kTypeA; + + virtual ~ARecordRdata(); + static scoped_ptr<ARecordRdata> Create(const base::StringPiece& data, + const DnsRecordParser& parser); + virtual bool IsEqual(const RecordRdata* other) const OVERRIDE; + virtual uint16 Type() const OVERRIDE; + + const IPAddressNumber& address() const { return address_; } + + private: + ARecordRdata(); + + IPAddressNumber address_; + + DISALLOW_COPY_AND_ASSIGN(ARecordRdata); +}; + +// AAAA Record format (http://www.ietf.org/rfc/rfc1035.txt): +// 16 bytes for IP address. +class NET_EXPORT_PRIVATE AAAARecordRdata : public RecordRdata { + public: + static const uint16 kType = dns_protocol::kTypeAAAA; + + virtual ~AAAARecordRdata(); + static scoped_ptr<AAAARecordRdata> Create(const base::StringPiece& data, + const DnsRecordParser& parser); + virtual bool IsEqual(const RecordRdata* other) const OVERRIDE; + virtual uint16 Type() const OVERRIDE; + + const IPAddressNumber& address() const { return address_; } + + private: + AAAARecordRdata(); + + IPAddressNumber address_; + + DISALLOW_COPY_AND_ASSIGN(AAAARecordRdata); +}; + +// CNAME record format (http://www.ietf.org/rfc/rfc1035.txt): +// cname: On the wire representation of domain name. +class NET_EXPORT_PRIVATE CnameRecordRdata : public RecordRdata { + public: + static const uint16 kType = dns_protocol::kTypeCNAME; + + virtual ~CnameRecordRdata(); + static scoped_ptr<CnameRecordRdata> Create(const base::StringPiece& data, + const DnsRecordParser& parser); + virtual bool IsEqual(const RecordRdata* other) const OVERRIDE; + virtual uint16 Type() const OVERRIDE; + + std::string cname() const { return cname_; } + + private: + CnameRecordRdata(); + + std::string cname_; + + DISALLOW_COPY_AND_ASSIGN(CnameRecordRdata); +}; + +// PTR record format (http://www.ietf.org/rfc/rfc1035.txt): +// domain: On the wire representation of domain name. +class NET_EXPORT_PRIVATE PtrRecordRdata : public RecordRdata { + public: + static const uint16 kType = dns_protocol::kTypePTR; + + virtual ~PtrRecordRdata(); + static scoped_ptr<PtrRecordRdata> Create(const base::StringPiece& data, + const DnsRecordParser& parser); + virtual bool IsEqual(const RecordRdata* other) const OVERRIDE; + virtual uint16 Type() const OVERRIDE; + + std::string ptrdomain() const { return ptrdomain_; } + + private: + PtrRecordRdata(); + + std::string ptrdomain_; + + DISALLOW_COPY_AND_ASSIGN(PtrRecordRdata); +}; + +// TXT record format (http://www.ietf.org/rfc/rfc1035.txt): +// texts: One or more <character-string>s. +// a <character-string> is a length octet followed by as many characters. +class NET_EXPORT_PRIVATE TxtRecordRdata : public RecordRdata { + public: + static const uint16 kType = dns_protocol::kTypeTXT; + + virtual ~TxtRecordRdata(); + static scoped_ptr<TxtRecordRdata> Create(const base::StringPiece& data, + const DnsRecordParser& parser); + virtual bool IsEqual(const RecordRdata* other) const OVERRIDE; + virtual uint16 Type() const OVERRIDE; + + const std::vector<std::string>& texts() const { return texts_; } + + private: + TxtRecordRdata(); + + std::vector<std::string> texts_; + + DISALLOW_COPY_AND_ASSIGN(TxtRecordRdata); +}; + +// Only the subset of the NSEC record format required by mDNS is supported. +// Nsec record format is described in http://www.ietf.org/rfc/rfc3845.txt and +// the limited version required for mDNS described in +// http://www.rfc-editor.org/rfc/rfc6762.txt Section 6.1. +class NET_EXPORT_PRIVATE NsecRecordRdata : public RecordRdata { + public: + static const uint16 kType = dns_protocol::kTypeNSEC; + + virtual ~NsecRecordRdata(); + static scoped_ptr<NsecRecordRdata> Create(const base::StringPiece& data, + const DnsRecordParser& parser); + virtual bool IsEqual(const RecordRdata* other) const OVERRIDE; + virtual uint16 Type() const OVERRIDE; + + // Length of the bitmap in bits. + unsigned bitmap_length() const { return bitmap_.size() * 8; } + + // Returns bit i-th bit in the bitmap, where bits withing a byte are organized + // most to least significant. If it is set, a record with rrtype i exists for + // the domain name of this nsec record. + bool GetBit(unsigned i) const; + + private: + NsecRecordRdata(); + + std::vector<uint8> bitmap_; + + DISALLOW_COPY_AND_ASSIGN(NsecRecordRdata); +}; + + +} // namespace net + +#endif // NET_DNS_RECORD_RDATA_H_ diff --git a/chromium/net/dns/record_rdata_unittest.cc b/chromium/net/dns/record_rdata_unittest.cc new file mode 100644 index 00000000000..90bac446e2e --- /dev/null +++ b/chromium/net/dns/record_rdata_unittest.cc @@ -0,0 +1,222 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/memory/scoped_ptr.h" +#include "net/base/net_util.h" +#include "net/dns/dns_response.h" +#include "net/dns/record_rdata.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +base::StringPiece MakeStringPiece(const uint8* data, unsigned size) { + const char* data_cc = reinterpret_cast<const char*>(data); + return base::StringPiece(data_cc, size); +} + +TEST(RecordRdataTest, ParseSrvRecord) { + scoped_ptr<SrvRecordRdata> record1_obj; + scoped_ptr<SrvRecordRdata> record2_obj; + + // These are just the rdata portions of the DNS records, rather than complete + // records, but it works well enough for this test. + + const uint8 record[] = { + 0x00, 0x01, + 0x00, 0x02, + 0x00, 0x50, + 0x03, 'w', 'w', 'w', + 0x06, 'g', 'o', 'o', 'g', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, + 0x01, 0x01, + 0x01, 0x02, + 0x01, 0x03, + 0x04, 'w', 'w', 'w', '2', + 0xc0, 0x0a, // Pointer to "google.com" + }; + + DnsRecordParser parser(record, sizeof(record), 0); + const unsigned first_record_len = 22; + base::StringPiece record1_strpiece = MakeStringPiece( + record, first_record_len); + base::StringPiece record2_strpiece = MakeStringPiece( + record + first_record_len, sizeof(record) - first_record_len); + + record1_obj = SrvRecordRdata::Create(record1_strpiece, parser); + ASSERT_TRUE(record1_obj != NULL); + ASSERT_EQ(1, record1_obj->priority()); + ASSERT_EQ(2, record1_obj->weight()); + ASSERT_EQ(80, record1_obj->port()); + + ASSERT_EQ("www.google.com", record1_obj->target()); + + record2_obj = SrvRecordRdata::Create(record2_strpiece, parser); + ASSERT_TRUE(record2_obj != NULL); + ASSERT_EQ(257, record2_obj->priority()); + ASSERT_EQ(258, record2_obj->weight()); + ASSERT_EQ(259, record2_obj->port()); + + ASSERT_EQ("www2.google.com", record2_obj->target()); + + ASSERT_TRUE(record1_obj->IsEqual(record1_obj.get())); + ASSERT_FALSE(record1_obj->IsEqual(record2_obj.get())); +} + +TEST(RecordRdataTest, ParseARecord) { + scoped_ptr<ARecordRdata> record_obj; + + // These are just the rdata portions of the DNS records, rather than complete + // records, but it works well enough for this test. + + const uint8 record[] = { + 0x7F, 0x00, 0x00, 0x01 // 127.0.0.1 + }; + + DnsRecordParser parser(record, sizeof(record), 0); + base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record)); + + record_obj = ARecordRdata::Create(record_strpiece, parser); + ASSERT_TRUE(record_obj != NULL); + + ASSERT_EQ("127.0.0.1", IPAddressToString(record_obj->address())); + + ASSERT_TRUE(record_obj->IsEqual(record_obj.get())); +} + +TEST(RecordRdataTest, ParseAAAARecord) { + scoped_ptr<AAAARecordRdata> record_obj; + + // These are just the rdata portions of the DNS records, rather than complete + // records, but it works well enough for this test. + + const uint8 record[] = { + 0x12, 0x34, 0x56, 0x78, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x09 // 1234:5678::9A + }; + + DnsRecordParser parser(record, sizeof(record), 0); + base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record)); + + record_obj = AAAARecordRdata::Create(record_strpiece, parser); + ASSERT_TRUE(record_obj != NULL); + + ASSERT_EQ("1234:5678::9", + IPAddressToString(record_obj->address())); + + ASSERT_TRUE(record_obj->IsEqual(record_obj.get())); +} + +TEST(RecordRdataTest, ParseCnameRecord) { + scoped_ptr<CnameRecordRdata> record_obj; + + // These are just the rdata portions of the DNS records, rather than complete + // records, but it works well enough for this test. + + const uint8 record[] = { + 0x03, 'w', 'w', 'w', + 0x06, 'g', 'o', 'o', 'g', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00 + }; + + DnsRecordParser parser(record, sizeof(record), 0); + base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record)); + + record_obj = CnameRecordRdata::Create(record_strpiece, parser); + ASSERT_TRUE(record_obj != NULL); + + ASSERT_EQ("www.google.com", record_obj->cname()); + + ASSERT_TRUE(record_obj->IsEqual(record_obj.get())); +} + +TEST(RecordRdataTest, ParsePtrRecord) { + scoped_ptr<PtrRecordRdata> record_obj; + + // These are just the rdata portions of the DNS records, rather than complete + // records, but it works well enough for this test. + + const uint8 record[] = { + 0x03, 'w', 'w', 'w', + 0x06, 'g', 'o', 'o', 'g', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00 + }; + + DnsRecordParser parser(record, sizeof(record), 0); + base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record)); + + record_obj = PtrRecordRdata::Create(record_strpiece, parser); + ASSERT_TRUE(record_obj != NULL); + + ASSERT_EQ("www.google.com", record_obj->ptrdomain()); + + ASSERT_TRUE(record_obj->IsEqual(record_obj.get())); +} + +TEST(RecordRdataTest, ParseTxtRecord) { + scoped_ptr<TxtRecordRdata> record_obj; + + // These are just the rdata portions of the DNS records, rather than complete + // records, but it works well enough for this test. + + const uint8 record[] = { + 0x03, 'w', 'w', 'w', + 0x06, 'g', 'o', 'o', 'g', 'l', 'e', + 0x03, 'c', 'o', 'm' + }; + + DnsRecordParser parser(record, sizeof(record), 0); + base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record)); + + record_obj = TxtRecordRdata::Create(record_strpiece, parser); + ASSERT_TRUE(record_obj != NULL); + + std::vector<std::string> expected; + expected.push_back("www"); + expected.push_back("google"); + expected.push_back("com"); + + ASSERT_EQ(expected, record_obj->texts()); + + ASSERT_TRUE(record_obj->IsEqual(record_obj.get())); +} + +TEST(RecordRdataTest, ParseNsecRecord) { + scoped_ptr<NsecRecordRdata> record_obj; + + // These are just the rdata portions of the DNS records, rather than complete + // records, but it works well enough for this test. + + const uint8 record[] = { + 0x03, 'w', 'w', 'w', + 0x06, 'g', 'o', 'o', 'g', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, + 0x00, 0x02, 0x40, 0x01 + }; + + DnsRecordParser parser(record, sizeof(record), 0); + base::StringPiece record_strpiece = MakeStringPiece(record, sizeof(record)); + + record_obj = NsecRecordRdata::Create(record_strpiece, parser); + ASSERT_TRUE(record_obj != NULL); + + ASSERT_EQ(16u, record_obj->bitmap_length()); + + EXPECT_FALSE(record_obj->GetBit(0)); + EXPECT_TRUE(record_obj->GetBit(1)); + for (int i = 2; i < 15; i++) { + EXPECT_FALSE(record_obj->GetBit(i)); + } + EXPECT_TRUE(record_obj->GetBit(15)); + + ASSERT_TRUE(record_obj->IsEqual(record_obj.get())); +} + + +} // namespace net diff --git a/chromium/net/dns/serial_worker.cc b/chromium/net/dns/serial_worker.cc new file mode 100644 index 00000000000..394721c1a65 --- /dev/null +++ b/chromium/net/dns/serial_worker.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/serial_worker.h" + +#include "base/bind.h" +#include "base/location.h" +#include "base/message_loop/message_loop_proxy.h" +#include "base/threading/worker_pool.h" + +namespace net { + +namespace { + // Delay between calls to WorkerPool::PostTask + const int kWorkerPoolRetryDelayMs = 100; +} + +SerialWorker::SerialWorker() + : message_loop_(base::MessageLoopProxy::current()), + state_(IDLE) {} + +SerialWorker::~SerialWorker() {} + +void SerialWorker::WorkNow() { + DCHECK(message_loop_->BelongsToCurrentThread()); + switch (state_) { + case IDLE: + if (!base::WorkerPool::PostTask(FROM_HERE, base::Bind( + &SerialWorker::DoWorkJob, this), false)) { +#if defined(OS_POSIX) + // See worker_pool_posix.cc. + NOTREACHED() << "WorkerPool::PostTask is not expected to fail on posix"; +#else + LOG(WARNING) << "Failed to WorkerPool::PostTask, will retry later"; + message_loop_->PostDelayedTask( + FROM_HERE, + base::Bind(&SerialWorker::RetryWork, this), + base::TimeDelta::FromMilliseconds(kWorkerPoolRetryDelayMs)); + state_ = WAITING; + return; +#endif + } + state_ = WORKING; + return; + case WORKING: + // Remember to re-read after |DoRead| finishes. + state_ = PENDING; + return; + case CANCELLED: + case PENDING: + case WAITING: + return; + default: + NOTREACHED() << "Unexpected state " << state_; + } +} + +void SerialWorker::Cancel() { + DCHECK(message_loop_->BelongsToCurrentThread()); + state_ = CANCELLED; +} + +void SerialWorker::DoWorkJob() { + this->DoWork(); + // If this fails, the loop is gone, so there is no point retrying. + message_loop_->PostTask(FROM_HERE, base::Bind( + &SerialWorker::OnWorkJobFinished, this)); +} + +void SerialWorker::OnWorkJobFinished() { + DCHECK(message_loop_->BelongsToCurrentThread()); + switch (state_) { + case CANCELLED: + return; + case WORKING: + state_ = IDLE; + this->OnWorkFinished(); + return; + case PENDING: + state_ = IDLE; + WorkNow(); + return; + default: + NOTREACHED() << "Unexpected state " << state_; + } +} + +void SerialWorker::RetryWork() { + DCHECK(message_loop_->BelongsToCurrentThread()); + switch (state_) { + case CANCELLED: + return; + case WAITING: + state_ = IDLE; + WorkNow(); + return; + default: + NOTREACHED() << "Unexpected state " << state_; + } +} + +} // namespace net + diff --git a/chromium/net/dns/serial_worker.h b/chromium/net/dns/serial_worker.h new file mode 100644 index 00000000000..59a4d3f189c --- /dev/null +++ b/chromium/net/dns/serial_worker.h @@ -0,0 +1,96 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_SERIAL_WORKER_H_ +#define NET_DNS_SERIAL_WORKER_H_ + +#include <string> + +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "net/base/net_export.h" + +// Forward declaration +namespace base { +class MessageLoopProxy; +} + +namespace net { + +// SerialWorker executes a job on WorkerPool serially -- **once at a time**. +// On |WorkNow|, a call to |DoWork| is scheduled on the WorkerPool. Once it +// completes, |OnWorkFinished| is called on the origin thread. +// If |WorkNow| is called (1 or more times) while |DoWork| is already under way, +// |DoWork| will be called once: after current |DoWork| completes, before a +// call to |OnWorkFinished|. +// +// This behavior is designed for updating a result after some trigger, for +// example reading a file once FilePathWatcher indicates it changed. +// +// Derived classes should store results of work done in |DoWork| in dedicated +// fields and read them in |OnWorkFinished| which is executed on the origin +// thread. This avoids the need to template this class. +// +// This implementation avoids locking by using the |state_| member to ensure +// that |DoWork| and |OnWorkFinished| cannot execute in parallel. +// +// TODO(szym): update to WorkerPool::PostTaskAndReply once available. +class NET_EXPORT_PRIVATE SerialWorker + : NON_EXPORTED_BASE(public base::RefCountedThreadSafe<SerialWorker>) { + public: + SerialWorker(); + + // Unless already scheduled, post |DoWork| to WorkerPool. + // Made virtual to allow mocking. + virtual void WorkNow(); + + // Stop scheduling jobs. + void Cancel(); + + bool IsCancelled() const { return state_ == CANCELLED; } + + protected: + friend class base::RefCountedThreadSafe<SerialWorker>; + // protected to allow sub-classing, but prevent deleting + virtual ~SerialWorker(); + + // Executed on WorkerPool, at most once at a time. + virtual void DoWork() = 0; + + // Executed on origin thread after |DoRead| completes. + virtual void OnWorkFinished() = 0; + + base::MessageLoopProxy* loop() { return message_loop_.get(); } + + private: + enum State { + CANCELLED = -1, + IDLE = 0, + WORKING, // |DoWorkJob| posted on WorkerPool, until |OnWorkJobFinished| + PENDING, // |WorkNow| while WORKING, must re-do work + WAITING, // WorkerPool is busy, |RetryWork| is posted + }; + + // Called on the worker thread, executes |DoWork| and notifies the origin + // thread. + void DoWorkJob(); + + // Called on the the origin thread after |DoWork| completes. + void OnWorkJobFinished(); + + // Posted to message loop in case WorkerPool is busy. (state == WAITING) + void RetryWork(); + + // Message loop for the thread of origin. + scoped_refptr<base::MessageLoopProxy> message_loop_; + + State state_; + + DISALLOW_COPY_AND_ASSIGN(SerialWorker); +}; + +} // namespace net + +#endif // NET_DNS_SERIAL_WORKER_H_ + diff --git a/chromium/net/dns/serial_worker_unittest.cc b/chromium/net/dns/serial_worker_unittest.cc new file mode 100644 index 00000000000..442526f29d2 --- /dev/null +++ b/chromium/net/dns/serial_worker_unittest.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/serial_worker.h" + +#include "base/bind.h" +#include "base/message_loop/message_loop.h" +#include "base/synchronization/lock.h" +#include "base/synchronization/waitable_event.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +class SerialWorkerTest : public testing::Test { + public: + // The class under test + class TestSerialWorker : public SerialWorker { + public: + explicit TestSerialWorker(SerialWorkerTest* t) + : test_(t) {} + virtual void DoWork() OVERRIDE { + ASSERT_TRUE(test_); + test_->OnWork(); + } + virtual void OnWorkFinished() OVERRIDE { + ASSERT_TRUE(test_); + test_->OnWorkFinished(); + } + private: + virtual ~TestSerialWorker() {} + SerialWorkerTest* test_; + }; + + // Mocks + + void OnWork() { + { // Check that OnWork is executed serially. + base::AutoLock lock(work_lock_); + EXPECT_FALSE(work_running_) << "DoRead is not called serially!"; + work_running_ = true; + } + BreakNow("OnWork"); + work_allowed_.Wait(); + // Calling from WorkerPool, but protected by work_allowed_/work_called_. + output_value_ = input_value_; + + { // This lock might be destroyed after work_called_ is signalled. + base::AutoLock lock(work_lock_); + work_running_ = false; + } + work_called_.Signal(); + } + + void OnWorkFinished() { + EXPECT_TRUE(message_loop_ == base::MessageLoop::current()); + EXPECT_EQ(output_value_, input_value_); + BreakNow("OnWorkFinished"); + } + + protected: + void BreakCallback(std::string breakpoint) { + breakpoint_ = breakpoint; + base::MessageLoop::current()->QuitNow(); + } + + void BreakNow(std::string b) { + message_loop_->PostTask(FROM_HERE, + base::Bind(&SerialWorkerTest::BreakCallback, + base::Unretained(this), b)); + } + + void RunUntilBreak(std::string b) { + message_loop_->Run(); + ASSERT_EQ(breakpoint_, b); + } + + SerialWorkerTest() + : input_value_(0), + output_value_(-1), + work_allowed_(false, false), + work_called_(false, false), + work_running_(false) { + } + + // Helpers for tests. + + // Lets OnWork run and waits for it to complete. Can only return if OnWork is + // executed on a concurrent thread. + void WaitForWork() { + RunUntilBreak("OnWork"); + work_allowed_.Signal(); + work_called_.Wait(); + } + + // test::Test methods + virtual void SetUp() OVERRIDE { + message_loop_ = base::MessageLoop::current(); + worker_ = new TestSerialWorker(this); + } + + virtual void TearDown() OVERRIDE { + // Cancel the worker to catch if it makes a late DoWork call. + worker_->Cancel(); + // Check if OnWork is stalled. + EXPECT_FALSE(work_running_) << "OnWork should be done by TearDown"; + // Release it for cleanliness. + if (work_running_) { + WaitForWork(); + } + } + + // Input value read on WorkerPool. + int input_value_; + // Output value written on WorkerPool. + int output_value_; + + // read is called on WorkerPool so we need to synchronize with it. + base::WaitableEvent work_allowed_; + base::WaitableEvent work_called_; + + // Protected by read_lock_. Used to verify that read calls are serialized. + bool work_running_; + base::Lock work_lock_; + + // Loop for this thread. + base::MessageLoop* message_loop_; + + // WatcherDelegate under test. + scoped_refptr<TestSerialWorker> worker_; + + std::string breakpoint_; +}; + +TEST_F(SerialWorkerTest, ExecuteAndSerializeReads) { + for (int i = 0; i < 3; ++i) { + ++input_value_; + worker_->WorkNow(); + WaitForWork(); + RunUntilBreak("OnWorkFinished"); + + EXPECT_TRUE(message_loop_->IsIdleForTesting()); + } + + // Schedule two calls. OnWork checks if it is called serially. + ++input_value_; + worker_->WorkNow(); + // read is blocked, so this will have to induce re-work + worker_->WorkNow(); + WaitForWork(); + WaitForWork(); + RunUntilBreak("OnWorkFinished"); + + // No more tasks should remain. + EXPECT_TRUE(message_loop_->IsIdleForTesting()); +} + +} // namespace + +} // namespace net + diff --git a/chromium/net/dns/single_request_host_resolver.cc b/chromium/net/dns/single_request_host_resolver.cc new file mode 100644 index 00000000000..31ef4c56b6e --- /dev/null +++ b/chromium/net/dns/single_request_host_resolver.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/single_request_host_resolver.h" + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "net/base/net_errors.h" + +namespace net { + +SingleRequestHostResolver::SingleRequestHostResolver(HostResolver* resolver) + : resolver_(resolver), + cur_request_(NULL), + callback_( + base::Bind(&SingleRequestHostResolver::OnResolveCompletion, + base::Unretained(this))) { + DCHECK(resolver_ != NULL); +} + +SingleRequestHostResolver::~SingleRequestHostResolver() { + Cancel(); +} + +int SingleRequestHostResolver::Resolve( + const HostResolver::RequestInfo& info, AddressList* addresses, + const CompletionCallback& callback, const BoundNetLog& net_log) { + DCHECK(addresses); + DCHECK_EQ(false, callback.is_null()); + DCHECK(cur_request_callback_.is_null()) << "resolver already in use"; + + HostResolver::RequestHandle request = NULL; + + // We need to be notified of completion before |callback| is called, so that + // we can clear out |cur_request_*|. + CompletionCallback transient_callback = + callback.is_null() ? CompletionCallback() : callback_; + + int rv = resolver_->Resolve( + info, addresses, transient_callback, &request, net_log); + + if (rv == ERR_IO_PENDING) { + DCHECK_EQ(false, callback.is_null()); + // Cleared in OnResolveCompletion(). + cur_request_ = request; + cur_request_callback_ = callback; + } + + return rv; +} + +void SingleRequestHostResolver::Cancel() { + if (!cur_request_callback_.is_null()) { + resolver_->CancelRequest(cur_request_); + cur_request_ = NULL; + cur_request_callback_.Reset(); + } +} + +void SingleRequestHostResolver::OnResolveCompletion(int result) { + DCHECK(cur_request_); + DCHECK_EQ(false, cur_request_callback_.is_null()); + + CompletionCallback callback = cur_request_callback_; + + // Clear the outstanding request information. + cur_request_ = NULL; + cur_request_callback_.Reset(); + + // Call the user's original callback. + callback.Run(result); +} + +} // namespace net diff --git a/chromium/net/dns/single_request_host_resolver.h b/chromium/net/dns/single_request_host_resolver.h new file mode 100644 index 00000000000..52d01328911 --- /dev/null +++ b/chromium/net/dns/single_request_host_resolver.h @@ -0,0 +1,56 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DNS_SINGLE_REQUEST_HOST_RESOLVER_H_ +#define NET_DNS_SINGLE_REQUEST_HOST_RESOLVER_H_ + +#include "net/dns/host_resolver.h" + +namespace net { + +// This class represents the task of resolving a hostname (or IP address +// literal) to an AddressList object. It wraps HostResolver to resolve only a +// single hostname at a time and cancels this request when going out of scope. +class NET_EXPORT SingleRequestHostResolver { + public: + // |resolver| must remain valid for the lifetime of |this|. + explicit SingleRequestHostResolver(HostResolver* resolver); + + // If a completion callback is pending when the resolver is destroyed, the + // host resolution is cancelled, and the completion callback will not be + // called. + ~SingleRequestHostResolver(); + + // Resolves the given hostname (or IP address literal), filling out the + // |addresses| object upon success. See HostResolver::Resolve() for details. + int Resolve(const HostResolver::RequestInfo& info, + AddressList* addresses, + const CompletionCallback& callback, + const BoundNetLog& net_log); + + // Cancels the in-progress request, if any. This prevents the callback + // from being invoked. Resolve() can be called again after cancelling. + void Cancel(); + + private: + // Callback for when the request to |resolver_| completes, so we dispatch + // to the user's callback. + void OnResolveCompletion(int result); + + // The actual host resolver that will handle the request. + HostResolver* const resolver_; + + // The current request (if any). + HostResolver::RequestHandle cur_request_; + CompletionCallback cur_request_callback_; + + // Completion callback for when request to |resolver_| completes. + CompletionCallback callback_; + + DISALLOW_COPY_AND_ASSIGN(SingleRequestHostResolver); +}; + +} // namespace net + +#endif // NET_DNS_SINGLE_REQUEST_HOST_RESOLVER_H_ diff --git a/chromium/net/dns/single_request_host_resolver_unittest.cc b/chromium/net/dns/single_request_host_resolver_unittest.cc new file mode 100644 index 00000000000..1b0198f4fde --- /dev/null +++ b/chromium/net/dns/single_request_host_resolver_unittest.cc @@ -0,0 +1,124 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/dns/single_request_host_resolver.h" + +#include "net/base/address_list.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/base/test_completion_callback.h" +#include "net/dns/mock_host_resolver.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +// Helper class used by SingleRequestHostResolverTest.Cancel test. +// It checks that only one request is outstanding at a time, and that +// it is cancelled before the class is destroyed. +class HangingHostResolver : public HostResolver { + public: + HangingHostResolver() : outstanding_request_(NULL) {} + + virtual ~HangingHostResolver() { + EXPECT_TRUE(!has_outstanding_request()); + } + + bool has_outstanding_request() const { + return outstanding_request_ != NULL; + } + + virtual int Resolve(const RequestInfo& info, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* out_req, + const BoundNetLog& net_log) OVERRIDE { + EXPECT_FALSE(has_outstanding_request()); + outstanding_request_ = reinterpret_cast<RequestHandle>(0x1234); + *out_req = outstanding_request_; + + // Never complete this request! Caller is expected to cancel it + // before destroying the resolver. + return ERR_IO_PENDING; + } + + virtual int ResolveFromCache(const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& net_log) OVERRIDE { + NOTIMPLEMENTED(); + return ERR_UNEXPECTED; + } + + virtual void CancelRequest(RequestHandle req) OVERRIDE { + EXPECT_TRUE(has_outstanding_request()); + EXPECT_EQ(req, outstanding_request_); + outstanding_request_ = NULL; + } + + private: + RequestHandle outstanding_request_; + + DISALLOW_COPY_AND_ASSIGN(HangingHostResolver); +}; + +// Test that a regular end-to-end lookup returns the expected result. +TEST(SingleRequestHostResolverTest, NormalResolve) { + // Create a host resolver dependency that returns address "199.188.1.166" + // for resolutions of "watsup". + MockHostResolver resolver; + resolver.rules()->AddIPLiteralRule("watsup", "199.188.1.166", std::string()); + + SingleRequestHostResolver single_request_resolver(&resolver); + + // Resolve "watsup:90" using our SingleRequestHostResolver. + AddressList addrlist; + TestCompletionCallback callback; + HostResolver::RequestInfo request(HostPortPair("watsup", 90)); + int rv = single_request_resolver.Resolve( + request, &addrlist, callback.callback(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(OK, callback.WaitForResult()); + + // Verify that the result is what we specified in the MockHostResolver. + ASSERT_FALSE(addrlist.empty()); + EXPECT_EQ("199.188.1.166", addrlist.front().ToStringWithoutPort()); +} + +// Test that the Cancel() method cancels any outstanding request. +TEST(SingleRequestHostResolverTest, Cancel) { + HangingHostResolver resolver; + + { + SingleRequestHostResolver single_request_resolver(&resolver); + + // Resolve "watsup:90" using our SingleRequestHostResolver. + AddressList addrlist; + TestCompletionCallback callback; + HostResolver::RequestInfo request(HostPortPair("watsup", 90)); + int rv = single_request_resolver.Resolve( + request, &addrlist, callback.callback(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_TRUE(resolver.has_outstanding_request()); + } + + // Now that the SingleRequestHostResolver has been destroyed, the + // in-progress request should have been aborted. + EXPECT_FALSE(resolver.has_outstanding_request()); +} + +// Test that the Cancel() method is a no-op when there is no outstanding +// request. +TEST(SingleRequestHostResolverTest, CancelWhileNoPendingRequest) { + HangingHostResolver resolver; + SingleRequestHostResolver single_request_resolver(&resolver); + single_request_resolver.Cancel(); + + // To pass, HangingHostResolver should not have received a cancellation + // request (since there is nothing to cancel). If it does, it will crash. +} + +} // namespace + +} // namespace net |