1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

boto3でECSのExecute Commandを使ってローカルとコンテナでファイルをコピーする

Last updated at Posted at 2024-11-04

概要

ECSのExecute Commandはものすごく便利で、コンテナ内に入ってデバッグする際には重宝します。そこで、ファイルのコピーがしたいなぁと思うこともあります。
それほど大きくないテキストファイル(例えばPEM公開鍵とか)ならコピペでもいいですが、ちょっと大きめのファイル(設定ファイルなど)とかバイナリファイルだと少し面倒です。

そういったことを実現するためにboto3を使ってセッションを開いて、セッション経由でいろいろ試してみましたのでメモしておきます。ここではファイルのコピーにフォーカスしていますが、セッションWebSocketでの通信についてもまとめてありますので、boto3を使ったオートメーションにも応用できると思います。通信仕様に関する公式なドキュメントが見つけられなかったので、仕様通りの動作ではないかもしれませんが、ファイルコピーに関しては一応動作しました。

ここではFargateのコンテナに対して行いました。コンテナへの接続設定は完了している前提で、設定方法は扱いませんので、以下のページなどを参考にして下さい。

関連しているかも、ドキュメント

そのものズバリな解決方法はなく、結構みなさん苦戦しているようでした(しかもちょっと古い)。あまり需要がないもしくはそれほど頻繁ではないので手作業でなんとか、という感じでしょうか。

  1. [GitHub] ECS接続のサンプルないんか?という議論
  2. [GitHub] セッションマネージャプラグイン
  3. [AWSドキュメント] セッションマネージャのStartSession

通信仕様について

雑感

ECSのExecute Commandはセッションマネージャ(SSM)のセッション通信を使っているのでセッションマネージャの仕様に従えばいいみたいです(通信の始め方は書いてありますが、通信の維持のドキュメントがないんですよね…)。GitHubの議論やそれに類する議論やブログなどでは、シンプルな通信については言及されています。ただ、あまり詳細に検証したものが見つからなかったので、みなさんそれほど必要としていないのかもしれないです。GitHubの議論(1)のコメントでは、受信確認(ACK)についても触れており、それを受けたこのスレッドではACKの構造について言及されています。

受信確認に関して試した感じだと、あまりに変なACKを送ってしまうと通信が切れることがありますが、間違ったACKを送ってもセッション出力のシーケンスは進むのでACKが必要かどうかはよくわかりません(今回のコードではACKを行ってますが)。実際、Message IDはバイト列では下位8バイト、上位8バイトの順でエンコードされているのに気付かずに送っていても問題なかったので、SSM側では検証されていないのかもしれません(こちらの受信確認時に気付いた)。また、議論ではACKのペイロードがMessageTypeなどになっていましたが、実際はAcknowledgedMessageTypeなどが正しいみたいでした。そのあたり、間違ったままでも処理はされていたので、こちらからのACKは不要なのかもしれません。

一方、ローカル側からデータを送信する際は、ACKを受信(検証はどちらでもよい)しておかないと送信データがコミットされないので、少なくとも受信は必要なようです。

また、あまり細かいことを気にせずに受信を続けて、(おそらく)ping的に使われているSeq=0のoutput_stream_dataを一度だけ受信して、その後は受信して無視、を繰り返すだけでもデータのやりとりはできるような気がします。

結局のところ、細かい仕様はセッションマネージャプラグインのソースを読んで確認していくという感じになりました。

セッション通信

ECSコンテナとのセッションは以下のように開かれるようです。ここでのACK送信は(私が試した範囲では)送らなくても特に問題は起こりませんでしたが、実装では送信しています。channel_closed受信の時はACK不要で切断すればいいと思います。

  1. Execute Command APIを実行
    • セッション用ストリームURL(WebSocket/TLS)、セッショントークンを得る
  2. それらの情報を使ってWebSocketを開く(SSM StartSessionに書いてある)
  3. start_publicationメッセージの受信+ACK送信
  4. output_stream_dataの受信+ACK送信(Seq=0は1度だけ受信し、2回目以降は受信のみ)
  5. (必要なら)input_stream_dataの送信+ACK受信(必須)
  6. channel_closedの受信
  7. WebSocket切断

