Scala
Akka
Akka-Stream
ScalaDay 10

akka-streamを使ってmemcachedクライアントを作るには

かなり前になるのですが、akka-streamの習作課題として、Memcachedクライアントを実装したので、概要を解説する記事をまとめます。詳しくはgithubをみてください。

https://github.com/j5ik2o/reactive-memcached

TCPのハンドリング方法

akkaでTCPをハンドリングするには、以下の二つになる。おそらく

  1. akka.actorakka.ioのパッケージを使う方式。つまり、アクターでTCP I/Oを実装する方法。
  2. akka.streamTcp#outgoingConnectionというAPIを使う方法

今回は、akka-streamをベースにしたかったので、2を採用しました。

outgoingConnectionメソッドの戻り値の型は、Flow[ByteString, ByteString, Future[OutgoingConnection]]です。ByteStringを渡せば、ByteStringが返ってくるらしい。わかりやすいですね。そしてストリームが起動中であれば、コネクションは常時接続された状態になります。

def outgoingConnection(
  remoteAddress : InetSocketAddress,
  localAddress : Option[InetSocketAddress],
  options : Traversable[SocketOption], 
  halfClose : Boolean,
  connectTimeout : Duration,
  idleTimeout : Duration) : Flow[ByteString, ByteString, Future[Tcp.OutgoingConnection]]

Memcachedのコネクションを表現するオブジェクトを実装する

では、早速 Memcachedのコネクションを表現するオブジェクトである MemcachedConnection を実装します。このオブジェクトを生成するとTCP接続され、破棄されると切断されます。そして、考えたインターフェイスは以下です(別にtrait切り出さなくてもよかったですが、説明のために分けました)。sendメソッドにコマンドを指定して呼ぶとレスポンスが返る単純なものです。
ちなみに、sendメソッドの戻り値の型はmonix.eval.Taskです。詳しくは https://monix.io/docs/3x/eval/task.html をご覧ください。

trait MemcachedConnection {
  def id: UUID // 接続ID
  def shutdown(): Unit // シャットダウン
  def peerConfig: Option[PeerConfig] // 接続先設定

  def send[C <: CommandRequest](cmd: C): Task[cmd.Response] // コマンドの送信

  // ...
}

実装はどうなるか。全貌は以下。先ほどのTcp#outgoingConnectionを使います。

private[memcached] class MemcachedConnectionImpl(_peerConfig: PeerConfig,
                                                 supervisionDecider: Option[Supervision.Decider])(
    implicit system: ActorSystem
) extends MemcachedConnection {

  // ...

  private implicit val mat: ActorMaterializer = ...

  protected val tcpFlow: Flow[ByteString, ByteString, NotUsed] =
    RestartFlow.withBackoff(minBackoff, maxBackoff, randomFactor, maxRestarts) { () =>
      Tcp()
        .outgoingConnection(remoteAddress, localAddress, options, halfClose, connectTimeout, idleTimeout)
    }

  protected val connectionFlow: Flow[RequestContext, ResponseContext, NotUsed] =
    Flow.fromGraph(GraphDSL.create() { implicit b =>
      import GraphDSL.Implicits._
      val requestFlow = b.add(
        Flow[RequestContext]
          .map { rc =>
            log.debug(s"request = [{}]", rc.commandRequestString)
            (ByteString.fromString(rc.commandRequest.asString + "\r\n"), rc)
          }
      )
      val responseFlow = b.add(Flow[(ByteString, RequestContext)].map {
        case (byteString, requestContext) =>
          log.debug(s"response = [{}]", byteString.utf8String)
          ResponseContext(byteString, requestContext)
      })
      val unzip = b.add(Unzip[ByteString, RequestContext]())
      val zip   = b.add(Zip[ByteString, RequestContext]())
      requestFlow.out ~> unzip.in
      unzip.out0 ~> tcpFlow ~> zip.in0
      unzip.out1 ~> zip.in1
      zip.out ~> responseFlow.in
      FlowShape(requestFlow.in, responseFlow.out)
    })

  protected val (requestQueue: SourceQueueWithComplete[RequestContext], killSwitch: UniqueKillSwitch) = Source
    .queue[RequestContext](requestBufferSize, overflowStrategy)
    .via(connectionFlow)
    .map { responseContext =>
      log.debug(s"req_id = {}, command = {}: parse",
                responseContext.commandRequestId,
                responseContext.commandRequestString)
      val result = responseContext.parseResponse // レスポンスのパース
      responseContext.completePromise(result.toTry) // パース結果を返す
    }
    .viaMat(KillSwitches.single)(Keep.both)
    .toMat(Sink.ignore)(Keep.left)
    .run()

  override def shutdown(): Unit = killSwitch.shutdown()

  override def send[C <: CommandRequest](cmd: C): Task[cmd.Response] = Task.deferFutureAction { implicit ec =>
    val promise = Promise[CommandResponse]()
    requestQueue
      .offer(RequestContext(cmd, promise, ZonedDateTime.now()))
      .flatMap {
        case QueueOfferResult.Enqueued =>
          promise.future.map(_.asInstanceOf[cmd.Response])
        case QueueOfferResult.Failure(t) =>
          Future.failed(BufferOfferException("Failed to send request", Some(t)))
        case QueueOfferResult.Dropped =>
          Future.failed(
            BufferOfferException(
              s"Failed to send request, the queue buffer was full."
            )
          )
        case QueueOfferResult.QueueClosed =>
          Future.failed(BufferOfferException("Failed to send request, the queue was closed"))
      }
  }

}

