LoginSignup
11
7

More than 3 years have passed since last update.

Rust+async/awaitでノンブロッキングGUI ~ webクライアント作成編

Last updated at Posted at 2019-08-25

導入

当記事はRust+async/awaitでノンブロッキングGUI ~ 頓挫編の続編です。

前回の記事では近々安定化予定の async/await を使い Rust でノンブロッキング GUI の実現を目指しました。
ところが既存の web クライアントライブラリで gtk-rs の Executor 上で動作するものは見当たらず、目論見は失敗に終わりました。
しかし gtk-rs のバックエンドとして使用していた gio 内にも futures 0.3 に対応した API が存在していました。

「なければ作ればいいじゃない」ということで、今回はこれを用いて web クライアントを自作し、リベンジを果たしたいと思います。

gio の Future API の利用法についてはこちらを参考にしました。

環境

rustc 1.39.0-nightly (760226733 2019-08-22)
async_await フィーチャの有効化には nightly ビルドが必要です。

※Windows の場合、msvc ツールチェインでは動作しません。glib-networkingライブラリが msvc 向けに提供されていないためです。Windows 環境下では gnu ツールチェインをお使いください。

準備

web モジュールの作成

web クライアント関係の機能をまとめるモジュールを作成します。

.
├── Cargo.toml
...
├── src
│   ├── main.rs
│   └── web.rs // <- NEW
└── target
    ...

futures フィーチャの有効化

gio の Future API を使用するために futures フィーチャを有効化する必要があります。
またより使い勝手の良い *_all_async_future はバージョン 2.44 以降の API であるため、v2_44 も同時に有効化します。

Cargo.toml
...
[dependencies]
...
- gio = "0.7.0"
+ gio = { version = "0.7.0", features = ["futures", "v2_44"] }
...

コーディング

基本的なフロー

まずは全体の流れを見てみましょう。

src/web.rs
use gio::{IOStreamExt, InputStreamExtManual, OutputStreamExtManual, SocketClientExt};

/// `gio`でダウンロードを行うリファレンス実装。
pub async fn get_async_essential(
    host: &str,
    path: &str,
    use_tls: bool,
    port: Option<u16>,
) -> Vec<u8> {
    // クライアントの生成
    let client = gio::SocketClient::new();
    client.set_tls(use_tls);

    // ポート指定がない場合、デフォルトポートを設定
    let port = port.unwrap_or_else(|| if use_tls { 80 } else { 443 });

    // ホストへ接続
    let conn = client
        .connect_to_host_async_future(host, port)
        .await
        .unwrap();

    // OutputStream の取得
    let output = conn.get_output_stream().unwrap();

    // リクエストの生成
    let request = format!(
        "GET {} HTTP/1.1\r\n\
         Host: {}\r\n\
         \r\n",
        path, host
    );

    // リクエストの送信
    let res = output
        .write_all_async_future(request, glib::PRIORITY_DEFAULT)
        .await;

    // 送信エラーのハンドリング
    if let Ok((_, _, Some(err))) | Err((_, err)) = res {
        panic!("Error: {:?}", err);
    }

    // InputStream の取得
    let input = conn.get_input_stream().unwrap();

    // レスポンスの継続受信
    let mut body = Vec::new();
    loop {
        // 戻り値:
        // Ok((v, u, e))
        //  : (引数として渡したバッファ、読み込んだバイト数、バッファリング中のエラー(無視してもいい))
        // Err((v, e))
        //  : (引数として渡したバッファ、エラー)
        let res = input
            .read_all_async_future(vec![0; 1024], glib::PRIORITY_DEFAULT)
            .await;

        match res {
            // 読み込んだバイト数が 0 == ダウンロード完了
            Ok((_, _u @ 0, _)) => break,
            // ダウンロード中
            Ok((mut chunk, u, _)) => {
                // チャンクを切り詰めてボディに追加
                chunk.truncate(u);
                body.extend(chunk);
            }
            // エラー発生
            Err((_, _)) => break,
        }
    }

    body
}

シグニチャの最前列に燦然と輝く async の文字が眩しいですね。これは非同期関数の証です。この関数から返されるのは Future を実装した存在型となり、Executor に渡されることで初めて処理を開始するようになります。
それはさておき基本的な流れを整理します。

  1. クライアントの生成
  2. (https の場合)SocketClientExt::set_tls(true) の呼び出し
  3. SocketClientExt::connect_to_host_async_future でホストに接続
  4. IOStreamExt::get_output_stream でOutputStreamの取得
  5. OutputStreamExtManual::write_all_async_future でリクエストを送信
  6. IOStreamExt::get_input_stream でInputStreamの取得
  7. InputStreamExtManual::read_all_async_future でレスポンスを受信
  8. 読み込んだバイトサイズが 0 or Err を返すまで7をループ

深くは考えず、こういうものとして飲み込んでください。

また *_all_async_future ではなく *_async_future という似たような名前のメソッドもありますが、*all_async_future は内部で繰り返し呼び出しを行い、可能な限り引数として与えられたバッファ全体に対する 書き込み/読み込み を試みるというのが違いです(たぶん)。

