LoginSignup
12
2

More than 1 year has passed since last update.

OpenAPI Generatorで生成したRustのコードを本番利用するまで

Last updated at Posted at 2022-12-05

これは株式会社LabBase テックカレンダー Advent Calendar 2022 4日目の記事です。

株式会社LabBaseでエンジニアをしている高橋です。
昨日の記事は @yiwi さんによるRust学習の試行錯誤の記事でした。(こちら)

今回はOpenAPI Generatorで生成したコードRustのコードを、検証プロダクトの本場環境利用するまでに行ったことについて紹介します。

TL;DR

OpenAPI GeneratorのRust言語サポートについて

OpenAPI Generator allows generation of API client libraries (SDK generation), server stubs, documentation and configuration automatically given an OpenAPI Spec (both 2.0 and 3.0 are supported). Currently, the following languages/frameworks are supported:

openapi-generatorより引用。

Open API GeneratorはOpen API Specificationで記載されたResfful APIの仕様に基づいてコードやドキュメントを生成してくれます。多くのフレームワークをサポートしており、Rust言語についてはrust-server を指定するとRustのコードが生成されます。

# こんなコマンドを指定するとコードが生成されます
openapi-generator-cli generate -g rust-server -i docs/api.yaml -o ../webapp/openapi/

生成時に指定したOpen API Specificationの内容を元に、各APIエンドポイントを表現する Api traitが生成され、Api traitのimplementによって実際のAPIの動作を実装することができます。

// openapi/src/lib.rs
#[async_trait]
pub trait Api<C: Send + Sync> {
    fn poll_ready(
        &self,
        _cx: &mut Context,
    ) -> Poll<Result<(), Box<dyn Error + Send + Sync + 'static>>> {
        Poll::Ready(Ok(()))
    }

    // sample GETの定義
    async fn sample_get(&self, context: &C) -> Result<SampleGetResponse, ApiError>;
    // 他のエンドポイントの定義が続く
}

// webapp/openapi/examples/server/server.rs
// サーバーをstructで定義(自動生成されたサンプルコードより)
#[derive(Copy, Clone)]
pub struct Server<C> {
    marker: PhantomData<C>,
}

impl<C> Server<C> {
    pub fn new() -> Self {
        Server {
            marker: PhantomData,
        }
    }
}

// Apiトレイトの実装(サンプルなのでApiErrorを返しています)
#[async_trait]
impl<C> Api<C> for Server<C>
where
    C: Has<XSpanIdString> + Send + Sync,
{
    async fn sample_get(&self, context: &C) -> Result<SampleGetResponse, ApiError> {
        let context = context.clone();
        info!("sample_get() - X-Span-ID: {:?}", context.get().0.clone());
        Err(ApiError("Generic failure".into()))
    }
}

// webapp/openapi/examples/server/main.rs
// tokio::mainでserver::createによってサーバーを起動
#[tokio::main]
async fn main() {
    env_logger::init();

    let matches = App::new("server")
        .arg(
            Arg::with_name("https")
                .long("https")
                .help("Whether to use HTTPS or not"),
        )
        .get_matches();

    let addr = "127.0.0.1:8080";

    server::create(addr, matches.is_present("https")).await;
}

サンプルレベルのAPIであれば特に問題なく実装することができます。Rustでもスキーマ駆動開発ができそうです。

rust-serverの機能サポートについて

rust-serverがサポートしている機能について記載がありますが、例えばBearerTokenをサポートしていないのでリクエストヘッダにJWTを指定した認証(認可も)がサポートされていません。

本番環境で使うためになんとかする必要があります。「できそう」と「できる」は違うのです。

サポートされていない機能の実装

今回の検証用プロダクトでは認証が必須なので、BearerTokenを実装しました。またCORSもサポートする必要があるのでプリフライトリクエストによるAccess-Control-Allow-Headersの解決も実装しました(初めて自分でCORSを実装しました…)。

BearerTokenの実装

BearerTokenはサポートしていないですがAuthDataというenumでJWTの値を保持しContextに格納する処理が存在するので、各エンドポイントの実際の処理でContextからJWTを解決しデコードすることで、JWTによる認証、認可が実現できます(記載時点で自動生成側でうまくハンドリングされておらず動かないですが、不具合っぽいのでその辺の修正の詳細は割愛して利用側だけ紹介します:pray:)。

