LoginSignup
9

More than 5 years have passed since last update.

SQLAlchemyを魔改造した

Last updated at Posted at 2015-12-15

PyramidやFlaskでよく使われるO/RマッパーであるSQLAlchemyを魔改造しました。

SQLAlchemyの構文（syntax）の問題

スクリーンショット 2015-12-15 11.36.53.png

魔改造した

シンプルに記述できるように魔改造してみました。また速度面では、DBセッション（とその背後にあるEngineのコネクションプール）をクエリのたびに生成せず、スレッドローカルストレージ（threading.local）に保持して使い回しているため、使い回さない場合と比べて約8倍速く動作します。ローカル環境では、selectクエリ1回あたり8msから1msに短縮されました。

SQLAlchemy魔改造後の使用例
# -*- coding: utf-8 -*-
from module.book import Book

# --- select -------------------------------------------------------------
# One row by primary key, then every book whose price is 2160.
book1 = Book.get(1)
books = Book.objects().filter(Book.price == 2160).all()

# --- insert -------------------------------------------------------------
# Build a transient Book and persist it; insert() adds it to the session,
# commits, and hands the same object back.
book_rye = Book.insert(Book(pk=None,
                            title="The catcher in the rye",
                            price=1000,
                            publish="J. D. Salinger",
                            published=""))
print(book_rye.id)

# --- update -------------------------------------------------------------
# Mutate the attribute, then save() re-adds the object and commits.
book_rye.price = 1200
book_rye.save()

# --- delete -------------------------------------------------------------
# Remove the row whose id matches this object and commit.
book_rye.delete()

book.py
# -*- coding: utf-8 -*-
from sqlalchemy import Column, String, Integer
from sqlalchemy.ext.declarative import declarative_base
from module.db.base import DBBaseMixin

# Declarative base for the models declared in this module (Book).
Base = declarative_base()


class Book(DBBaseMixin, Base):
    """
    A row of the ``book`` table.

    The integer ``id`` primary key and the table name (derived from the
    class name) are supplied by ``DBBaseMixin``.
    """
    title = Column('title', String(200))
    price = Column('price', Integer)
    # NOTE(review): the usage example stores an author name here
    # ("J. D. Salinger") -- confirm the intended meaning of this column.
    publish = Column('publish', String(200))
    published = Column('published', String(200))

    def __init__(self, pk, title, price, publish, published):
        """
        :param pk: explicit primary key, or None to let the database assign one
        :param title: str
        :param price: int
        :param publish: str
        :param published: str
        """
        # ``pk`` is not a mapped column; the mapped primary key is ``id``
        # (declared on DBBaseMixin). Keep the plain ``pk`` attribute for any
        # caller that reads it, but also copy an explicit value onto ``id``
        # so a requested key actually reaches the database. None is left
        # unset so the database still assigns the key.
        self.pk = pk
        if pk is not None:
            self.id = pk
        self.title = title
        self.price = price
        self.publish = publish
        self.published = published
base_mixin.py
# -*- coding: utf-8 -*-
from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base, declared_attr
import re
from utils.db import get_db_session

# NOTE(review): no code visible in this module uses this Base -- DBBaseMixin
# is a plain mixin and models supply their own Base. Confirm before removing.
Base = declarative_base()


def camel_to_snake(s):
    """
    Convert a CamelCase / mixedCase identifier to snake_case.

    A run of capitals is treated as one word (``HTTPResponse`` ->
    ``http_response``), and digits stay attached to the word before them.

    >>> camel_to_snake('CamelCase')
    'camel_case'
    >>> camel_to_snake('CamelCamelCase')
    'camel_camel_case'
    >>> camel_to_snake('Camel2Camel2Case')
    'camel2_camel2_case'
    >>> camel_to_snake('getHTTPResponseCode')
    'get_http_response_code'
    >>> camel_to_snake('get2HTTPResponseCode')
    'get2_http_response_code'
    >>> camel_to_snake('HTTPResponseCode')
    'http_response_code'
    >>> camel_to_snake('HTTPResponseCodeXYZ')
    'http_response_code_xyz'

    :param s: str
    :return: str
    """
    # Split before a Capitalised word: "getHTTPResponse" -> "getHTTP_Response".
    s1 = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', s)
    # Split a lowercase letter/digit from the capital that follows it:
    # "getHTTP_Response" -> "get_HTTP_Response"; then lowercase everything.
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


