summaryrefslogtreecommitdiff
path: root/src/saml2/mongo_store.py
diff options
context:
space:
mode:
authorFredrik Thulin <fredrik@thulin.net>2013-06-11 15:04:26 +0200
committerFredrik Thulin <fredrik@thulin.net>2013-06-11 15:04:26 +0200
commitc16d9792c6fa1965822f3bfdc31309948ae1f0be (patch)
treee76fe7aeb76350f8e5d60a32d3f8fe606b628e4b /src/saml2/mongo_store.py
parent93a9bf3a6a68cd15d4e40d45612a03a738ad8c8d (diff)
downloadpysaml2-c16d9792c6fa1965822f3bfdc31309948ae1f0be.tar.gz
_mdb_get_database: don't require tuple for URI
Diffstat (limited to 'src/saml2/mongo_store.py')
-rw-r--r--src/saml2/mongo_store.py46
1 files changed, 25 insertions, 21 deletions
diff --git a/src/saml2/mongo_store.py b/src/saml2/mongo_store.py
index 315f4a47..607678ee 100644
--- a/src/saml2/mongo_store.py
+++ b/src/saml2/mongo_store.py
@@ -4,6 +4,7 @@ import logging
from pymongo import MongoClient
from pymongo.mongo_replica_set_client import MongoReplicaSetClient
import pymongo.uri_parser
+import pymongo.errors
from saml2.eptid import Eptid
from saml2.mdstore import MetaData
from saml2.s_utils import PolicyError
@@ -251,39 +252,42 @@ class MDB(object):
-def _mdb_get_database(database, **kwargs):
+def _mdb_get_database(uri, **kwargs):
"""
Helper-function to connect to MongoDB and return a database object.
+ The `uri' argument should be either a full MongoDB connection URI string,
+ or just a database name in which case a connection to the default mongo
+ instance at mongodb://localhost:27017 will be made.
+
Performs explicit authentication if a username is provided in a connection
string URI, since PyMongo does not always seem to do that as promised.
- If database is *not* a connection to mongodb://localhost:27017 will be
- made.
-
:params database: name as string or (uri, name)
:returns: pymongo database object
"""
- _conn = None
- if isinstance(database, tuple):
- uri, collection, = database
- if uri.startswith("mongodb://"):
- if "replicaSet=" in uri:
- connection_factory = MongoReplicaSetClient
- else:
- connection_factory = MongoClient
+ connection_factory = MongoClient
+ _parsed_uri = {}
+ db_name = None
+ try:
+ _parsed_uri = pymongo.uri_parser.parse_uri(uri)
+ except pymongo.errors.InvalidURI:
+ # assume URI to be just the database name
+ db_name = uri
+ pass
+ else:
+ if "replicaset" in _parsed_uri["options"]:
+ connection_factory = MongoReplicaSetClient
+ db_name = _parsed_uri.get("database", "pysaml2")
- if not 'tz_aware' in kwargs:
- # default, but not forced
- kwargs['tz_aware'] = True
+ if not "tz_aware" in kwargs:
+ # default, but not forced
+ kwargs["tz_aware"] = True
- _conn = connection_factory(uri, **kwargs)
- if not _conn:
- _conn = MongoClient()
+ _conn = connection_factory(uri, **kwargs)
+ _db = _conn[db_name]
- _parsed_uri = pymongo.uri_parser.parse_uri(uri)
- _db = _conn[_parsed_uri.get("database")]
- if _parsed_uri.get("username", None):
+ if "username" in _parsed_uri:
_db.authenticate(
_parsed_uri.get("username", None),
_parsed_uri.get("password", None)