
The DAO Pattern with SQLAlchemy

We learned about the DAO pattern in the database chapter. That chapter contains an example using an SQLite database: the implementation ProjectDaoSqliteImplementation accesses the database with plain SQL.
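For reference, here is a minimal sketch of what such a plain-SQL DAO can look like with Python's built-in sqlite3 module. It is an assumption, not the code from the database chapter: the table layout mirrors the `projects` table used later in this article, and it uses `dataclasses` instead of `attrs` so it runs without extra packages. Note how every query method converts rows into `Project` objects by hand. That repetitive mapping code is exactly what an ORM takes over.

```python
import sqlite3
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Project:
    id: int
    name: str


class ProjectDaoSqliteImplementation:
    """Sketch of a plain-SQL DAO; the table layout is an assumption."""

    def __init__(self, connection: sqlite3.Connection) -> None:
        self._connection = connection
        self._connection.execute(
            "CREATE TABLE IF NOT EXISTS projects ("
            "id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT)"
        )

    def save(self, project: Project) -> Project:
        # A falsy id (e.g. 0) means "not stored yet": insert NULL so SQLite assigns one.
        cursor = self._connection.execute(
            "INSERT INTO projects (id, name) VALUES (?, ?)",
            (project.id or None, project.name),
        )
        self._connection.commit()
        return Project(id=cursor.lastrowid, name=project.name)

    def find_by_id(self, id_: int) -> Optional[Project]:
        row = self._connection.execute(
            "SELECT id, name FROM projects WHERE id = ?", (id_,)
        ).fetchone()
        # Hand-written row -> object mapping, repeated in every query method.
        return Project(id=row[0], name=row[1]) if row else None

    def find_all(self) -> List[Project]:
        rows = self._connection.execute("SELECT id, name FROM projects").fetchall()
        return [Project(id=row[0], name=row[1]) for row in rows]

    def exists_by_name(self, name: str) -> bool:
        row = self._connection.execute(
            "SELECT 1 FROM projects WHERE name = ? LIMIT 1", (name,)
        ).fetchone()
        return row is not None
```

Usage: `ProjectDaoSqliteImplementation(sqlite3.connect(":memory:")).save(Project(id=0, name="demo"))` stores the project and returns it with the id SQLite assigned.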

Often, though, an ORM is used to access the database so that you don't have to write the same mapping code over and over again. You might think: the DAO pattern lets me change my database, and SQLAlchemy already supports multiple databases, so I don't need the DAO pattern.

But that’s a bad idea. The DAO pattern is not only about changing your database; it’s about being able to change your database technology quickly, and that includes an ORM like SQLAlchemy. What if SQLAlchemy is no longer maintained someday, or doesn’t offer all the features you need? SQLAlchemy also supports only relational databases. What happens if you want to use a document database like MongoDB or another NoSQL database? Without the DAO pattern, such a change is very hard or even impossible, because you have to touch every single file in your codebase that interacts with the database. If the database technology is encapsulated in a DataAccessObject, you only have to change the implementation of the DataAccessObject.
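To make the "easy to change" claim concrete, here is a minimal sketch in which a dict-based in-memory store stands in for a different database technology (say, a document database). `ProjectDaoInMemoryImplementation` and `register_project` are hypothetical names, and the sketch uses `dataclasses` instead of `attrs` to stay dependency-free. The calling code depends only on the `ProjectDao` interface, so swapping the storage means writing one new class and passing a different object in.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Optional


@dataclass
class Project:
    id: int
    name: str


class ProjectDao:
    """The interface callers depend on (mirrors ProjectDao in this article)."""

    def save(self, project: Project) -> Project:
        raise NotImplementedError()

    def find_by_id(self, id_: int) -> Optional[Project]:
        raise NotImplementedError()

    def find_all(self) -> Iterable[Project]:
        raise NotImplementedError()

    def exists_by_name(self, name: str) -> bool:
        raise NotImplementedError()


class ProjectDaoInMemoryImplementation(ProjectDao):
    """Hypothetical implementation that keeps projects in a dict instead of a database."""

    def __init__(self) -> None:
        self._projects: Dict[int, Project] = {}
        self._next_id = 1

    def save(self, project: Project) -> Project:
        # Same convention as the article: a falsy id means "assign a new id".
        id_ = project.id or self._next_id
        self._next_id = max(self._next_id, id_ + 1)
        stored = Project(id=id_, name=project.name)
        self._projects[id_] = stored
        return stored

    def find_by_id(self, id_: int) -> Optional[Project]:
        return self._projects.get(id_)

    def find_all(self) -> Iterable[Project]:
        return list(self._projects.values())

    def exists_by_name(self, name: str) -> bool:
        return any(p.name == name for p in self._projects.values())


def register_project(dao: ProjectDao, name: str) -> Project:
    # Business code sees only the ProjectDao interface, never the storage technology.
    if dao.exists_by_name(name):
        raise ValueError("Project with name {} already exists!".format(name))
    return dao.save(Project(id=0, name=name))
```

Replacing `ProjectDaoInMemoryImplementation()` with an SQL, SQLAlchemy or MongoDB-backed implementation would leave `register_project` untouched.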

To allow for an easy change, the DataAccessObject has to encapsulate all database-related technology, including SQLAlchemy. To achieve this, we still use a plain Project TransferObject that contains no database-related code. ORMs like SQLAlchemy use mapping classes to map objects to database tables, so we define a private mapping class for projects, marked with a leading underscore, which is used only by the DAO implementation. In our example, that implementation is ProjectDaoSqliteSqlAlchemyImplementation; it maps between Project and the mapping class. Queries are written only inside the DataAccessObject implementation, because they use the SQLAlchemy-specific API.

Of course, this requires more effort. But in the end, the code will be much cleaner and easier to maintain.

Let’s take a look at the example code:

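# typing.Iterable works on every Python 3 version; collections.Iterable was removed in 3.10
from typing import Iterable, Optional

import attr
from sqlalchemy import Integer, Column, String
# sqlalchemy.orm.declarative_base is the location since SQLAlchemy 1.4
from sqlalchemy.orm import declarative_base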


@attr.s
class Project(object):
    id = attr.ib(type=int)
    name = attr.ib(type=str)


class ProjectDao(object):
    def save(self, project):
        # type: (Project) -> Project
        raise NotImplementedError()

    def find_by_id(self, id_):
        # type: (int) -> Optional[Project]
        raise NotImplementedError()

    def find_all(self):
        # type: () -> Iterable[Project]
        raise NotImplementedError()

    def exists_by_name(self, name):
        # type: (str) -> bool
        raise NotImplementedError()


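# A single leading underscore marks these names as module-private. Two leading
# underscores would trigger Python's name mangling inside the DAO class body:
# __ProjectMapping would become _ProjectDaoSqliteSqlAlchemyImplementation__ProjectMapping
# and raise a NameError.
_Base = declarative_base()


class _ProjectMapping(_Base):
    __tablename__ = "projects"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String)


class ProjectDaoSqliteSqlAlchemyImplementation(ProjectDao):
    def __init__(self, session_maker):
        self._session_maker = session_maker

    @staticmethod
    def _to_project(project_mapping):
        # type: (_ProjectMapping) -> Project
        # Only plain Project objects leave the DAO; the mapping class stays inside.
        return Project(id=project_mapping.id, name=project_mapping.name)

    def save(self, project):
        # type: (Project) -> Project
        # The with block always closes the session, discarding uncommitted
        # changes if an error occurs (Session is a context manager since 1.4).
        with self._session_maker() as session:
            project_mapping = _ProjectMapping(
                id=project.id if project.id else None, name=project.name
            )

            session.add(project_mapping)
            session.flush()  # assigns the autoincremented id

            project = self._to_project(project_mapping)

            session.commit()

        return project

    def find_by_id(self, id_):
        # type: (int) -> Optional[Project]
        with self._session_maker() as session:
            project_mapping = (
                session.query(_ProjectMapping).filter(_ProjectMapping.id == id_).first()
            )
            return self._to_project(project_mapping) if project_mapping else None

    def find_all(self):
        # type: () -> Iterable[Project]
        with self._session_maker() as session:
            return [self._to_project(m) for m in session.query(_ProjectMapping).all()]

    def exists_by_name(self, name):
        # type: (str) -> bool
        with self._session_maker() as session:
            project_mapping = (
                session.query(_ProjectMapping).filter(_ProjectMapping.name == name).first()
            )
            return project_mapping is not None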


class ProjectCreator(object):
    _project_dao = None  # type: ProjectDao

    def __init__(self, project_dao):
        self._project_dao = project_dao

    def create_project(self, name):
        # type: (str) -> Project
        assert name

        if self._project_dao.exists_by_name(name):
            raise ValueError("Project with name {} already exists!".format(name))

        # id=0 tells the DAO to let the database assign a new id
        project = Project(id=0, name=name)
        project = self._project_dao.save(project)
        print("Saved project {}".format(project))
        return project