4.、5.を繰り返すことでインタラクティブな通信を行うことができます。
ここで使うメッセージタイプはstart_publicationoutput_stream_datainput_stream_dataacknowledgechannel_closedです。

通信はテキストを想定されているためか、改行コードがCRLFに変換されます。そのため、バイナリデータを送ると不整合が起こるため、バイナリデータはBase64などで送る必要があります。

メッセージフォーマット

メッセージはバイト文字列です。WebSocket接続直後はJSON文字列を送りますが、それ以降はバイト文字列で通信することになります。

  • バイトオーダーはビッグエンディアン
  • ペイロードは可変長だが、最大1024バイトかも(送られてくるサイズがそうなってる)
オフセット 内容 サイズ
0 ヘッダ長 UInt32 4 116(固定)
4 メッセージタイプ String 32 ※1
36 Schema Version UInt32 4 1(固定)
40 メッセージ生成時刻 UInt64 8 UNIX秒(msec)
48 シーケンス番号 UInt64 8
56 フラグ UInt64 8 1 or 3 ※2
64 メッセージIDの下位8バイト UInt64 8
72 メッセージIDの上位8バイト UInt64 8
80 ペイロードのSHA256ダイジェスト Bytes 32 ※3
112 ペイロードタイプ UInt32 4 0 or 1 ※2
116 ペイロード長 UInt32 4
120 ペイロード Bytes 可変 最大1024bytes?

※1 ここで使うのは start_publication, channel_closed, output_stream_data, input_stream_data, acknowledge
※2 下の表を参照
※3 start_publicationの時はペイロードにstart_publicationがセットされているが、ダイジェストは空文字列のSHA256ハッシュ(e3b0c442...55)になっている

メッセージタイプ シーケンス番号 フラグ ペイロードタイプ
start_publication 0 3 0
channel_closed 0 3 0
output_stream_data 連番 Seq=0の時:1、それ以外:0 1
input_stream_data 連番 Seq=0の時:1、それ以外:0 1
acknowledge 0 3 0

実装

通信仕様について、以上の内容で理解し実装しました。以下に実装を掲載します。エラー処理などはあまり実装していないので、動作検証用です。不具合等が起こっても保証はできません。

WebSocketにはwebsocketsモジュールを使用しています。

動作確認

  • Rocky Linux 9.4
  • Python 3.12.3
  • websockets 13.1 -- PyPI

使い方

usage: ecs-cp.py [-h] --cluster CLUSTER [--service-name SERVICE_NAME | --task TASK] src dst

ECSコンテナとファイル転送

positional arguments:
  src                   ソースファイルパス
  dst                   宛先ファイルパス

options:
  -h, --help            show this help message and exit
  --cluster CLUSTER     ECSクラスター名
  --service-name SERVICE_NAME
                        ServiceNameを指定
  --task TASK           TaskIDを指定

クラスター名とサービス名またはタスクIDを組み合わせてタスクを特定します。クラスター名のみ、または、クラスター名とサービス名で指定した場合はlist-tasksでヒットした先頭のタスクに接続します。タスクIDを指定した場合はそのタスクに接続します。クラスター名は必須です。

ローカルファイルとECSファイルを指定しますが、ECS側のパスはecs://path/toで指定します。
ECS側のパスは単純にecs://を空文字列に置換するだけなので、絶対パスの場合はecs:///root/test.txtというようにスラッシュが3つ並びます。

また、デバッグ用のprint()が入っていますので適宜コメントアウトして下さい。

ローカルからECSにコピー

ローカルで指定したファイルをペイロードにおさまるサイズのbase64チャンクにして送ります。単純にecho -n BASE64DATA | base64 -d >> /path/to/test.dataというコマンドで送っているだけなので、DSTパラメータにはファイル名まで指定する必要があります。

$ ./ecs-cp.py --cluster MyCluster test.data ecs:///path/to/test.data

ECSからローカルにコピー

