diff options
author | Bob Halley <halley@dnspython.org> | 2021-12-01 06:48:58 -0800 |
---|---|---|
committer | Bob Halley <halley@dnspython.org> | 2021-12-01 06:48:58 -0800 |
commit | 9a16076cb3b7d36efaabff030688693dd56f0ee6 (patch) | |
tree | 173052cae1ce98b0147d84943aaa26cd47e713df /dns/versioned.py | |
parent | c706e26e990856311e35f46cb58eaf333e80ed2f (diff) | |
download | dnspython-zone-refactor.tar.gz |
Refactor zone transactions to always use versioned CoW code.zone-refactor
Diffstat (limited to 'dns/versioned.py')
-rw-r--r-- | dns/versioned.py | 219 |
1 files changed, 19 insertions, 200 deletions
diff --git a/dns/versioned.py b/dns/versioned.py index 686a83b..42f2c81 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -11,12 +11,9 @@ except ImportError: # pragma: no cover import dns.exception import dns.immutable import dns.name -import dns.node import dns.rdataclass import dns.rdatatype -import dns.rdata import dns.rdtypes.ANY.SOA -import dns.transaction import dns.zone @@ -24,142 +21,13 @@ class UseTransaction(dns.exception.DNSException): """To alter a versioned zone, use a transaction.""" -class Version: - def __init__(self, zone, id): - self.zone = zone - self.id = id - self.nodes = {} - - def _validate_name(self, name): - if name.is_absolute(): - if not name.is_subdomain(self.zone.origin): - raise KeyError("name is not a subdomain of the zone origin") - if self.zone.relativize: - name = name.relativize(self.origin) - return name - - def get_node(self, name): - name = self._validate_name(name) - return self.nodes.get(name) - - def get_rdataset(self, name, rdtype, covers): - node = self.get_node(name) - if node is None: - return None - return node.get_rdataset(self.zone.rdclass, rdtype, covers) - - def items(self): - return self.nodes.items() # pylint: disable=dict-items-not-iterating - - -class WritableVersion(Version): - def __init__(self, zone, replacement=False): - # The zone._versions_lock must be held by our caller. - if len(zone._versions) > 0: - id = zone._versions[-1].id + 1 - else: - id = 1 - super().__init__(zone, id) - if not replacement: - # We copy the map, because that gives us a simple and thread-safe - # way of doing versions, and we have a garbage collector to help - # us. We only make new node objects if we actually change the - # node. - self.nodes.update(zone.nodes) - # We have to copy the zone origin as it may be None in the first - # version, and we don't want to mutate the zone until we commit. - self.origin = zone.origin - self.changed = set() - - def _maybe_cow(self, name): - name = self._validate_name(name) - node = self.nodes.get(name) - if node is None or node.id != self.id: - new_node = self.zone.node_factory() - new_node.id = self.id - if node is not None: - # moo! copy on write! - new_node.rdatasets.extend(node.rdatasets) - self.nodes[name] = new_node - self.changed.add(name) - return new_node - else: - return node - - def delete_node(self, name): - name = self._validate_name(name) - if name in self.nodes: - del self.nodes[name] - self.changed.add(name) - - def put_rdataset(self, name, rdataset): - node = self._maybe_cow(name) - node.replace_rdataset(rdataset) - - def delete_rdataset(self, name, rdtype, covers): - node = self._maybe_cow(name) - node.delete_rdataset(self.zone.rdclass, rdtype, covers) - if len(node) == 0: - del self.nodes[name] - - -@dns.immutable.immutable -class ImmutableVersion(Version): - def __init__(self, version): - # We tell super() that it's a replacement as we don't want it - # to copy the nodes, as we're about to do that with an - # immutable Dict. - super().__init__(version.zone, True) - # set the right id! - self.id = version.id - # Make changed nodes immutable - for name in version.changed: - node = version.nodes.get(name) - # it might not exist if we deleted it in the version - if node: - version.nodes[name] = ImmutableNode(node) - self.nodes = dns.immutable.Dict(version.nodes, True) - - -# A node with a version id. - -class Node(dns.node.Node): - __slots__ = ['id'] - - def __init__(self): - super().__init__() - # A proper id will get set by the Version - self.id = 0 - - -@dns.immutable.immutable -class ImmutableNode(Node): - __slots__ = ['id'] - - def __init__(self, node): - super().__init__() - self.id = node.id - self.rdatasets = tuple( - [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] - ) - - def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): - if create: - raise TypeError("immutable") - return super().find_rdataset(rdclass, rdtype, covers, False) - - def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): - if create: - raise TypeError("immutable") - return super().get_rdataset(rdclass, rdtype, covers, False) - - def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): - raise TypeError("immutable") - - def replace_rdataset(self, replacement): - raise TypeError("immutable") +# Backwards compatibility +Node = dns.zone.VersionedNode +ImmutableNode = dns.zone.ImmutableVersionedNode +Version = dns.zone.Version +WritableVersion = dns.zone.WritableVersion +ImmutableVersion = dns.zone.ImmutableVersion +Transaction = dns.zone.Transaction class Zone(dns.zone.Zone): @@ -198,7 +66,9 @@ class Zone(dns.zone.Zone): self._write_event = None self._write_waiters = collections.deque() self._readers = set() - self._commit_version_unlocked(None, WritableVersion(self), origin) + self._commit_version_unlocked(None, + WritableVersion(self, replacement=True), + origin) def reader(self, id=None, serial=None): # pylint: disable=arguments-differ if id is not None and serial is not None: @@ -247,7 +117,8 @@ class Zone(dns.zone.Zone): # give up the lock, so that we hold the lock as # short a time as possible. This is why we call # _setup_version() below. - self._write_txn = Transaction(self, replacement) + self._write_txn = Transaction(self, replacement, + make_immutable=True) # give up our exclusive right to make a Transaction self._write_event = None break @@ -367,6 +238,13 @@ class Zone(dns.zone.Zone): with self._version_lock: self._commit_version_unlocked(txn, version, origin) + def _get_next_version_id(self): + if len(self._versions) > 0: + id = self._versions[-1].id + 1 + else: + id = 1 + return id + def find_node(self, name, create=False): if create: raise UseTransaction @@ -394,62 +272,3 @@ class Zone(dns.zone.Zone): def replace_rdataset(self, name, replacement): raise UseTransaction - - -class Transaction(dns.transaction.Transaction): - - def __init__(self, zone, replacement, version=None): - read_only = version is not None - super().__init__(zone, replacement, read_only) - self.version = version - - @property - def zone(self): - return self.manager - - def _setup_version(self): - assert self.version is None - self.version = WritableVersion(self.zone, self.replacement) - - def _get_rdataset(self, name, rdtype, covers): - return self.version.get_rdataset(name, rdtype, covers) - - def _put_rdataset(self, name, rdataset): - assert not self.read_only - self.version.put_rdataset(name, rdataset) - - def _delete_name(self, name): - assert not self.read_only - self.version.delete_node(name) - - def _delete_rdataset(self, name, rdtype, covers): - assert not self.read_only - self.version.delete_rdataset(name, rdtype, covers) - - def _name_exists(self, name): - return self.version.get_node(name) is not None - - def _changed(self): - if self.read_only: - return False - else: - return len(self.version.changed) > 0 - - def _end_transaction(self, commit): - if self.read_only: - self.zone._end_read(self) - elif commit and len(self.version.changed) > 0: - self.zone._commit_version(self, ImmutableVersion(self.version), - self.version.origin) - else: - # rollback - self.zone._end_write(self) - - def _set_origin(self, origin): - if self.version.origin is None: - self.version.origin = origin - - def _iterate_rdatasets(self): - for (name, node) in self.version.items(): - for rdataset in node: - yield (name, rdataset) |