"""A base class for SQLAlchemy-based repositories."""
from typing import Any, Generic, Iterable, List, Optional, Sequence, Tuple
from kerno.kerno import Kerno
from kerno.typing import DictStr, Entity
[docs]class SpyRepo:
"""Nice test double, can be inspected at the end of a test."""
def __init__(self, **kw) -> None: # noqa
self.new: List[Any] = []
self.deleted: List[Any] = []
self.flushed = False
for key, val in kw.items():
setattr(self, key, val)
[docs] def add(self, entity: Entity) -> Entity: # noqa
self.new.append(entity)
return entity
[docs] def add_all(self, entities: Sequence[Entity]) -> None: # noqa
self.new.extend(entities)
[docs] def delete(self, entity) -> None: # noqa
self.deleted.append(entity)
[docs] def flush(self) -> None: # noqa
self.flushed = True
[docs]class BaseSQLAlchemyRepository:
"""Base class for a SQLAlchemy-based repository."""
SAS = "session factory" # name of the SQLAlchemy session utility
def __init__(self, kerno: Kerno, session_factory: Any = None):
"""Construct a SQLAlchemy repository instance to serve ONE request.
- ``kerno`` is the Kerno instance for the current application.
- ``session_factory`` is a function that returns a
SQLAlchemy session to be used in this request -- scoped or not.
If not provided as an argument, get it from the
kerno utility registry (under the "session factory" name.
"""
self.kerno = kerno
self.sas = self.new_sas(session_factory or kerno.utilities[self.SAS])
assert self.sas
[docs] def new_sas(self, session_factory):
"""Obtain a new SQLAlchemy session instance."""
assert session_factory is not None
is_scoped_session = hasattr(session_factory, "query")
# Because we don't want to depend on SQLAlchemy:
if callable(session_factory) and not is_scoped_session:
return session_factory()
else:
return session_factory
[docs] def add(self, entity: Entity) -> Entity:
"""Add an object to the SQLAlchemy session, then return it."""
self.sas.add(entity)
return entity
[docs] def add_all(self, entities: Sequence[Entity]) -> None:
"""Add model instances to the SQLAlchemy session."""
self.sas.add_all(entities)
[docs] def delete(self, entity: Entity) -> None:
"""Delete an ``entity`` from the database."""
self.sas.delete(entity)
[docs] def flush(self) -> None:
"""Obtain IDs on new objects and update state on the database.
Without committing the transaction.
"""
self.sas.flush()
def _get_or_add(self, cls: type, **filters):
"""Retrieve or add an entity with ``filters``; return the entity.
The entity will have a transient ``_is_new`` flag telling you whether
it already existed.
This is a helper for the implementation of repository methods and
should not be used elsewhere.
"""
entity = self.sas.query(cls).filter_by(**filters).first()
if entity is None:
entity = cls(**filters)
self.add(entity=entity)
is_new = True
else:
is_new = False
assert not hasattr(entity, "_is_new")
entity._is_new = is_new
return entity
def _update_or_add(self, cls: type, props: DictStr = {}, **filters):
"""Load and modify entity if it exists, else create one.
First obtain either an existing object or a new one, based on
``filters``. Then apply ``props`` and return the entity.
The entity will have a transient ``_is_new`` flag telling you whether
it already existed.
This is a helper for the implementation of repository methods and
should not be used elsewhere.
"""
assert "_is_new" not in props
entity = self._get_or_add(cls, **filters)
for key, val in props.items():
setattr(entity, key, val)
return entity
[docs] def update_association(
self,
cls: type,
field: str,
ids: Sequence[int],
filters: DictStr,
synchronize_session=None,
) -> List[Entity]:
"""Update a many-to-many relationship. Return only NEW associations."""
return update_association(
cls=cls,
field=field,
ids=ids,
filters=filters,
sas=self.sas,
synchronize_session=synchronize_session,
)
[docs] def get_or_create(self, cls: type, **filters) -> Tuple[Any, bool]:
"""Retrieve or add object; return a tuple ``(object, is_new)``.
``is_new`` is False if the object already existed in the database.
"""
instance = self.sas.query(cls).filter_by(**filters).first()
is_new = not instance
if is_new:
instance = cls(**filters)
self.sas.add(instance)
return instance, is_new
[docs] def create_or_update(
self, cls: type, values: DictStr = {}, **filters
) -> Tuple[Any, bool]:
"""Load and update entity if it exists, else create one.
First obtain either an existing object or a new one, based on ``filters``.
Then apply ``values`` and return a tuple ``(object, is_new)``.
"""
instance, is_new = self.get_or_create(cls, **filters)
for k, v in values.items():
setattr(instance, k, v)
return instance, is_new
[docs]def update_association(
cls: type,
field: str,
ids: Sequence[int],
filters: DictStr,
sas,
synchronize_session=None,
) -> List[Entity]:
"""Update a many-to-many relationship. Return only NEW associations.
When you have a many-to-many relationship, there is an association
table between 2 main tables. The problem of setting the data in
this case is a recurring one and it is solved here.
Existing associations might be deleted and some might be created.
Example usage::
user = session.query(User).get(1)
# Suppose there's a many-to-many relationship to Address,
# through an entity in the middle named UserAddress which contains
# just the columns user_id and address_id.
new_associations = update_association(
cls=UserAddress, # the association class
field='address_id' # name of the remote foreign key
ids=[5, 42, 89], # the IDs of the user's addresses
filters={"user": user}, # to load existing associations
sas=my_sqlalchemy_session,
)
for item in new_associations:
print(item)
This method returns a list of any new association instances
because you might want to finish the job by doing something
more with them (e. g. setting other attributes).
A new query is needed to retrieve the totality of the associations.
"""
# Fetch eventually existing association IDs
existing_ids = frozenset(
[o[0] for o in sas.query(getattr(cls, field)).filter_by(**filters)]
)
# Delete association rows that we no longer want
desired_ids = frozenset(ids)
to_remove = existing_ids - desired_ids
if to_remove:
q_remove = (
sas.query(cls)
.filter_by(**filters)
.filter(getattr(cls, field).in_(to_remove))
)
if synchronize_session is not None:
q_remove.delete(synchronize_session=synchronize_session)
else:
for entity in q_remove:
sas.delete(entity)
# Create desired associations that do not yet exist
to_create = desired_ids - existing_ids
new_associations = []
for id in to_create:
association = cls(**filters)
setattr(association, field, id)
new_associations.append(association)
sas.add_all(new_associations)
return new_associations
[docs]class Query(Iterable, Generic[Entity]):
"""Typing stub for a returned SQLAlchemy query.
This is purposely very incomplete. It is intended to be used as return
value in repository methods, such that user code can use, but not change,
the returned query, which is what we like to do in this architecture.
If you want a more complete implementation, try
https://github.com/dropbox/sqlalchemy-stubs
Their query stub is at
https://github.com/dropbox/sqlalchemy-stubs/blob/master/sqlalchemy-stubs/orm/query.pyi
"""
[docs] def all(self) -> List[Entity]: # noqa
...
[docs] def count(self) -> int: # noqa
...
[docs] def delete(self) -> None: # noqa
...
# def exists(self): ... # noqa
[docs] def first(self) -> Optional[Entity]: # noqa
...
[docs] def get(self, ident) -> Optional[Entity]: # noqa
...
[docs] def one(self) -> Entity: # noqa
...
# def slice(self, start: int, stop: Optional[int]): ... # noqa
[docs] def yield_per(self, count: int) -> List[Entity]: # noqa
...