TCPの送受信するにはtcpFlowメソッドが返すFlowに必要なByteStringを流して、ByteStringを受け取ればいいですが、リクエストとレスポンスを扱うためにもう少し工夫が必要です。そのためにconnectionFlowメソッドはtcpFlowメソッドを内部Flowとして利用します。そして、FlowはRequestContextResponseContextの型を利用します。その実装は以下のようなシンプルなものです。キューであるrequestQueueconnectionFlowにリクエストを送信します。ストリームの後半で返ってきたレスポンスを処理するという流れになっています。キューへのリクエストのエンキューはsendメソッドが行います。sendメソッドが呼ばれるとRequestContextを作ってrequestQueueにエンキュー、完了するとレスポンスを返します。requestQueueはストリームと繋がっていて、エンキューされたRequestContextconnectionFlowに渡され返ってきたResponseContextがレスポンス内容をパースし、その結果をPromise経由で返すというものです1。エラーハンドリングについて、エラーを起こしやすいところにRestartFlow.withBackoffを入れています。

final case class RequestContext(commandRequest: CommandRequest,
                                promise: Promise[CommandResponse],
                                requestAt: ZonedDateTime) {
  val id: UUID                     = commandRequest.id
  val commandRequestString: String = commandRequest.asString
}

final case class ResponseContext(byteString: ByteString,
                                 requestContext: RequestContext,
                                 requestsInTx: Seq[CommandRequest] = Seq.empty,
                                 responseAt: ZonedDateTime = ZonedDateTime.now)
    extends ResponseBase {

  val commandRequest: CommandRequest = requestContext.commandRequest

  def withRequestsInTx(values: Seq[CommandRequest]): ResponseContext = copy(requestsInTx = values)
  // レスポンスのパース
  def parseResponse: Either[ParseException, CommandResponse] = {
    requestContext.commandRequest match {
      case scr: CommandRequest =>
        scr.parse(ByteVector(byteString.toByteBuffer)).map(_._1)
    }
  }

}

コネクションオブジェクトの使い方は簡単です。ほとんどの場合は、このようにコネクションオブジェクトを直接操作せずに、高レベルなAPIを操作したいと思うはずです。そのためのクライアントオブジェクトはあとで述べます。

val connection = MemcachedConnection(
  PeerConfig(new InetSocketAddress("127.0.0.1", memcachedTestServer.getPort),
  backoffConfig = BackoffConfig(maxRestarts = 1)),
  None
)

val resultFuture = (for {
  _   <- connection.send(SetRequest(UUID.randomUUID(), "key1", "1", 10 seconds))
  _   <- connection.send(SetRequest(UUID.randomUUID(), "key2", "2", 10 seconds))
  gr1 <- connection.send(GetRequest(UUID.randomUUID(), "key1"))
  gr2 <- connection.send(GetRequest(UUID.randomUUID(), "key2"))
} yield (gr1, gr2)).runAsync
val result = Await.result(resultFuture, Duration.Inf) // (1, 2)

コマンドの実装

コマンドの一例はこんな感じです。
asStringメソッドは文字通りMemcachedにコマンドを送信するときに利用される文字列です。responseParserはfastparseで実装されたレスポンスパーサを指定します。parseResponseメソッドはMemcachedからのレスポンスを表現する構文木をコマンドのレンスポンスに変換するための関数です。これは少し複雑なのであとで説明します。

final class GetRequest private (val id: UUID, val key: String) extends CommandRequest with StringParsersSupport {

  override type Response = GetResponse
  override val isMasterOnly: Boolean = false

  override def asString: String = s"get $key"

  // レスポンスを解析するためのパーサを指定。レスポンス仕様に合わせて指定する
  override protected def responseParser: P[Expr] = P(retrievalCommandResponse)

