"""Functions that help define SQLAlchemy models."""
import re
from datetime import datetime
from typing import List, Tuple, Union
from warnings import warn
from sqlalchemy import Table, Column, ForeignKey, Sequence
from sqlalchemy.orm import backref as _backref, class_mapper, ColumnProperty
from sqlalchemy.orm.attributes import CollectionAttributeImpl, ScalarObjectAttributeImpl
from sqlalchemy.orm.dynamic import DynamicAttributeImpl
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.types import Integer, DateTime, Unicode
from bag.settings import resolve
from bag.web.exceptions import Problem
from ..web import gravatar_image
# Cascade string that fk_rel() passes to the backref when
# ``ondelete="CASCADE"``: the "all" cascade plus "delete-orphan".
# http://docs.sqlalchemy.org/en/latest/orm/cascades.html
CASC = "all, delete-orphan"
def now_column(nullable: bool = False, **k) -> Column:
    """Return a DateTime column whose default is ``datetime.utcnow``.

    ``datetime.utcnow`` is passed as a callable (it is not called here).
    The column is NOT NULL unless ``nullable`` is true; any extra keyword
    arguments are forwarded to ``Column``.
    """
    utc_default = datetime.utcnow
    return Column(DateTime, nullable=nullable, default=utc_default, **k)
def get_col(model, attribute_name):
    """Return the column object mapped to ``attribute_name`` on ``model``.

    Introspects the SQLAlchemy mapper of ``model``, e.g.
    ``get_col(User, 'email')``.
    """
    mapper = model._sa_class_manager.mapper
    return mapper.columns[attribute_name]
def _get_length(col):
return None if col is None else getattr(col.type, "length", None)
def get_length(model, field):
    """Return the declared length of column ``field`` of a SQLAlchemy ``model``.

    Returns None when the column type has no ``length``.
    """
    column = get_col(model, field)
    return _get_length(column)
def col(attrib):
    """Return the column that stores an ``attrib`` of a model.

    ``attrib`` is a model class attribute (an SQLAlchemy
    ``InstrumentedAttribute``), e.g. ``col(User.email)``. The first column
    of its mapped property is returned.
    """
    columns = attrib.property.columns
    return columns[0]
def length(attrib):
    """Return the declared length of the column behind ``attrib``.

    Returns None when the column type has no ``length``.
    """
    column = col(attrib)
    return _get_length(column)
def fk(
    attrib, nullable=False, index=True, primary_key=False, doc=None, ondelete="CASCADE"
):
    """Return a ForeignKey column while automatically setting the type.

    ``attrib`` is the model attribute to point at (e.g. ``User.id``). The new
    column takes its type from a copy of that attribute's column and has a
    ForeignKey to it.

    ``ondelete`` must be one of:

    - ``"CASCADE"``: creates ON DELETE CASCADE;
    - ``"SET NULL"``: creates ON DELETE SET NULL;
    - ``None``: no ON DELETE clause, so the database default (usually
      NO ACTION) applies, with more runtime errors when parents are deleted.

    :raises ValueError: if ``ondelete`` is any other value. The check is an
        explicit raise, so it also runs under ``python -O``.
    """
    allowed = ("CASCADE", "SET NULL", None)
    if ondelete not in allowed:
        raise ValueError(
            "ondelete must be one of {!r}, not {!r}.".format(allowed, ondelete)
        )
    column = col(attrib)
    return Column(
        column.copy().type,
        ForeignKey(column, ondelete=ondelete),
        doc=doc,
        index=index,
        primary_key=primary_key,
        nullable=nullable,
    )