// webapp/src/lib.rs
#[derive(Clone)]
pub struct Server<C> {
    marker: PhantomData<C>,
    // JWTのデコーダー
    auth_token_decoder: AuthTokenDecoder,
}

// 認証Guardの実行結果
enum GuardResult {
    // JWTをデコードした値がOKで取得できる
    Ok(AccessTokenModel),
    // 認証エラー
    UnauthorizedErr(UnauthorizedErrorResponse),
    // 認可エラー
    ForbiddenError(ForbiddenErrorResponse),
}

impl<C> Server<C> {
    pub fn new(auth_token_decoder: AuthTokenDecoder) -> Self {
        Server {
            marker: PhantomData,
            auth_token_decoder,
        }
    }

    /// JWTをデコードした値が、引数の認可情報を所持しているかチェックします
    fn guard_service(
        &self,
        context: &C,
        guard_target: Vec<Service>,
    ) -> Result<GuardResult, ApiError>
    where
        C: Has<XSpanIdString> + Has<Option<AuthData>> + Send + Sync,
    {
        let auth_context = context as &(dyn Has<Option<AuthData>> + Send + Sync);
        let auth_data = auth_context.get();

        match auth_data {
            Some(AuthData::ApiKey(auth_data)) => {
                let splitted = auth_data.split(" ").collect::<Vec<_>>();
                let token = splitted
                    .get(1)
                    .ok_or_else(|| ApiError(format!("invalid bearer format.{}", auth_data)))?;

                // JWTのデコード
                let access_token_model = self.auth_token_decoder.decode_token(&token.to_string());
                if let Err(e) = &access_token_model {
                    if matches!(e.kind(), 
                    // JWTの期限切れ
                    auth_sdk::error::ErrorKind::ExpiredSignature) {
                        // 401エラーを返す
                        return Ok(GuardResult::UnauthorizedErr(
                            Self::create_unauthorized_response(context),
                        ));
                    }
                }
                let access_token_model = access_token_model.map_err(|e| {
                    ApiError(format!(
                        "tokenのデコードに失敗しました。:{:?}, detail:{:?}",
                        token, e
                    ))
                })?;

                // 認可チェック
                let check = &access_token_model
                    .scope()
                    .get()
                    .iter()
                    .map(|service_id| Service::try_from(service_id))
                    .collect::<Result<Vec<_>>>()
                    .map_err(|e| ApiError(format!("service_idが不正です。Error:{:?}", &e)))?
                    .iter()
                    .any(|service| guard_target.contains(service));
                if *check {
                    Ok(GuardResult::Ok(access_token_model))
                } else {
                    Ok(GuardResult::ForbiddenError(
                        Self::create_forbidden_response(context),
                    ))
                }
            }
            None => Err(ApiError(format!(
                "auth_dataが不正が存在しません:{:?}",
                &auth_data
            ))),
            _ => Err(ApiError(format!(
                "auth_dataが不正です。auth_data:{:?}",
                &auth_data
            ))),
        }
    }
}