ECS側のファイルをbase64で送っています。base64 -w0 /myecs.binで改行なしのbase64で送っているだけです。これも受信側はファイル名まで指定して下さい。

$ ./ecs-cp.py --cluster MyCluster ecs:///myecs.bin myecs.bin

実装

ecs-cp.py
#!/usr/bin/env python3
from hashlib import sha256
from datetime import datetime
from base64 import b64decode, b64encode
import json, uuid, struct
import boto3
from websockets.sync.client import connect

# ECSのExecuteCommandで出力を得る
# 出力は改行コードがCRLFになるため、バイナリ通信をしたいときはbase64などでカプセル化する必要がありそう

# 2024-11-02: この議論を参考にした(ACKに関してが未解決)
# https://github.com/boto/boto3/issues/3496#issuecomment-1319039520

PAYLOAD_MAX_SIZE = 1024     # 送信時ペイロードの最大長(たぶん)

class Payload:
    # ペイロード部分
    def __init__(self):
        self.type = 0
        self.content = b""

    @classmethod
    def from_bytes(cls, data, payload_type, expected_digest=None):
        # バイト列からインスタンスを生成する
        # ダイジェストを確認する
        if expected_digest is not None and sha256(data).digest() != expected_digest:
            raise ValueError("Digest not match!!")

        obj = cls()
        obj.type = payload_type
        obj.content = data
        return obj

    def to_bytes(self):
        # バイト列に変換する
        # |Size(uint32_t)|Content(bytes)|
        return (len(self.content) + 4).to_bytes(4, "big") + self.content

class Message:
    # メッセージをパースするためのクラス
    def __init__(self):
        self.typename = None
        self.schema_version = 1         # 固定
        self.created = datetime.now()   # 生成日時
        self.sequence_number = 0        # Sequence No.
        self.flags = None               # フラグ
        self.message_id = None          # UUID
        self.payload = Payload()        # ペイロード

    @classmethod
    def from_bytes(cls, data):
        # バイナリ文字列をパースする
        # |HL |Message Type                   |ScVer  |Timestamp      |SeqNumber      |Flags          |
        # |MsgID(Lower)   |MsgID(Upper)   |Payload Digest                 |PLType |
        header_length = struct.unpack(">I", data[:4])[0]                    # ヘッダ長
        header, data = data[4:header_length], data[header_length:]          # ヘッダだけを分離
        mtyp, ver, ts, seq, flg, midl, midu, dgst, ptyp = struct.unpack("!32sIQQQ8s8s32sI", header)
        payload_length, data = struct.unpack(">I", data[:4])[0], data[4:]   # ペイロード長
        payload, data = data[:payload_length], data[payload_length:]        # ペイロードだけを分離
        assert payload_length == len(payload)                               # 長さ確認
        assert not data                                                     # 残りはないはず

        # オブジェクトを生成する(ACKでもMessageオブジェクトで生成)
        obj = cls()
        obj.typename = mtyp.replace(b"\x00", b"").decode().strip()
        obj.schema_version = ver                            # 1(固定)
        obj.created = datetime.fromtimestamp(ts / 1000)     # 生成されたタイムスタンプ(msec)
        obj.sequence_number = seq                           # シーケンス番号
        obj.flags = flg                                     # フラグ(1 or 3)
        obj.message_id = uuid.UUID(bytes=midu+midl)         # メッセージID(UUID)

        # ペイロードを構成する
        if obj.typename == "start_publication":
            # start_publicationの時はペイロード文字列もstart_publication
            # ダイジェストは空文字列のもの(e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855)
            assert payload == b"start_publication"
            payload = b""
        obj.payload = Payload.from_bytes(payload, ptyp, dgst)

        return obj

    def to_bytes(self):
        # バイト列に変換する
        # UUID(16bytes)は下位8バイト+上位8バイトという並び
        header  = self.typename.encode().ljust(32, b"\x00")                 # メッセージタイプ
        header += self.schema_version.to_bytes(4, "big")                    # スキーマバージョン(=1)
        header += int(self.created.timestamp() * 1000).to_bytes(8, "big")   # タイムスタンプ(msec)
        header += self.sequence_number.to_bytes(8, "big")                   # SeqNo
        header += self.flags.to_bytes(8, "big")                             # flags
        header += self.message_id.bytes[8:]                                 # MessageID(UUID下位8バイト)
        header += self.message_id.bytes[:8]                                 # MessageID(UUID上位8バイト)
        header += sha256(self.payload.content).digest()                     # ペイロード内容のダイジェスト
        header += self.payload.type.to_bytes(4, "big")                      # ペイロードのタイプ

        header = (len(header) + 4).to_bytes(4, "big") + header              # ヘッダ先頭にヘッダサイズを追加
        payload = self.payload.to_bytes()                                   # ペイロードを追加

        return header + payload

    def acknowledge(self):
        # ACKメッセージを構成する
        # ペイロードの情報
        # https://github.com/aws/session-manager-plugin/blob/61cf1288/src/datachannel/streaming.go#L383
        obj = {
            "AcknowledgedMessageType": self.typename,
            "AcknowledgedMessageId": str(self.message_id),
            "AcknowledgedSequenceNumber": self.sequence_number,
            "IsSequentialMessage": True,
        }
        return AcknowledgeMessage(json.dumps(obj, separators=(",", ":")).encode())