-> read_async_futureread_all_async_future

便利な一方、Ok を返した場合にも第 3 要素にループ呼び出し中のエラーが格納されるのでシグニチャとしては複雑になっています。

Err の場合にも引数として与えたバッファが戻り値で返ってきているのは、C の実装のバインディングなので Rust 側でアロケートしたバッファの所有権を取り戻し、Rust 側で開放するためだと思います(自信なし)。

では実際にこの関数を使ってデータをダウンロードしてみましょう。
ビュー側に逐一変更を加えるのも手間なので、ユニットテストを作成します。
内容についてはコメントをご参照ください。

src/web.rs
...
#[cfg(test)]
mod test {
    use futures::FutureExt;

    use super::*;

    /// メインループを起動して引数の`Future`が完了するまでブロックする。
    fn exec_async(future: impl std::future::Future<Output = ()> + 'static) {
        // メインコンテキスト == 非同期オペレーションを実行するスレッド
        // https://valadoc.org/glib-2.0/GLib.MainContext.push_thread_default.html

        // メインコンテキストの生成
        let main_context = glib::MainContext::new();
        // 現在のスレッドをメインコンテキストとしてマーク
        main_context.push_thread_default();
        // メインコンテキストを引数としてメインループを生成
        let main_loop = glib::MainLoop::new(Some(&main_context), false);
        // メインコンテキストにタスクを追加
        main_context.spawn_local(future.map({
            let main_loop = main_loop.clone();
            move |_| main_loop.quit()
        }));
        // メインループを開始。MainLoop::quit が呼び出されるまでブロックする
        main_loop.run();
        // 現在のスレッドをメインコンテキストから外す(不要)
        // コンテキストをネストさせる場合に push_thread_default と対になるように呼び出す
        main_context.pop_thread_default();
    }

    #[test]
    fn download_test() {
        let future = async move {
            get_async_essential("some.host.com", "/path?query=hoge", true, None).await;

            println!("success");
        };

        exec_async(future);
    }
}

テストを実行します。

~$ cargo test -- --nocapture
    Finished dev [unoptimized + debuginfo] target(s) in 0.33s
     Running target\debug\deps\async_image_downloader-10760be0578b0999.exe

running 1 test
success
test web::test::download_test ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out

無事に成功しました。1

実装

上記の関数が一応の動作をすることを確認したので、本格的な実装に移ります。
get_async関数を定義し、機能を追加していきましょう。

pub async fn get_async(...) -> ... {
    ...
}

URI のパース

まずはこの関数の使い勝手を改善しましょう。

let res = get_async_essential("some.host.com", "/path?query=hoge", true, Some(443)).await;

現在はこのように呼び出さなければならないところを

let res = get_async("https://some.host.com/path?query=hoge").await;

と URI を渡すだけで内部で自動的にパースされるようにします。
今回は hyper の内部クレートでもあるhttpを使用しましょう。

Cargo.toml に http を追加します。

Cargo.toml
...
[dependencies]
...
http = "0.1.18"

URI からリクエストへのパースは以下のように書けます。

let uri = http::Request::get("https://some.host.com/path?query=hoge")
        .body(())
        .unwrap()
        .into_parts()
        .0
        .uri;;

ボディには用がないので、URI 関連の情報を持つURIのみを取り出しています。

URI からはhostport_u16など、そのままな名前で必要な情報が取り出せます。
省略可能な部分は Option 型で返されるので、unwrap_orで適宜デフォルト値を設定します。

変更を加えたget_async関数は以下のようになります。