def fk_rel(
    cls,
    attrib="id",
    nullable=False,
    index=True,
    primary_key=False,
    doc=None,
    ondelete="CASCADE",
    backref=None,
    order_by=None,
    lazy="select",
):
    """Return a ForeignKey column and a relationship.

    Automatically sets the type of the foreign key.

    Usage::

        # A relationship in an Address model pointing to a parent Person:
        person_id, person = fk_rel(Person, nullable=False,
            index=True, backref='addresses', ondelete='CASCADE')

    A backref is created only if you provide its name in the argument.
    ``nullable`` and ``index`` are usually omitted, because these are the
    default values and they are good.

    ``ondelete`` is "CASCADE" by default, but you can set it to "SET NULL",
    or None which translates to "NO ACTION" (less interesting).

    If provided, ``order_by`` is used on the backref.
    To load the backref greedily, use ``lazy='joined'`` as per
    http://docs.sqlalchemy.org/en/latest/orm/loading_relationships.html
    ``order_by`` and ``lazy`` only configure the backref, so they are
    ignored when ``backref`` is not given.

    You may also pass ``attrib``: the name of the attribute of ``cls`` that
    the foreign key points to (``"id"`` by default).
    """
    # http://docs.sqlalchemy.org/en/latest/orm/collections.html#passive-deletes
    from sqlalchemy.orm import relationship

    foreign_key = fk(
        getattr(cls, attrib),
        nullable=nullable,
        index=index,
        primary_key=primary_key,
        doc=doc,
        ondelete=ondelete,
    )
    if not backref:
        return foreign_key, relationship(cls)

    if ondelete == "CASCADE":
        # The database's ON DELETE CASCADE removes the children, so the ORM
        # can rely on it (passive_deletes) instead of loading them first.
        cascade = CASC
        passive_deletes = True
    else:
        # SQLAlchemy's default relationship cascade, spelled out explicitly
        # instead of passing the ``False`` "use the default" sentinel.
        cascade = "save-update, merge"
        passive_deletes = False
    relation = relationship(
        cls,
        backref=_backref(
            backref,
            cascade=cascade,
            passive_deletes=passive_deletes,
            order_by=order_by,
            lazy=lazy,
        ),
    )
    return foreign_key, relation
def many_to_many(Model1, Model2, pk1="id", pk2="id", metadata=None, backref=None):
    """Easily set up a many-to-many relationship between 2 existing models.

    Return an association table and the relationship itself.

    The association table is named ``<table1>_<table2>``. It has two
    non-nullable, indexed foreign key columns, ``<table1>_id`` and
    ``<table2>_id``, which point at the ``pk1`` column of ``Model1`` and the
    ``pk2`` column of ``Model2``. It is added to ``metadata``, or to the
    metadata of ``Model1``'s table when ``metadata`` is falsy. The
    relationship targets ``Model2``; its backref defaults to ``<table1>s``.

    Usage::

        customer_user, Customer.users = many_to_many(Customer, User,
                                                     pk2='__id__')
    """
    from sqlalchemy.orm import relationship

    name1, name2 = Model1.__tablename__, Model2.__tablename__
    target1 = col(getattr(Model1, pk1))
    target2 = col(getattr(Model2, pk2))
    # Like fk(), take each type from a copy of the referenced column.
    fk_type1, fk_type2 = target1.copy().type, target2.copy().type
    if not metadata:
        metadata = Model1.__table__.metadata

    link_columns = [
        Column(
            table_name + "_id",
            fk_type,
            ForeignKey(table_name + "." + target.name),
            nullable=False,
            index=True,
        )
        for table_name, fk_type, target in (
            (name1, fk_type1, target1),
            (name2, fk_type2, target2),
        )
    ]
    association = Table(name1 + "_" + name2, metadata, *link_columns)

    if not backref:
        backref = name1 + "s"
    rel = relationship(Model2, secondary=association, backref=backref)
    return association, rel
def pk(tablename: str) -> Column:
    """Return an autoincrementing Integer primary key column.

    The column is backed by a Sequence named ``<tablename>_id_seq``.
    """
    # The type must be Integer for Sequences to work, AFAICT.
    # Maybe this problem is in Python only?
    sequence_name = tablename + "_id_seq"
    return Column(
        Integer, Sequence(sequence_name), autoincrement=True, primary_key=True
    )
def is_model_class(val) -> bool:
    """Return whether the parameter is a SQLAlchemy model class.

    True when ``val`` has both ``__base__`` (so it is a class rather than an
    instance) and ``__table__``.
    """
    return all(hasattr(val, attr) for attr in ("__base__", "__table__"))
