diff options
author | James M Snell <jasnell@gmail.com> | 2021-03-25 13:27:03 -0700 |
---|---|---|
committer | James M Snell <jasnell@gmail.com> | 2021-04-02 06:16:42 -0700 |
commit | e79471deb7b5c91f061bf395bbef9816e26b4bf9 (patch) | |
tree | 4fd83036cf9e667f16324bb733955aee1aed34b1 /src | |
parent | 0b90d352948393e5f92ceb2ef749c643b1002a68 (diff) | |
download | node-new-e79471deb7b5c91f061bf395bbef9816e26b4bf9.tar.gz |
net: allow net.BlockList to use net.SocketAddress objects
Signed-off-by: James M Snell <jasnell@gmail.com>
PR-URL: https://github.com/nodejs/node/pull/37917
Reviewed-By: Matteo Collina <matteo.collina@gmail.com>
Diffstat (limited to 'src')
-rw-r--r-- | src/node_sockaddr.cc | 140 | ||||
-rw-r--r-- | src/node_sockaddr.h | 42 |
2 files changed, 75 insertions, 107 deletions
diff --git a/src/node_sockaddr.cc b/src/node_sockaddr.cc index e745589b12..b15b6ab471 100644 --- a/src/node_sockaddr.cc +++ b/src/node_sockaddr.cc @@ -392,18 +392,18 @@ SocketAddressBlockList::SocketAddressBlockList( : parent_(parent) {} void SocketAddressBlockList::AddSocketAddress( - const SocketAddress& address) { + const std::shared_ptr<SocketAddress>& address) { Mutex::ScopedLock lock(mutex_); std::unique_ptr<Rule> rule = std::make_unique<SocketAddressRule>(address); rules_.emplace_front(std::move(rule)); - address_rules_[address] = rules_.begin(); + address_rules_[*address.get()] = rules_.begin(); } void SocketAddressBlockList::RemoveSocketAddress( - const SocketAddress& address) { + const std::shared_ptr<SocketAddress>& address) { Mutex::ScopedLock lock(mutex_); - auto it = address_rules_.find(address); + auto it = address_rules_.find(*address.get()); if (it != std::end(address_rules_)) { rules_.erase(it->second); address_rules_.erase(it); @@ -411,8 +411,8 @@ void SocketAddressBlockList::RemoveSocketAddress( } void SocketAddressBlockList::AddSocketAddressRange( - const SocketAddress& start, - const SocketAddress& end) { + const std::shared_ptr<SocketAddress>& start, + const std::shared_ptr<SocketAddress>& end) { Mutex::ScopedLock lock(mutex_); std::unique_ptr<Rule> rule = std::make_unique<SocketAddressRangeRule>(start, end); @@ -420,7 +420,7 @@ void SocketAddressBlockList::AddSocketAddressRange( } void SocketAddressBlockList::AddSocketAddressMask( - const SocketAddress& network, + const std::shared_ptr<SocketAddress>& network, int prefix) { Mutex::ScopedLock lock(mutex_); std::unique_ptr<Rule> rule = @@ -428,7 +428,8 @@ void SocketAddressBlockList::AddSocketAddressMask( rules_.emplace_front(std::move(rule)); } -bool SocketAddressBlockList::Apply(const SocketAddress& address) { +bool SocketAddressBlockList::Apply( + const std::shared_ptr<SocketAddress>& address) { Mutex::ScopedLock lock(mutex_); for (const auto& rule : rules_) { if (rule->Apply(address)) @@ -438,59 +439,60 @@ bool SocketAddressBlockList::Apply(const SocketAddress& address) { } SocketAddressBlockList::SocketAddressRule::SocketAddressRule( - const SocketAddress& address_) + const std::shared_ptr<SocketAddress>& address_) : address(address_) {} SocketAddressBlockList::SocketAddressRangeRule::SocketAddressRangeRule( - const SocketAddress& start_, - const SocketAddress& end_) + const std::shared_ptr<SocketAddress>& start_, + const std::shared_ptr<SocketAddress>& end_) : start(start_), end(end_) {} SocketAddressBlockList::SocketAddressMaskRule::SocketAddressMaskRule( - const SocketAddress& network_, + const std::shared_ptr<SocketAddress>& network_, int prefix_) : network(network_), prefix(prefix_) {} bool SocketAddressBlockList::SocketAddressRule::Apply( - const SocketAddress& address) { - return this->address.is_match(address); + const std::shared_ptr<SocketAddress>& address) { + return this->address->is_match(*address.get()); } std::string SocketAddressBlockList::SocketAddressRule::ToString() { std::string ret = "Address: "; - ret += address.family() == AF_INET ? "IPv4" : "IPv6"; + ret += address->family() == AF_INET ? "IPv4" : "IPv6"; ret += " "; - ret += address.address(); + ret += address->address(); return ret; } bool SocketAddressBlockList::SocketAddressRangeRule::Apply( - const SocketAddress& address) { - return address >= start && address <= end; + const std::shared_ptr<SocketAddress>& address) { + return *address.get() >= *start.get() && + *address.get() <= *end.get(); } std::string SocketAddressBlockList::SocketAddressRangeRule::ToString() { std::string ret = "Range: "; - ret += start.family() == AF_INET ? "IPv4" : "IPv6"; + ret += start->family() == AF_INET ? "IPv4" : "IPv6"; ret += " "; - ret += start.address(); + ret += start->address(); ret += "-"; - ret += end.address(); + ret += end->address(); return ret; } bool SocketAddressBlockList::SocketAddressMaskRule::Apply( - const SocketAddress& address) { - return address.is_in_network(network, prefix); + const std::shared_ptr<SocketAddress>& address) { + return address->is_in_network(*network.get(), prefix); } std::string SocketAddressBlockList::SocketAddressMaskRule::ToString() { std::string ret = "Subnet: "; - ret += network.family() == AF_INET ? "IPv4" : "IPv6"; + ret += network->family() == AF_INET ? "IPv4" : "IPv6"; ret += " "; - ret += network.address(); + ret += network->address(); ret += "/" + std::to_string(prefix); return ret; } @@ -591,20 +593,11 @@ void SocketAddressBlockListWrap::AddAddress( SocketAddressBlockListWrap* wrap; ASSIGN_OR_RETURN_UNWRAP(&wrap, args.Holder()); - CHECK(args[0]->IsString()); - CHECK(args[1]->IsInt32()); - - sockaddr_storage address; - Utf8Value value(args.GetIsolate(), args[0]); - int32_t family; - if (!args[1]->Int32Value(env->context()).To(&family)) - return; - - if (!SocketAddress::ToSockAddr(family, *value, 0, &address)) - return; + CHECK(SocketAddressBase::HasInstance(env, args[0])); + SocketAddressBase* addr; + ASSIGN_OR_RETURN_UNWRAP(&addr, args[0]); - wrap->blocklist_->AddSocketAddress( - SocketAddress(reinterpret_cast<const sockaddr*>(&address))); + wrap->blocklist_->AddSocketAddress(addr->address()); args.GetReturnValue().Set(true); } @@ -615,30 +608,21 @@ void SocketAddressBlockListWrap::AddRange( SocketAddressBlockListWrap* wrap; ASSIGN_OR_RETURN_UNWRAP(&wrap, args.Holder()); - CHECK(args[0]->IsString()); - CHECK(args[1]->IsString()); - CHECK(args[2]->IsInt32()); - - sockaddr_storage address[2]; - Utf8Value start(args.GetIsolate(), args[0]); - Utf8Value end(args.GetIsolate(), args[1]); - int32_t family; - if (!args[2]->Int32Value(env->context()).To(&family)) - return; - - if (!SocketAddress::ToSockAddr(family, *start, 0, &address[0]) || - !SocketAddress::ToSockAddr(family, *end, 0, &address[1])) { - return; - } + CHECK(SocketAddressBase::HasInstance(env, args[0])); + CHECK(SocketAddressBase::HasInstance(env, args[1])); - SocketAddress start_addr(reinterpret_cast<const sockaddr*>(&address[0])); - SocketAddress end_addr(reinterpret_cast<const sockaddr*>(&address[1])); + SocketAddressBase* start_addr; + SocketAddressBase* end_addr; + ASSIGN_OR_RETURN_UNWRAP(&start_addr, args[0]); + ASSIGN_OR_RETURN_UNWRAP(&end_addr, args[1]); // Starting address must come before the end address - if (start_addr > end_addr) + if (*start_addr->address().get() > *end_addr->address().get()) return args.GetReturnValue().Set(false); - wrap->blocklist_->AddSocketAddressRange(start_addr, end_addr); + wrap->blocklist_->AddSocketAddressRange( + start_addr->address(), + end_addr->address()); args.GetReturnValue().Set(true); } @@ -649,29 +633,22 @@ void SocketAddressBlockListWrap::AddSubnet( SocketAddressBlockListWrap* wrap; ASSIGN_OR_RETURN_UNWRAP(&wrap, args.Holder()); - CHECK(args[0]->IsString()); + CHECK(SocketAddressBase::HasInstance(env, args[0])); CHECK(args[1]->IsInt32()); - CHECK(args[2]->IsInt32()); - sockaddr_storage address; - Utf8Value network(args.GetIsolate(), args[0]); - int32_t family; + SocketAddressBase* addr; + ASSIGN_OR_RETURN_UNWRAP(&addr, args[0]); + int32_t prefix; - if (!args[1]->Int32Value(env->context()).To(&family) || - !args[2]->Int32Value(env->context()).To(&prefix)) { + if (!args[1]->Int32Value(env->context()).To(&prefix)) { return; } - if (!SocketAddress::ToSockAddr(family, *network, 0, &address)) - return; - - CHECK_IMPLIES(family == AF_INET, prefix <= 32); - CHECK_IMPLIES(family == AF_INET6, prefix <= 128); + CHECK_IMPLIES(addr->address()->family() == AF_INET, prefix <= 32); + CHECK_IMPLIES(addr->address()->family() == AF_INET6, prefix <= 128); CHECK_GE(prefix, 0); - wrap->blocklist_->AddSocketAddressMask( - SocketAddress(reinterpret_cast<const sockaddr*>(&address)), - prefix); + wrap->blocklist_->AddSocketAddressMask(addr->address(), prefix); args.GetReturnValue().Set(true); } @@ -682,21 +659,11 @@ void SocketAddressBlockListWrap::Check( SocketAddressBlockListWrap* wrap; ASSIGN_OR_RETURN_UNWRAP(&wrap, args.Holder()); - CHECK(args[0]->IsString()); - CHECK(args[1]->IsInt32()); + CHECK(SocketAddressBase::HasInstance(env, args[0])); + SocketAddressBase* addr; + ASSIGN_OR_RETURN_UNWRAP(&addr, args[0]); - sockaddr_storage address; - Utf8Value value(args.GetIsolate(), args[0]); - int32_t family; - if (!args[1]->Int32Value(env->context()).To(&family)) - return; - - if (!SocketAddress::ToSockAddr(family, *value, 0, &address)) - return; - - args.GetReturnValue().Set( - wrap->blocklist_->Apply( - SocketAddress(reinterpret_cast<const sockaddr*>(&address)))); + args.GetReturnValue().Set(wrap->blocklist_->Apply(addr->address())); } void SocketAddressBlockListWrap::GetRules( @@ -869,7 +836,6 @@ void SocketAddressBase::Detail(const FunctionCallbackInfo<Value>& args) { } void SocketAddressBase::GetFlowLabel(const FunctionCallbackInfo<Value>& args) { - Environment* env = Environment::GetCurrent(args); SocketAddressBase* base; ASSIGN_OR_RETURN_UNWRAP(&base, args.Holder()); args.GetReturnValue().Set(base->address_->flow_label()); diff --git a/src/node_sockaddr.h b/src/node_sockaddr.h index 62dcab6bad..704fe0c511 100644 --- a/src/node_sockaddr.h +++ b/src/node_sockaddr.h @@ -168,6 +168,10 @@ class SocketAddressBase : public BaseObject { v8::Local<v8::Object> wrap, std::shared_ptr<SocketAddress> address); + inline const std::shared_ptr<SocketAddress>& address() const { + return address_; + } + void MemoryInfo(MemoryTracker* tracker) const override; SET_MEMORY_INFO_NAME(SocketAddressBase); SET_SELF_SIZE(SocketAddressBase); @@ -246,38 +250,36 @@ class SocketAddressBlockList : public MemoryRetainer { std::shared_ptr<SocketAddressBlockList> parent = {}); ~SocketAddressBlockList() = default; - void AddSocketAddress( - const SocketAddress& address); + void AddSocketAddress(const std::shared_ptr<SocketAddress>& address); - void RemoveSocketAddress( - const SocketAddress& address); + void RemoveSocketAddress(const std::shared_ptr<SocketAddress>& address); void AddSocketAddressRange( - const SocketAddress& start, - const SocketAddress& end); + const std::shared_ptr<SocketAddress>& start, + const std::shared_ptr<SocketAddress>& end); void AddSocketAddressMask( - const SocketAddress& address, + const std::shared_ptr<SocketAddress>& address, int prefix); - bool Apply(const SocketAddress& address); + bool Apply(const std::shared_ptr<SocketAddress>& address); size_t size() const { return rules_.size(); } v8::MaybeLocal<v8::Array> ListRules(Environment* env); struct Rule : public MemoryRetainer { - virtual bool Apply(const SocketAddress& address) = 0; + virtual bool Apply(const std::shared_ptr<SocketAddress>& address) = 0; inline v8::MaybeLocal<v8::Value> ToV8String(Environment* env); virtual std::string ToString() = 0; }; struct SocketAddressRule final : Rule { - SocketAddress address; + std::shared_ptr<SocketAddress> address; - explicit SocketAddressRule(const SocketAddress& address); + explicit SocketAddressRule(const std::shared_ptr<SocketAddress>& address); - bool Apply(const SocketAddress& address) override; + bool Apply(const std::shared_ptr<SocketAddress>& address) override; std::string ToString() override; void MemoryInfo(node::MemoryTracker* tracker) const override; @@ -286,14 +288,14 @@ class SocketAddressBlockList : public MemoryRetainer { }; struct SocketAddressRangeRule final : Rule { - SocketAddress start; - SocketAddress end; + std::shared_ptr<SocketAddress> start; + std::shared_ptr<SocketAddress> end; SocketAddressRangeRule( - const SocketAddress& start, - const SocketAddress& end); + const std::shared_ptr<SocketAddress>& start, + const std::shared_ptr<SocketAddress>& end); - bool Apply(const SocketAddress& address) override; + bool Apply(const std::shared_ptr<SocketAddress>& address) override; std::string ToString() override; void MemoryInfo(node::MemoryTracker* tracker) const override; @@ -302,14 +304,14 @@ class SocketAddressBlockList : public MemoryRetainer { }; struct SocketAddressMaskRule final : Rule { - SocketAddress network; + std::shared_ptr<SocketAddress> network; int prefix; SocketAddressMaskRule( - const SocketAddress& address, + const std::shared_ptr<SocketAddress>& address, int prefix); - bool Apply(const SocketAddress& address) override; + bool Apply(const std::shared_ptr<SocketAddress>& address) override; std::string ToString() override; void MemoryInfo(node::MemoryTracker* tracker) const override; |