summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook@pioneer.com>2022-06-15 21:19:18 -0500
committerJordan Cook <jordan.cook.git@proton.me>2023-03-01 17:37:59 -0600
commita89bf63914cadc42f56d6d4f8b1c67660c4ff604 (patch)
treebec0566ba2d047f6b6da4e2ed2036e6e3594cf87
parentb4a8ed037dccc5ee4ac9b763551642bcd7c0e05f (diff)
downloadrequests-cache-a89bf63914cadc42f56d6d4f8b1c67660c4ff604.tar.gz
Add basic dict-like methods
-rw-r--r--requests_cache/backends/db.py124
-rw-r--r--requests_cache/serializers/cattrs.py8
2 files changed, 93 insertions, 39 deletions
diff --git a/requests_cache/backends/db.py b/requests_cache/backends/db.py
index a6d7ef4..67d6db2 100644
--- a/requests_cache/backends/db.py
+++ b/requests_cache/backends/db.py
@@ -1,18 +1,35 @@
-"""Generic relational database backend that works with any dialect supported by SQLAlchemy"""
+"""Generic relational database backend that works with any dialect supported by SQLAlchemy.
+
+**Example (PostgreSQL):**
+
+::
+
+ >>> from requests_cache import CachedSession, DBCache
+ >>> from sqlalchemy import create_engine
+ >>>
+ >>> engine = create_engine('postgresql://user@localhost:5432/postgres')
+ >>> session = CachedSession(backend=DBCache(engine))
+
+"""
import json
+from datetime import timedelta
+from typing import Optional, Type, TypeAlias
-from sqlalchemy import Column, DateTime, Integer, LargeBinary, String
+from requests.cookies import RequestsCookieJar, cookiejar_from_dict
+from requests.structures import CaseInsensitiveDict
+from sqlalchemy import Column, DateTime, Float, Integer, LargeBinary, String, delete, func, select
from sqlalchemy.engine import Engine
+from sqlalchemy.orm import Session as SA_Session
from sqlalchemy.orm import declarative_base
from ..models import CachedRequest, CachedResponse
from . import BaseCache, BaseStorage
-Base = declarative_base()
+Base: TypeAlias = declarative_base() # type: ignore
# TODO: Benchmark read & write with ORM model vs. creating a CachedResponse from raw SQL query
-class SQLResponse(Base):
+class DBResponse(Base):
"""Database model based on :py:class:`.CachedResponse`. Instead of full serialization, this maps
request attributes to database columns. The corresponding table is generated based on this model.
"""
@@ -23,7 +40,7 @@ class SQLResponse(Base):
cookies = Column(String)
content = Column(LargeBinary)
created_at = Column(DateTime, nullable=False, index=True)
- elapsed = Column(Integer)
+ elapsed = Column(Float)
expires = Column(DateTime, index=True)
encoding = Column(String)
headers = Column(String)
@@ -36,22 +53,23 @@ class SQLResponse(Base):
status_code = Column(Integer, nullable=False)
url = Column(String, nullable=False, index=True)
+ # TODO: Maybe these conversion methods should be moved to a different module as a "serializer"?
@classmethod
def from_cached_response(cls, response: CachedResponse):
"""Convert from db model into CachedResponse (to emulate the original response)"""
return cls(
key=response.cache_key,
- cookies=json.dumps(response.cookies),
+ cookies=_save_cookies(response.cookies),
content=response.content,
created_at=response.created_at,
- elapsed=response.elapsed,
+ elapsed=response.elapsed.total_seconds(),
expires=response.expires,
encoding=response.encoding,
- headers=json.dumps(response.headers),
+ headers=_save_headers(response.headers),
reason=response.reason,
request_body=response.request.body,
- request_cookies=json.dumps(response.request.cookies),
- request_headers=json.dumps(response.request.headers),
+ request_cookies=_save_cookies(response.request.cookies),
+ request_headers=_save_headers(response.request.headers),
request_method=response.request.method,
request_url=response.request.url,
status_code=response.status_code,
@@ -59,22 +77,22 @@ class SQLResponse(Base):
)
def to_cached_response(self) -> CachedResponse:
- """Convert from CachedResponse to db model (so SA can handle dialect-specific types, etc.)"""
+ """Convert from CachedResponse to db model (so SA can handle dialect-specific behavior)"""
request = CachedRequest(
body=self.request_body,
- cookies=json.loads(self.request_cookies) if self.request_cookies else None,
- headers=json.loads(self.request_headers) if self.request_headers else None,
+ cookies=_load_cookies(self.request_cookies),
+ headers=_load_headers(self.request_headers),
method=self.request_method,
url=self.request_url,
)
obj = CachedResponse(
- cookies=json.loads(self.cookies) if self.cookies else None,
+ cookies=_load_cookies(self.cookies),
content=self.content,
created_at=self.created_at,
- elapsed=self.elapsed,
+ elapsed=timedelta(seconds=float(self.elapsed)),
expires=self.expires,
encoding=self.encoding,
- headers=json.loads(self.headers) if self.headers else None,
+ headers=_load_headers(self.headers),
reason=self.reason,
request=request,
status_code=self.status_code,
@@ -84,39 +102,77 @@ class SQLResponse(Base):
return obj
-class SQLRedirect(Base):
+class DBRedirect(Base):
__tablename__ = 'redirect'
- redirect_key = Column(String, primary_key=True)
+ key = Column(String, primary_key=True)
response_key = Column(String, index=True)
-class DbCache(BaseCache):
+class DBCache(BaseCache):
def __init__(self, engine: Engine, **kwargs):
super().__init__(**kwargs)
- self.redirects = DbStorage(engine, model=SQLResponse, **kwargs)
- self.responses = DbStorage(engine, model=SQLRedirect, **kwargs)
+ # Create tables if they don't exist
+ Base.metadata.create_all(engine)
+ session = SA_Session(engine, future=True)
+
+ self.responses = DBStorage(session, model=DBResponse, **kwargs)
+ # TODO: Separate class for handling redirects
+ self.redirects = BaseStorage() # DBStorage(session, model=DBRedirect, **kwargs)
-class DbStorage(BaseStorage):
- def __init__(self, engine: Engine, model, **kwargs):
+class DBStorage(BaseStorage):
+ def __init__(self, session: SA_Session, model: Type[Base], **kwargs):
super().__init__(no_serializer=True, **kwargs)
- self.engine = engine
+ self.session = session
self.model = model
- def __getitem__(self, key):
- pass
+ def __getitem__(self, key: str) -> CachedResponse:
+ with self.session:
+ stmt = select(DBResponse).where(DBResponse.key == key)
+ result = self.session.execute(stmt).fetchone()
+ if not result:
+ raise KeyError
+ return result[0].to_cached_response()
- def __setitem__(self, key, value):
- pass
+ def __setitem__(self, key: str, value: CachedResponse):
+ value.cache_key = key
+ self.session.merge(DBResponse.from_cached_response(value))
+ self.session.commit()
- def __delitem__(self, key):
- pass
+ def __delitem__(self, key: str):
+ stmt = delete(DBResponse).where(DBResponse.key == key)
+ if not self.session.execute(stmt).rowcount:
+ raise KeyError
def __iter__(self):
- pass
+ for row in self.session.execute(select(DBResponse)).all():
+ yield row
- def __len__(self):
- pass
+ def __len__(self) -> int:
+ stmt = select([func.count()]).select_from(DBResponse)
+ return self.session.execute(stmt).scalar()
def clear(self):
- pass
+ self.session.execute(delete(DBResponse))
+
+
+def _load_cookies(cookies_str: str) -> RequestsCookieJar:
+ try:
+ return cookiejar_from_dict(json.loads(cookies_str))
+ except (TypeError, ValueError):
+ return RequestsCookieJar()
+
+
+def _save_cookies(cookies: RequestsCookieJar) -> Optional[str]:
+ return json.dumps(cookies.get_dict()) if cookies else None
+
+
+def _load_headers(headers_str: str) -> CaseInsensitiveDict:
+ try:
+ return json.loads(headers_str)
+ except (TypeError, ValueError):
+ return CaseInsensitiveDict()
+
+
+def _save_headers(headers: CaseInsensitiveDict) -> Optional[str]:
+ return json.dumps(dict(headers)) if headers else None
diff --git a/requests_cache/serializers/cattrs.py b/requests_cache/serializers/cattrs.py
index bf45c2d..935d064 100644
--- a/requests_cache/serializers/cattrs.py
+++ b/requests_cache/serializers/cattrs.py
@@ -100,10 +100,10 @@ def init_converter(
converter.register_unstructure_hook(
timedelta, lambda obj: obj.total_seconds() if obj else None
)
- converter.register_structure_hook(timedelta, _to_timedelta)
+ converter.register_structure_hook(timedelta, lambda obj, cls: timedelta(seconds=float(obj)))
# Convert dict-like objects to and from plain dicts
- converter.register_unstructure_hook(RequestsCookieJar, lambda obj: dict(obj.items()))
+ converter.register_unstructure_hook(RequestsCookieJar, lambda obj: obj.get_dict())
converter.register_structure_hook(RequestsCookieJar, lambda obj, cls: cookiejar_from_dict(obj))
converter.register_unstructure_hook(CaseInsensitiveDict, dict)
converter.register_structure_hook(
@@ -173,8 +173,6 @@ def _to_datetime(obj, cls) -> datetime:
def _to_timedelta(obj, cls) -> timedelta:
- if isinstance(obj, (int, float)):
- obj = timedelta(seconds=obj)
- elif isinstance(obj, Decimal):
+ if not isinstance(obj, timedelta):
obj = timedelta(seconds=float(obj))
return obj