summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/horizontal_shard.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/horizontal_shard.py')
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py24
1 files changed, 14 insertions, 10 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index d20fbd484..8902ae606 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -61,17 +61,21 @@ class ShardedQuery(Query):
# were done, this is where it would happen
return iter(partial)
- def get(self, ident, **kwargs):
- if self._shard_id is not None:
- return super(ShardedQuery, self).get(ident)
- else:
- ident = util.to_list(ident)
- for shard_id in self.id_chooser(self, ident):
- o = self.set_shard(shard_id).get(ident, **kwargs)
- if o is not None:
- return o
+ def _get_impl(self, ident, fallback_fn):
+ def _fallback(query, ident):
+ if self._shard_id is not None:
+ return fallback_fn(self, ident)
else:
- return None
+ ident = util.to_list(ident)
+ for shard_id in self.id_chooser(self, ident):
+ q = self.set_shard(shard_id)
+ o = fallback_fn(q, ident)
+ if o is not None:
+ return o
+ else:
+ return None
+
+ return super(ShardedQuery, self)._get_impl(ident, _fallback)
class ShardedSession(Session):