SQLAlchemyを魔改造した

  • 9
    Like
  • 0
    Comment
More than 1 year has passed since last update.

PyramidやFlaskでよく使われるO/RマッパーであるSQLAlchemyを魔改造しました。

SQLAlchemyの構文（syntax）の問題

スクリーンショット 2015-12-15 11.36.53.png

魔改造した

シンプルに記述できるように魔改造してみました。また速度面では、DBセッションをスレッドローカルストレージ（threading.local）に保持して使い回し、コネクションプールを活かしているため、使い回さない場合と比べて約8倍高速に動作します。ローカル環境では、selectクエリ1回あたりの所要時間が8msから1msに短縮されました。

SQLAlchemy魔改造後の使用例
# -*- coding: utf-8 -*-
# Usage example of the customized models: CRUD through the DBBaseMixin helpers.
from module.book import Book

# select
# get() looks a row up by primary key; objects() returns a Query to filter on.
book1 = Book.get(1)
books = Book.objects().filter(Book.price==2160).all()

# insert
# pk=None: no explicit primary key is given; the id is presumably generated
# by the database (autoincrement) when the row is inserted.
book_rye = Book(pk=None,
                title="The catcher in the rye",
                price=1000,
                publish="J. D. Salinger",
                published="")
# insert() adds the object to the session, commits, and returns the same object.
book_rye = Book.insert(book_rye)
print(book_rye.id)

# update
# save() adds the modified object to the session again and commits.
book_rye.price = 1200
book_rye.save()

# delete
# delete() issues a DELETE matched on this object's id and commits.
book_rye.delete()

book.py
# -*- coding: utf-8 -*-
from sqlalchemy import Column, String, Integer
from sqlalchemy.ext.declarative import declarative_base
from module.db.base import DBBaseMixin

Base = declarative_base()


class Book(DBBaseMixin, Base):
    """A book row.

    The ``id`` primary key and the table name (``book``, derived from the
    class name) come from DBBaseMixin.
    """
    title = Column('title', String(200))
    price = Column('price', Integer)
    publish = Column('publish', String(200))
    published = Column('published', String(200))

    def __init__(self, pk, title, price, publish, published):
        """Build an unsaved Book.

        :param pk: primary key to assign, or None to leave it unset so the
            database can generate it on insert
        :param title: str
        :param price: int
        :param publish: str
        :param published: str
        """
        # The mapped primary-key column is ``id`` (declared on DBBaseMixin).
        # Assigning to a plain ``pk`` attribute is never persisted, so an
        # explicit key passed here used to be silently ignored.
        self.id = pk
        self.title = title
        self.price = price
        self.publish = publish
        self.published = published

    @property
    def pk(self):
        """Alias for the ``id`` primary key, kept for callers that read ``pk``."""
        return self.id

    @pk.setter
    def pk(self, value):
        self.id = value
base_mixin.py
# -*- coding: utf-8 -*-
from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base, declared_attr
import re
from utils.db import get_db_session

# NOTE(review): this module-level declarative Base is not referenced by the
# code visible in this module (DBBaseMixin derives from object, and models
# combine the mixin with their own Base) - possibly leftover; confirm before
# removing.
Base = declarative_base()


def camel_to_snake(s):
    """Convert a CamelCase / camelCase identifier to snake_case.

    Runs of capitals (acronyms such as ``HTTP``) stay together as one word,
    and digits stay attached to the word before them.

    >>> camel_to_snake('CamelCase')
    'camel_case'
    >>> camel_to_snake('CamelCamelCase')
    'camel_camel_case'
    >>> camel_to_snake('Camel2Camel2Case')
    'camel2_camel2_case'
    >>> camel_to_snake('getHTTPResponseCode')
    'get_http_response_code'
    >>> camel_to_snake('get2HTTPResponseCode')
    'get2_http_response_code'
    >>> camel_to_snake('HTTPResponseCode')
    'http_response_code'
    >>> camel_to_snake('HTTPResponseCodeXYZ')
    'http_response_code_xyz'

    :param s: str
    :return: str
    """
    # Pass 1: insert "_" before a capital that starts a lowercase word,
    # e.g. "HTTPResponse" -> "HTTP_Response".
    s1 = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', s)
    # Pass 2: insert "_" between a lowercase letter/digit and a capital,
    # e.g. "getHTTP" -> "get_HTTP"; then lowercase everything.
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


