LoginSignup
6
3

More than 1 year has passed since last update.

Go で AWS のプライベート VPC リソースにアクセスする

Last updated at Posted at 2020-12-13

やりたいこと

AWS Systems Manager 経由で SSH トンネルを使用してプライベート VPC リソースにアクセスしたいと考えています。どうすればよいですか?

こちらの記事で紹介されているように、AWS Systems Manager Session Manager を利用することで、VPC 内に用意した踏み台サーバ経由でプライベート VPC 内のリソース (RDS など) にアクセスすることができます。

本来は SSH Client と AWS CLI を組み合わせて SSH トンネリング (ポートフォワーディング) を行うのですが、今回はこれを Go 言語と aws-sdk-go でやってみようと思います。
これにより、いちいち ssh コマンドを叩かずにプライベート VPC リソースに対してプログラムを実行できるのでかっこいいです(たぶん)。

検証環境

VPC 内の Private Subnet に RDS インスタンス (MySQL) と踏み台用の EC2 インスタンスがある環境を想定します。

bastion.png

図のようにローカル PC から Session Manager 経由で踏み台サーバに接続し、最終的にプライベートな RDS インスタンスに対して SHOW DATABASES を実行することをゴールとします。

上記の検証環境を再現する CloudFormation テンプレートを用意したのでお手元で試したい方は下記の詳細をご参照ください。

詳細

次のテンプレートを使用して CloudFormation スタックを作成すると検証環境を作成できます(ap-northeast-1 限定)。

AWSTemplateFormatVersion: 2010-09-09
Description: Create private RDS and bastion instance in VPC

Metadata:
  AWS::CloudFormation::Interface:
    ParameterGroups:
      - Label:
          default: VPC Configuration
        Parameters:
          - VPCCIDR
          - PrivateSubnetACIDR
          - PrivateSubnetCCIDR
      - Label:
          default: DB Configuration
        Parameters:
          - DBMasterUsername
          - DBMasterPassword
      - Label:
          default: Bastion Configuration
        Parameters:
          - BastionKeyPair
          - BastionImageId

    ParameterLabels:
      VPCCIDR:
        default: VPC CIDR
      PrivateSubnetACIDR:
        default: Private Subnet A CIDR
      PrivateSubnetCCIDR:
        default: Private Subnet C CIDR
      DBMasterUsername:
        default: Database Master Username
      DBMasterPassword:
        default: Database Master Password
      BastionKeyPair:
        default: Bastion Server Key Pair Name
      BastionImageId:
        default: Bastion Server Image ID (DO NOT CHANGE)

Parameters:
  VPCCIDR:
    Type: String
    Default: 10.1.0.0/24
  PrivateSubnetACIDR:
    Type: String
    Default: 10.1.0.1/26
  PrivateSubnetCCIDR:
    Type: String
    Default: 10.1.0.64/26
  DBMasterUsername:
    Type: String
    Default: root
  DBMasterPassword:
    Type: String
  BastionKeyPair:
    Type: String
  BastionImageId:
    Type: AWS::SSM::Parameter::Value<String>
    Default: /aws/service/ami-amazon-linux-latest/amzn2-ami-hvm-x86_64-gp2

