// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/dns/mock_host_resolver.h" #include #include #include "base/bind.h" #include "base/callback_helpers.h" #include "base/location.h" #include "base/logging.h" #include "base/memory/ptr_util.h" #include "base/memory/ref_counted.h" #include "base/no_destructor.h" #include "base/single_thread_task_runner.h" #include "base/stl_util.h" #include "base/strings/pattern.h" #include "base/strings/string_split.h" #include "base/strings/string_util.h" #include "base/threading/platform_thread.h" #include "base/threading/thread_task_runner_handle.h" #include "base/time/default_tick_clock.h" #include "base/time/tick_clock.h" #include "net/base/host_port_pair.h" #include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" #include "net/dns/host_cache.h" #include "net/dns/public/resolve_error_info.h" #include "net/url_request/url_request_context.h" #if defined(OS_WIN) #include "net/base/winsock_init.h" #endif namespace net { namespace { // Cache size for the MockCachingHostResolver. const unsigned kMaxCacheEntries = 100; // TTL for the successful resolutions. Failures are not cached. const unsigned kCacheEntryTTLSeconds = 60; } // namespace int ParseAddressList(const std::string& host_list, const std::string& canonical_name, AddressList* addrlist) { *addrlist = AddressList(); addrlist->set_canonical_name(canonical_name); for (const base::StringPiece& address : base::SplitStringPiece( host_list, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL)) { IPAddress ip_address; if (!ip_address.AssignFromIPLiteral(address)) { LOG(WARNING) << "Not a supported IP literal: " << address.as_string(); return ERR_UNEXPECTED; } addrlist->push_back(IPEndPoint(ip_address, 0)); } return OK; } class MockHostResolverBase::RequestImpl : public HostResolver::ResolveHostRequest { public: RequestImpl(const HostPortPair& request_host, const NetworkIsolationKey& network_isolation_key, const base::Optional& optional_parameters, base::WeakPtr resolver) : request_host_(request_host), network_isolation_key_(network_isolation_key), parameters_(optional_parameters ? optional_parameters.value() : ResolveHostParameters()), priority_(parameters_.initial_priority), host_resolver_flags_(ParametersToHostResolverFlags(parameters_)), resolve_error_info_(ResolveErrorInfo(ERR_IO_PENDING)), id_(0), resolver_(resolver), complete_(false) {} ~RequestImpl() override { if (id_ > 0) { if (resolver_) resolver_->DetachRequest(id_); id_ = 0; resolver_ = nullptr; } } void DetachFromResolver() { id_ = 0; resolver_ = nullptr; } int Start(CompletionOnceCallback callback) override { DCHECK(callback); // Start() may only be called once per request. DCHECK_EQ(0u, id_); DCHECK(!complete_); DCHECK(!callback_); // Parent HostResolver must still be alive to call Start(). DCHECK(resolver_); int rv = resolver_->Resolve(this); DCHECK(!complete_); if (rv == ERR_IO_PENDING) { DCHECK_GT(id_, 0u); callback_ = std::move(callback); } else { DCHECK_EQ(0u, id_); complete_ = true; } return rv; } const base::Optional& GetAddressResults() const override { DCHECK(complete_); return address_results_; } const base::Optional>& GetTextResults() const override { DCHECK(complete_); static const base::NoDestructor>> nullopt_result; return *nullopt_result; } const base::Optional>& GetHostnameResults() const override { DCHECK(complete_); static const base::NoDestructor>> nullopt_result; return *nullopt_result; } net::ResolveErrorInfo GetResolveErrorInfo() const override { DCHECK(complete_); return resolve_error_info_; } const base::Optional& GetStaleInfo() const override { DCHECK(complete_); return staleness_; } void ChangeRequestPriority(RequestPriority priority) override { priority_ = priority; } void SetError(int error) { // Should only be called before request is marked completed. DCHECK(!complete_); resolve_error_info_ = ResolveErrorInfo(error); } void set_address_results( const AddressList& address_results, base::Optional staleness) { // Should only be called at most once and before request is marked // completed. DCHECK(!complete_); DCHECK(!address_results_); DCHECK(!parameters_.is_speculative); address_results_ = address_results; staleness_ = std::move(staleness); } void OnAsyncCompleted(size_t id, int error) { DCHECK_EQ(id_, id); id_ = 0; // Check that error information has been set and that the top-level error // code is valid. DCHECK(resolve_error_info_.error != ERR_IO_PENDING); DCHECK(error == OK || error == ERR_NAME_NOT_RESOLVED); DCHECK(!complete_); complete_ = true; DCHECK(callback_); std::move(callback_).Run(error); } const HostPortPair& request_host() const { return request_host_; } const NetworkIsolationKey& network_isolation_key() const { return network_isolation_key_; } const ResolveHostParameters& parameters() const { return parameters_; } int host_resolver_flags() const { return host_resolver_flags_; } size_t id() { return id_; } RequestPriority priority() const { return priority_; } void set_id(size_t id) { DCHECK_GT(id, 0u); DCHECK_EQ(0u, id_); id_ = id; } bool complete() { return complete_; } private: const HostPortPair request_host_; const NetworkIsolationKey network_isolation_key_; const ResolveHostParameters parameters_; RequestPriority priority_; int host_resolver_flags_; base::Optional address_results_; base::Optional staleness_; ResolveErrorInfo resolve_error_info_; // Used while stored with the resolver for async resolution. Otherwise 0. size_t id_; CompletionOnceCallback callback_; // Use a WeakPtr as the resolver may be destroyed while there are still // outstanding request objects. base::WeakPtr resolver_; bool complete_; DISALLOW_COPY_AND_ASSIGN(RequestImpl); }; class MockHostResolverBase::ProbeRequestImpl : public HostResolver::ProbeRequest { public: explicit ProbeRequestImpl(base::WeakPtr resolver) : resolver_(std::move(resolver)) {} ProbeRequestImpl(const ProbeRequestImpl&) = delete; ProbeRequestImpl& operator=(const ProbeRequestImpl&) = delete; ~ProbeRequestImpl() override { if (resolver_ && resolver_->doh_probe_request_ == this) resolver_->doh_probe_request_ = nullptr; } int Start() override { DCHECK(resolver_); DCHECK(!resolver_->doh_probe_request_); resolver_->doh_probe_request_ = this; return ERR_IO_PENDING; } private: base::WeakPtr resolver_; }; class MockHostResolverBase::MdnsListenerImpl : public HostResolver::MdnsListener { public: MdnsListenerImpl(const HostPortPair& host, DnsQueryType query_type, base::WeakPtr resolver) : host_(host), query_type_(query_type), delegate_(nullptr), resolver_(resolver) { DCHECK_NE(DnsQueryType::UNSPECIFIED, query_type_); DCHECK(resolver_); } ~MdnsListenerImpl() override { if (resolver_) resolver_->RemoveCancelledListener(this); } int Start(Delegate* delegate) override { DCHECK(delegate); DCHECK(!delegate_); DCHECK(resolver_); delegate_ = delegate; resolver_->AddListener(this); return OK; } void TriggerAddressResult(Delegate::UpdateType update_type, IPEndPoint address) { delegate_->OnAddressResult(update_type, query_type_, std::move(address)); } void TriggerTextResult(Delegate::UpdateType update_type, std::vector text_records) { delegate_->OnTextResult(update_type, query_type_, std::move(text_records)); } void TriggerHostnameResult(Delegate::UpdateType update_type, HostPortPair host) { delegate_->OnHostnameResult(update_type, query_type_, std::move(host)); } void TriggerUnhandledResult(Delegate::UpdateType update_type) { delegate_->OnUnhandledResult(update_type, query_type_); } const HostPortPair& host() const { return host_; } DnsQueryType query_type() const { return query_type_; } private: const HostPortPair host_; const DnsQueryType query_type_; Delegate* delegate_; // Use a WeakPtr as the resolver may be destroyed while there are still // outstanding listener objects. base::WeakPtr resolver_; }; MockHostResolverBase::~MockHostResolverBase() { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); // Sanity check that pending requests are always cleaned up, by waiting for // completion, manually cancelling, or calling OnShutdown(). DCHECK(requests_.empty()); } void MockHostResolverBase::OnShutdown() { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); // Cancel all pending requests. for (auto& request : requests_) { request.second->DetachFromResolver(); } requests_.clear(); // Prevent future requests by clearing resolution rules and the cache. rules_map_.clear(); cache_ = nullptr; doh_probe_request_ = nullptr; } std::unique_ptr MockHostResolverBase::CreateRequest( const HostPortPair& host, const NetworkIsolationKey& network_isolation_key, const NetLogWithSource& source_net_log, const base::Optional& optional_parameters) { return std::make_unique(host, network_isolation_key, optional_parameters, AsWeakPtr()); } std::unique_ptr MockHostResolverBase::CreateDohProbeRequest() { return std::make_unique(AsWeakPtr()); } std::unique_ptr MockHostResolverBase::CreateMdnsListener(const HostPortPair& host, DnsQueryType query_type) { return std::make_unique(host, query_type, AsWeakPtr()); } HostCache* MockHostResolverBase::GetHostCache() { return cache_.get(); } int MockHostResolverBase::LoadIntoCache( const HostPortPair& host, const NetworkIsolationKey& network_isolation_key, const base::Optional& optional_parameters) { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK(cache_); ResolveHostParameters parameters = optional_parameters.value_or(ResolveHostParameters()); AddressList addresses; base::Optional stale_info; int rv = ResolveFromIPLiteralOrCache( host, network_isolation_key, parameters.dns_query_type, ParametersToHostResolverFlags(parameters), parameters.source, parameters.cache_usage, &addresses, &stale_info); if (rv != ERR_DNS_CACHE_MISS) { // Request already in cache (or IP literal). No need to load it. return rv; } // Just like the real resolver, refuse to do anything with invalid // hostnames. if (!IsValidDNSDomain(host.host())) return ERR_NAME_NOT_RESOLVED; return ResolveProc(host, network_isolation_key, DnsQueryTypeToAddressFamily(parameters.dns_query_type), ParametersToHostResolverFlags(parameters), parameters.source, &addresses); } void MockHostResolverBase::ResolveAllPending() { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); DCHECK(ondemand_mode_); for (auto i = requests_.begin(); i != requests_.end(); ++i) { base::ThreadTaskRunnerHandle::Get()->PostTask( FROM_HERE, base::BindOnce(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first)); } } size_t MockHostResolverBase::last_id() { if (requests_.empty()) return 0; return requests_.rbegin()->first; } void MockHostResolverBase::ResolveNow(size_t id) { auto it = requests_.find(id); if (it == requests_.end()) return; // was canceled RequestImpl* req = it->second; requests_.erase(it); AddressList addresses; int error = ResolveProc( req->request_host(), req->network_isolation_key(), DnsQueryTypeToAddressFamily(req->parameters().dns_query_type), req->host_resolver_flags(), req->parameters().source, &addresses); req->SetError(error); if (error == OK && !req->parameters().is_speculative) req->set_address_results(addresses, base::nullopt); req->OnAsyncCompleted(id, SquashErrorCode(error)); } void MockHostResolverBase::DetachRequest(size_t id) { auto it = requests_.find(id); CHECK(it != requests_.end()); requests_.erase(it); } const std::string& MockHostResolverBase::request_host(size_t id) { DCHECK(request(id)); return request(id)->request_host().host(); } RequestPriority MockHostResolverBase::request_priority(size_t id) { DCHECK(request(id)); return request(id)->priority(); } const NetworkIsolationKey& MockHostResolverBase::request_network_isolation_key( size_t id) { DCHECK(request(id)); return request(id)->network_isolation_key(); } void MockHostResolverBase::ResolveOnlyRequestNow() { DCHECK_EQ(1u, requests_.size()); ResolveNow(requests_.begin()->first); } void MockHostResolverBase::TriggerMdnsListeners( const HostPortPair& host, DnsQueryType query_type, MdnsListener::Delegate::UpdateType update_type, const IPEndPoint& address_result) { for (auto* listener : listeners_) { if (listener->host() == host && listener->query_type() == query_type) listener->TriggerAddressResult(update_type, address_result); } } void MockHostResolverBase::TriggerMdnsListeners( const HostPortPair& host, DnsQueryType query_type, MdnsListener::Delegate::UpdateType update_type, const std::vector& text_result) { for (auto* listener : listeners_) { if (listener->host() == host && listener->query_type() == query_type) listener->TriggerTextResult(update_type, text_result); } } void MockHostResolverBase::TriggerMdnsListeners( const HostPortPair& host, DnsQueryType query_type, MdnsListener::Delegate::UpdateType update_type, const HostPortPair& host_result) { for (auto* listener : listeners_) { if (listener->host() == host && listener->query_type() == query_type) listener->TriggerHostnameResult(update_type, host_result); } } void MockHostResolverBase::TriggerMdnsListeners( const HostPortPair& host, DnsQueryType query_type, MdnsListener::Delegate::UpdateType update_type) { for (auto* listener : listeners_) { if (listener->host() == host && listener->query_type() == query_type) listener->TriggerUnhandledResult(update_type); } } MockHostResolverBase::RequestImpl* MockHostResolverBase::request(size_t id) { RequestMap::iterator request = requests_.find(id); DCHECK(request != requests_.end()); return (*request).second; } // start id from 1 to distinguish from NULL RequestHandle MockHostResolverBase::MockHostResolverBase(bool use_caching, int cache_invalidation_num) : last_request_priority_(DEFAULT_PRIORITY), last_secure_dns_mode_override_(base::nullopt), synchronous_mode_(false), ondemand_mode_(false), initial_cache_invalidation_num_(cache_invalidation_num), next_request_id_(1), num_resolve_(0), num_resolve_from_cache_(0), num_non_local_resolves_(0), tick_clock_(base::DefaultTickClock::GetInstance()) { rules_map_[HostResolverSource::ANY] = CreateCatchAllHostResolverProc(); rules_map_[HostResolverSource::SYSTEM] = CreateCatchAllHostResolverProc(); rules_map_[HostResolverSource::DNS] = CreateCatchAllHostResolverProc(); rules_map_[HostResolverSource::MULTICAST_DNS] = CreateCatchAllHostResolverProc(); if (use_caching) cache_.reset(new HostCache(kMaxCacheEntries)); else DCHECK_GE(0, cache_invalidation_num); } int MockHostResolverBase::Resolve(RequestImpl* request) { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); last_request_priority_ = request->parameters().initial_priority; last_request_network_isolation_key_ = request->network_isolation_key(); last_secure_dns_mode_override_ = request->parameters().secure_dns_mode_override; num_resolve_++; AddressList addresses; base::Optional stale_info; int rv = ResolveFromIPLiteralOrCache( request->request_host(), request->network_isolation_key(), request->parameters().dns_query_type, request->host_resolver_flags(), request->parameters().source, request->parameters().cache_usage, &addresses, &stale_info); request->SetError(rv); if (rv == OK && !request->parameters().is_speculative) request->set_address_results(addresses, std::move(stale_info)); if (rv != ERR_DNS_CACHE_MISS || request->parameters().source == HostResolverSource::LOCAL_ONLY) { return SquashErrorCode(rv); } // Just like the real resolver, refuse to do anything with invalid // hostnames. if (!IsValidDNSDomain(request->request_host().host())) { request->SetError(ERR_NAME_NOT_RESOLVED); return ERR_NAME_NOT_RESOLVED; } if (synchronous_mode_) { int rv = ResolveProc( request->request_host(), request->network_isolation_key(), DnsQueryTypeToAddressFamily(request->parameters().dns_query_type), request->host_resolver_flags(), request->parameters().source, &addresses); request->SetError(rv); if (rv == OK && !request->parameters().is_speculative) request->set_address_results(addresses, base::nullopt); return SquashErrorCode(rv); } // Store the request for asynchronous resolution size_t id = next_request_id_++; request->set_id(id); requests_[id] = request; if (!ondemand_mode_) { base::ThreadTaskRunnerHandle::Get()->PostTask( FROM_HERE, base::BindOnce(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id)); } return ERR_IO_PENDING; } int MockHostResolverBase::ResolveFromIPLiteralOrCache( const HostPortPair& host, const NetworkIsolationKey& network_isolation_key, DnsQueryType dns_query_type, HostResolverFlags flags, HostResolverSource source, HostResolver::ResolveHostParameters::CacheUsage cache_usage, AddressList* addresses, base::Optional* out_stale_info) { DCHECK(addresses); DCHECK(out_stale_info); *out_stale_info = base::nullopt; IPAddress ip_address; if (ip_address.AssignFromIPLiteral(host.host())) { // This matches the behavior HostResolverImpl. if (dns_query_type != DnsQueryType::UNSPECIFIED && dns_query_type != AddressFamilyToDnsQueryType(GetAddressFamily(ip_address))) { return ERR_NAME_NOT_RESOLVED; } *addresses = AddressList::CreateFromIPAddress(ip_address, host.port()); if (flags & HOST_RESOLVER_CANONNAME) addresses->SetDefaultCanonicalName(); return OK; } int rv = ERR_DNS_CACHE_MISS; bool cache_allowed = cache_usage == HostResolver::ResolveHostParameters::CacheUsage::ALLOWED || cache_usage == HostResolver::ResolveHostParameters::CacheUsage::STALE_ALLOWED; if (cache_.get() && cache_allowed) { // Local-only requests search the cache for non-local-only results. HostResolverSource effective_source = source == HostResolverSource::LOCAL_ONLY ? HostResolverSource::ANY : source; HostCache::Key key(host.host(), dns_query_type, flags, effective_source, network_isolation_key); const std::pair* cache_result; HostCache::EntryStaleness stale_info = HostCache::kNotStale; if (cache_usage == HostResolver::ResolveHostParameters::CacheUsage::STALE_ALLOWED) { cache_result = cache_->LookupStale(key, tick_clock_->NowTicks(), &stale_info, true /* ignore_secure */); } else { cache_result = cache_->Lookup(key, tick_clock_->NowTicks(), true /* ignore_secure */); } if (cache_result) { rv = cache_result->second.error(); if (rv == OK) { *addresses = AddressList::CopyWithPort( cache_result->second.addresses().value(), host.port()); *out_stale_info = std::move(stale_info); } auto cache_invalidation_iterator = cache_invalidation_nums_.find(key); if (cache_invalidation_iterator != cache_invalidation_nums_.end()) { DCHECK_LE(1, cache_invalidation_iterator->second); cache_invalidation_iterator->second--; if (cache_invalidation_iterator->second == 0) { HostCache::Entry new_entry(cache_result->second); cache_->Set(key, new_entry, tick_clock_->NowTicks(), base::TimeDelta()); cache_invalidation_nums_.erase(cache_invalidation_iterator); } } } } return rv; } int MockHostResolverBase::ResolveProc( const HostPortPair& host, const NetworkIsolationKey& network_isolation_key, AddressFamily requested_address_family, HostResolverFlags flags, HostResolverSource source, AddressList* addresses) { DCHECK(rules_map_.find(source) != rules_map_.end()); ++num_non_local_resolves_; AddressList addr; int rv = rules_map_[source]->Resolve(host.host(), requested_address_family, flags, &addr, nullptr); if (cache_.get()) { HostCache::Key key(host.host(), AddressFamilyToDnsQueryType(requested_address_family), flags, source, network_isolation_key); // Storing a failure with TTL 0 so that it overwrites previous value. base::TimeDelta ttl; if (rv == OK) { ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds); if (initial_cache_invalidation_num_ > 0) cache_invalidation_nums_[key] = initial_cache_invalidation_num_; } cache_->Set(key, HostCache::Entry(rv, addr, HostCache::Entry::SOURCE_UNKNOWN), tick_clock_->NowTicks(), ttl); } if (rv == OK) *addresses = AddressList::CopyWithPort(addr, host.port()); return rv; } void MockHostResolverBase::AddListener(MdnsListenerImpl* listener) { listeners_.insert(listener); } void MockHostResolverBase::RemoveCancelledListener(MdnsListenerImpl* listener) { listeners_.erase(listener); } MockHostResolverFactory::MockHostResolverFactory( scoped_refptr rules, bool use_caching, int cache_invalidation_num) : rules_(std::move(rules)), use_caching_(use_caching), cache_invalidation_num_(cache_invalidation_num) {} MockHostResolverFactory::~MockHostResolverFactory() = default; std::unique_ptr MockHostResolverFactory::CreateResolver( HostResolverManager* manager, base::StringPiece host_mapping_rules, bool enable_caching) { DCHECK(host_mapping_rules.empty()); // Explicit new to access private constructor. auto resolver = base::WrapUnique(new MockHostResolverBase( enable_caching && use_caching_, cache_invalidation_num_)); if (rules_) resolver->set_rules(rules_.get()); return resolver; } std::unique_ptr MockHostResolverFactory::CreateStandaloneResolver( NetLog* net_log, const HostResolver::ManagerOptions& options, base::StringPiece host_mapping_rules, bool enable_caching) { return CreateResolver(nullptr, host_mapping_rules, enable_caching); } //----------------------------------------------------------------------------- RuleBasedHostResolverProc::Rule::Rule(ResolverType resolver_type, const std::string& host_pattern, AddressFamily address_family, HostResolverFlags host_resolver_flags, const std::string& replacement, const std::string& canonical_name, int latency_ms) : resolver_type(resolver_type), host_pattern(host_pattern), address_family(address_family), host_resolver_flags(host_resolver_flags), replacement(replacement), canonical_name(canonical_name), latency_ms(latency_ms) {} RuleBasedHostResolverProc::Rule::Rule(const Rule& other) = default; RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous) : HostResolverProc(previous), modifications_allowed_(true) {} void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern, const std::string& replacement) { AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED, replacement); } void RuleBasedHostResolverProc::AddRuleForAddressFamily( const std::string& host_pattern, AddressFamily address_family, const std::string& replacement) { DCHECK(!replacement.empty()); HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; Rule rule(Rule::kResolverTypeSystem, host_pattern, address_family, flags, replacement, std::string(), 0); AddRuleInternal(rule); } void RuleBasedHostResolverProc::AddRuleWithFlags( const std::string& host_pattern, const std::string& replacement, HostResolverFlags flags, const std::string& canonical_name) { DCHECK(!replacement.empty()); Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, flags, replacement, canonical_name, 0); AddRuleInternal(rule); } void RuleBasedHostResolverProc::AddIPLiteralRule( const std::string& host_pattern, const std::string& ip_literal, const std::string& canonical_name) { // Literals are always resolved to themselves by HostResolverImpl, // consequently we do not support remapping them. IPAddress ip_address; DCHECK(!ip_address.AssignFromIPLiteral(host_pattern)); HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; if (!canonical_name.empty()) flags |= HOST_RESOLVER_CANONNAME; Rule rule(Rule::kResolverTypeIPLiteral, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name, 0); AddRuleInternal(rule); } void RuleBasedHostResolverProc::AddRuleWithLatency( const std::string& host_pattern, const std::string& replacement, int latency_ms) { DCHECK(!replacement.empty()); HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, flags, replacement, std::string(), latency_ms); AddRuleInternal(rule); } void RuleBasedHostResolverProc::AllowDirectLookup( const std::string& host_pattern) { HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; Rule rule(Rule::kResolverTypeSystem, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, flags, std::string(), std::string(), 0); AddRuleInternal(rule); } void RuleBasedHostResolverProc::AddSimulatedFailure( const std::string& host_pattern) { HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; Rule rule(Rule::kResolverTypeFail, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, flags, std::string(), std::string(), 0); AddRuleInternal(rule); } void RuleBasedHostResolverProc::AddSimulatedTimeoutFailure( const std::string& host_pattern) { HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; Rule rule(Rule::kResolverTypeFailTimeout, host_pattern, ADDRESS_FAMILY_UNSPECIFIED, flags, std::string(), std::string(), 0); AddRuleInternal(rule); } void RuleBasedHostResolverProc::ClearRules() { CHECK(modifications_allowed_); base::AutoLock lock(rule_lock_); rules_.clear(); } void RuleBasedHostResolverProc::DisableModifications() { modifications_allowed_ = false; } RuleBasedHostResolverProc::RuleList RuleBasedHostResolverProc::GetRules() { RuleList rv; { base::AutoLock lock(rule_lock_); rv = rules_; } return rv; } int RuleBasedHostResolverProc::Resolve(const std::string& host, AddressFamily address_family, HostResolverFlags host_resolver_flags, AddressList* addrlist, int* os_error) { base::AutoLock lock(rule_lock_); RuleList::iterator r; for (r = rules_.begin(); r != rules_.end(); ++r) { bool matches_address_family = r->address_family == ADDRESS_FAMILY_UNSPECIFIED || r->address_family == address_family; // Ignore HOST_RESOLVER_SYSTEM_ONLY, since it should have no impact on // whether a rule matches. HostResolverFlags flags = host_resolver_flags & (~HOST_RESOLVER_SYSTEM_ONLY & ~HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6); // Flags match if all of the bitflags in host_resolver_flags are enabled // in the rule's host_resolver_flags. However, the rule may have additional // flags specified, in which case the flags should still be considered a // match. bool matches_flags = (r->host_resolver_flags & flags) == flags; if (matches_flags && matches_address_family && base::MatchPattern(host, r->host_pattern)) { if (r->latency_ms != 0) { base::PlatformThread::Sleep( base::TimeDelta::FromMilliseconds(r->latency_ms)); } // Remap to a new host. const std::string& effective_host = r->replacement.empty() ? host : r->replacement; // Apply the resolving function to the remapped hostname. switch (r->resolver_type) { case Rule::kResolverTypeFail: return ERR_NAME_NOT_RESOLVED; case Rule::kResolverTypeFailTimeout: return ERR_DNS_TIMED_OUT; case Rule::kResolverTypeSystem: #if defined(OS_WIN) EnsureWinsockInit(); #endif return SystemHostResolverCall(effective_host, address_family, host_resolver_flags, addrlist, os_error); case Rule::kResolverTypeIPLiteral: { AddressList raw_addr_list; int result = ParseAddressList( effective_host, !r->canonical_name.empty() ? r->canonical_name : host, &raw_addr_list); // Filter out addresses with the wrong family. *addrlist = AddressList(); for (const auto& address : raw_addr_list) { if (address_family == ADDRESS_FAMILY_UNSPECIFIED || address_family == address.GetFamily()) { addrlist->push_back(address); } } addrlist->set_canonical_name(raw_addr_list.canonical_name()); if (result == OK && addrlist->empty()) return ERR_NAME_NOT_RESOLVED; return result; } default: NOTREACHED(); return ERR_UNEXPECTED; } } } return ResolveUsingPrevious(host, address_family, host_resolver_flags, addrlist, os_error); } RuleBasedHostResolverProc::~RuleBasedHostResolverProc() = default; void RuleBasedHostResolverProc::AddRuleInternal(const Rule& rule) { Rule fixed_rule = rule; // SystemResolverProc expects valid DNS addresses. // So for kResolverTypeSystem rules: // * If the replacement is an IP address, switch to an IP literal rule. // * If it's a non-empty invalid domain name, switch to a fail rule (Empty // domain names mean use a direct lookup). if (fixed_rule.resolver_type == Rule::kResolverTypeSystem) { IPAddress ip_address; bool valid_address = ip_address.AssignFromIPLiteral(fixed_rule.replacement); if (valid_address) { fixed_rule.resolver_type = Rule::kResolverTypeIPLiteral; } else if (!fixed_rule.replacement.empty() && !IsValidDNSDomain(fixed_rule.replacement)) { // TODO(mmenke): Can this be replaced with a DCHECK instead? fixed_rule.resolver_type = Rule::kResolverTypeFail; } } CHECK(modifications_allowed_); base::AutoLock lock(rule_lock_); rules_.push_back(fixed_rule); } RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() { RuleBasedHostResolverProc* catchall = new RuleBasedHostResolverProc(nullptr); // Note that IPv6 lookups fail. catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost"); // Next add a rules-based layer the use controls. return new RuleBasedHostResolverProc(catchall); } //----------------------------------------------------------------------------- // Implementation of ResolveHostRequest that tracks cancellations when the // request is destroyed after being started. class HangingHostResolver::RequestImpl : public HostResolver::ResolveHostRequest, public HostResolver::ProbeRequest { public: explicit RequestImpl(base::WeakPtr resolver) : resolver_(resolver) {} ~RequestImpl() override { if (is_running_ && resolver_) resolver_->num_cancellations_++; } int Start(CompletionOnceCallback callback) override { return Start(); } int Start() override { DCHECK(resolver_); is_running_ = true; return ERR_IO_PENDING; } const base::Optional& GetAddressResults() const override { IMMEDIATE_CRASH(); } const base::Optional>& GetTextResults() const override { IMMEDIATE_CRASH(); } const base::Optional>& GetHostnameResults() const override { IMMEDIATE_CRASH(); } net::ResolveErrorInfo GetResolveErrorInfo() const override { IMMEDIATE_CRASH(); } const base::Optional& GetStaleInfo() const override { IMMEDIATE_CRASH(); } void ChangeRequestPriority(RequestPriority priority) override {} private: // Use a WeakPtr as the resolver may be destroyed while there are still // outstanding request objects. base::WeakPtr resolver_; bool is_running_ = false; DISALLOW_COPY_AND_ASSIGN(RequestImpl); }; HangingHostResolver::HangingHostResolver() = default; HangingHostResolver::~HangingHostResolver() = default; void HangingHostResolver::OnShutdown() { shutting_down_ = true; } std::unique_ptr HangingHostResolver::CreateRequest( const HostPortPair& host, const NetworkIsolationKey& network_isolation_key, const NetLogWithSource& source_net_log, const base::Optional& optional_parameters) { last_host_ = host; last_network_isolation_key_ = network_isolation_key; if (shutting_down_) return CreateFailingRequest(ERR_CONTEXT_SHUT_DOWN); if (optional_parameters && optional_parameters.value().source == HostResolverSource::LOCAL_ONLY) { return CreateFailingRequest(ERR_DNS_CACHE_MISS); } return std::make_unique(weak_ptr_factory_.GetWeakPtr()); } std::unique_ptr HangingHostResolver::CreateDohProbeRequest() { if (shutting_down_) return CreateFailingProbeRequest(ERR_CONTEXT_SHUT_DOWN); return std::make_unique(weak_ptr_factory_.GetWeakPtr()); } //----------------------------------------------------------------------------- ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() = default; ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc( HostResolverProc* proc) { Init(proc); } ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() { HostResolverProc* old_proc = HostResolverProc::SetDefault(previous_proc_.get()); // The lifetimes of multiple instances must be nested. CHECK_EQ(old_proc, current_proc_.get()); } void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) { current_proc_ = proc; previous_proc_ = HostResolverProc::SetDefault(current_proc_.get()); current_proc_->SetLastProc(previous_proc_.get()); } } // namespace net