class DBBaseMixin(object):
    """Mixin for declarative models: an ``id`` primary key, a table name
    derived from the class name, and CRUD helpers on ``get_db_session()``.

    Every write helper (insert / bulk_insert / delete / save) commits
    immediately. If a commit fails, the session is rolled back before the
    exception propagates, so the reused session stays usable.
    """
    id = Column('id', Integer, primary_key=True)

    def __repr__(self):
        # ``id`` here is the builtin: class attributes (such as the ``id``
        # column above) are not visible inside a method's scope.
        return '<{0}.{1} object at {2}>'.format(
            self.__module__, type(self).__name__, hex(id(self)))

    @declared_attr
    def __tablename__(cls):
        """Derive the table name from the class name (``BookShelf`` -> ``book_shelf``)."""
        return camel_to_snake(cls.__name__)

    @classmethod
    def objects(cls):
        """Return a query over this model.

        :rtype : sqlalchemy.orm.query.Query
        """
        return get_db_session().query(cls)

    @classmethod
    def session(cls):
        """Return the session the helpers use (from ``get_db_session``)."""
        return get_db_session()

    @staticmethod
    def _commit(session):
        """Commit *session*; roll it back and re-raise if the commit fails."""
        try:
            session.commit()
        except Exception:
            # The session is reused across calls, so leaving it in a failed
            # transaction would make every later operation on it raise too.
            session.rollback()
            raise

    @classmethod
    def get(cls, pk):
        """Return the instance with primary key *pk* (None if there is none).

        :param pk: int
        :rtype: cls
        """
        return cls.objects().get(pk)

    @classmethod
    def insert(cls, obj):
        """Add *obj* to the session and commit.

        :param obj: cls
        :rtype: cls
        """
        session = cls.session()
        session.add(obj)
        cls._commit(session)
        return obj

    @classmethod
    def bulk_insert(cls, objs):
        """Add all of *objs* to the session and commit once.

        :param objs: list[cls]
        :rtype: list[cls]
        """
        session = cls.session()
        session.add_all(objs)
        cls._commit(session)
        return objs

    def delete(self):
        """Delete this object's row (matched on ``id``) and commit."""
        model = type(self)
        session = model.session()
        session.query(model).filter(model.id == self.id).delete()
        self._commit(session)

    def save(self):
        """Add this object to the session (insert or update) and commit.

        :return: self
        """
        session = type(self).session()
        session.add(self)
        self._commit(session)
        return self

db.py
# -*- coding: utf-8 -*-
import threading
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

tls = threading.local()


def get_db_session():
    """Return this thread's SQLAlchemy ``scoped_session``, building it once.

    On the first call in a thread, a ``scoped_session`` is built over a
    ``sessionmaker`` bound to ``get_db_engine()`` (autocommit and autoflush
    off) and stored on the thread-local ``tls``. Later calls in the same
    thread return that same object.

    :rtype : scoped_session
    """
    try:
        return tls.db_session
    except AttributeError:
        pass

    # First call on this thread: build the session registry and cache it.
    factory = sessionmaker(bind=get_db_engine(), autocommit=False, autoflush=False)
    registry = scoped_session(factory)
    tls.db_session = registry
    return registry


# Engines are cached per database URL. An Engine owns its connection pool, so
# it has to be shared rather than rebuilt for every thread that asks for one.
_engine_lock = threading.Lock()
_engines = {}


def get_db_engine(db_user='root', db_host='127.0.0.1', db_name='Flask'):
    """Return the process-wide SQLAlchemy Engine for a MySQL database.

    The engine is created on the first call for a given URL and reused
    afterwards, so every caller and thread shares one connection pool instead
    of each call creating a new engine with its own pool.

    :param db_user: MySQL user name (connects without a password)
    :param db_host: database host
    :param db_name: database (schema) name
    :rtype: sqlalchemy.engine.Engine
    """
    db_path = 'mysql://{}@{}/{}'.format(db_user, db_host, db_name)
    # The lock makes concurrent first calls share one engine instead of each
    # building its own.
    with _engine_lock:
        engine = _engines.get(db_path)
        if engine is None:
            # NOTE: this pool (pool_size=5) is now shared by all threads;
            # size it for the expected concurrency.
            engine = create_engine(db_path, encoding='utf-8', pool_size=5)
            _engines[db_path] = engine
    return engine

このコードのかなり致命的な問題

複数の更新を1つのDBトランザクションにまとめる処理に対応していません。insert / save / delete がそれぞれ呼び出しのたびにcommitしてしまうためです。with文でセッションを生成し、save(transaction_session=session) のようにセッションを渡せるようにすれば対応できると思います。もちろんcommitはwith文の __exit__ で実行します（例外発生時はrollbackします）。

DBトランザクション対応の疑似コード
# Swap the prices of two books inside a single transaction.
# NOTE: pseudocode for a proposed API. commit_on_success and the
# transaction_session / for_update keyword arguments do not exist yet.
with commit_on_success as session:
    # select for update (SELECT ... FOR UPDATE) so both rows stay locked
    book1 = Book.get(1, transaction_session=session, for_update=True)
    book2 = Book.get(2, transaction_session=session, for_update=True)

    # Exchange the prices with tuple unpacking; no temporary variable needed.
    book1.price, book2.price = book2.price, book1.price

    # update: stage both changes; the commit happens in the with block's __exit__
    book1.save(transaction_session=session)
    book2.save(transaction_session=session)