Resources:
  VPC:
    Type: AWS::EC2::VPC
    Properties:
      CidrBlock: !Ref VPCCIDR
      EnableDnsSupport: true
      EnableDnsHostnames: true
      InstanceTenancy: default
      Tags:
        - Key: Name
          Value: ssm-bastion-example-vpc

  PrivateSubnetA:
    Type: AWS::EC2::Subnet
    Properties:
      AvailabilityZone: ap-northeast-1a
      CidrBlock: !Ref PrivateSubnetACIDR
      VpcId: !Ref VPC
      Tags:
        - Key: Name
          Value: ssm-bastion-example-private-subnet-a

  PrivateSubnetC:
    Type: AWS::EC2::Subnet
    Properties:
      AvailabilityZone: ap-northeast-1c
      CidrBlock: !Ref PrivateSubnetCCIDR
      VpcId: !Ref VPC
      Tags:
        - Key: Name
          Value: ssm-bastion-example-private-subnet-c

  PrivateRouteTable:
    Type: AWS::EC2::RouteTable
    Properties:
      VpcId: !Ref VPC
      Tags:
        - Key: Name
          Value: ssm-bastion-example-private-route

  PrivateSubnetRouteTableAssociationA:
    Type: AWS::EC2::SubnetRouteTableAssociation
    Properties:
      SubnetId: !Ref PrivateSubnetA
      RouteTableId: !Ref PrivateRouteTable

  PrivateSubnetRouteTableAssociationC:
    Type: AWS::EC2::SubnetRouteTableAssociation
    Properties:
      SubnetId: !Ref PrivateSubnetC
      RouteTableId: !Ref PrivateRouteTable

  VPCEndpointSecurityGroup:
    Type: AWS::EC2::SecurityGroup
    Properties:
      GroupDescription: Serucity group for vpc endpoint
      VpcId: !Ref VPC
      SecurityGroupIngress:
        - IpProtocol: tcp
          FromPort: 443
          ToPort: 443
          CidrIp: !Ref VPCCIDR

  VPCEndpointSSM:
    Type: AWS::EC2::VPCEndpoint
    Properties:
      ServiceName: com.amazonaws.ap-northeast-1.ssm
      VpcEndpointType: Interface
      VpcId: !Ref VPC
      SubnetIds:
        - !Ref PrivateSubnetA
      SecurityGroupIds:
        - !Ref VPCEndpointSecurityGroup
      PrivateDnsEnabled: true

  VPCEndpointSSMMessages:
    Type: AWS::EC2::VPCEndpoint
    Properties:
      ServiceName: com.amazonaws.ap-northeast-1.ssmmessages
      VpcEndpointType: Interface
      VpcId: !Ref VPC
      SecurityGroupIds:
        - !Ref VPCEndpointSecurityGroup
      SubnetIds:
        - !Ref PrivateSubnetA
      PrivateDnsEnabled: true

  VPCEndpointEC2Messages:
    Type: AWS::EC2::VPCEndpoint
    Properties:
      ServiceName: com.amazonaws.ap-northeast-1.ec2messages
      VpcEndpointType: Interface
      VpcId: !Ref VPC
      SecurityGroupIds:
        - !Ref VPCEndpointSecurityGroup
      SubnetIds:
        - !Ref PrivateSubnetA
      PrivateDnsEnabled: true

  VPCEndpointS3:
    Type: AWS::EC2::VPCEndpoint
    Properties:
      ServiceName: com.amazonaws.ap-northeast-1.s3
      VpcEndpointType: Gateway
      VpcId: !Ref VPC
      RouteTableIds:
        - !Ref PrivateRouteTable

  BastionRole:
    Type: AWS::IAM::Role
    Properties:
      Description: EC2 role for SSM
      AssumeRolePolicyDocument:
        Version: 2012-10-17
        Statement:
          - Effect: Allow
            Principal:
              Service:
                - ec2.amazonaws.com
            Action:
              - sts:AssumeRole
      ManagedPolicyArns:
        - arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore

  BastionInstanceProfile:
    Type: AWS::IAM::InstanceProfile
    Properties:
      Roles:
        - !Ref BastionRole

  BastionSecurityGroup:
    Type: AWS::EC2::SecurityGroup
    Properties:
      GroupDescription: Security group for bastion server
      GroupName: ssm-bastion-example-bastion-sg
      VpcId: !Ref VPC

  BastionServer:
    Type: AWS::EC2::Instance
    Properties:
      ImageId: !Ref BastionImageId
      InstanceType: t2.micro
      SubnetId: !Ref PrivateSubnetA
      SecurityGroupIds:
        - !Ref BastionSecurityGroup
      IamInstanceProfile: !Ref BastionInstanceProfile
      KeyName: !Ref BastionKeyPair
      Tags:
        - Key: Name
          Value: ssm-bastion-example-bastion-server

  DBSecurityGroup:
    Type: AWS::EC2::SecurityGroup
    Properties:
      GroupDescription: Security group for bastion exmaple db
      GroupName: ssm-bastion-example-db-sg
      VpcId: !Ref VPC
      SecurityGroupIngress:
        - SourceSecurityGroupId: !Ref BastionSecurityGroup
          IpProtocol: tcp
          FromPort: 3306
          ToPort: 3306

  DBSubnetGroup:
    Type: AWS::RDS::DBSubnetGroup
    Properties:
      DBSubnetGroupDescription: DB subnet group for bastion example db
      DBSubnetGroupName: ssm-bastion-example-db-sng
      SubnetIds:
        - !Ref PrivateSubnetA
        - !Ref PrivateSubnetC
      Tags:
        - Key: Name
          Value: ssm-bastion-example-db-sng

  DB:
    Type: AWS::RDS::DBInstance
    Properties:
      DBInstanceClass: db.t2.micro
      Engine: MySQL
      AllocatedStorage: 5
      PubliclyAccessible: false
      DBSubnetGroupName: !Ref DBSubnetGroup
      VPCSecurityGroups:
        - !GetAtt DBSecurityGroup.GroupId
      MasterUsername: !Ref DBMasterUsername
      MasterUserPassword: !Ref DBMasterPassword