#[async_trait]
impl<C> Api<C> for Server<C>
where
    C: Has<XSpanIdString> + Has<Option<AuthData>> + Send + Sync,
{
    async fn sample_get(&self, context: &C) -> Result<SampleGetResponse, ApiError> {
        // APIを利用できるサービスの種類を指定して、認証を行う
        let access_token_model =
            self.guard_service(context, vec![Service::HogeService])?;

        match access_token_model {
            GuardResult::Ok(access_token_model) => {
                // デコードしたモデルから認証情報を取得して処理を行う
            }
            GuardResult::ForbiddenError(error_response) => {
                Ok(SampleGetResponse::ForbiddenError(error_response))
            }
            GuardResult::UnauthorizedErr(error_response) => {
                Ok(SampleGetResponse::UnauthorizedError(error_response))
            }
        }
}

CORSの実装

Open API SpecificationではCORSをドキュメントで表現できないので(間違っていたらごめんなさい)、別に定義した情報を元にCORSを実現する必要があります。今回は環境変数で許可するOriginを定義し、プリフライトリクエストが送信された際にレスポンスヘッダに Access-Control-Allow-Originを返却するようにしてCORSの一部を実装しました

// webapp/openapi/src/server/mod.rs
impl<T, C> hyper::service::Service<(Request<Body>, C)> for Service<T, C>
where
    T: Api<C> + Clone + Send + Sync + 'static,
    C: Has<XSpanIdString> + Has<Option<AuthData>> + Send + Sync + 'static,
{
    fn call(&mut self, req: (Request<Body>, C)) -> Self::Future {
        async fn run<T, C>(
            mut api_impl: T,
            req: (Request<Body>, C),
        ) -> Result<Response<Body>, crate::ServiceError>
        where
            T: Api<C> + Clone + Send + 'static,
            C: Has<XSpanIdString> + Has<Option<AuthData>> + Send + Sync + 'static,
        {

            let (request, context) = req;
            let (parts, body) = request.into_parts();
            let (method, uri, headers) = (parts.method, parts.uri, parts.headers);
            let path = paths::GLOBAL_REGEX_SET.matches(uri.path());

            // envに指定がないときは固定値を返す
            let allow_origins =
                env::var("CORS_ALLOW_ORIGINS").unwrap_or(String::from("https://example.com"));
            let allow_origins = allow_origins.split(",").collect::<Vec<_>>();

            let allowed_origin = match headers.get(ORIGIN) {
                Some(origin) => {
                    // originと一致する物を探す
                    let allowed_origin = allow_origins
                        .iter()
                        .find(|allow_origin| **allow_origin == origin.to_str().unwrap());
                    match allowed_origin {
                        Some(allowed_origin) => {
                            // OPTIONSメソッドでリクエストされた場合
                            if method == Method::OPTIONS {
                                let mut response = Response::new(Body::empty());
                                response.headers_mut().insert(
                                    ACCESS_CONTROL_ALLOW_ORIGIN,
                                    HeaderValue::from_str(allowed_origin)
                                        .expect("cannot create header value"),
                                );
                                response.headers_mut().insert(
                                    ACCESS_CONTROL_ALLOW_METHODS,
                                    HeaderValue::from_str(
                                        &vec![
                                            Method::OPTIONS,
                                            Method::GET,
                                            Method::POST,
                                            Method::PUT,
                                            Method::DELETE,
                                        ]
                                        .iter()
                                        .map(|m| m.to_string())
                                        .collect::<Vec<_>>()
                                        .join(", "),
                                    )
                                    .expect("cannot create header value"),
                                );
                                response.headers_mut().insert(
                                    ACCESS_CONTROL_ALLOW_HEADERS,
                                    HeaderValue::from_str(
                                        &vec![
                                            CONTENT_TYPE.to_string(),
                                            ACCEPT.to_string(),
                                            ORIGIN.to_string(),
                                            PRAGMA.to_string(),
                                            AUTHORIZATION.to_string(),
                                            String::from("X-Requested-With"),
                                        ]
                                        .into_iter()
                                        .collect::<Vec<String>>()
                                        .join(", "),
                                    )
                                    .expect("cannot create header value"),
                                );
                                return Ok(response);
                            }
                            allowed_origin
                        }
                        // originが許可されていない場合
                        None => {
                            let mut response = Response::new(Body::empty());
                            response.headers_mut().insert(
                                ACCESS_CONTROL_ALLOW_ORIGIN,
                                HeaderValue::from_str(&allow_origins.get(0).unwrap_or_else(|| &""))
                                    .expect("invalid header value"),
                            );
                            return Ok(response);
                        }
                    }
                }
                None => {
                    // Originが未指定の場合は許可する
                    allow_origins.get(0).unwrap_or_else(|| &"")
                }
            };

            // 長い実装が続く…
        }
}

まとめ

以上、OpenAPI Generatorで生成したRustのコードで不足している機能を自力で実装した紹介でした。なんとかなりましたが結構な負債になってしまったので来年にはなんとかしたいです。

次回のアドベントカレンダーはゾネス君 (@takahiro-yamada)です。よろしくお願いします。

12
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
2