def models_and_tables_in(arg) -> Tuple[List, List]:
    """Return 2 lists containing the model classes and tables in ``arg``.

    ``arg`` may be a resource spec, a module or a dictionary::

        models, tables = models_and_tables_in(globals())

    Anything that is not a dict is passed to ``resolve()`` and the resulting
    module's ``__dict__`` is scanned.
    """
    namespace = arg if isinstance(arg, dict) else resolve(arg).__dict__
    models, tables = [], []
    for obj in namespace.values():
        if is_model_class(obj):
            models.append(obj)
        if isinstance(obj, Table):
            tables.append(obj)
    return models, tables
def model_property_names(
    cls, whitelist=None, blacklist=None, include_relationships=True
):
    """Return the names of the mapped properties of ``cls``, maybe filtered.

    - ``whitelist``: if non-empty, keep only names contained in it.
    - ``blacklist``: if non-empty, drop names contained in it.
    - ``include_relationships``: if false, drop attributes whose
      instrumentation is a collection, dynamic or scalar-object
      implementation (i.e. relationship-style attributes).

    Names keep the order in which the mapper yields its properties.
    """
    filtered = []
    for prop in cls.__mapper__.iterate_properties:
        # ``prop.key`` is the attribute name. Parsing ``str(prop)`` as
        # "Class.attr" depended on each property type's __str__ format.
        name = prop.key
        if blacklist and name in blacklist:
            continue
        if whitelist and name not in whitelist:
            continue
        if not include_relationships and isinstance(
            getattr(cls, name).impl,
            (CollectionAttributeImpl, DynamicAttributeImpl, ScalarObjectAttributeImpl),
        ):
            continue
        filtered.append(name)
    return filtered
def foreign_key_from_col(col):
    """Return one ForeignKey of column ``col``, or None if it has none.

    ``col.foreign_keys`` is a set; the first element its iteration yields is
    returned.
    """
    return next(iter(col.foreign_keys), None)
def foreign_keys_in(cls):
    """Map attribute names of ``cls`` to one of their ForeignKey objects.

    Only non-relationship attributes are considered (see
    ``model_property_names(..., include_relationships=False)``). Attributes
    whose expression has no foreign keys are omitted. When there are several,
    the first one yielded by the set is used.
    """
    result = {}
    for name in model_property_names(cls, include_relationships=False):
        first = next(iter(getattr(cls, name).expression.foreign_keys), None)
        if first is not None:
            result[name] = first
    return result
def persistent_attribute_names_of(cls):
    """Return the names of the column-backed attributes of the model ``cls``.

    Only ``ColumnProperty`` attributes are included. Relationships, whether
    collections or scalar references, are therefore left out.
    """
    names = []
    for prop in class_mapper(cls).iterate_properties:
        if isinstance(prop, ColumnProperty):
            names.append(prop.key)
    return names