Outputs:
  BastionInstanceId:
    Description: Bastion server instance id
    Value: !Ref BastionServer
  DBEndpoint:
    Description: Database endpoint
    Value: !GetAtt DB.Endpoint.Address

スタック作成時に次のパラメータを適切に設定してください。

  • VPC CIDR
    • 作成する VPC の CIDR ブロックです
    • 既存の VPC とぶつかる場合は適切な値に変更してください
  • Private Subnet A CIDR, Private Subnet C CIDR
    • 作成する Private Subnet の CIDR ブロックです
    • VPC CIDR を変更した場合はこちらも適切な値に変更してください
  • Database Master Password
    • 作成する RDS インスタンスのマスタパスワードです
  • Bastion Server Key Pair Name
    • 踏み台サーバに接続するためのキーペア名です
    • キーペアは予め作成し、秘密鍵をローカル PC にダウンロードしておいてください

どうやって実装するか

SSH Client と AWS CLI の処理を Go で書ければ実現可能なはずです。

SSH Client + AWS CLI の場合の処理の流れは次の通りです。

  1. SSH Client から ProxyCommand として aws ssm start-session を実行する
    1. AWS CLI は SSM の StartSession を呼んでセッションを開始する
    2. StartSession のレスポンスとして得られた URL とトークンを使って踏み台インスタンスと WebSocket で通信する
  2. SSH Client が ProxyCommand の通信を利用して RDS のポートをローカルポートにフォワーディングする

SSH Client の処理は golang.org/x/crypto/ssh パッケージを利用することで実装可能です。

問題は AWS CLI の処理です。
aws-sdk-goSSM.StartSession() で WebSocket 通信用の URL とトークンを得ることができますが、その後の WebSocket 通信の仕様が明らかにされていないのでどう使えばよいのか全くの不明です。

AWS CLI の実装を見ると、 Boto3 の SSM.Client.start_session で得られた出力を session-manager-plugin に渡していることが分かります。

どうやら WebSocket 通信はこの session-manager-plugin に任せているようです(session-manager-plugin はバイナリ形式で配布されているため実装の詳細は不明)。

WebSocket 通信の仕様を頑張って解読するのは不毛な上、いつ変更されるかもわからないので今回は AWS CLI 同様に session-manager-plugin を呼び出す形で実装することにします。
AWS CLI と全く同じ呼び出し方をしてあげれば問題なく使えるはずです。
また、AWS CLI と session-manager-plugin 間のインタフェースは互換性を保つためにそう簡単には変更されないものと予想されます。

2021/06/07 追記:
session-manager-plugin が OSS になりました!
しかも Go で書かれているためコア部分をそのままライブラリとして利用することもできそうです。

実装

go.mod
module port-forward

go 1.15

require (
    github.com/aws/aws-sdk-go v1.36.0
    github.com/go-sql-driver/mysql v1.5.0
    golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
)
main.go
package main

import (
    "database/sql"
    "encoding/json"
    "errors"
    "flag"
    "fmt"
    "io"
    "io/ioutil"
    "net"
    "os"
    "os/exec"
    "path"
    "runtime"
    "strconv"

    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/aws/session"
    "github.com/aws/aws-sdk-go/service/ssm"
    "github.com/go-sql-driver/mysql"
    "golang.org/x/crypto/ssh"
    "golang.org/x/crypto/ssh/knownhosts"
)

