PlanetcaleをSQLAlchemyを使ってPandasで読み書きする
やること
- Pythonで
- SQL(planetscale)からpandasへ読み込む
- 平文でSQLを書くのはアレなのでSQLAlchemyを使う
PlanetScaleを使う理由
なんかアツいらしいからです。(適当)
詳しくは以下の記事を参考にしてください。
PlanetScaleというサーバレスDBが凄く勢いのあるサービスらしいのでQuick Startやってみた - Qiita
なお、今回は上記の記事を読み、ブランチを作成した前提で行きます。
脆弱性の対策
SQLを普通に文字列で書いて渡して使う〜のが普通のやり方なんですが、
それだとSQLインジェクションという脆弱性があるとのことなので、
怖いね〜ってことで対策しつつ行きます。
SQLAlchemyというライブラリを使うとSQLインジェクションを回避できるということで
これを使って準備していきます。
また今回は昨今話題らしい、planetscaleというデータベースを使って接続していきます。
DBへの接続準備
planetscaleのクイックスタートを一部見ながら進めていきます。
クイックスタートはOverview
→ connect
から閲覧できます。
今回は Python
で進めていくので言語を設定します。
ターミナルで諸々準備する
-
pip
で必要なライブラリをインストールします。 -
.env
を作り、ここにパスワードとか諸々を打ち込んでいきます。
$ pip install python-dotenv mysqlclient
$ touch .env
.env
を編集する
- エディタで
.env
を編集します。 - クイックスタートに
.env
というタブがあるので、ここをクリックして中身をコピペします。
HOST=ほにゃらら
USERNAME=ほにゃらら
PASSWORD=ほにゃらら
DATABASE=ほにゃらら
SQLAlchemyからPlanetscaleへ接続
- 諸々インポートします。
-
load_dotenv
で先程書いた.env
を読み込み、os.getenv("HOST")
とかで読み出します。 -
create_engine
でDBに接続します、フォーマットなどはコードを参照してください。- SSL接続が必須なので、
?ssl_mode=VERIFY_IDENTITY",connect_args={"ssl": {"ca": "/etc/ssl/cert.pem
を記載しておきましょう。
- SSL接続が必須なので、
- これを
settings.py
で保存しておきます。
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from dotenv import load_dotenv
load_dotenv()
import os
HOST = os.getenv("HOST")
USER = os.getenv("USERNAME")
PASSWD = os.getenv("PASSWORD")
DB = os.getenv("DATABASE")
# データベース接続
ENGINE = create_engine(
f"mysql://{USER}:{PASSWD}@{HOST}/{DB}?ssl_mode=VERIFY_IDENTITY",
connect_args={"ssl": {"ca": "/etc/ssl/cert.pem"}},
)
session = scoped_session(sessionmaker(bind=ENGINE))
# modelで使用する
Base = declarative_base()
Base.query = session.query_property()
テーブルを定義する
-
setting.py
と別のファイルを用意します、名前はuser.py
あたりでいいでしょう。 -
setting.py
からBase
とENGINE
をimportします、名前を変えてたら該当箇所を適宜変えてください。 -
class User(Base)
でテーブルを定義します。- 今回は適当に全部
string
で設定しています。 -
primary_key=True
は設定しておきましょう。
- 今回は適当に全部
-
Base.metadata.create_all(ENGINE)
を実行すると、定義したデータでテーブルがDB上に作成されます。- これがSQLで言うところの
CREATE
にあたります。
- これがSQLで言うところの
from sqlalchemy import Column, String
from setting import Base, ENGINE
class User(Base):
"""
ユーザモデル
"""
__tablename__ = "test"
user_id = Column("user_id", String(767),primary_key=True)
data = Column("data", String(767))
def main():
Base.metadata.create_all(ENGINE)
if __name__ == "__main__":
main()
データベースを読み込む
今回、あらかじめPlanetscale上に10万行ほどデータを挿入してあります。
これを読み込んで pandas
の DataFrame
に変換しましょう。
pandas
の read_sql
を使ってみる(失敗)
from setting import ENGINE
from sqlalchemy import MetaData, Table, func
from sqlalchemy.sql import select
import pandas as pd
# 既存のデータを取得
metadata = MetaData()
event_data = Table("liked", metadata, autoload=True, autoload_with=ENGINE)
df_tweet_id = pd.read_sql_query(
sql=select([
event_data,
func.count("*").label("rows")
]).group_by(event_data),
con=ENGINE,
)
エラーが出た
(MySQLdb._exceptions.OperationalError) (1153, 'rpc error: code = ResourceExhausted desc = grpc: received message larger than max (~ vs. ~)')
らしいです。
つまり、一度に大量のデータを読み込むんじゃねえぞアホがってことだと思います。
別の方法を考えましょう。
SQLAlchemyでちょっとずつ読み込む(失敗)
以下を参考に 、ちょっとずつ読み込んでいく作戦で行ってみます。
SQLAlchemyで、DBから大量にデータを取ってくる時に一度に全部取得せずちょっとずつ取る - 日々精進
from setting import ENGINE
from sqlalchemy.sql import select
sel = select(User.tweet_id).select_from(User)
con = ENGINE.connect()
res = con.execution_options(stream_results=True).execute(sel)
と書いて、本当はここから for
を書いていく予定だったんですが、
そもそも con.execution_options(stream_results=True).execute(sel)
の時点で同じエラーが出てしまうという結果に終わってしまいました。
引っ張ってくるデータはちゃんと選定する(成功…と思いきや)
…ここまでの原因として、すべての列を10万行分読み込もうとしているというのがあります。
横着はいけません、ちゃんと欲しい列をしっかりと指定してあげましょう。
from sqlalchemy.orm import sessionmaker
from setting import ENGINE
# セッション作成
SessionClass = sessionmaker(ENGINE)
session = SessionClass()
# SELECT
ids = session.query(User.tweet_id).all()
# DataFrameに変換
df = pd.DataFrame(ids)
追記ここから
10万件を超えるデータはリミットとオフセットをかけてあげる
実際10万件超えてくると流石に駄目と言われてくるので、
LIMIT
と OFFSET
を使い、小分けにしながらデータを収集します。
- データベースにどれくらいの行があるか見る
- データベースを小分けにしながら
for
で回す - 出てくるのが何重にもネストされた配列なので、解いて一次元配列にする。
# データベース上に何件あるかカウント見る
count_from_db = session.query(User.tweet_id).count()
# データベースのすべての行の1列を取得
db_datas = [
session.query(User.tweet_id).filter(User.tweet_id).limit(1).offset(i).all()
for i in tqdm(range(0, count_from_db, 10000), desc="DB取得中", leave=False)
]
# リストの平坦化
db_datas = [x[0] for x in list(itertools.chain.from_iterable(numbers))]
心配なら、データの個数を比較するのもいいかもしれません。
追記ここまで
データベースを書き込む
pandasにto_sql
というDataFrameを簡単に書き込める機能があるのでこれを使っていきます。
読みと同じく、10万件レベルの大量のデータを書き込む前提で考えていきます。
forでちょっとずつ書き込んでいく
-
task
でだいたい何件ずつ書き込んでいくかを決めていきます。ここは実際動かしながら調整すると良いと思います。多すぎるとエラー吐きます。 -
for
でDataFrame
の中身を回していきます。 -
df.iloc
でDataFrame
の位置を決めながら情報を取得していきます。 -
to_sql
で書き込みます。-
if_exists
はもしテーブルが存在したときにどうするか?を決めます、デフォルトだとエラーを吐くようになっているのでちゃんと設定します。 -
method
,chunksize
を設定して高速化を図ります。
-
from setting import ENGINE
import pandas as pd
task = 500
for i in tqdm(range(0, len(df), task)):
w = df.iloc[i : i + task, 0:task]
w.to_sql(
"liked",
con=ENGINE,
if_exists="append",
method="multi",
index=False,
chunksize=task,
)
おわりに
手探りでSQLalchemyとPandasを触ってみましたが、いかがでしたでしょうか?
とりあえず自分が引っかかったところは網羅できたかなと思います。