  // ASTをコマンドレスポンスに変換する
  override protected def parseResponse: Handler = {
    case (EndExpr, next) =>
      (GetSucceeded(UUID.randomUUID(), id, None), next)
    case (ValueExpr(key, flags, length, casUnique, value), next) =>
      (GetSucceeded(UUID.randomUUID(), id, Some(ValueDesc(key, flags, length, casUnique, value))), next)
    case (ErrorExpr, next) =>
      (GetFailed(UUID.randomUUID(), id, MemcachedIOException(ErrorType.OtherType, None)), next)
    case (ClientErrorExpr(msg), next) =>
      (GetFailed(UUID.randomUUID(), id, MemcachedIOException(ErrorType.ClientType, Some(msg))), next)
    case (ServerErrorExpr(msg), next) =>
      (GetFailed(UUID.randomUUID(), id, MemcachedIOException(ErrorType.ServerType, Some(msg))), next)
  }

}

これはGetRequestの上位traitであるCommandRequestです。Memcachedのコマンド送信に必要なのはasStringメソッドだけで、コマンドのレスポンスのパースにはparseメソッドが利用されます。

trait CommandRequest {
  type Elem
  type Repr
  type P[+T] = core.Parser[T, Elem, Repr]

  type Response <: CommandResponse

  type Handler = PartialFunction[(Expr, Int), (Response, Int)]

  val id: UUID
  val key: String
  val isMasterOnly: Boolean

  def asString: String

  protected def responseParser: P[Expr]

  protected def convertToParseSource(s: ByteVector): Repr

  def parse(text: ByteVector, index: Int = 0): Either[ParseException, (Response, Int)] = {
    responseParser.parse(convertToParseSource(text), index) match {
      case f @ Parsed.Failure(_, index, _) =>
        Left(new ParseException(f.msg, index))
      case Parsed.Success(value, index) => Right(parseResponse((value, index)))
    }
  }

  protected def parseResponse: Handler

}

レスポンス・パーサの実装

レスポンスパーサは、Memcachedの仕様に合わせて以下の定義から適切なものを選んで利用します。詳しくは公式ドキュメントをみてください。標準のパーサーコンビネータと少し違うところがありますが、概ね似たような感じで書けます。Memcachedのレスポンスのルールは単純で左再帰除去などもしなくていいので、パーサーコンビネータのはじめての題材にはよいと思います。

object StringParsers {

  val digit: P0      = P(CharIn('0' to '9'))
  val lowerAlpha: P0 = P(CharIn('a' to 'z'))
  val upperAlpha: P0 = P(CharIn('A' to 'Z'))
  val alpha: P0      = P(lowerAlpha | upperAlpha)
  val alphaDigit: P0 = P(alpha | digit)
  val crlf: P0       = P("\r\n")

  val error: P[ErrorExpr.type]        = P("ERROR" ~ crlf).map(_ => ErrorExpr)
  val clientError: P[ClientErrorExpr] = P("CLIENT_ERROR" ~ (!crlf ~/ AnyChar).rep(1).! ~ crlf).map(ClientErrorExpr)
  val serverError: P[ServerErrorExpr] = P("SERVER_ERROR" ~ (!crlf ~/ AnyChar).rep(1).! ~ crlf).map(ServerErrorExpr)
  val allErrors: P[Expr]              = P(error | clientError | serverError)

  val end: P[EndExpr.type]             = P("END" ~ crlf).map(_ => EndExpr)
  val deleted: P[DeletedExpr.type]     = P("DELETED" ~ crlf).map(_ => DeletedExpr)
  val stored: P[StoredExpr.type]       = P("STORED" ~ crlf).map(_ => StoredExpr)
  val notStored: P[NotStoredExpr.type] = P("NOT_STORED" ~ crlf).map(_ => NotStoredExpr)
  val exists: P[ExistsExpr.type]       = P("EXISTS" ~ crlf).map(_ => ExistsExpr)
  val notFound: P[NotFoundExpr.type]   = P("NOT_FOUND" ~ crlf).map(_ => NotFoundExpr)
  val touched: P[TouchedExpr.type]     = P("TOUCHED" ~ crlf).map(_ => TouchedExpr)

  val key: P[String]     = (!" " ~/ AnyChar).rep(1).!
  val flags: P[Int]      = digit.rep(1).!.map(_.toInt)
  val bytes: P[Long]     = digit.rep(1).!.map(_.toLong)
  val casUnique: P[Long] = digit.rep(1).!.map(_.toLong)

  val incOrDecCommandResponse: P[Expr] = P(notFound | ((!crlf ~/ AnyChar).rep.! ~ crlf).map(StringExpr) | allErrors)
  val storageCommandResponse: P[Expr]  = P((stored | notStored | exists | notFound) | allErrors)