type config struct {
    instanceID string
    region     string
    user       string
    keyPath    string
    localPort  uint16
    dbHost     string
    dbPort     uint16
    dbUser     string
    dbPass     string
}

func main() {
    conf := &config{}

    var localPort, dbPort uint

    flags := flag.NewFlagSet("port-forward", flag.ContinueOnError)
    flags.StringVar(&conf.instanceID, "instance-id", "", "bastion server instance id")
    flags.StringVar(&conf.region, "region", "ap-northeast-1", "aws region")
    flags.StringVar(&conf.user, "ssh-user", "ec2-user", "ssh user for bastion server")
    flags.StringVar(&conf.keyPath, "key", "", "ssh key file path")
    flags.UintVar(&localPort, "local-port", 9090, "local port for port-fowarding")
    flags.StringVar(&conf.dbHost, "db-host", "", "database host")
    flags.UintVar(&dbPort, "db-port", 3306, "database port")
    flags.StringVar(&conf.dbUser, "db-user", "root", "database user")
    flags.StringVar(&conf.dbPass, "db-pass", "", "database password")
    if err := flags.Parse(os.Args[1:]); err != nil {
        os.Exit(2)
    }

    conf.localPort = uint16(localPort)
    conf.dbPort = uint16(dbPort)

    if err := run(conf); err != nil {
        fmt.Fprintln(os.Stderr, err)
        os.Exit(1)
    }
    os.Exit(0)
}

func run(conf *config) error {
    sess, err := session.NewSession(&aws.Config{
        Region: aws.String(conf.region),
    })
    if err != nil {
        return err
    }

    svc := ssm.New(sess)

    proxyCmd, closeSession, err := openSession(svc, conf.instanceID)
    if err != nil {
        return err
    }
    defer closeSession()

    sshConfig, err := newSSHClientConfig(conf.user, conf.keyPath)
    if err != nil {
        return err
    }

    client, killProxyCmd, err := newSSHClientWithProxyCommand(conf.instanceID, 22, proxyCmd, sshConfig)
    if err != nil {
        return err
    }
    defer killProxyCmd()
    defer client.Close()

    done, err := portForward(conf.localPort, client, conf.dbHost, conf.dbPort)
    if err != nil {
        return err
    }
    defer done()

    if err := printDBList("localhost", conf.localPort, conf.dbUser, conf.dbPass); err != nil {
        return err
    }

    return nil
}

// openSession AWS Systems Manager Session Manager のセッションを開始し、
// session-manager-plugin を実行する *exec.Cmd とセッションを終了する関数を返す。
func openSession(svc *ssm.SSM, instanceID string) (*exec.Cmd, func() error, error) {
    in := &ssm.StartSessionInput{
        DocumentName: aws.String("AWS-StartSSHSession"),
        Parameters: map[string][]*string{
            "portNumber": {aws.String("22")},
        },
        Target: aws.String(instanceID),
    }
    out, err := svc.StartSession(in)
    if err != nil {
        return nil, nil, err
    }

    close := func() error {
        in := &ssm.TerminateSessionInput{
            SessionId: out.SessionId,
        }
        if _, err := svc.TerminateSession(in); err != nil {
            return err
        }
        return nil
    }

    cmd, err := sessionManagerPlugin(svc, in, out)
    if err != nil {
        defer close()
        return nil, nil, err
    }

    return cmd, close, nil
}

// sessionManagerPlugin session-manager-plugin を実行する *exec.Cmd を返す。
func sessionManagerPlugin(
    svc *ssm.SSM,
    in *ssm.StartSessionInput,
    out *ssm.StartSessionOutput,
) (*exec.Cmd, error) {
    command := "session-manager-plugin"
    if runtime.GOOS == "windows" {
        command += ".exe"
    }

    encodedIn, err := json.Marshal(in)
    if err != nil {
        return nil, err
    }
    encodedOut, err := json.Marshal(out)
    if err != nil {
        return nil, err
    }
    region := *svc.Config.Region
    profile := getAWSProfile()
    endpoint := svc.Endpoint

    cmd := exec.Command(command, string(encodedOut), region,
        "StartSession", profile, string(encodedIn), endpoint)

    return cmd, nil
}

