diff options
author | Eli Collins <elic@assurancetechnologies.com> | 2016-06-11 11:22:23 -0400 |
---|---|---|
committer | Eli Collins <elic@assurancetechnologies.com> | 2016-06-11 11:22:23 -0400 |
commit | 14234b6605506befd32b30304a84d6acac4c9737 (patch) | |
tree | 3d0f0fc51a9670ff756ddd4e5ee7f3e76ff5e3bb /passlib/apache.py | |
parent | 57006ea0feadf44e680935256825981fde6c856d (diff) | |
download | passlib-14234b6605506befd32b30304a84d6acac4c9737.tar.gz |
passlib.apache: drastically rewrote parsing & rendering code
* now preserves whitespace & comments to match htpasswd.
* now warns about & skips duplicate entries for the same user.
* added semi-private .set_hash() helpers for writing a hash directly.
* some internal cleanups to HtdigestFile's realm handling.
Diffstat (limited to 'passlib/apache.py')
-rw-r--r-- | passlib/apache.py | 217 |
1 files changed, 169 insertions, 48 deletions
diff --git a/passlib/apache.py b/passlib/apache.py index 1aea738..60f5cb5 100644 --- a/passlib/apache.py +++ b/passlib/apache.py @@ -14,7 +14,7 @@ from passlib.context import CryptContext from passlib.exc import ExpectedStringError from passlib.hash import htdigest from passlib.utils import render_bytes, to_bytes, deprecated_method, is_ascii_codec -from passlib.utils.compat import join_bytes, unicode, BytesIO, iteritems, PY3, OrderedDict +from passlib.utils.compat import join_bytes, unicode, BytesIO, PY3, OrderedDict # local __all__ = [ 'HtpasswdFile', @@ -31,6 +31,10 @@ _BCOLON = b":" # byte values that aren't allowed in fields. _INVALID_FIELD_CHARS = b":\n\r\t\x00" +#: _CommonFile._source token types +_SKIPPED = "skipped" +_RECORD = "record" + #============================================================================= # common helpers #============================================================================= @@ -54,10 +58,14 @@ class _CommonFile(object): # if true, automatically save to local file after changes are made. autosave = False - # ordered dict mapping key -> value for all records in database. + # dict mapping key -> value for all records in database. # (e.g. user => hash for Htpasswd) _records = None + #: list of tokens for recreating original file contents when saving. if present, + #: will be sequence of (_SKIPPED, b"whitespace/comments") and (_RECORD, <record key>) tuples. + _source = None + #=================================================================== # alt constuctors #=================================================================== @@ -128,7 +136,8 @@ class _CommonFile(object): if path and not new: self.load() else: - self._records = OrderedDict() + self._records = {} + self._source = [] def __repr__(self): tail = '' @@ -141,13 +150,16 @@ class _CommonFile(object): return "<%s 0x%0x%s>" % (self.__class__.__name__, id(self), tail) # NOTE: ``path`` is a property so that ``_mtime`` is wiped when it's set. 
- def _get_path(self): + + @property + def path(self): return self._path - def _set_path(self, value): + + @path.setter + def path(self, value): if value != self._path: self._mtime = 0 self._path = value - path = property(_get_path, _set_path) @property def mtime(self): @@ -210,25 +222,64 @@ class _CommonFile(object): def _load_lines(self, lines): """load from sequence of lists""" - # XXX: found reference that "#" comment lines may be supported by - # htpasswd, should verify this, and figure out how to handle them. - # if true, this would also affect what can be stored in user field. - # XXX: if multiple entries for a key, should we use the first one - # or the last one? going w/ first entry for now. - # XXX: how should this behave if parsing fails? currently - # it will contain everything that was loaded up to error. - # could clear / restore old state instead. parse = self._parse_record - records = self._records = OrderedDict() + records = {} + source = [] + skipped = b'' for idx, line in enumerate(lines): + # NOTE: per htpasswd source (https://github.com/apache/httpd/blob/trunk/support/htpasswd.c), + # lines with only whitespace, or with "#" as first non-whitespace char, + # are left alone / ignored. 
+ tmp = line.lstrip() + if not tmp or tmp.startswith("#"): + skipped += line + continue + + # parse valid line key, value = parse(line, idx+1) - if key not in records: - records[key] = value + + # NOTE: if multiple entries for a key, we use the first one, + # which seems to match htpasswd source + if key in records: + log.warning("username occurs multiple times in source file: %r" % key) + skipped += line + continue + + # flush buffer of skipped whitespace lines + if skipped: + source.append((_SKIPPED, skipped)) + skipped = b'' + + # store new user line + records[key] = value + source.append((_RECORD, key)) + + # don't bother preserving trailing whitespace, but do preserve trailing comments + if skipped.rstrip(): + source.append((_SKIPPED, skipped)) + + # NOTE: not replacing ._records until parsing succeeds, so loading is atomic. + self._records = records + self._source = source def _parse_record(self, record, lineno): # pragma: no cover - abstract method """parse line of file into (key, value) pair""" raise NotImplementedError("should be implemented in subclass") + def _set_record(self, key, value): + """ + helper for setting record which takes care of inserting source line if needed; + + :returns: + bool if key already present + """ + records = self._records + existing = (key in records) + records[key] = value + if not existing: + self._source.append((_RECORD, key)) + return existing + #=================================================================== # saving #=================================================================== @@ -255,9 +306,40 @@ class _CommonFile(object): """Export current state as a string of bytes""" return join_bytes(self._iter_lines()) + # def clean(self): + # """ + # discard any comments or whitespace that were being preserved from the source file, + # and re-sort keys in alphabetical order + # """ + # self._source = [(_RECORD, key) for key in sorted(self._records)] + # self._autosave() + def _iter_lines(self): """iterator yielding 
lines of database""" - return (self._render_record(key,value) for key,value in iteritems(self._records)) + # NOTE: this relies on <records> being an OrderedDict so that it outputs + # records in a deterministic order. + records = self._records + if __debug__: + pending = set(records) + for action, content in self._source: + if action == _SKIPPED: + # 'content' is whitespace/comments to write + yield content + else: + assert action == _RECORD + # 'content' is record key + if content not in records: + # record was deleted + # NOTE: doing it lazily like this so deleting & re-adding user + # preserves their original location in the file. + continue + yield self._render_record(content, records[content]) + if __debug__: + pending.remove(content) + if __debug__: + # sanity check that we actually wrote all the records + # (otherwise _source & _records are somehow out of sync) + assert not pending, "failed to write all records: missing=%r" % (pending,) def _render_record(self, key, value): # pragma: no cover - abstract method """given key/value pair, encode as line of file""" @@ -525,7 +607,7 @@ class HtpasswdFile(_CommonFile): #=================================================================== # NOTE: _records map stores <user> for the key, and <hash> for the value, - # both in bytes which use self.encoding + # both in bytes which use self.encoding #=================================================================== # init & serialization @@ -561,7 +643,9 @@ class HtpasswdFile(_CommonFile): #=================================================================== def users(self): - """Return list of all users in database""" + """ + Return list of all users in database + """ return [self._decode_field(user) for user in self._records] ##def has_user(self, user): @@ -588,14 +672,8 @@ class HtpasswdFile(_CommonFile): to prevent ambiguity with the dictionary method. The old alias is deprecated, and will be removed in Passlib 1.8. 
""" - user = self._encode_user(user) hash = self.context.hash(password) - if PY3: - hash = hash.encode(self.encoding) - existing = (user in self._records) - self._records[user] = hash - self._autosave() - return existing + return self.set_hash(user, hash) @deprecated_method(deprecated="1.6", removed="1.8", replacement="set_password") @@ -616,6 +694,24 @@ class HtpasswdFile(_CommonFile): except KeyError: return None + def set_hash(self, user, hash): + """ + semi-private helper which allows writing a hash directly; + adds user if needed. + + .. warning:: + does not (currently) do any validation of the hash string + + .. versionadded:: 1.7 + """ + # assert self.context.identify(hash), "unrecognized hash format" + if PY3 and isinstance(hash, str): + hash = hash.encode(self.encoding) + user = self._encode_user(user) + existing = self._set_record(user, hash) + self._autosave() + return existing + @deprecated_method(deprecated="1.6", removed="1.8", replacement="get_hash") def find(self, user): @@ -638,7 +734,9 @@ class HtpasswdFile(_CommonFile): return True def check_password(self, user, password): - """Verify password for specified user. + """ + Verify password for specified user. + If algorithm marked as deprecated by CryptContext, will automatically be re-hashed. :returns: * ``None`` if user not found. 
@@ -661,6 +759,7 @@ class HtpasswdFile(_CommonFile): ok, new_hash = self.context.verify_and_update(password, hash) if ok and new_hash is not None: # rehash user's password if old hash was deprecated + assert user in self._records # otherwise would have to use ._set_record() self._records[user] = new_hash self._autosave() return ok @@ -843,15 +942,21 @@ class HtdigestFile(_CommonFile): user, realm = key return render_bytes("%s:%s:%s\n", user, realm, hash) - def _encode_realm(self, realm): - # override default _encode_realm to fill in default realm field + def _require_realm(self, realm): if realm is None: realm = self.default_realm if realm is None: raise TypeError("you must specify a realm explicitly, " - "or set the default_realm attribute") + "or set the default_realm attribute") + return realm + + def _encode_realm(self, realm): + realm = self._require_realm(realm) return self._encode_field(realm, "realm") + def _encode_key(self, user, realm): + return self._encode_user(user), self._encode_realm(realm) + #=================================================================== # public methods #=================================================================== @@ -873,9 +978,7 @@ class HtdigestFile(_CommonFile): ##def has_user(self, user, realm=None): ## "check if user+realm combination exists" - ## user = self._encode_user(user) - ## realm = self._encode_realm(realm) - ## return (user,realm) in self._records + ## return self._encode_key(user,realm) in self._records ##def rename_realm(self, old, new): ## """rename all accounts in realm""" @@ -884,7 +987,7 @@ class HtdigestFile(_CommonFile): ## keys = [key for key in self._records if key[1] == old] ## for key in keys: ## hash = self._records.pop(key) - ## self._records[key[0],new] = hash + ## self._set_record((key[0], new), hash) ## self._autosave() ## return len(keys) @@ -894,7 +997,7 @@ class HtdigestFile(_CommonFile): ## new = self._encode_user(new) ## realm = self._encode_realm(realm) ## hash = 
self._records.pop((old,realm)) - ## self._records[new,realm] = hash + ## self._set_record((new, realm), hash) ## self._autosave() def set_password(self, user, realm=None, password=_UNSET): @@ -912,16 +1015,9 @@ class HtdigestFile(_CommonFile): if password is _UNSET: # called w/ two args - (user, password), use default realm realm, password = None, realm - user = self._encode_user(user) - realm = self._encode_realm(realm) - key = (user, realm) - existing = (key in self._records) + realm = self._require_realm(realm) hash = htdigest.hash(password, user, realm, encoding=self.encoding) - if PY3: - hash = hash.encode(self.encoding) - self._records[key] = hash - self._autosave() - return existing + return self.set_hash(user, realm, hash) @deprecated_method(deprecated="1.6", removed="1.8", replacement="set_password") @@ -929,7 +1025,6 @@ class HtdigestFile(_CommonFile): """set password for user""" return self.set_password(user, realm, password) - # XXX: rename to something more explicit, like get_hash()? def get_hash(self, user, realm=None): """Return :class:`~passlib.hash.htdigest` hash stored for user. @@ -941,7 +1036,7 @@ class HtdigestFile(_CommonFile): for clarity. The old name is deprecated, and will be removed in Passlib 1.8. """ - key = (self._encode_user(user), self._encode_realm(realm)) + key = self._encode_key(user, realm) hash = self._records.get(key) if hash is None: return None @@ -949,6 +1044,32 @@ class HtdigestFile(_CommonFile): hash = hash.decode(self.encoding) return hash + def set_hash(self, user, realm=None, hash=_UNSET): + """ + semi-private helper which allows writing a hash directly; + adds user & realm if needed. + + If ``self.default_realm`` has been set, this may be called + with the syntax ``set_hash(user, hash)``, + otherwise it must be called with all three arguments: + ``set_hash(user, realm, hash)``. + + .. warning:: + does not (currently) do any validation of the hash string + + .. 
versionadded:: 1.7 + """ + if hash is _UNSET: + # called w/ two args - (user, hash), use default realm + realm, hash = None, realm + # assert htdigest.identify(hash), "unrecognized hash format" + if PY3 and isinstance(hash, str): + hash = hash.encode(self.encoding) + key = self._encode_key(user, realm) + existing = self._set_record(key, hash) + self._autosave() + return existing + @deprecated_method(deprecated="1.6", removed="1.8", replacement="get_hash") def find(self, user, realm): @@ -965,7 +1086,7 @@ class HtdigestFile(_CommonFile): * ``True`` if user deleted, * ``False`` if user not found in realm. """ - key = (self._encode_user(user), self._encode_realm(realm)) + key = self._encode_key(user, realm) try: del self._records[key] except KeyError: |