  val value: P[ValueExpr] =
    P("VALUE" ~ " " ~ key ~ " " ~ flags ~ " " ~ bytes ~ (" " ~ casUnique).? ~ crlf ~ (!crlf ~/ AnyChar).rep.! ~ crlf)
      .map {
        case (key, flags, bytes, cas, value) =>
          ValueExpr(key, flags, bytes, cas, value)
      }

  val version: P[VersionExpr] = P("VERSION" ~ " " ~ alphaDigit.rep(1).!).map(VersionExpr)

  val retrievalCommandResponse: P[Expr] = P(end | value | allErrors)

  val deletionCommandResponse: P[Expr] = P(deleted | notFound | allErrors)

  val touchCommandResponse: P[Expr] = P(touched | notFound | allErrors)

  val versionCommandResponse: P[Expr] = P(version | allErrors)
}

クライアントの実装

次はクライアント実装。MemcachedClient#sendメソッドはコマンドを受け取り、ReaderTTaskMemcachedConnection[cmd.Response]を返します。ReaderTTaskMemcachedConnectionはMemcachedのために特化したReaderTで、コネクションを受け取り、レスポンスを返すTaskを返します。Memcachedのための高レベルなAPIはこのsendメソッドを利用して実装されます。get,setなどのメソッドの内部でコマンドを作成し、レスポンスを戻り値にマッピングするだけで、使いやすい高レベルAPIになります。また、ReaderTを返すためfor式での簡潔な記述もできます。

package object memcached {

  type ReaderTTask[C, A]                  = ReaderT[Task, C, A]
  type ReaderTTaskMemcachedConnection[A]  = ReaderTTask[MemcachedConnection, A]
  type ReaderMemcachedConnection[M[_], A] = ReaderT[M, MemcachedConnection, A]

}
final class MemcachedClient()(implicit system: ActorSystem) {

  def send[C <: CommandRequest](cmd: C): ReaderTTaskMemcachedConnection[cmd.Response] = ReaderT(_.send(cmd))

  def get(key: String): ReaderTTaskMemcachedConnection[Option[ValueDesc]] =
    send(GetRequest(UUID.randomUUID(), key)).flatMap {
      case GetSucceeded(_, _, result) => ReaderTTask.pure(result)
      case GetFailed(_, _, ex)        => ReaderTTask.raiseError(ex)
    }

  def set[A: Show](key: String,
                   value: A,
                   expireDuration: Duration = Duration.Inf,
                   flags: Int = 0): ReaderTTaskMemcachedConnection[Int] =
    send(SetRequest(UUID.randomUUID(), key, value, expireDuration, flags)).flatMap {
      case SetExisted(_, _)    => ReaderTTask.pure(0)
      case SetNotFounded(_, _) => ReaderTTask.pure(0)
      case SetNotStored(_, _)  => ReaderTTask.pure(0)
      case SetSucceeded(_, _)  => ReaderTTask.pure(1)
      case SetFailed(_, _, ex) => ReaderTTask.raiseError(ex)
    }

   // ...

}
  • for式での利用例
val client = new MemcachedClient()
val resultFuture = (for {
  _  <- client.set(key, "1")
  r1 <- client.get(key)
  _  <- client.set(key, "2")
  r2 <- client.get(key)
} yield (r1, r2)).run(connection).runAsync
val result = Await.result(resultFuture, Duration.Inf) // (1, 2)

コネクションプール

コネクションとクライアントを実装できたら、次はコネクションプールも欲しくなります。MemcachedConnectionをプーリングして選択できればいいわけなので、commons-poolなど使えば楽ですね。コネクションをプールするアルゴリズム(HashRingなど)も複数考えられるので、工夫するとなかなか面白いです。

implicit val system = ActorSystem()

val peerConfig = PeerConfig(remoteAddress = new InetSocketAddress("127.0.0.1", 6379))
val pool = MemcachedConnectionPool.ofSingleRoundRobin(sizePerPeer = 5, peerConfig, RedisConnection(_)) // powered by RoundRobinPool
val connection = MemcachedConnection(connectionConfig)
val client = MemcachedClient()

// ローンパターン形式
val resultFuture1 = pool.withConnectionF{ con =>
  (for{
    _ <- client.set("foo", "bar")
    r <- client.get("foo")
  } yield r).run(con) 
}.runAsync

// モナド形式
val resultFuture2 = (for {
  _ <- ConnectionAutoClose(pool)(client.set("foo", "bar").run)
  r <- ConnectionAutoClose(pool)(client.get("foo").run)
} yield r).run().runAsync

まとめ

というわけでこんな風にakka-streamを使えばTCPクライアントも比較的簡単に実装できると思います。参考にしてみてください。


  1. バックプレッシャを適切にハンドリングするには、ストリームを常に起動した状態にしておく必要があります、要求ごとにストリームを起動していると効果的ではありません))