// getAWSProfile 有効な AWS Profile を取得する。
func getAWSProfile() string {
    profile := os.Getenv("AWS_PROFILE")
    if profile != "" {
        return profile
    }

    enableSharedConfig, _ := strconv.ParseBool(os.Getenv("AWS_SDK_LOAD_CONFIG"))
    if enableSharedConfig {
        profile = os.Getenv("AWS_DEFAULT_PROFILE")
    }

    return profile
}

// newSSHClientConfig *ssh.ClientConfig を生成する。
func newSSHClientConfig(user string, keyPath string) (*ssh.ClientConfig, error) {
    key, err := ioutil.ReadFile(keyPath)
    if err != nil {
        return nil, err
    }

    signer, err := ssh.ParsePrivateKey(key)
    if err != nil {
        return nil, err
    }

    hostKeyCallback, err := newHostKeyCallback()
    if err != nil {
        return nil, err
    }

    return &ssh.ClientConfig{
        User: user,
        Auth: []ssh.AuthMethod{
            ssh.PublicKeys(signer),
        },
        HostKeyCallback: hostKeyCallback,
    }, nil
}

// newHostKeyCallback ~/.ssh/known_hosts を参照して
// ホストの公開鍵を確認する ssh.HostKeyCallback を返す。
func newHostKeyCallback() (ssh.HostKeyCallback, error) {
    home, err := os.UserHomeDir()
    if err != nil {
        return nil, err
    }

    knownHosts := path.Join(home, ".ssh", "known_hosts")

    cb, err := knownhosts.New(knownHosts)
    if err != nil {
        return nil, err
    }

    return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
        // net.Pipe() から生成した net.Conn で ssh.Conn を作ると
        // remote.String() の値が "pipe" となり net.SplitHostPort() が失敗してしまう。
        // https://github.com/golang/crypto/blob/5f87f3452ae9/ssh/knownhosts/knownhosts.go#L336
        //
        // hostname には `${instance-id}:22` が入っているので
        // それを返す net.Addr に差し替えておく。
        if remote.String() == "pipe" {
            remote = &addrImpl{
                network: remote.Network(),
                addr:    hostname,
            }
        }

        err := cb(hostname, remote, key)

        var ke *knownhosts.KeyError
        if errors.As(err, &ke) {
            // known_hosts と一致しない場合はエラー
            if len(ke.Want) > 0 {
                return ke
            }

            f, err := os.OpenFile(knownHosts, os.O_WRONLY|os.O_APPEND, 0644)
            if err != nil {
                return err
            }
            defer f.Close()

            // 未知のホストの場合は known_hosts に追記する
            line := knownhosts.Line([]string{remote.String()}, key)
            fmt.Fprintln(f, line)

            return nil
        }

        return err
    }, nil
}

// addrImple net.Addr の実装。
type addrImpl struct {
    network string
    addr    string
}

func (s *addrImpl) Network() string {
    return s.network
}

func (s *addrImpl) String() string {
    return s.addr
}

// newSSHClientWithProxyCommand ProxyCommand を利用した *ssh.Client を返す。
func newSSHClientWithProxyCommand(
    host string,
    port uint16,
    proxyCmd *exec.Cmd,
    conf *ssh.ClientConfig,
) (*ssh.Client, func() error, error) {
    c, s := net.Pipe()

    proxyCmd.Stdin = s
    proxyCmd.Stdout = s
    proxyCmd.Stderr = os.Stderr

    if err := proxyCmd.Start(); err != nil {
        return nil, nil, err
    }

    done := func() error {
        return proxyCmd.Process.Kill()
    }

    addr := fmt.Sprintf("%s:%d", host, port)
    conn, chans, reqs, err := ssh.NewClientConn(c, addr, conf)
    if err != nil {
        defer done()
        return nil, nil, err
    }

    client := ssh.NewClient(conn, chans, reqs)

    return client, done, nil
}

