最近Dataflowを触っていて、細かいところで実装の仕方に悩むことが多かったので
自分へのメモも兼ねて、サンプルコードを載せておきます。
(細かい用語の説明などはしません)
やりたいこと
以下データを使って、設問IDと回答IDの組み合わせ毎の件数を求める。
| ユーザID | 設問ID | 回答 |
|---|---|---|
| A | 1 | 1,2,3 |
| B | 1 | 3,4 |
| C | 2 | 1,3 |
ポイント
- カンマ区切りの文字列を分解し、別々のレコードにする
- 複合キーでのGroupByをする
カンマ区切りの文字列を分解し、別々のレコードにする
class SplitAnswers(beam.DoFn):
"""answer_idを,で分解し別々のレコードにする"""
def process(self, element, *args, **kwargs):
user_id, question_id, answer_ids = element
for answer_id in answer_ids.split(","):
yield user_id, question_id, answer_id
1:1のサンプルはいくつか見たものの、1:多のサンプルが見当たらなかったので。
1:1と特に変わらず、分解結果をyieldで返すだけでできる。
複合キーでのGroupByをする
class QuestionAnswer(object):
"""GroupByに使用する複合キー用のクラス"""
def __init__(self, question_id, answer_id):
self.question_id = question_id
self.answer_id = answer_id
class QuestionAnswerCoder(beam.coders.Coder):
SEPARATOR = "##SEP##"
def encode(self, value):
return self.SEPARATOR.join([value.question_id, value.answer_id])
def decode(self, encoded):
return QuestionAnswer(*encoded.split(self.SEPARATOR))
def is_deterministic(self):
return True
beam.coders.registry.register_coder(QuestionAnswer, QuestionAnswerCoder)
複合キーでのGroupByをするには、beam.Coderを使う。
基本的に、SEPARATORは複数文字列を使った方が良い。
実運用の時に、何も考えず公式サンプル同様に「:」を使っていたら
ユーザIDと、アクセスしたサイトのURLをキーにしようとした時に
http://hogehoge.com/http://hogehoge.com
のような不正なデータが入っていて、パースエラーになったことがあった。
サンプルプログラム
import unittest
import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
class SplitAnswers(beam.DoFn):
"""answer_idを,で分解し別々のレコードにする"""
def process(self, element, *args, **kwargs):
user_id, question_id, answer_ids = element
for answer_id in answer_ids.split(","):
yield user_id, question_id, answer_id
class QuestionAnswer(object):
"""GroupByに使用する複合キー用のクラス"""
def __init__(self, question_id, answer_id):
self.question_id = question_id
self.answer_id = answer_id
class QuestionAnswerCoder(beam.coders.Coder):
SEPARATOR = "##SEP##"
def encode(self, value):
return self.SEPARATOR.join([value.question_id, value.answer_id])
def decode(self, encoded):
return QuestionAnswer(*encoded.split(self.SEPARATOR))
def is_deterministic(self):
return True
beam.coders.registry.register_coder(QuestionAnswer, QuestionAnswerCoder)
class EnqueteTest(unittest.TestCase):
# user_id, question_id, answer_id
DATA = [
["A", "1", "1,2,3"],
["B", "1", "3,4"],
["C", "2", "1,3"]
]
def test_enquete(self):
with TestPipeline() as pipeline:
result = (
pipeline
| beam.Create(EnqueteTest.DATA)
| beam.ParDo(SplitAnswers())
| beam.Map(lambda (user_id, question_id, answer_id): QuestionAnswer(question_id, answer_id))
| beam.combiners.Count.PerElement()
| beam.Map(lambda (qa, value): (qa.question_id, qa.answer_id, value))
)
assert_that(result, equal_to([
('1', '1', 1),
('1', '2', 1),
('1', '3', 2),
('1', '4', 1),
('2', '1', 1),
('2', '3', 1)
]))
if __name__ == '__main__':
unittest.main()
終わりに
今回はDoFnとCoderについてでしたが、三ヶ月触っていただけあって
ネタはたくさんあるので少しずつ共有できたらなと思います。