class MinimalBase:
    """Declarative base class that auto-generates __tablename__.

    It also offers small conveniences for model instances: ``update()``,
    ``update_from_schema()`` and ``clone()``.
    """

    # MySQL-specific table options (storage engine and character set).
    __table_args__: Union[dict, tuple] = {
        "mysql_engine": "InnoDB",
        "mysql_charset": "utf8",
    }

    @declared_attr
    def __tablename__(cls):
        """Convert CamelCase class to underscores_between_words table name.

        Any "Mixin" in the class name is removed first, so e.g.
        ``OrderLineMixin`` becomes ``order_line``.
        """
        name = cls.__name__.replace("Mixin", "")
        return name[0].lower() + re.sub(
            r"([A-Z])", lambda m: "_" + m.group(0).lower(), name[1:]
        )

    def update(self, adict, transient=False):
        """Merge dictionary ``adict`` into this entity and return ``self``.

        Unless ``transient`` is true, every key must be an attribute of the
        model class. All keys are checked before anything is assigned, so a
        bad key leaves the entity untouched.

        :raises AssertionError: if a key is not an attribute of the model
            class. It is raised explicitly rather than through ``assert``,
            so the check still runs under ``python -O``.
        """
        model = type(self)
        items = list(adict.items())
        if not transient:
            for key, _ in items:
                if not hasattr(model, key):
                    raise AssertionError(
                        "Model {} does not have a '{}' attribute.".format(
                            model.__name__, key
                        )
                    )
        for key, value in items:
            setattr(self, key, value)
        return self

    def update_from_schema(self, schema, adict):
        """Validate ``adict`` against ``schema``; return updated entity.

        Validates the information in the dictionary ``adict`` against
        a Colander ``schema``. If validation fails, colander.Invalid
        is raised. If happy, returns the updated model instance.
        """
        schema._model_instance = self  # makes some validations easier
        clean = schema.deserialize(adict)  # May raise colander.Invalid
        self.update(clean)
        return self

    def clone(self, values=None, pk="id", sas=None):
        """Return a new instance of this model with the same column values.

        Only column attributes are copied (``persistent_attribute_names_of``
        returns ColumnProperty names); relationships are not.

        - ``values``: optional dict whose items override the copied values.
        - ``pk``: name of the primary key attribute, which is left out of the
          copied values (a KeyError is raised if it is not among them).
          Pass a falsy value to copy it as well.
        - ``sas``: optional SQLAlchemy session; if given, the clone is added
          to it.
        """
        adict = {
            attr: getattr(self, attr)
            for attr in persistent_attribute_names_of(self.__class__)
        }
        if pk:
            del adict[pk]
        if values:
            adict.update(values)
        clone = self.__class__(**adict)
        if sas is not None:  # Optionally add the clone to the session
            sas.add(clone)
        return clone
class ID:
    """Mixin that gives a model an autoincrementing Integer column ``id``.

    The column is the model's primary key.
    """

    @declared_attr
    def id(cls):
        """Primary key column for your model."""
        return Column(Integer, primary_key=True, autoincrement=True)
class Names:
    """Mixin class that includes 2 ways to handle a person's names.

    It stores a required ``full_name`` (up to 120 chars) and a required
    ``short_name`` (up to 16 chars), and offers two read-only views that
    prefer one or the other.
    """

    @declared_attr
    def full_name(cls):
        """Required Unicode(120) column with the complete name."""
        return Column(Unicode(120), nullable=False)

    @declared_attr
    def short_name(cls):
        """Required Unicode(16) column with the short name."""
        return Column(Unicode(16), nullable=False)

    @property
    def display_name(self):
        """Return ``short_name`` if it is truthy, else ``full_name``."""
        short = self.short_name
        return short if short else self.full_name

    @property
    def formal_name(self):
        """Return ``full_name`` if it is truthy, else ``short_name``."""
        full = self.full_name
        return full if full else self.short_name
class AddressBase:
    """Base class for addresses.

    In subclasses you can just define ``__tablename__``, ``id``,
    the foreign key, and maybe indexes.

    Every column declared here is created with ``default=""``.
    """

    # Example of what a subclass might define:
    # __tablename__ = 'customer'
    # pk = pk(__tablename__)
    street = Column("street", Unicode(160), default="")
    district = Column("district", Unicode(80), default="")
    city = Column("city", Unicode(80), default="")
    province = Column("province", Unicode(40), default="")
    # 2 characters: presumably an ISO 3166-1 alpha-2 code — TODO confirm.
    country_code = Column("country_code", Unicode(2), default="")
    postal_code = Column("postal_code", Unicode(16), default="", doc="Zip code")
    # Optional columns a subclass might add (kept as examples, not active):
    # kind = Column(Unicode(1), default='',
    #               doc="c for commercial, r for residential")
    # charge = Column(Boolean, default=False,
    #                 doc="Whether this is the address to bill to.")
    # comment = Column(Unicode, default='')