...
pub async fn get_async<T>(uri: T) -> ...
where
    http::Uri: http::HttpTryFrom<T>,
{
    // URI をパースして各情報を取得
    let uri = http::Request::get(uri)
        .body(())
        .unwrap()
        .into_parts()
        .0
        .uri;
    // ホスト
    let host = uri.host().unwrap();
    // TLS 有効/無効
    let use_tls = uri.scheme_part() == Some(&http::uri::Scheme::HTTPS);
    // ポート
    let port = uri.port_u16().unwrap_or(if use_tls { 443 } else { 80 });
    // パス & クエリ
    let path_and_query = uri.path_and_query().map(|p| p.as_str()).unwrap_or("/");
    ...

せっかくなのでジェネリクスにしました。境界がやや複雑ですが、これはRequest::getが引数に対して課すものをそのまま反映しています。
&strStringもこの境界を満たすので、無事呼び出し側は以下の通り書けるようになりました 🎉

let res = get_async("https://some.host.com/path?query=hoge").await;

レスポンスをヘッダとボディに分離

ダウンロードした生のデータの先頭にはヘッダが含まれます。
今回はデータを画像として扱うわけですが、何の処理をするにせよ実際のデータ本体であるボディだけを渡さなければなりません。
これを分離する機能を実装していきましょう。

ヘッダの終了は"\r\n\r\n"によって表現されるので、それを検出する find_crlf 関数を定義します。

/// レスポンス内のヘッダ終端部を探索。
/// 発見した場合はその位置を Ok として、未発見の場合は次回の探索開始位置を Err として返す。
fn find_crlf(buffer: &[u8]) -> Result<usize, usize> {
    buffer
        // スライスを 4 要素ずつ返すイテレータを生成
        .windows(4)
        .position(|b| b == b"\r\n\r\n")
        // マッチ位置は先頭なので b"\r\n\r\n" の文字数(4文字)分足す
        .map(|u| u + 4)
        // 発見できなかった場合、次回は (末尾 - 3) の位置から探索を開始する
        .ok_or_else(|| buffer.len().saturating_sub(4))
}

文字コードは UTF-8 ではなく ASCII なのでリテラル上は b"\r\n\r\n" となります。
この関数はループ内で呼び出されることになるのですが、パフォーマンスのために毎回データ全体を渡すのではなく、検査の進行度を表すインデックスを保持し、それ以降のスライスを渡すようにします。

    ...
    // レスポンスの継続受信を開始

    let mut buffer = Vec::new();
    let mut body = Vec::new();
    let mut header_found = false;
    let mut search_index = 0;

    loop {
        match input
            .read_all_async_future(vec![0; 1024], glib::PRIORITY_DEFAULT)
            .await
        {
            // ダウンロード完了
            Ok((_, _u @ 0, _)) => break,
            // ダウンロード進行中
            Ok((mut chunk, u, _)) => {
                // チャンクを切り詰める
                chunk.truncate(u);

                // ヘッダ発見済みの場合はボディに追加して continue
                if header_found  {
                    body.extend(chunk);
                    continue;
                }

                // ヘッダ未発見の場合は探索
                buffer.extend(chunk);
                // buffer から search_index 以降のスライスを渡す
                match find_crlf(&buffer[search_index..]) {
                    // ヘッダ発見
                    Ok(u) => {
                        header_found = true;
                        // ヘッダ以降の部分をボディとして切り出す
                        body = buffer.split_off(search_index + u);
                    }
                    // ヘッダ発見できず
                    Err(u) => {
                        // search_index を次回の探索開始位置まで進める
                        search_index += u;
                    }
                }
            }
            // エラー発生
            Err((_, err)) => ...

取り出したヘッダの方もパースして使い勝手のいい形に変換しましょう。
ヘッダを表現する構造体には http のHeaderMapを利用します。
バイト列を受け取ってHeaderMapを返すparse_header関数を定義します。
この関数内では正規表現を扱うので、デファクトスタンダードな正規表現ライブラリであるregexを使用します。

Cargo.toml
...
[dependencies]
...
regex = "1.2.1"
/// バイト列をヘッダとしてパースする。
fn parse_header(header: &[u8]) -> http::HeaderMap<http::HeaderValue> {
    // 複数行モードに設定して `{key}: {value}` のパターンを検索する
    regex::bytes::Regex::new(r"(?m)^(.+?)\s*:\s*(.+)\r$")
        .expect("failed to build Regex object")
        // マッチした全要素をイテレータとして返す
        .captures_iter(header)
        .filter_map(|cap| {
            // 第 1 キャプチャ部を key、第 2 キャプチャ部を value としてパース
            let (k, v) = (cap.get(1)?, cap.get(2)?);
            Some((
                http::header::HeaderName::from_bytes(k.as_bytes()).ok()?,
                http::HeaderValue::from_bytes(v.as_bytes()).ok()?,
            ))
        })
        // impl<T> FromIterator<(HeaderName, T)> for HeaderMap<T> の実装を利用
        .collect()
}

ざっくりとした説明になりますが、ヘッダは以下のようなフォーマットになっています。

HTTP/1.1 200 OK
content-type: text/css
content-length: 187200
header-name: header-value
...

この{ヘッダ名}: {ヘッダ値}をキャプチャする正規表現オブジェクトを生成し、両方のキャプチャが成功した場合にそれぞれHeaderNameHeaderValueとしてパース、両方ともパースに成功すればFromIteratorの実装を利用して(HeaderName, HeaderValue)のタプルからHeaderMap<HeaderValue>を生成するという流れになります。

この関数をヘッダの検出時に呼び出すことでデータ受信部は完成です。

    ...
    // レスポンスの継続受信を開始

    let mut buffer = Vec::new();
    let mut header = None;
    let mut body = Vec::new();
    let mut search_index = 0;

    loop {
        match input
            .read_all_async_future(vec![0; 1024], glib::PRIORITY_DEFAULT)
            .await
        {
            // ダウンロード完了
            Ok((_, _u @ 0, _)) => break,
            // ダウンロード進行中
            Ok((mut chunk, u, _)) => {
                // チャンクを切り詰める
                chunk.truncate(u);

                // ヘッダ発見済みの場合はボディに追加して continue
                if header.is_some() {
                    body.extend(chunk);
                    continue;
                }

                // ヘッダ未発見の場合は探索
                buffer.extend(chunk);
                // buffer から search_index 以降のスライスを渡す
                match find_crlf(&buffer[search_index..]) {
                    // ヘッダ発見
                    Ok(u) => {
                        // ヘッダ以降の部分をボディとして切り出し、残りをヘッダとしてパース
                        body = buffer.split_off(search_index + u);
                        header = Some(parse_header(&buffer));
                    }
                    // ヘッダ発見できず
                    Err(u) => {
                        // search_index を次回の探索開始位置まで進める
                        search_index += u;
                    }
                }
            }
            // エラー発生
            Err((_, err)) => ...

エラーハンドリング

仕上げに関数全体のエラーハンドリングを行い、unwrappanicをコードから取り除きましょう。

お手軽なのはBox<dyn std::error::Error>>でラップしてしまうことですが、今回はリッチで簡便なエラーハンドリングを提供するsnafuを使用することにします。

Cargo.toml に snafu を追加

Cargo.toml
...
[dependencies]
...
snafu = "0.4.4"

まとめたいエラーの数だけバリアントを持つ enum を定義し、Snafuマクロを derive します。

#[derive(Debug)]
pub enum Direction {
    Input,
    Output,
}

#[derive(Debug, snafu::Snafu)]
pub enum Error {
    /// URI のパースに失敗した。
    RequestParse { source: http::Error },
    /// `glib`内でエラーが発生した。
    Glib { source: glib::Error },
    /// 無効なホスト。
    #[snafu(display("Invalid host"))]
    InvalidHost,
    /// `Input/OutputStream`の取得に失敗した。
    #[snafu(display("Could not open {:?} stream", direction))]
    AcquireStream { direction: Direction },
}

エラー型の定義に必要なコードはたったこれだけ。実際の利用に関わるボイラープレートの生成は全て Snafu マクロが行ってくれます。すごい 👀

sourceフィールドを持つバリアントは Result から変換されるエラー、持たないバリアントは Option 型から変換、もしくはロジックによって判定されるエラーを表しています。
基本的には ResultExt/OptionExt をインポートし、それぞれのcontextメソッドにバリアント名を指定することで変換を行います。
正確な説明ではありませんが、大体このような認識で問題ありません。

以上を適用したコードがこちらになります。

use {
    gio::{IOStreamExt, InputStreamExtManual, OutputStreamExtManual, SocketClientExt},
    // 追加
    snafu::{OptionExt, ResultExt},
};
...
pub async fn get_async<T>(uri: T) -> Result<Vec<u8>, Error>
where
    http::Uri: http::HttpTryFrom<T>,
{
    // URI をパースして各情報を取得
    let uri = http::Request::get(uri)
        .body(())
        .context(RequestParse)?
        .into_parts()
        .0
        .uri;
    // ホスト
    let host = uri.host().context(InvalidHost)?;
    // TLS 有効/無効
    let use_tls = uri.scheme_part() == Some(&http::uri::Scheme::HTTPS);
    // ポート
    let port = uri.port_u16().unwrap_or(if use_tls { 443 } else { 80 });
    // パス & クエリ
    let path_and_query = uri.path_and_query().map(|p| p.as_str()).unwrap_or("/");

    // クライアントの生成
    let client = gio::SocketClient::new();
    client.set_tls(use_tls);

    // ホストに接続
    let conn = client
        .connect_to_host_async_future(host, port)
        .await
        .context(Glib)?;

    // OutputStream を取得
    let output = conn.get_output_stream().context(AcquireStream {
        direction: Direction::Output,
    })?;

    // リクエストの生成・送信
    let request = format!(
        "GET {} HTTP/1.1\r\n\
         Host: {}\r\n\
         \r\n",
        path_and_query, host
    );
    if let Ok((_, _, Some(err))) | Err((_, err)) = output
        .write_all_async_future(request, glib::PRIORITY_DEFAULT)
        .await
    {
        Err(err).context(Glib)?;
    }

    // InputStream を取得
    let input = conn.get_input_stream().context(AcquireStream {
        direction: Direction::Input,
    })?;
    ...
let output = conn.get_output_stream().context(AcquireStream {
        direction: Direction::Output,
    })?;

の部分に注目してください。
定義部で

    /// `Input/OutputStream`の取得に失敗した。
    #[snafu(display("Could not open {:?} stream", direction))]
    AcquireStream { direction: Direction },

と任意のフィールドを持たせたバリアントはそれをコンテキストとして保持し、displayアトリビュートで指定する Display トレイトを介してフォーマットされる際の文章に情報を加えることができます。

続いて論理エラーと偽陽性エラーを判定するロジックを加えます。
まずはヘッダの探索に上限を設けることにしましょう。数キロバイトほど探して見つからなければエラーとしてダウンロードを中止します。
(これは後に Stream の実装型に変換することをにらんでの実装です。)

閾値を表す定数とバリアントを追加し

...
/// ヘッダの探索を試行する閾値。
const HEADER_SEARCH_LIMIT: usize = 2048;
...
#[derive(Debug, snafu::Snafu)]
pub enum Error {
    ...
    /// ヘッダ未発見のまま閾値に到達した。
    #[snafu(display("Could not find header"))]
    HeaderNotFound,
}
...

get_async関数のデータ受信部に変更を加えます。

    ...
    // レスポンスの継続受信を開始
    ...

    loop {
        match input
            .read_all_async_future(vec![0; 1024], glib::PRIORITY_DEFAULT)
            .await
        {
            // ダウンロード完了
            Ok((_, _u @ 0, _)) => {
                // ヘッダ未発見の場合はエラー
                snafu::ensure!(header.is_some(), HeaderNotFound);
                break;
            }
            // ダウンロード進行中
            Ok((mut chunk, u, _)) => {
                ...
                // buffer から search_index 以降のスライスを渡す
                match find_crlf(&buffer[search_index..]) {
                    // ヘッダ発見
                    Ok(u) => ...
                    // ヘッダ発見できず
                    Err(u) => {
                        // search_index を次回の探索開始位置まで進める
                        search_index += u;
                        // 閾値に達したらヘッダの探索を諦めてエラーを返す
                        snafu::ensure!(search_index < HEADER_SEARCH_LIMIT, HeaderNotFound);
                    }
                }
            }
            // エラー発生
            Err((_, err)) => ...

ensureマクロは第 1 引数に渡された式が false を返した場合に第 2 引数のエラーを呼び出し元に返します。
assertの snafu 版という立ち位置です。紐付けられるエラーがないので、sourceフィールドを持たないバリアントしか指定することができません。

続いて偽陽性エラーの検出です。
ダウンロードが完了している場合にもread_all_async_futureが Err を返すことがあるのですが、データ自体は正常に受信されているため、このエラーは無視します。
先程パースしておいた HeaderMap を使いましょう。

Err が返った時にヘッダのパースが行われていれば、ボディのバイト長を表すContent-Lengthプロパティを取得し、ダウンロード済みのデータのサイズと比較します。
ダウンロード済みのデータサイズがContent-Length以上ならエラーを無視します。

    ...
    // レスポンスの継続受信を開始
    ...

    loop {
        match input
            .read_all_async_future(vec![0; 1024], glib::PRIORITY_DEFAULT)
            .await
        {
            // ダウンロード完了
            Ok((_, _u @ 0, _)) => ...,
            // ダウンロード進行中
            Ok((mut chunk, u, _)) => ...
            // エラー発生
            Err((_, err)) => {
                // ダウンロード済みの容量が Content-Length 以上なら無視
                header
                    .as_ref()
                    .and_then(|header| header.get("content-length"))
                    .and_then(|v| v.to_str().ok())
                    .and_then(|v| v.parse().ok())
                    .filter(|&cl| body.len() >= cl)
                    .ok_or(err)
                    .context(Glib)?;
                break;
            }

以上でエラーハンドリングの実装も完了です。お疲れさまでした。

(補足) Snafu マクロの実際

実用的な説明に留めたSnafuマクロの振る舞いについて少し補足します。

Snafuマクロを derive すると、適用される型に対してErrorトレイトが実装されるだけでなく(Snafuトレイトのようなものはありません)、同モジュール内に コンテキストセレクタ と呼ばれる特殊な構造体が定義されます。

#[derive(Debug, snafu::Snafu)]
pub enum Error {
    /// URI のパースに失敗した。
    RequestParse { source: http::Error },
    /// `glib`内でエラーが発生した。
    Glib { source: glib::Error },
    /// 無効なホスト。
    #[snafu(display("Invalid host"))]
    InvalidHost,
    /// `Input/OutputStream`の取得に失敗した。
    #[snafu(display("Could not open {:?} stream", direction))]
    AcquireStream { direction: Direction },
    /// ヘッダ未発見のまま閾値に到達した。
    #[snafu(display("Could not find header"))]
    HeaderNotFound,
}

// この構造体定義から生成されるコンテキストセレクタ

struct RequestParse;
struct Glib;
struct InvalidHost;
struct AcquireStream<T>
// この境界は構造体定義ではなく実装側に課される
// where T: Into<Direction>,
{
    direction: T,
}
struct HeaderNotFound;

コンテキストセレクタはソースとなるエラー型(= 基底エラー型)から、ユーザー定義エラー型への変換を仲立ちしてくれます。各バリアントに 1:1 で対応するものが生成され、その際に元のバリアントでsourceの名前がつけられていたフィールドは基底エラー型として扱われ、コンテキストセレクタのフィールドからは除去されます。

contextメソッドにバリアント名を指定することで変換を行います」と説明した部分では、実際にはこのコンテキストセレクタを渡していた訳ですね。
Result型からの変換にはsourceフィールドを持ったバリアント(= 基底エラー型を持つバリアント)に対応するコンテキストセレクタが要求されます。当然この時には基底エラー型とレシーバであるResult<T, E>Eの型は一致していなければなりません。
逆にOption型からの変換には基底エラー型を持ったコンテクストセレクタを渡すことはできません。

注意点としてはコンテキストセレクタはプライベート構造体だということです。同階層以下のモジュールからしか参照できません。
これはエラーの送出元を特定のモジュールのみに限定できるという利点の裏返しでもあります。

(要検討) 未初期化バッファを渡す

ところでread_all_async_futureに渡すバッファですが、関数内で初期化され、返ってきた際には未初期化部分は切り捨てられていますね?
ならば呼び出し元で初期化する必要はないのではないでしょうか。

ということで初期化されていないバッファを返す関数を定義します。

/// バッファサイズ。
const BUFFER_SIZE: usize = 1024;
...
/// read_all_async_future に渡す未初期化バッファを生成。
fn create_uninit_buffer() -> Vec<u8> {
    unsafe {
        let mut v = Vec::with_capacity(BUFFER_SIZE);
        v.set_len(BUFFER_SIZE);
        v
    }
}

// 呼び出し側
input.read_all_async_future(create_uninit_buffer(), glib::PRIORITY_DEFAULT);

このような安易な unsafe コードの使用には賛否あるかもしれません。

完成

完成品がこちらになります。

詳細
src/web.rs
use {
    gio::{IOStreamExt, InputStreamExtManual, OutputStreamExtManual, SocketClientExt},
    snafu::{OptionExt, ResultExt},
};

pub mod stream;

/// バッファサイズ。
const BUFFER_SIZE: usize = 1024;
/// ヘッダの探索を試行する閾値。
const HEADER_SEARCH_LIMIT: usize = 2048;

#[derive(Debug)]
pub enum Direction {
    Input,
    Output,
}

#[derive(Debug, snafu::Snafu)]
pub enum Error {
    /// URI のパースに失敗した。
    RequestParse { source: http::Error },
    /// `glib`内でエラーが発生した。
    Glib { source: glib::Error },
    /// 無効なホスト。
    #[snafu(display("Invalid host"))]
    InvalidHost,
    /// `Input/OutputStream`の取得に失敗した。
    #[snafu(display("Could not open {:?} stream", direction))]
    AcquireStream { direction: Direction },
    /// ヘッダ未発見のまま閾値に到達した。
    #[snafu(display("Could not find header"))]
    HeaderNotFound,
}

pub async fn get_async<T>(uri: T) -> Result<Vec<u8>, Error>
where
    http::Uri: http::HttpTryFrom<T>,
{
    // URI をパースして各情報を取得
    let uri = http::Request::get(uri)
        .body(())
        .context(RequestParse)?
        .into_parts()
        .0
        .uri;
    // ホスト
    let host = uri.host().context(InvalidHost)?;
    // TLS 有効/無効
    let use_tls = uri.scheme_part() == Some(&http::uri::Scheme::HTTPS);
    // ポート
    let port = uri.port_u16().unwrap_or(if use_tls { 443 } else { 80 });
    // パス & クエリ
    let path_and_query = uri.path_and_query().map(|p| p.as_str()).unwrap_or("/");

    // クライアントの生成
    let client = gio::SocketClient::new();
    client.set_tls(use_tls);

    // ホストに接続
    let conn = client
        .connect_to_host_async_future(host, port)
        .await
        .context(Glib)?;

    // OutputStream を取得
    let output = conn.get_output_stream().context(AcquireStream {
        direction: Direction::Output,
    })?;

    // リクエストの生成・送信
    let request = format!(
        "GET {} HTTP/1.1\r\n\
         Host: {}\r\n\
         \r\n",
        path_and_query, host
    );
    if let Ok((_, _, Some(err))) | Err((_, err)) = output
        .write_all_async_future(request, glib::PRIORITY_DEFAULT)
        .await
    {
        Err(err).context(Glib)?;
    }

    // InputStream を取得
    let input = conn.get_input_stream().context(AcquireStream {
        direction: Direction::Input,
    })?;

    // レスポンスの継続受信を開始

    let mut buffer = Vec::new();
    let mut header = None;
    let mut body = Vec::new();
    let mut search_index = 0;

    loop {
        match input
            .read_all_async_future(create_uninit_buffer(), glib::PRIORITY_DEFAULT)
            .await
        {
            // ダウンロード完了
            Ok((_, _u @ 0, _)) => {
                // ヘッダ未発見の場合はエラー
                snafu::ensure!(header.is_some(), HeaderNotFound);
                break;
            }
            // ダウンロード進行中
            Ok((mut chunk, u, _)) => {
                // チャンクを切り詰める
                chunk.truncate(u);

                // ヘッダ発見済みの場合はボディに追加して continue
                if header.is_some() {
                    body.extend(chunk);
                    continue;
                }

                // ヘッダ未発見の場合は探索
                buffer.extend(chunk);
                // buffer から search_index 以降のスライスを渡す
                match find_crlf(&buffer[search_index..]) {
                    // ヘッダ発見
                    Ok(u) => {
                        // ヘッダ以降の部分をボディとして切り出し、残りをヘッダとしてパース
                        body = buffer.split_off(search_index + u);
                        header = Some(parse_header(&buffer));
                    }
                    // ヘッダ発見できず
                    Err(u) => {
                        // search_index を次回の探索開始位置まで進める
                        search_index += u;
                        // 閾値に達したらヘッダの探索を諦めてエラーを返す
                        snafu::ensure!(search_index < HEADER_SEARCH_LIMIT, HeaderNotFound);
                    }
                }
            }
            // エラー発生
            Err((_, err)) => {
                // ダウンロード済みの容量が Content-Length 以上なら無視
                header
                    .as_ref()
                    .and_then(|header| header.get("content-length"))
                    .and_then(|v| v.to_str().ok())
                    .and_then(|v| v.parse().ok())
                    .filter(|&cl| body.len() >= cl)
                    .ok_or(err)
                    .context(Glib)?;
                break;
            }
        }
    }

    Ok(body)
}

/// read_all_async_future に渡す未初期化バッファを生成。
fn create_uninit_buffer() -> Vec<u8> {
    unsafe {
        let mut v = Vec::with_capacity(BUFFER_SIZE);
        v.set_len(BUFFER_SIZE);
        v
    }
}

/// レスポンス内のヘッダ終端部を探索。
/// 発見した場合はその位置を Ok として、未発見の場合は次回の探索開始位置を Err として返す。
fn find_crlf(buffer: &[u8]) -> Result<usize, usize> {
    buffer
        // スライスを 4 要素ずつ返すイテレータを生成
        .windows(4)
        .position(|b| b == b"\r\n\r\n")
        // マッチ位置は先頭なので b"\r\n\r\n" の文字数(4文字)分足す
        .map(|u| u + 4)
        // 発見できなかった場合、次回は (末尾 - 3) の位置から探索を開始する
        .ok_or_else(|| buffer.len().saturating_sub(4))
}

/// バイト列をヘッダとしてパースする。
fn parse_header(header: &[u8]) -> http::HeaderMap<http::HeaderValue> {
    // 複数行モードに設定して `{key}: {value}` のパターンを検索する
    regex::bytes::Regex::new(r"(?m)^(.+?)\s*:\s*(.+)\r$")
        .expect("failed to build Regex object")
        // マッチした全要素をイテレータとして返す
        .captures_iter(header)
        .filter_map(|cap| {
            // 第 1 キャプチャ部を key、第 2 キャプチャ部を value としてパース
            let (k, v) = (cap.get(1)?, cap.get(2)?);
            Some((
                http::header::HeaderName::from_bytes(k.as_bytes()).ok()?,
                http::HeaderValue::from_bytes(v.as_bytes()).ok()?,
            ))
        })
        // impl<T> FromIterator<(HeaderName, T)> for HeaderMap<T> の実装を利用
        .collect()
}

#[cfg(test)]
mod test {
    use futures::FutureExt;

    use super::*;

    /// メインループを起動して引数の`Future`が完了するまでブロックする。
    fn exec_async(future: impl std::future::Future<Output = ()> + 'static) {
        // メインコンテキスト == 非同期オペレーションを実行するスレッド
        // https://valadoc.org/glib-2.0/GLib.MainContext.push_thread_default.html

        // メインコンテキストの生成
        let main_context = glib::MainContext::new();
        // 現在のスレッドをメインコンテキストとしてマーク
        main_context.push_thread_default();
        // メインコンテキストを引数としてメインループを生成
        let main_loop = glib::MainLoop::new(Some(&main_context), false);
        // メインコンテキストにタスクを追加
        main_context.spawn_local(future.map({
            let main_loop = main_loop.clone();
            move |_| main_loop.quit()
        }));
        // メインループを開始。MainLoop::quit が呼び出されるまでブロックする
        main_loop.run();
        // 現在のスレッドをメインコンテキストから外す(不要)
        // コンテキストをネストさせる場合に push_thread_default と対になるように呼び出す
        main_context.pop_thread_default();
    }

    #[test]
    fn download_test() {
        let future = async move {
            get_async_essential("some.host.com", "/path?query=hoge", true, None).await;

            println!("success");
        };

        exec_async(future);
    }
}

実行

それでは完成した web クライアントを使用して再度ダウンロードに挑戦してみましょう。
ボタンのclickedシグナルハンドラを以下のように書き換えます。
前回の記事のコードをお忘れの方は今一度ご確認ください。

src/main.rs
...
    // ボタンのクリックでダウンロードを開始する
    get_button.connect_clicked(move |button| {
        // ダウンロード中はウィジェットを無効化
        entry.set_sensitive(false);
        button.set_sensitive(false);

        // ボタンがクリックされている == 有効な時点でエントリは空欄ではないので unwrap して取り出す
        let uri = entry.get_text().unwrap();

        let progress = progress.clone();
        let image = image.clone();
        // future の生成
        let future = async move {
            /*
            // クライアントを生成
            let res = reqwest::r#async::Client::new().get(uri.as_str())?.send().await?;

            // Content-Length を取得
            let content_length = res.content_length().map(|cl| cl as f64);
            // ボディを返す Stream を取得
            let mut decoder = res.into_body();

            // Stream からボディを継続受信
            let mut body = Vec::new();
            while let Some(chunk) = decoder.try_next().await? {
                body.extend(chunk);

                // 進捗に応じてプログレスバーを更新
                if let Some(content_length) = content_length {
                    let percent = body.len() as f64 / content_length;
                    progress.set_fraction(percent);
                    progress.set_text(Some(&format!("{:>6.02}%", percent * 100.0)));
                }
            }
            */

            // ダウンロード開始
            let body = web::get_async(uri.as_str()).await?;

            // ダウンロード完了

            progress.set_fraction(100.0);
            progress.set_text(Some("100.00%"));

            // Pixbuf にバイト列を流し込んでイメージにセット
            let loader = gdk_pixbuf::PixbufLoader::new();
            loader.write(&body)?;
            loader.close()?;
            image.set_from_pixbuf(loader.get_pixbuf().as_ref());

            Result::<_, Box<dyn std::error::Error>>::Ok(())
        }
            // ダウンロード後の処理
            .map({
                let entry = entry.clone();
                let button = button.clone();
                move |res| {
                    // 標準エラー出力にエラーの内容を表示
                    if let Err(err) = res {
                        eprintln!("Download failed: {:?}", err);
                    }
                    // 無効化したウィジェットを復元
                    entry.set_sensitive(true);
                    button.set_sensitive(true);
                }
            });

        // Executor にタスクを追加
        glib::MainContext::default().spawn_local(future);
    });
...

too_big.gif

でっかくなっちゃった!!!!
どうやら Image ウィジェットには画像の拡大縮小機能はついていないようです。画像の内容は重要ではありませんので、ここはひとまず見切れるように修正します。
Layout ウィジェットで Image を囲むことで固定レイアウトにしてしまいます。

src/main.rs
fn build_ui(app: &gtk::Application) {
    // 各ウィジェットの生成

    // イメージ
    let layout = gtk::LayoutBuilder::new().expand(true).build();
    let image = gtk::ImageBuilder::new().parent(layout.upcast_ref()).build();
    ...
    // ルートコンテナ
    let root = gtk::Box::new(gtk::Orientation::Vertical, 0);
    // 各ウィジェットを追加
    // root.add(&image);
    root.add(&layout);
    ...

fixed.gif

所望通りの動作になりました。
やった! やりました! これにてタイトルの「Rust+async/await でノンブロッキング GUI」は達成です!!!
これからはこの分野でも Rust 採用の機運が高まると良いですね。

そして Stream へ

ちょっと待って下さい。
確かにスピナーくんが元気いっぱいに回転しているのでノンブロッキングであることは伝わります。
しかしそれに一体何の意味があるでしょう?
これ見よがしに配置されたプログレスバーくんの存在意義は?
進捗に応じて任意の処理を行うための非同期 API だったはずです。

ではどうしましょう?
get_async 関数が on_progress 的なコールバックを受け取るようにしてみますか?

pub async fn get_async<T, F: FnMut(f64)>(uri: T, on_progress: F) -> Result<Vec<u8>, Error>
where
    http::Uri: http::HttpTryFrom<T>,
{
    ...
    loop {
        match input
            .read_all_async_future(create_uninit_buffer(), glib::PRIORITY_DEFAULT)
            .await
        {
            // ダウンロード完了
            Ok((_, _u @ 0, _)) => ...,
            // ダウンロード進行中
            Ok((mut chunk, u, _)) => {
                // チャンクを切り詰める
                chunk.truncate(u);

                // ヘッダ発見済みの場合はボディに追加
                if header.is_some() {
                    body.extend(chunk);
                    let content_length = /* header から 取得 */;
                    // コールバックの呼び出し
                    on_progress(body.len() as f64 / content_length as f64);
                    continue;
                }
                ...

確かにこれは期待通りの動作をします。
しかし呼び出し側はこのようになってしまいます。

        ...
        let future = async move {
            // 第 2 引数にコールバックを渡す
            let body = web::get_async(uri.as_str(), {
                let progress = progress.clone();
                move |percent| {
                    // 進捗に応じてプログレスバーを更新
                    progress.set_fraction(percent);
                    progress.set_text(Some(&format!("{:>6.02}%", percent * 100.0)));
                }
            })
            .await?;
        ...

こんな調子でコールバックを追加していけば、結局はコールバック地獄が顕現してしまい元の木阿弥です。
対して reqwest を利用した当初のコードはこうです。

        ...
        let future = async move {
            // クライアントを生成
            let client = reqwest::r#async::Client::new()
                .get(uri.as_str())
                .send()
                .await?;

            // Content-Length を取得
            let content_length = client.content_length();

            // ボディを返す Stream を取得
            let mut decoder = client.into_body();

            // Stream からボディを継続受信
            let mut body = Vec::new();
            while let Some(chunk) = decoder.try_next().await? {
                body.extend(chunk);
                // 進捗に応じてプログレスバーを更新
                if let Some(content_length) = content_length {
                    let percent = body.len() as f64 / content_length;
                    progress.set_fraction(percent);
                    progress.set_text(Some(&format!("{:>6.02}%", percent * 100.0)));
                }
            }
            ...

圧倒的に美しいですね。ね?
処理をいくらでも呼び出し元に追加することができますし、所有権や借用のいざこざも考える必要がありません。
やはりこちらが正道なのです。

という訳でここからは作成した web クライアントに reqwest と同等の API を追加するためStreamトレイトの実装を行っていきます。

またしても長くなってしまったので続きは別記事として投稿します。


  1. 本当は大抵のサイトで失敗してしまうのですが、成功したという体で進めます。ヘッダが足りないんですかね。 

11
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
7