// portForward ポートフォワードを行う。
func portForward(
    localPort uint16,
    sshClient *ssh.Client,
    remoteHost string,
    remotePort uint16,
) (func(), error) {
    listener, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort))
    if err != nil {
        return nil, err
    }

    remoteAddr := fmt.Sprintf("%s:%d", remoteHost, remotePort)

    done := make(chan struct{})

    go func() {
        defer listener.Close()

        for {
            select {
            case <-done:
                return
            default:
            }

            localConn, err := listener.Accept()
            if err != nil {
                var ne net.Error
                if errors.As(err, &ne) && ne.Temporary() {
                    continue
                }
                fmt.Fprintln(os.Stderr, "accept failed: ", err)
                return
            }

            remoteConn, err := sshClient.Dial("tcp", remoteAddr)
            if err != nil {
                fmt.Fprintln(os.Stderr, "dial failed: ", err)
                return
            }

            go func() {
                defer localConn.Close()
                defer remoteConn.Close()
                if _, err := io.Copy(remoteConn, localConn); err != nil {
                    fmt.Fprintln(os.Stderr, "copy failed: ", err)
                }
            }()

            go func() {
                if _, err := io.Copy(localConn, remoteConn); err != nil {
                    fmt.Fprintln(os.Stderr, "copy failed: ", err)
                }
            }()
        }
    }()

    return func() {
        close(done)
    }, nil
}

// printDBList RDS に接続し DB 一覧を出力する。
func printDBList(host string, port uint16, user, password string) error {
    conf := mysql.NewConfig()
    conf.User = user
    conf.Passwd = password
    conf.Addr = fmt.Sprintf("%s:%d", host, port)
    conf.Net = "tcp"

    dsn := conf.FormatDSN()
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        return err
    }

    res, err := db.Query("SHOW DATABASES")
    if err != nil {
        return err
    }
    defer res.Close()

    var database string
    for res.Next() {
        if err := res.Scan(&database); err != nil {
            return err
        }
        fmt.Println(database)
    }

    if err := res.Err(); err != nil {
        return err
    }

    return nil
}

前述の通り session-manager-plugin がインストールされている必要があります。

次のように踏み台サーバのインスタンス ID ・ RDS インスタンスのエンドポイント・秘密鍵のパスなどを与えて実行します。
(デフォルトでは RDS インスタンスの 3306 ポートがローカルの 9090 ポートにフォワーディングされます)

$ go run main.go -instance-id i-xxxxxx -key ~/.ssh/bastion.key -db-host xxxx.xxxx.ap-northeast-1.rds.amazonaws.com -db-pass xxxx
information_schema
mysql
performance_schema

プライベート VPC 内の RDS インスタンスに SHOW DATABASES を実行して得られた DB 一覧が出力されます。

コードの解説

要点だけを解説します。

メインの処理は run() 関数に実装されています。

func run(conf *config) error {
    sess, err := session.NewSession(&aws.Config{
        Region: aws.String(conf.region),
    })
    if err != nil {
        return err
    }

    svc := ssm.New(sess)

    proxyCmd, closeSession, err := openSession(svc, conf.instanceID)
    if err != nil {
        return err
    }
    defer closeSession()

    sshConfig, err := newSSHClientConfig(conf.user, conf.keyPath)
    if err != nil {
        return err
    }

    client, killProxyCmd, err := newSSHClientWithProxyCommand(conf.instanceID, 22, proxyCmd, sshConfig)
    if err != nil {
        return err
    }
    defer killProxyCmd()
    defer client.Close()

    done, err := portForward(conf.localPort, client, conf.dbHost, conf.dbPort)
    if err != nil {
        return err
    }
    defer done()

    if err := printDBList("localhost", conf.localPort, conf.dbUser, conf.dbPass); err != nil {
        return err
    }

    return nil
}

次のような流れになっています。

  1. openSession() で Session Manager のセッションを開始
  2. newSSHClientWithProxyCommand() で session-manager-plugin を ProxyCommand として使う SSH Client を生成
  3. portFoward() で RDS インスタンスのポートをローカルのポートにフォワーディング
  4. ローカルポートに対してクエリを実行

openSession()