class AcknowledgeMessage(Message):
    # ACKメッセージ
    TYPENAME = "acknowledge"
    def __init__(self, data):
        super().__init__()

        # データ構造
        # https://github.com/aws/session-manager-plugin/blob/61cf1288/src/message/messageparser.go#L488
        self.typename = self.TYPENAME       # 固定(AcknowledgeMessage)
        self.flags = 3                      # 固定(=3)
        self.message_id = uuid.uuid4()      # Message IDを生成
        self.payload.type = 0               # 固定(=0) たぶん
        self.payload.content = data

class InputStreamData(Message):
    # データ送信
    TYPENAME = "input_stream_data"
    def __init__(self, data, seq):
        super().__init__()
        self.typename = self.TYPENAME
        self.sequence_number = seq
        self.flags = 1 if seq else 0        # Seq=0の時は1のようだ
        self.message_id = uuid.uuid4()      # Message IDを生成
        self.payload.type = 1               # 固定(=1) たぶん
        self.payload.content = data

class MessageClient:
    # WebSocketでのメッセージクライアント
    def __init__(self, ws_conn, recv_handler):
        self.conn = ws_conn             # WebSocket接続(websockets.sync.client.connect())
        self.received_seq = -1          # 受け取ったシーケンス番号
        self.sent_seq = -1              # 送ったシーケンス番号
        self.handler = recv_handler     # 受け取ったデータを処理するハンドラ

    def recv(self):
        # ACK(受信確認)メッセージを送信しているが、送らなくても特に問題はないかも
        # 送っても送らなくても動作に違いがなさそう…
        # start_publication -> output_stream_data -> ... -> channel_closed
        # データ送信メッセージ(input_stream_data)を送ったときは、acknowledgeを待ってから次を送る
        msg = Message.from_bytes(self.conn.recv())          # メッセージをパースする

        # これらはACK送信しない
        if msg.typename == "channel_closed": return msg     # 接続終了要求
        if msg.typename == "acknowledge": return msg        # ACK受信

        self.conn.send(msg.acknowledge().to_bytes())        # ACK送信

        if msg.typename == "start_publication": return msg  # 通信開始

        # ここにはoutput_stream_dataのみが届くはず
        assert msg.typename == "output_stream_data", f"Returned: {msg.typename}"
        print(f"{msg.sequence_number:04} {msg.flags} {msg.payload.type} Msg:{len(msg.to_bytes()):4}B Payload:{len(msg.payload.content):4}B")

        # 2回目以降のseq=0はスキップ(Pingっぽい?不定期に届く)
        if self.received_seq >= 0 and msg.sequence_number == 0: return msg

        # Seq確認(received_seq+1が届くはず)
        expected_seq = self.received_seq + 1
        assert msg.sequence_number == expected_seq, f"Seq: {msg.sequence_number}, Expected: {expected_seq}"
        self.received_seq = msg.sequence_number  # seqを合わせる

        # データを処理
        self.handler.handle_data(msg.payload.content)

        return msg

    def send(self, data):
        # メッセージを送信する
        self.sent_seq += 1                          # 送信Seqをインクリメント
        msg = InputStreamData(data, self.sent_seq)  # 送信メッセージ作成
        self.conn.send(msg.to_bytes())              # メッセージ送信
        assert len(msg.payload.content) <= PAYLOAD_MAX_SIZE, f"Payload max size exceeded ({len(msg.payload.content)}bytes)"
        print(f"SEND: Msg:{len(msg.to_bytes()):4}B Payload:{len(msg.payload.content):4}B")

        while True:
            # ACKを待つ
            # output_stream_dataが届いたらrecv()内で処理される
            resp = self.recv()
            if resp.typename == "acknowledge":
                # ACK応答を確認する
                obj = json.loads(resp.payload.content)
                assert obj["AcknowledgedMessageType"] == msg.typename, \
                    f"{msg.typename}, Expected: {obj['AcknowledgedMessageType']}"
                assert obj["AcknowledgedMessageId"] == str(msg.message_id), \
                    f"{str(msg.message_id)}, Expected: {obj['AcknowledgedMessageId']}"
                assert obj["AcknowledgedMessageSequenceNumber"] == msg.sequence_number, \
                    f"{msg.sequence_number}, Expected: {obj['AcknowledgedMessageSequenceNumber']}"
                assert obj["IsSequentialMessage"], f"{obj['IsSequentialMessage']}, Expected: True"
                break

