はじめに
ConnectしかできないMQTTサーバをつくってみましょう
なぜ?
- 知ってるもの使っているものを一から組み上げる楽しさを味わう
- HTTP以外のネットワークプロトコルにふれる
- ネットワークストリームのデコード方法を学ぶ
目次
- Lesson1: TCPサーバを作る
- Lesson2: MQTTデコーダの作成
- Lesson3: MQTTエンコーダの作成
MQTTについて
MQTTは軽量さやシンプルさが特徴のアプリケーションレイヤーのプロトコルです。
前提
- Rustを使用します
- 非同期ランタイムはtokioを使用します
今回作成したソースコード
ここにあります。lessonごとの終了時のソースコードです。
https://github.com/heya-naohiro/mqtt-rust-learn
Lesson1 TCPサーバを作る
本記事で作成するMQTTはMQTT over TCPです。MQTTはアプリケーションレイヤーのプロトコルですのでトランスポートレイヤーのプロトコルが必要です。MQTTのトランスポートレイヤーとしてよく使われるのはTCPとWebsocketです。
まずはRustでTCPサーバを作成してみましょう。
- TcpListenerを初期化 (
let listener = TcpListener::bind(addr).await?;) -
acceptしたら別のスレッドを立てて(tokio::spawn)その接続はそのスレッド内で取り扱う - 今後の取り回しがよくなるようにreaderとwriterを分離します。(
let (mut reader, mut _writer) = tcp_stream.into_split();具体的には書き込みと読み取りは別スレッドで行う想定です - 読み取った (
reader.read_buf(&mut buf)) この章では捨てます (buf.clear()) - bufferとしてBytesMutを使用します
let mut buf = BytesMut::with_capacity(4096);でバッファーを確保し、let n = match reader.read_buf(&mut buf).await {で受信待ちに入ります
use bytes::BytesMut;
use std::net::SocketAddr;
use tokio::io::AsyncReadExt;
use tokio::net::TcpListener;
#[tokio::main]
async fn main() -> std::io::Result<()> {
let addr = "127.0.0.1:1883".parse::<SocketAddr>().unwrap();
let listener = TcpListener::bind(addr).await?;
loop {
let (tcp_stream, remote_addr) = listener.accept().await?;
tokio::spawn(async move {
let (mut reader, mut _writer) = tcp_stream.into_split();
let mut buf = BytesMut::with_capacity(4096);
loop {
let n = match reader.read_buf(&mut buf).await {
Ok(0) => {
println!("Connection closed by {}", remote_addr);
break;
}
Ok(n) => n,
Err(e) => {
eprintln!("Read error: {}", e);
break;
}
};
println!("Received {} bytes from {}: {:?}", n, remote_addr, &buf[..]);
buf.clear(); // ひとまず、受け取ったら、buf clearしていることに注意!
}
});
}
}
実際に動かしてみましょう。まだtcpサーバしかできていないのでクライアントとしてはncコマンドを使用します。
サーバ側を起動します
[~/mqtt_server_learn/lesson1]$cargo run
Compiling lesson1 v0.1.0 (/Users/syureneko/mqtt_server_learn/lesson1)
Finished `dev` profile [unoptimized + debuginfo] target(s) in 1.91s
Running `target/debug/lesson1`
端末をもう一枚開いて、ncコマンドで接続し、helloと打ってEnterキー
[~/mqtt_server_learn]$nc 127.0.0.1 1883
hello
サーバ側の端末に、受信したメッセージが表示されます。
[~/mqtt_server_learn/lesson1]$cargo run
Compiling lesson1 v0.1.0 (/Users/syureneko/mqtt_server_learn/lesson1)
Finished `dev` profile [unoptimized + debuginfo] target(s) in 1.91s
Running `target/debug/lesson1`
Received 6 bytes from 127.0.0.1:50490: [104, 101, 108, 108, 111, 10]
Lesson2 MQTTデコーダの作成
MQTTのパケット構造
MQTTのパケット構造は3つのパートが合わさったものになっています。
- fixed header ・・・ パケットの種類を含む
- variable header ・・・ パケットに付随するメタデータを含む
- payload ・・・ パケットのデータ本体
+-------------------+----------------------+--------------------+
| 固定ヘッダ | 可変ヘッダ | ペイロード |
+-------------------+----------------------+--------------------+
よって、ソフトウェアの構造としても
fixed header -> variable header -> payloadの順番にデコードします。パケットの種類によってvariable headerの中身とpayloadの中身はそれぞれ異なるのでパケット種類ごとにdecoderを実装する必要があります。
パケットの種類
今回はConnectしかできなくても良いのでConnectしか実装しません。MQTTパケットの種類は全部で15種類です。Connect, Connack, Publish, Puback, Pubrec, Pubrel, Pubcomp, Subscribe, Suback, Unsubscribe, Unsuback, Pingreq, Pingresp, Disconnect, Auth
リファクタ
lesson1では実装しているファイルはmain.rsのみでしたが、ここではlib.rsを作成し機能を集約します。
.
├── Cargo.lock
├── Cargo.toml
├── src
│ ├── lib.rs
│ └── main.rs
└── tests
└── mqtt_connect.rs
main.rsは薄く下記のようにしてあります。
#[tokio::main]
async fn main() -> std::io::Result<()> {
my_mqtt::run_server().await
}
デコーダーの作成
デコーダーはTCP層で流れるバイトストリームをアプリケーション的に意味のある変数や構造体に変換する役割を持ちます。パケットの構と同様にdecodeのメソッドを実装します。そして使用する際も順番に使用してデコードしていきます。
decode_fixed_headerdecode_variable_headerdecode_payload
デコーダーのインターフェイスの工夫
decodeする関数には引数として「デコード依頼する最初のインデックス」をbufferとともに渡し、処理した返り値として、「次にbufferから読むとしたらここだよ!の終端のインデックス」を返すことにします。
こうする理由は、MQTTは可変長で区切りのあるプロトコルではなく、パケットを送信し終わっても接続しっぱなしのが前提であり、バッファのどこからどこまでかをわかりやすくすることが実装やデバッグ観点で大事であると考えるからです。
lib.rsの流れ
- decode_fixed_headerでパケットの種類を特定し、構造体を返す
- パケットに応じてdecode_variable_headerを呼んで構造体の中のメタ情報を埋める
- ペイロードをdecode_payloadでデコードする
pub async fn run_server() -> std::io::Result<()> {
let addr = "127.0.0.1:1883".parse::<SocketAddr>().unwrap();
...省略
... 省略, ソケットからbufferに読み取る
println!("Received {} bytes from {}: {:?}", n, remote_addr, &buf[..]);
// buf.clear(); // ひとまず、受け取ったら、buf clearしていることに注意!
// clearしない、insufficient/多すぎる場合がある。
// fixed headerをdecodeする
// fixed header -> 構造体 -> 各パケットに応じたdecode
let (packet, next_pos, remaining_length) = match decode_fixed_header(&buf) { // 1
Ok(ok) => ok,
Err(_) => return, // 失敗時は関数を抜ける
};
match packet {
Packet::CONNECT(mut connect) => {
println!("fixed header decoded: connect packet");
let next_pos = match connect.decode_variable_header(&buf, next_pos) {
Ok(next_pos) => next_pos,
Err(DecodeError::InsufficientBytes) => {
continue;
}
Err(e) => {
eprintln!("connect decode error: {}", e);
break;
}
};
println!("variable header decoded {:?}", connect);
let remain = remaining_length - next_pos;
let _next_pos = match connect.decode_payload(&buf, next_pos, remain) {
Ok(a) => a,
Err(DecodeError::InsufficientBytes) => {
continue;
}
Err(e) => {
eprintln!("decode payload error {}", e);
break;
}
};
println!("end: connect packet {:?}", connect);
/* Send Connack */
}
_ => {}
}
}
});
}
}
Remaining Length
Remaining Lengthは可変長のバイトを使って残りの長さ(fixed headerより後に続く残りのバイト数)を表現しています。最小"1bit + 7bit"で表現され、先頭1ビットは後続に8bitが続くかどうかのみを示しています。最大4ブロック続きます。
+------------------------+------------------------+------------------------+
| 1bit | 7bit | 1bit | 7bit | 1bit | 7bit |
| (MSB) | (Value bits) | (MSB) | (Value bits) | (MSB) | (Value bits) |
+------------------------+------------------------+------------------------+
│ │ │ │ │
│ └─ ×128^0 │ └─ ×128^1 │
│ │ └─ ×128^2
└─ 続きがあるなら 1、
最後のバイトなら 0
pub fn decode_variable_length(
buf: &bytes::BytesMut,
start_pos: usize,
) -> Result<(usize, usize), DecodeError> {
let mut remaining_length: usize = 0;
let mut inc = 0;
for pos in start_pos..=start_pos + 3 {
if buf.len() <= pos {
return Err(DecodeError::InsufficientBytes);
}
remaining_length += ((buf[pos] & 0b01111111) as usize) << (inc * 7); // A
if (buf[pos] & 0b10000000) == 0 {
return Ok((remaining_length, pos + 1)); // 次のバイト位置を返す
}
inc += 1;
if inc >= 4 {
// B
return Err(DecodeError::InvalidRemainingLength);
}
}
Ok((remaining_length, start_pos + 4)) // 全4バイトを処理した場合も次の位置を返す
}
デコーダーの流れ
- httpのようにプロトコル上区切りはあるわけではありませんので、規格書よりそれぞれ何バイトかを確認してバイト列で表現されたデータをintやString, boolにキャストしていきます
impl Connect {
fn decode_variable_header(
&mut self,
buf: &BytesMut,
start_pos: usize,
) -> Result<usize, DecodeError> {
let protocol_name_length =
u16::from_be_bytes([buf[start_pos], buf[start_pos + 1]]) as usize;
// this protocol name must be "MQTT in 3.1.1"
let _protocol_name =
std::str::from_utf8(&buf[start_pos + 2..start_pos + 2 + protocol_name_length])
.expect("invalid utf-8")
.to_string();
let mut next_pos = start_pos + 2 + protocol_name_length;
self.protocol_level = buf[next_pos];
next_pos = next_pos + 1;
self.username_flag = buf[next_pos] & (1 << 7) != 0;
self.password_flag = buf[next_pos] & (1 << 6) != 0;
self.will_retain = buf[next_pos] & (1 << 5) != 0;
let qos = (buf[next_pos] >> 3) & 0b11;
self.will_qos = match qos {
0 => QoS::QoS0,
1 => QoS::QoS1,
2 => QoS::QoS2,
_ => return Err(DecodeError::InvalidFormat),
};
self.will_flag = buf[next_pos] & (1 << 2) != 0;
self.clean_session = buf[next_pos] & (1 << 1) != 0;
/* [TODO] check reserve */
next_pos = next_pos + 1;
self.keep_alive = u16::from_be_bytes([buf[next_pos], buf[next_pos + 1]]);
next_pos = next_pos + 2;
println!("p: {:?}", self);
Ok(next_pos)
}
fn decode_payload(
&mut self,
buf: &BytesMut,
start_pos: usize,
_remain: usize,
) -> Result<usize, DecodeError> {
/* MUST: ClientID */
let len = u16::from_be_bytes([buf[start_pos], buf[start_pos + 1]]) as usize;
self.client_id = std::str::from_utf8(&buf[start_pos + 2..start_pos + 2 + len])
.map_err(|_| DecodeError::ProtocolViolation)?
.to_string();
let next_pos = start_pos + 2 + len;
/* Will Topic */
/* if self.will_flag {
...will topic decode...
}
*/
return Ok(next_pos);
}
}
パケットの種類を表現するためにenumを用意しています。
pub enum Packet {
CONNECT(Connect),
CONNACK(Connack),
PUBLISH(Publish),
OTHER,
}
ここでは簡単のため端折っていますが、下記の事柄を考える必要があります。
- tcpのバッファーにすべて乗っているわけではありませんので下記のどちらかの対応をする必要があります
- 不足している場合は固有のエラーを出して上位でさらなるデータを要求する
- remain lengthから残りのバイト数を計算してデコードが完全にできることを保証するまでデータきてからデコードする
- 実際の実装ではありえない数字が入ってきた場合にはProtocol違反ですのでその接続を閉じる必要があります
テストを書く
テストを書いてみましょう。testsフォルダをおいてファイルを追加します。
├── Cargo.lock
├── Cargo.toml
├── book.md
├── src
│ ├── lib.rs
│ └── main.rs
└── tests
└── mqtt_connect.rs
2つのテストを作成します。
- tcpの接続を確認するテスト
test_server_tcp_connect - mqttのConnectを実行するテスト
test_server_mqtt_connect
use my_mqtt;
use tokio::net::TcpStream;
use tokio::task;
use tokio::time::{Duration, sleep};
#[tokio::test]
async fn test_server_tcp_connect() {
tokio::spawn(async {
my_mqtt::run_server().await.unwrap();
});
sleep(Duration::from_millis(1000)).await;
let conn = TcpStream::connect("127.0.0.1:1883").await;
assert!(conn.is_ok());
}
#[tokio::test]
async fn test_server_mqtt_connect() {
let server = tokio::spawn(async {
my_mqtt::run_server().await.unwrap();
});
tokio::time::sleep(Duration::from_millis(200)).await;
// サーバ起動完了を待つ
let result = task::spawn_blocking(|| {
let client_opts = paho_mqtt::CreateOptionsBuilder::new()
.server_uri("tcp://127.0.0.1:1883")
.client_id("test_client")
.finalize();
let client = paho_mqtt::Client::new(client_opts).unwrap();
let connect_opts = paho_mqtt::ConnectOptionsBuilder::new_v3()
.keep_alive_interval(Duration::from_secs(30))
.clean_session(true)
.finalize();
client.connect(connect_opts)
})
.await
.unwrap();
// 接続が成功したかをチェック
assert!(result.is_ok(), "MQTT connect failed: {:?}", result.err());
}
- 並列で実行されるとポートが被ってしまうため、
--test-threads=1をつけて実行します - まだConnectに対してConnackを送信する機能を実装していないため
test_server_mqtt_connectは失敗します - 標準出力が出ている通りデコードは成功しています→
end: connect packet Connect { protocol_level: 4, username_flag: false, password_flag: false, will_retain: false, will_qos: QoS0, ...
[~/mqtt_server_learn/lesson2]$cargo test -- --test-threads=1
...
...
running 2 tests
test test_server_mqtt_connect ... FAILED
test test_server_tcp_connect ... ok
failures:
---- test_server_mqtt_connect stdout ----
run server
accept
read buf
Received 25 bytes from 127.0.0.1:50370: [16, 23, 0, 4, 77, 81, 84, 84, 4, 2, 0, 30, 0, 11, 116, 101, 115, 116, 95, 99, 108, 105, 101, 110, 116]
fixed header decoded: connect packet
p: Connect { protocol_level: 4, username_flag: false, password_flag: false, will_retain: false, will_qos: QoS0, will_flag: false, clean_session: true, keep_alive: 30, client_identifier: "", will_topic: "", will_message: b"", client_id: "" }
variable header decoded Connect { protocol_level: 4, username_flag: false, password_flag: false, will_retain: false, will_qos: QoS0, will_flag: false, clean_session: true, keep_alive: 30, client_identifier: "", will_topic: "", will_message: b"", client_id: "" }
end: connect packet Connect { protocol_level: 4, username_flag: false, password_flag: false, will_retain: false, will_qos: QoS0, will_flag: false, clean_session: true, keep_alive: 30, client_identifier: "", will_topic: "", will_message: b"", client_id: "test_client" }
read buf
Connection closed by 127.0.0.1:50370
thread 'test_server_mqtt_connect' panicked at tests/mqtt_connect.rs:45:5:
MQTT connect failed: Some(TcpConnectTimeout)
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
failures:
test_server_mqtt_connect
test result: FAILED. 1 passed; 1 failed; 0 ignored; 0 measured; 0 filtered out; finished in 32.24s
error: test failed, to rerun pass `--test mqtt_connect`
[~/mqtt_server_learn/lesson2]$
Lesson3 MQTTエンコーダの作成
エンコーダの実装
- エンコーダはアプリケーション的に意味のある構造体をバイト列に変換する役割を持ちます
- Connackだけだと割り切って固定値を返します
- 使用する際は実際はロジックに応じて構造体を作成し、エンコードを呼び出すことで構造体から順番にバイト列に格納していきます
pub struct Connack {}
impl Connack {
fn encode_all(&self) -> bytes::Bytes {
let mut buf = BytesMut::with_capacity(4096);
//fixed header
buf.extend_from_slice(&[0x20]); // type connack
buf.extend_from_slice(&[0x02]); //remaining length = 2
buf.extend_from_slice(&[0x00]); // new session
buf.extend_from_slice(&[0x00]); // accept connect
buf.freeze()
}
}
送信の実装
- 複数のスレッドにまたがって共有するためにArcを利用します.
let send_map: Arc<DashMap<String, mpsc::Sender<Packet>>> = Arc::new(DashMap::new()); - 送信専用のスレッドを立ち上げますが、スレッドにデータを渡すためにchannelを作ります
let (tx, mut rx) = mpsc::channel::<Packet>(256); - tcpの接続を受けた後に送信専用スレッドを立ち上げます。
tokio::spawn({ - 送信専用スレッドはチャンネルを受信してエンコードしてwriterに書き込みます
writer.write_all(&byte_all).await
pub async fn run_server() -> std::io::Result<()> {
let addr = "127.0.0.1:1883".parse::<SocketAddr>().unwrap();
let listener = TcpListener::bind(addr).await?;
// Connection Map
let send_map: Arc<DashMap<String, mpsc::Sender<Packet>>> = Arc::new(DashMap::new());
println!("run server");
loop {
let (tcp_stream, remote_addr) = listener.accept().await?;
println!("accept");
let (mut reader, mut writer) = tcp_stream.into_split();
let (tx, mut rx) = mpsc::channel::<Packet>(256);
let tx_for_map = tx.clone();
// 送信用スレッド
tokio::spawn({
async move {
while let Some(packet) = rx.recv().await {
println!("recieve start!!!!");
// decode
let byte_all = match packet {
Packet::CONNACK(packet) => packet.encode_all(),
_ => continue,
};
// write all
if let Err(e) = writer.write_all(&byte_all).await {
eprintln!("Write error to {}: {}", remote_addr, e);
break;
}
println!("Send all done");
}
}
});
...(省略
- Connectを受信したあとに、接続マップにclient idを追加します
map.insert(connect.client_id.clone(), tx_for_map.clone()); - 送信専用スレッドが待ち受けているチャンネルにConnackを送信します
if let Err(e) = sender.send(Packet::CONNACK(Connack {})).await
println!("end: connect packet {:?}", connect);
map.insert(connect.client_id.clone(), tx_for_map.clone());
/* Send Connack */
println!("insert end");
if let Some(sender) = map.get(&connect.client_id) {
println!("send will");
if let Err(e) = sender.send(Packet::CONNACK(Connack {})).await {
eprintln!("channel send error: {}", e);
} else {
println!("channel send done");
}
}
Lesson2では失敗していたテストも通ることを確認します
[~/mqtt_server_learn/lesson3]$cargo test -- --test-threads=1
running 2 tests
test test_server_mqtt_connect ... ok
test test_server_tcp_connect ... ok
まとめ
今回のゴールであるConnectのみ対応したMQTTサーバを作成しました。
この考え方を延長すれば、PublishやSubscribeも受けれるようになるはずです。
また、ネットワーク・プロトコルの考え方や、バッファーに追加して中身を解析するといった知識はほかでも役に立つはずです。