class EmailParts:
    """Mixin class that stores an email address in 2 columns.

    One column contains the local part, another contains the domain.
    This makes it easy to find emails from the same domain.

    Typical usage:

    .. code-block:: python

        class Customer(SABase, EmailParts):
            __table_args__ = (UniqueConstraint('email_local', 'email_domain',
                              name='customer_email_key'), {})
    """

    email_local = Column("email_local", Unicode(160), nullable=False)
    email_domain = Column("email_domain", Unicode(255), nullable=False)

    @hybrid_property
    def email(self):
        """Get or set the entire email, in Python or in the RDBMS."""
        return self.email_local + "@" + self.email_domain

    # The setter must carry the same name as the hybrid: ``.setter`` returns
    # a modified copy of the hybrid. Under another name, ``email`` itself
    # would stay read-only.
    @email.setter
    def email(self, val):
        """Split ``val`` at its "@" and store the local part and the domain.

        :raises Problem: if ``val`` does not contain exactly one "@", or if
            either part is empty. Nothing is assigned in that case.
        """
        parts = val.split("@")
        if len(parts) != 2:
            raise Problem("The email address must contain exactly one '@'.")
        local, domain = parts
        if not local:
            raise Problem("Missing the local part of the email address.")
        if not domain:
            raise Problem("Missing the domain part of the email address.")
        self.email_local, self.email_domain = local, domain

    # Backward-compatible alias: the setter used to be defined under this name.
    set_email = email

    def gravatar_image(
        self,
        default: str = "mm",
        size: int = 80,
        cacheable: bool = True,
    ) -> str:
        """Return the URL for the gravatar image for this email address."""
        # Inside this method, the bare name ``gravatar_image`` is the
        # module-level function imported from ``..web``, not this method.
        return gravatar_image(
            self.email, default=default, size=size, cacheable=cacheable
        )
def commit_session_or_transaction(sas) -> None:
    """Commit ``sas``, falling back to the ``transaction`` package if needed.

    Not sure if using the transaction package or not? No problem. If
    ``sas.commit()`` raises AssertionError with the exact message
    "Transaction must be committed using the transaction manager"
    (presumably from a session under transaction-manager control), then
    ``transaction.commit()`` is called instead. Any other AssertionError
    propagates unchanged.
    """
    manager_message = "Transaction must be committed using the transaction manager"
    try:
        sas.commit()
    except AssertionError as exc:
        if str(exc) != manager_message:
            raise
        import transaction

        transaction.commit()
class SubtransactionTrick:
    """Encloses your code in a subtransaction. Good for writing tests.

    Usage::

        trick = SubtransactionTrick(my_engine, sessionmaker)
        # Be sure to use the session provided as the ``sas`` variable:
        my_session = trick.sas
        # Finally, call ``close()`` to roll back the changes:
        trick.close()
    """

    def __init__(self, engine, sessionmaker):
        """Connect, begin a transaction and bind a session to the connection.

        - ``engine`` should be a completely configured SQLAlchemy engine.
        - ``sessionmaker`` may be one of:

          - a scoped session (detected by its ``query`` attribute). It is
            reconfigured with ``bind=connection`` and used as ``self.sas``;
          - a session factory, such as the result of ``sessionmaker(...)``.
            It is called with ``bind=connection``;
          - a class such as ``sessionmaker`` itself. It is first
            instantiated without arguments to obtain such a factory.
        """
        self.connection = engine.connect()
        # begin a non-ORM transaction
        self.transaction = self.connection.begin()
        # bind an individual Session to the connection
        if hasattr(sessionmaker, "query"):  # scoped session detected
            sessionmaker.configure(bind=self.connection)
            self.sas = sessionmaker
        else:  # not a scoped session
            # Calling a ready factory returns a Session, and a Session is not
            # callable, so only a class gets the extra instantiation step.
            if isinstance(sessionmaker, type):
                factory = sessionmaker()
            else:
                factory = sessionmaker
            self.sas = factory(bind=self.connection)

    def close(self):
        """Roll back everything that happened with the session.

        ...including calls to commit().
        """
        self.transaction.rollback()
        self.sas.close()
        # NOTE(review): the connection is never closed here; a
        # ``self.connection.close()`` call was left commented out. Confirm
        # whether leaving it open is intentional.