LoginSignup
14
12

More than 5 years have passed since last update.

Redis Mutex を Python で実装する

Posted at 2014-03-16

発端

サービス構築時, 相互排他でロックを獲得したいが, RDS を使うほどでも無いため Redis を使用したい.
Python で SETNX – Redis の "Design pattern: Locking with SETNX" を実装しているライブラリを探したが, ささっと探した感じ見当たらなかったため, 実装してみた.
なお, SETNX – Redis の日本語訳は 文字列型 - redis 2.0.3 documentation を参照のこと.

実装

Source Code

mutex.py
from datetime import datetime
import time
from functools import wraps

from .exception import (DuplicateLockError,
                        HasNotLockError,
                        ExpiredLockError,
                        SetnxError,
                        LockError)


class Mutex(object):
    """Inter-process mutex on top of Redis ("Locking with SETNX" pattern).

    The lock is a single Redis key whose value is the UNIX time (float
    seconds) at which the lock expires.  Once that time has passed the lock
    is treated as stale and another client may take it over with GETSET, so
    a holder that never unlocks blocks others only until its expiry.

    A Mutex object is not reentrant: calling lock() again while it holds the
    lock raises DuplicateLockError.  It can be used via lock()/unlock(), as a
    context manager (``with mutex: ...``) or as a decorator (``@mutex``).

    Args:
        client: Redis client providing setnx/get/getset/delete
            (e.g. ``redis.StrictRedis()``).
        key: Redis key that represents the lock.
        expire: seconds after acquisition at which the lock becomes stale.
        retry_count: maximum acquisition attempts per lock() call; lock()
            waits at most roughly ``retry_count * retry_sleep_sec`` seconds.
        retry_setnx_count: maximum SETNX/GET rounds per attempt; a round is
            repeated only when the key vanishes between SETNX and GET.
        retry_sleep_sec: seconds to sleep between acquisition attempts.
    """

    def __init__(self, client, key,
                 expire=10,
                 retry_count=6,  # retry_count * retry_sleep_sec ~= max wait
                 retry_setnx_count=100,
                 retry_sleep_sec=0.25):
        # Time (from _get_now) at which this object acquired the lock, or
        # None while it does not hold the lock.
        self._lock = None
        self._r = client
        self._key = key
        self._expire = expire
        self._retry_count = retry_count
        self._retry_setnx_count = retry_setnx_count
        self._retry_sleep_sec = retry_sleep_sec

    def _get_now(self):
        """Return the current UNIX time in seconds as a float.

        Uses time.time() rather than ``strftime('%s.%f')``: ``%s`` is not a
        format code documented by Python and depends on the platform's C
        library (it is not available on Windows, for example).
        """
        return time.time()

    def lock(self):
        """Acquire the lock, retrying while another client holds it.

        Raises:
            DuplicateLockError: this object already holds the lock.
            LockError: the lock could not be acquired within retry_count
                attempts.
            SetnxError: the SETNX/GET loop gave up (see _setnx).
        """
        if self._lock:
            raise DuplicateLockError(self._key)
        self._do_lock()

    def _do_lock(self):
        """Run the SETNX / GET / GETSET acquisition loop.

        Sets self._lock on success; raises LockError when every attempt
        fails.
        """
        for _ in range(self._retry_count):
            is_set, old_expire = self._setnx()
            if is_set:
                self._lock = self._get_now()
                return

            # Still held by a live owner: _need_retry sleeps, then retry.
            if self._need_retry(old_expire):
                continue

            # The current lock is stale.  GETSET our own expiry; we own the
            # lock only if the value we replaced was stale too.  Otherwise
            # another client took it over first and _need_retry sleeps
            # before the next attempt.
            if not self._need_retry(self._getset()):
                self._lock = self._get_now()
                return

        raise LockError(self._key)

    def _setnx(self):
        """SETNX the lock key, reading the holder's expiry if it is taken.

        Returns:
            (True, 0) when this call set the key, otherwise
            (False, expiry_of_current_holder) as a float.

        Raises:
            SetnxError: the key disappeared between SETNX and GET in every
                one of retry_setnx_count rounds.
        """
        for _ in range(self._retry_setnx_count):
            is_set = self._r.setnx(self._key, self._get_now() + self._expire)
            if is_set:
                return True, 0

            old_expire = self._r.get(self._key)
            # None means the key was deleted between SETNX and GET: retry.
            if old_expire is not None:
                return False, float(old_expire)

        raise SetnxError(self._key)

    def _need_retry(self, expire):
        """Return whether a lock expiring at *expire* is still valid.

        When it has not expired yet, sleep retry_sleep_sec first so the
        caller can retry immediately.
        """
        if expire < self._get_now():
            return False

        time.sleep(self._retry_sleep_sec)
        return True

    def _getset(self):
        """Write a fresh expiry with GETSET and return the previous one.

        A missing key yields 0, which _need_retry treats as expired.
        """
        old_expire = self._r.getset(self._key, self._get_now() + self._expire)
        if old_expire is None:
            return 0

        return float(old_expire)

    def unlock(self):
        """Release the lock held by this object.

        Raises:
            HasNotLockError: this object does not currently hold the lock.
            ExpiredLockError: ``expire`` seconds or more have passed since the
                lock was acquired, so another client may already own the key.
                The Redis key is left untouched, but this object's local
                state is reset so it can call lock() again.
        """
        if not self._lock:
            raise HasNotLockError(self._key)

        elapsed_time = self._get_now() - self._lock
        if self._expire <= elapsed_time:
            # Do not DELETE: the key may now belong to someone else.  Drop
            # our local claim, otherwise every later lock() on this object
            # would raise DuplicateLockError.
            self._lock = None
            raise ExpiredLockError(self._key, elapsed_time)

        self._r.delete(self._key)
        self._lock = None

    def __enter__(self):
        self.lock()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # The body may already have called unlock() itself.
        if self._lock:
            self.unlock()
        # Never suppress an exception raised inside the with-block.
        return False

    def __call__(self, func):
        """Decorate *func* so that every call runs while holding the lock."""
        @wraps(func)
        def inner(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return inner
exception.py
class MutexError(Exception):
    """Common base class for every error raised by Mutex."""


class DuplicateLockError(MutexError):
    """Raised when lock() is called again on a Mutex that already holds it.

    Call unlock() first, or create a separate Mutex object.
    """


class HasNotLockError(MutexError):
    """Raised when unlock() is called on a Mutex that has not called lock().

    lock() must be called before unlock().
    """


class ExpiredLockError(MutexError):
    """Raised when unlock() is called after lock() but the lock has already
    been released because its expire time elapsed.
    """


class SetnxError(MutexError):
    """Raised when the SETNX/GET retry loop gives up."""


class LockError(MutexError):
    """Raised when the lock cannot be acquired within the retry budget."""

解説

ざっくりしたロックの流れは次の通り.

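  1. SETNX で "有効期限X" (現在時刻 + expire) を設定
  2. 設定できれば Lock 成功
  3. 設定できなければ, GET で現在のロックの "有効期限Y" を取得
  4. "有効期限Y" がまだ有効ならば, retry_sleep_sec 秒待ってから 1 へ
  5. "有効期限Y" が既に切れていれば, GETSET で "有効期限X" を設定
  6. GETSET で取得した (上書き前の) "有効期限Z" が切れていれば Lock 成功
  7. "有効期限Z" が有効なら, 他のプロセスが先に GETSET 済みであるため, retry_sleep_sec 秒待ってから 1 へ
  8. 1〜7 を retry_count 回繰り返しても Lock できなければ LockError を送出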

使い方は次の通り.

usage.py
# Usage examples for Mutex.  `user_id` is assumed to be defined by the caller.
# NOTE: mutex.py imports `.exception` relatively, so it has to live inside a
# package; adjust the import path below to match.
import redis

from mutex import Mutex

client = redis.StrictRedis()
key = ':'.join(['EmitAccessToken', user_id])

# 1. As a context manager: the lock is released when the block exits.
with Mutex(client, key):
    # do something ...
    pass


# 2. As a decorator: every call of the function runs under the lock.
@Mutex(client, key)
def emit_access_token():
    # do something ...
    pass


# 3. Explicit lock()/unlock(): use try/finally so the lock is released
#    even if the protected code raises.
mutex = Mutex(client, key)
mutex.lock()
try:
    # do something ...
    pass
finally:
    mutex.unlock()

テスト

test.py
import unittest
import redis
import time
from multiprocessing import Process

from .mutex import Mutex
from .exception import (DuplicateLockError,
                        HasNotLockError,
                        ExpiredLockError,
                        LockError)


class TestMutex(unittest.TestCase):
    """Integration tests for Mutex against a live Redis server.

    Assumes a Redis server reachable with ``redis.StrictRedis()`` defaults.
    The key 'spam' is used as the lock and 'ham' as a shared counter.
    """

    def setUp(self):
        self.key = 'spam'
        self.r = redis.StrictRedis()
        # Start clean: a lock or counter left behind by an aborted earlier
        # run would otherwise make these tests fail spuriously.
        self.r.delete(self.key, 'ham')
        self.mutex = Mutex(self.r, self.key)

    def tearDown(self):
        mutex = self.mutex
        try:
            if mutex._lock:
                mutex.unlock()
        finally:
            # Always remove test keys, even if unlock() raised.
            mutex._r.delete(self.key, 'ham')

    def test_lock(self):
        mutex = self.mutex
        mutex.lock()
        self.assertIsNotNone(mutex._r.get(mutex._key))

        with self.assertRaises(DuplicateLockError):
            mutex.lock()

    def test_unlock(self):
        self.test_lock()

        mutex = self.mutex
        mutex.unlock()
        self.assertIsNone(mutex._r.get(mutex._key))

        with self.assertRaises(HasNotLockError):
            mutex.unlock()

        self.test_lock()
        # Wait just past the lock's expire time (10 s by default).
        time.sleep(mutex._expire + 0.5)
        with self.assertRaises(ExpiredLockError):
            mutex.unlock()
        mutex._lock = None  # force-reset local state

    def test_expire(self):
        mutex1 = self.mutex

        mutex2 = Mutex(self.r, self.key, expire=2)
        mutex2.lock()  # holds the lock for 2 seconds

        with self.assertRaises(LockError):
            mutex1.lock()  # gives up after 6 retries * 0.25 s = 1.5 s

        time.sleep(0.6)  # margin so that mutex2's 2-second expiry passes
        mutex1.lock()
        self.assertIsNotNone(mutex1._r.get(mutex1._key))

    def test_with(self):
        mutex1 = self.mutex
        with mutex1:
            self.assertIsNotNone(mutex1._r.get(mutex1._key))
        self.assertIsNone(mutex1._r.get(mutex1._key))

        mutex2 = Mutex(self.r, self.key, expire=2)
        mutex2.lock()  # holds the lock for 2 seconds

        with self.assertRaises(LockError):
            with mutex1:  # gives up after 6 retries * 0.25 s = 1.5 s
                pass

        mutex2.unlock()

        with mutex1:
            with self.assertRaises(DuplicateLockError):
                with mutex1:
                    pass

    def test_decorator(self):
        mutex = self.mutex

        @mutex
        def egg():
            self.assertIsNotNone(mutex._r.get(mutex._key))
        egg()
        self.assertIsNone(mutex._r.get(mutex._key))

    def test_multi_process(self):
        procs = 20
        counter = 100

        # NOTE(review): a nested function as Process target only works with
        # the 'fork' start method (it cannot be pickled for 'spawn').
        def incr():
            mutex = Mutex(redis.StrictRedis(), self.key, retry_count=100)
            for _ in range(counter):
                mutex.lock()

                ham = mutex._r.get('ham') or 0
                mutex._r.set('ham', int(ham) + 1)

                mutex.unlock()

        ps = [Process(target=incr) for _ in range(procs)]
        for p in ps:
            p.start()

        for p in ps:
            p.join()

        # A worker that died (e.g. LockError after exhausting its retries)
        # would otherwise show up only as a wrong counter value.
        self.assertEqual([p.exitcode for p in ps], [0] * procs)
        self.assertEqual(int(self.mutex._r.get('ham')), counter * procs)
14
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
12