class SimpleHandler:
    # 受け取ってそのまま格納する
    def __init__(self): self.received_data = b""
    def __str__(self): return self.received_data.decode()
    def handle_data(self, data): self.received_data += data
    def close(self): pass

class Base64StreamHandler:
    # Base64データを受け取り、デコードする
    BUFFER_SIZE = 1024 * 128    # 128KB

    def __init__(self, stream, bufsize=BUFFER_SIZE):
        self.stream = stream
        self.b64buffer = b""
        self.bufsize = bufsize

    def handle_data(self, data):
        self.b64buffer += data
        if len(self.b64buffer) >= self.bufsize:
            self._flush_to_stream()

    def close(self):
        self._flush_to_stream()

    def _flush_to_stream(self):
        # b64データをデコードしてストリームに書き出す
        self.stream.write(b64decode(self.b64buffer[:self.bufsize]))
        self.b64buffer = self.b64buffer[self.bufsize:]

def open_stream(ws_url, token, receive_handler, callback=None):
    # WebSocketストリームを開き、読み取る
    # callbackが無い場合は書き込まない
    # start_publication -> output_stream_data ... -> channel_closed
    with connect(ws_url) as ws:
        print("Connected.")
        init_payload = {"MessageSchemaVersion": "1.0", "RequestId": str(uuid.uuid4()), "TokenValue": token}
        ws.send(json.dumps(init_payload))

        try:
            client = MessageClient(ws, receive_handler)
            while True:
                msg = client.recv()
                if msg.typename == "start_publication": continue    # 通信開始
                if msg.typename == "channel_closed": break          # 通信終了

                if callback:    # 一度だけcallbackをコールする
                    callback(client)
                    callback = None
        finally:
            # 切断する
            print("Close")
            receive_handler.close()
            ws.close_socket()

def open_session(cmd, cluster, service_name=None, task=None):
    # セッションを開始する
    ecs = boto3.client("ecs")

    # クラスターのタスクIDを得る
    # タスクが指定されていたらtaskを使用する
    param = {"cluster": cluster}
    if task:
        task_id = task
    else:
        if service_name: param["serviceName"] = service_name
        res = ecs.list_tasks(**param)
        task_id = res["taskArns"][0]

    # コマンド実行
    res = ecs.execute_command(
        cluster=cluster,
        interactive=True,
        task=task_id,
        command=cmd,
    )

    return {
        "wss": res["session"]["streamUrl"],     # WebSocketのURL
        "token": res["session"]["tokenValue"],  # 接続時に初期化するためのトークン値
    }