class DBBaseMixin(object):
    """
    Declarative mixin giving a model an integer ``id`` primary key, a table
    name derived from the class name, and Active-Record style helpers.

    All helpers use the session returned by ``get_db_session()``; the write
    helpers (insert, bulk_insert, save, delete) commit immediately.
    """
    # Surrogate primary key shared by every model that uses this mixin.
    id = Column('id', Integer, primary_key=True)

    def __repr__(self):
        # ``id`` below is the builtin (method bodies do not see class-level
        # names), so this shows the object's identity, not its primary key.
        return '<{0}.{1} object at {2}>'.format(
            self.__module__, type(self).__name__, hex(id(self)))

    @declared_attr
    def __tablename__(cls):
        """Table name: the class name in snake_case (``BookAuthor`` -> ``book_author``)."""
        return camel_to_snake(cls.__name__)

    @classmethod
    def objects(cls):
        """
        Start a query over this model using the current session.

        :rtype : sqlalchemy.orm.query.Query
        """
        return get_db_session().query(cls)

    @classmethod
    def session(cls):
        """Return the session object provided by ``get_db_session()``."""
        return get_db_session()

    @classmethod
    def _commit(cls, session):
        """
        Commit *session*; on failure roll it back and re-raise.

        ``get_db_session()`` hands the same session back on later calls from
        this thread. SQLAlchemy refuses to use a session whose flush/commit
        failed until it has been rolled back, so without this rollback a
        single failed write would break every later query on the thread.
        """
        try:
            session.commit()
        except Exception:
            session.rollback()
            raise

    @classmethod
    def get(cls, pk):
        """
        Fetch one instance by primary key.

        :param pk: int
        :rtype: cls, or None when no row has this key
        """
        return cls.objects().get(pk)

    @classmethod
    def insert(cls, obj):
        """
        Add *obj* to the session and commit.

        :param obj: cls
        :rtype: cls (the same object that was passed in)
        """
        session = cls.session()
        session.add(obj)
        cls._commit(session)
        return obj

    @classmethod
    def bulk_insert(cls, objs):
        """
        Add every object in *objs* to the session and commit once.

        :param objs: list[cls]
        :rtype: list[cls] (the same list that was passed in)
        """
        session = cls.session()
        session.add_all(objs)
        cls._commit(session)
        return objs

    def delete(self):
        """
        Delete the row whose ``id`` matches this object, then commit.

        NOTE: this is a bulk ``Query.delete()``, so ORM-level relationship
        cascades are not applied.
        """
        model = type(self)
        session = model.session()
        session.query(model).filter(model.id == self.id).delete()
        model._commit(session)

    def save(self):
        """
        Add this object to the session (a no-op if it is already there)
        and commit, flushing any pending attribute changes.

        :return: self
        """
        model = type(self)
        session = model.session()
        session.add(self)
        model._commit(session)
        return self

db.py
# -*- coding: utf-8 -*-
import threading
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

# Per-thread storage: get_db_session() caches the session object it returns
# here, so later calls from the same thread reuse it.
tls = threading.local()


# One scoped_session registry for the whole process. scoped_session already
# keeps a separate Session per thread, so the registry -- and the single
# Engine (connection pool) behind it -- can be shared by every thread.
_db_session_registry = None
_db_session_lock = threading.Lock()


def get_db_session():
    """
    Return the shared SQLAlchemy session registry, creating it on first use.

    The Engine is built only once. Building one per thread would give every
    thread its own connection pool, so connections could not be reused
    across threads (e.g. one thread per request), and each new thread would
    pay the connection-setup cost again.

    :rtype : scoped_session
    """
    global _db_session_registry
    # Fast path: this thread has already looked the registry up.
    if hasattr(tls, "db_session"):
        return tls.db_session

    # Serialize first-time creation so two threads starting at the same
    # moment cannot each build their own Engine.
    with _db_session_lock:
        if _db_session_registry is None:
            engine = get_db_engine()
            _db_session_registry = scoped_session(
                sessionmaker(autocommit=False,
                             autoflush=False,
                             bind=engine))
    tls.db_session = _db_session_registry
    return _db_session_registry


def get_db_engine(db_user='root', db_host='127.0.0.1', db_name='Flask',
                  pool_size=5):
    """
    Create a SQLAlchemy Engine for the MySQL database.

    The defaults reproduce the original hard-coded local settings: user
    ``root`` with no password on 127.0.0.1, database ``Flask``, pool size 5.

    :param db_user: str
    :param db_host: str
    :param db_name: str
    :param pool_size: int, passed through to ``create_engine``
    :rtype: sqlalchemy.engine.Engine
    """
    db_path = 'mysql://{}@{}/{}'.format(db_user, db_host, db_name)
    # ``encoding='utf-8'`` is intentionally not passed: utf-8 was already
    # create_engine's default, and the argument is deprecated in
    # SQLAlchemy 1.4 and rejected by 2.0.
    return create_engine(db_path, pool_size=pool_size)

このコードのかなり致命的な問題

複数の操作を1つのDBトランザクションにまとめる更新に対応していません（insert・save・deleteがそれぞれ呼び出しごとにcommitしてしまうため）。with文でsessionを生成し、save(transaction_session=session)のようにそのsessionを渡せるようにすれば対応できると思います。もちろんcommitはwith文の__exit__で実行します。

DBトランザクション対応の疑似コード
# Swap the prices of two books inside a single transaction.
# Pseudo-code: commit_on_success, transaction_session= and for_update= are
# proposed APIs; DBBaseMixin as written does not implement them. The commit
# is meant to happen in the with-block's __exit__.
# NOTE(review): commit_on_success presumably has to be called
# (commit_on_success()) to yield a context manager -- confirm when implementing.
with commit_on_success as session:
  # select for update
  # (intended to lock both rows until the with-block exits, so the swap is atomic)
  book1 = Book.get(1, transaction_session=session, for_update=True)
  book2 = Book.get(2, transaction_session=session, for_update=True)

  # exchange price
  _tmp_price = book1.price
  book1.price = book2.price
  book2.price = _tmp_price

  # update
  # (both saves share the same session; nothing is committed until __exit__)
  book1.save(transaction_session=session)
  book2.save(transaction_session=session)

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9