diff options
author | Fredrik Thulin <fredrik@thulin.net> | 2013-06-11 15:04:26 +0200 |
---|---|---|
committer | Fredrik Thulin <fredrik@thulin.net> | 2013-06-11 15:04:26 +0200 |
commit | c16d9792c6fa1965822f3bfdc31309948ae1f0be (patch) | |
tree | e76fe7aeb76350f8e5d60a32d3f8fe606b628e4b /src/saml2/mongo_store.py | |
parent | 93a9bf3a6a68cd15d4e40d45612a03a738ad8c8d (diff) | |
download | pysaml2-c16d9792c6fa1965822f3bfdc31309948ae1f0be.tar.gz |
_mdb_get_database: don't require tuple for URI
Diffstat (limited to 'src/saml2/mongo_store.py')
-rw-r--r-- | src/saml2/mongo_store.py | 46 |
1 files changed, 25 insertions, 21 deletions
diff --git a/src/saml2/mongo_store.py b/src/saml2/mongo_store.py index 315f4a47..607678ee 100644 --- a/src/saml2/mongo_store.py +++ b/src/saml2/mongo_store.py @@ -4,6 +4,7 @@ import logging from pymongo import MongoClient from pymongo.mongo_replica_set_client import MongoReplicaSetClient import pymongo.uri_parser +import pymongo.errors from saml2.eptid import Eptid from saml2.mdstore import MetaData from saml2.s_utils import PolicyError @@ -251,39 +252,42 @@ class MDB(object): -def _mdb_get_database(database, **kwargs): +def _mdb_get_database(uri, **kwargs): """ Helper-function to connect to MongoDB and return a database object. + The `uri' argument should be either a full MongoDB connection URI string, + or just a database name in which case a connection to the default mongo + instance at mongodb://localhost:27017 will be made. + Performs explicit authentication if a username is provided in a connection string URI, since PyMongo does not always seem to do that as promised. - If database is *not* a connection to mongodb://localhost:27017 will be - made. - :params database: name as string or (uri, name) :returns: pymongo database object """ - _conn = None - if isinstance(database, tuple): - uri, collection, = database - if uri.startswith("mongodb://"): - if "replicaSet=" in uri: - connection_factory = MongoReplicaSetClient - else: - connection_factory = MongoClient + connection_factory = MongoClient + _parsed_uri = {} + db_name = None + try: + _parsed_uri = pymongo.uri_parser.parse_uri(uri) + except pymongo.errors.InvalidURI: + # assume URI to be just the database name + db_name = uri + pass + else: + if "replicaset" in _parsed_uri["options"]: + connection_factory = MongoReplicaSetClient + db_name = _parsed_uri.get("database", "pysaml2") - if not 'tz_aware' in kwargs: - # default, but not forced - kwargs['tz_aware'] = True + if not "tz_aware" in kwargs: + # default, but not forced + kwargs["tz_aware"] = True - _conn = connection_factory(uri, **kwargs) - if not _conn: - _conn = MongoClient() + _conn = connection_factory(uri, **kwargs) + _db = _conn[db_name] - _parsed_uri = pymongo.uri_parser.parse_uri(uri) - _db = _conn[_parsed_uri.get("database")] - if _parsed_uri.get("username", None): + if "username" in _parsed_uri: _db.authenticate( _parsed_uri.get("username", None), _parsed_uri.get("password", None) |