// openSession AWS Systems Manager Session Manager のセッションを開始し、
// session-manager-plugin を実行する *exec.Cmd とセッションを終了する関数を返す。
func openSession(svc *ssm.SSM, instanceID string) (*exec.Cmd, func() error, error) {
    in := &ssm.StartSessionInput{
        DocumentName: aws.String("AWS-StartSSHSession"),
        Parameters: map[string][]*string{
            "portNumber": {aws.String("22")},
        },
        Target: aws.String(instanceID),
    }
    out, err := svc.StartSession(in)
    if err != nil {
        return nil, nil, err
    }

    close := func() error {
        in := &ssm.TerminateSessionInput{
            SessionId: out.SessionId,
        }
        if _, err := svc.TerminateSession(in); err != nil {
            return err
        }
        return nil
    }

    cmd, err := sessionManagerPlugin(svc, in, out)
    if err != nil {
        defer close()
        return nil, nil, err
    }

    return cmd, close, nil
}

SSM.StartSession() を叩いてセッションを開始します。
AWS CLI の実装と、SSH の ProxyCommand 設定での呼び出し方を参考に実装しています。

SSM.StartSession() の入出力を JSON エンコードしたものを session-manager-plugin に与える必要があるので、ここで session-manager-plugin を実行するための *exec.Cmd も生成してしまっています(実際に生成している箇所は sessionManagerPlugin())。

newSSHClientWithProxyCommand()

// newSSHClientWithProxyCommand ProxyCommand を利用した *ssh.Client を返す。
func newSSHClientWithProxyCommand(
    host string,
    port uint16,
    proxyCmd *exec.Cmd,
    conf *ssh.ClientConfig,
) (*ssh.Client, func() error, error) {
    c, s := net.Pipe()

    proxyCmd.Stdin = s
    proxyCmd.Stdout = s
    proxyCmd.Stderr = os.Stderr

    if err := proxyCmd.Start(); err != nil {
        return nil, nil, err
    }

    done := func() error {
        return proxyCmd.Process.Kill()
    }

    addr := fmt.Sprintf("%s:%d", host, port)
    conn, chans, reqs, err := ssh.NewClientConn(c, addr, conf)
    if err != nil {
        defer done()
        return nil, nil, err
    }

    client := ssh.NewClient(conn, chans, reqs)

    return client, done, nil
}

与えられた *exec.Cmd を ProxyCommand として使用する SSH Client を生成します。
net.Pipe() を使用してコマンドの入出力を SSH Client に結び付けるのがポイントです。

portFoward()

// portForward ポートフォワードを行う。
func portForward(
    localPort uint16,
    sshClient *ssh.Client,
    remoteHost string,
    remotePort uint16,
) (func(), error) {
    listener, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort))
    if err != nil {
        return nil, err
    }

    remoteAddr := fmt.Sprintf("%s:%d", remoteHost, remotePort)

    done := make(chan struct{})

    go func() {
        defer listener.Close()

        for {
            select {
            case <-done:
                return
            default:
            }

            localConn, err := listener.Accept()
            if err != nil {
                var ne net.Error
                if errors.As(err, &ne) && ne.Temporary() {
                    continue
                }
                fmt.Fprintln(os.Stderr, "accept failed: ", err)
                return
            }

            remoteConn, err := sshClient.Dial("tcp", remoteAddr)
            if err != nil {
                fmt.Fprintln(os.Stderr, "dial failed: ", err)
                return
            }

            go func() {
                defer localConn.Close()
                defer remoteConn.Close()
                if _, err := io.Copy(remoteConn, localConn); err != nil {
                    fmt.Fprintln(os.Stderr, "copy failed: ", err)
                }
            }()

            go func() {
                if _, err := io.Copy(localConn, remoteConn); err != nil {
                    fmt.Fprintln(os.Stderr, "copy failed: ", err)
                }
            }()
        }
    }()

    return func() {
        close(done)
    }, nil
}

ポートフォワードを行います。
Listener.Accept() で得られたローカルポートの net.Conn と SSH Client から RDS インスタンスに Dial() して得られた net.Conn とを goroutine 内で相互に io.Copy() することでポートフォワードを実現できます。

無限ループ内で Accept(), Dial() することで複数のコネクションを扱うことが可能です。

参考

6
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
3