def setup_callback(src, dst):
    # メッセージ送信用コールバック
    def _write_callback(client):
        # ペイロードを1024bytes以下で送る(送られてくるペイロードが1024bytes以下)
        # 6bitごとにエンコードされるので3bytesならパディング不要にできる
        # 送信チャンクごとに受け側でデコードする

        # ファイルを削除
        client.send(f"rm -f {dst}\n".encode())

        # base64にして転送する(ペイロードを1024バイトに近づけて送る)
        cmdtmpl = "echo -n {} | base64 -d >> {}\n"
        cmdlen = len(cmdtmpl.format("", dst))
        chunk_size = int((PAYLOAD_MAX_SIZE - cmdlen) / 4) * 3   # base64でサイズが4/3になる
        with open(src, "rb") as fh:
            for buf in iter(lambda: fh.read(chunk_size), b""):
                b64chunk = b64encode(buf)
                client.send(cmdtmpl.format(b64chunk.decode(), dst).encode())

        # セッション終了
        client.send(b"exit\n")

    return _write_callback

def get_file(src, dst, cluster, service_name=None, task=None):
    # ファイルを取得
    print(src, dst)
    ssn = open_session(f"base64 -w0 {src}", cluster, service_name, task)
    with open(dst, "wb") as fh:
        handler = Base64StreamHandler(fh)
        open_stream(ssn["wss"], ssn["token"], handler)

def put_file(src, dst, cluster, service_name=None, task=None):
    # ファイルを送信
    ssn = open_session("bash", cluster, service_name, task)
    handler = SimpleHandler()
    open_stream(ssn["wss"], ssn["token"], handler, setup_callback(src, dst))

if __name__ == "__main__":
    import sys
    import argparse

    parser = argparse.ArgumentParser(description="ECSコンテナとファイル転送")

    # 必須の引数
    parser.add_argument("--cluster", required=True, type=str, help="ECSクラスター名")

    # 相互排他のオプション引数
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--service-name", type=str, help="ServiceNameを指定")
    group.add_argument("--task", type=str, help="TaskIDを指定")

    # 位置引数
    parser.add_argument("src", type=str, help="ソースファイルパス")
    parser.add_argument("dst", type=str, help="宛先ファイルパス")

    args = parser.parse_args()

    if args.src.startswith("ecs://") == args.dst.startswith("ecs://"):
        print("コンテナはSRCまたはDSTに指定してください")
        sys.exit(1)

    if args.src.startswith("ecs://"):
        src = args.src.replace("ecs://", "", 1)
        get_file(src, args.dst, args.cluster, args.service_name, args.task)
    else:
        dst = args.dst.replace("ecs://", "", 1)
        put_file(args.src, dst, args.cluster, args.service_name, args.task)

コード説明

class Payload

ペイロードクラス。メッセージのペイロード部分。
ペイロードのエンコード、デコードを行う。

class Message

メッセージクラス。メッセージを扱う。ペイロードオブジェクトを保持。
メッセージのエンコード、デコードを行う。ACKメッセージを生成する。

class AcknowledgeMessage(Message)

受信確認(ACK)メッセージ。

class InputStreamData(Message)

データ送信用メッセージ。

class MessageClient

websockets.sync.client.connect()を介して通信します。

class SimpleHandler

受信データの処理用ハンドラ。単純に受信したデータをreceived_dataに蓄積します。

class Base64StreamHandler

Base64データを受信し、デコードしてファイルに書き出します。

def open_stream(ws_url, token, receive_handler, callback=None)

セッションURLに接続し、通信を開始します。通信開始後はMessageClientオブジェクトを生成し、受信ループを回します。
最初のoutput_stream_dataを受け取った後、callbackが指定されていれば、callback(client)を実行します(このときにデータ送信のサブループを回しています)。

def open_session(cmd, cluster, service_name=None, task=None)

ECSのタスクを探し、Execute Commandを実行し、ストリームURLとトークンを得ます。

setup_callback

データ送信の処理を行います。

参考

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?