
Writing an np.pad-like function for Rust's ndarray

Posted at 2024-03-11

Motivation

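I keep hearing that Rust is great, so I've been studying it lately.
To get used to Rust, I wrote a Game of Life and tried to manage its board with ndarray. For the boundary conditions I thought something like numpy's pad would make things easy, but as far as I could find, ndarray doesn't have one.
So I wrote a pad function myself.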

Caveats

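  • I don't really understand generics yet, so this function only accepts two-dimensional arrays of type u8. For other primitive integer types, replacing u8 in your editor should be enough.
  • I have checked it on [[1, 2, 3], [4, 5, 6], [7, 8, 9]], but please be careful when using it. In particular, "wrap", "reflect", and "symmetric" slice width rows and columns out of the original, so they panic if width is larger than the number of rows or columns.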
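On the generics point: a constant-fill pad only needs to clone the fill value, so a `T: Clone` bound is enough to lift the u8 restriction. The sketch below is a hypothetical `pad_constant` written on plain `Vec<Vec<T>>` rather than ndarray, so it runs with the standard library alone. It also borrows its input, so the caller keeps the original board. For an ndarray version, keep in mind that `Array::zeros` additionally requires the element type to implement `num_traits::Zero`.

```rust
// Hypothetical generic sketch (std only, not ndarray): pad a 2-D grid with a fixed value.
// Works for any element type that can be cloned, not just u8.
fn pad_constant<T: Clone>(board: &[Vec<T>], width: usize, fill: T) -> Vec<Vec<T>> {
    let rows = board.len();
    let cols = board.first().map_or(0, |row| row.len());
    // Start from a (rows + 2 * width) x (cols + 2 * width) grid filled with `fill`.
    let mut padded = vec![vec![fill; cols + 2 * width]; rows + 2 * width];
    // Copy the original grid into the centre.
    for (i, row) in board.iter().enumerate() {
        for (j, value) in row.iter().enumerate() {
            padded[i + width][j + width] = value.clone();
        }
    }
    padded
}

fn main() {
    let board = vec![vec![1u8, 2], vec![3, 4]];
    // prints [[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]
    println!("{:?}", pad_constant(&board, 1, 0u8));
    // The same function works for other element types, e.g. f64:
    println!("{:?}", pad_constant(&[vec![0.5f64]], 1, -1.0));
}
```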

Code

use ndarray::{array, s, Array, Array2, ArrayView2}; // `array` is only used by the usage examples

// original: the original array.
// width: the width of the padding.
// padding: the type of padding to apply. Allowed values: "constant", "edge", "wrap", "reflect", "symmetric", "minimum", "maximum", "mean"
// Returns the padded array.
fn pad(original: Array2<u8>, width: usize, padding: &str) -> Array2<u8> {
    let (rows, cols) = original.dim();
    let mut padded_board = Array::zeros((rows + 2 * width, cols + 2 * width));
    for i in 0..rows {
        for j in 0..cols {
            padded_board[[i + width, j + width]] = original[[i, j]];
        }
    }
    let inner_cols_range = width..width + cols;
    let inner_rows_range = width..width + rows;
    match padding {
        "constant" => {}
        "edge" => {
            padded_board
                .slice_mut(s![..width, ..width])
                .fill(original[[0, 0]]);
            padded_board
                .slice_mut(s![..width, width + cols..])
                .fill(original[[0, cols - 1]]);
            padded_board
                .slice_mut(s![width + rows.., ..width])
                .fill(original[[rows - 1, 0]]);
            padded_board
                .slice_mut(s![width + rows.., width + cols..])
                .fill(original[[rows - 1, cols - 1]]);
            padded_board
                .slice_mut(s![..width, inner_cols_range.clone()])
                .assign(&original.slice(s![0..1, ..]));
            padded_board
                .slice_mut(s![rows + width.., inner_cols_range.clone()])
                .assign(&original.slice(s![rows - 1..rows, ..]));
            padded_board
                .slice_mut(s![inner_rows_range.clone(), ..width])
                .assign(&original.slice(s![.., 0..1]));
            padded_board
                .slice_mut(s![inner_rows_range.clone(), cols + width..])
                .assign(&original.slice(s![.., cols - 1..cols]));
        }
        "maximum" => {
            let max_val = original.iter().max().unwrap();
            padded_board.slice_mut(s![..width, ..width]).fill(*max_val);
            padded_board
                .slice_mut(s![..width, width + cols..])
                .fill(*max_val);
            padded_board
                .slice_mut(s![width + rows.., ..width])
                .fill(*max_val);
            padded_board
                .slice_mut(s![width + rows.., width + cols..])
                .fill(*max_val);
            let cols_max = original
                .columns()
                .into_iter()
                .map(|col| col.iter().max().unwrap().clone())
                .collect::<Vec<u8>>();
            let cols_max = Array::from_shape_vec((1, cols), cols_max).unwrap();
            let rows_max = original
                .rows()
                .into_iter()
                .map(|row| row.iter().max().unwrap().clone())
                .collect::<Vec<u8>>();
            let rows_max = Array::from_shape_vec((rows, 1), rows_max).unwrap();
            padded_board
                .slice_mut(s![..width, inner_cols_range.clone()])
                .assign(&cols_max);
            padded_board
                .slice_mut(s![rows + width.., inner_cols_range.clone()])
                .assign(&cols_max);
            padded_board
                .slice_mut(s![inner_rows_range.clone(), ..width])
                .assign(&rows_max);
            padded_board
                .slice_mut(s![inner_rows_range.clone(), cols + width..])
                .assign(&rows_max);
        }
        "mean" => {
            // Sum in isize so that the total cannot overflow u8.
            let mean_val = (original.iter().map(|&x| x as isize).sum::<isize>() / (rows * cols) as isize) as u8;
            padded_board.slice_mut(s![..width, ..width]).fill(mean_val);
            padded_board
                .slice_mut(s![..width, width + cols..])
                .fill(mean_val);
            padded_board
                .slice_mut(s![width + rows.., ..width])
                .fill(mean_val);
            padded_board
                .slice_mut(s![width + rows.., width + cols..])
                .fill(mean_val);
            let cols_mean = original
                .columns()
                .into_iter()
                .map(|col| col.iter().map(|&x| x as isize).sum::<isize>() / rows as isize)
                .collect::<Vec<isize>>();
            let cols_mean = Array::from_shape_vec((1, cols), cols_mean).unwrap();
            let rows_mean = original
                .rows()
                .into_iter()
                .map(|row| row.iter().map(|&x| x as isize).sum::<isize>() / cols as isize)
                .collect::<Vec<isize>>();
            let rows_mean = Array::from_shape_vec((rows, 1), rows_mean).unwrap();
            padded_board
                .slice_mut(s![..width, inner_cols_range.clone()])
                .assign(&cols_mean.mapv(|x| x as u8));
            padded_board
                .slice_mut(s![rows + width.., inner_cols_range.clone()])
                .assign(&cols_mean.mapv(|x| x as u8));
            padded_board
                .slice_mut(s![inner_rows_range.clone(), ..width])
                .assign(&rows_mean.mapv(|x| x as u8));
            padded_board
                .slice_mut(s![inner_rows_range.clone(), cols + width..])
                .assign(&rows_mean.mapv(|x| x as u8));
        }
        "minimum" => {
            let min_val = original.iter().min().unwrap();
            padded_board.slice_mut(s![..width, ..width]).fill(*min_val);
            padded_board
                .slice_mut(s![..width, width + cols..])
                .fill(*min_val);
            padded_board
                .slice_mut(s![width + rows.., ..width])
                .fill(*min_val);
            padded_board
                .slice_mut(s![width + rows.., width + cols..])
                .fill(*min_val);
            let cols_min = original
                .columns()
                .into_iter()
                .map(|col| *col.iter().min().unwrap())
                .collect::<Vec<u8>>();
            let cols_min = Array::from_shape_vec((1, cols), cols_min).unwrap();
            let rows_min = original
                .rows()
                .into_iter()
                .map(|row| *row.iter().min().unwrap())
                .collect::<Vec<u8>>();
            let rows_min = Array::from_shape_vec((rows, 1), rows_min).unwrap();
            padded_board
                .slice_mut(s![..width, inner_cols_range.clone()])
                .assign(&cols_min);
            padded_board
                .slice_mut(s![rows + width.., inner_cols_range.clone()])
                .assign(&cols_min);
            padded_board
                .slice_mut(s![inner_rows_range.clone(), ..width])
                .assign(&rows_min);
            padded_board
                .slice_mut(s![inner_rows_range.clone(), cols + width..])
                .assign(&rows_min);
        }
        "wrap" => {
            // The first slice index is the row range, the second the column range,
            // so non-square arrays are handled correctly.
            padded_board
                .slice_mut(s![..width, ..width])
                .assign(&original.slice(s![rows - width.., cols - width..]));
            padded_board
                .slice_mut(s![..width, width + cols..])
                .assign(&original.slice(s![rows - width.., ..width]));
            padded_board
                .slice_mut(s![width + rows.., ..width])
                .assign(&original.slice(s![..width, cols - width..]));
            padded_board
                .slice_mut(s![width + rows.., width + cols..])
                .assign(&original.slice(s![..width, ..width]));
            padded_board
                .slice_mut(s![..width, inner_cols_range.clone()])
                .assign(&original.slice(s![rows - width.., ..]));
            padded_board
                .slice_mut(s![rows + width.., inner_cols_range.clone()])
                .assign(&original.slice(s![..width, ..]));
            padded_board
                .slice_mut(s![inner_rows_range.clone(), ..width])
                .assign(&original.slice(s![.., cols - width..]));
            padded_board
                .slice_mut(s![inner_rows_range.clone(), cols + width..])
                .assign(&original.slice(s![.., ..width]));
        }
        "reflect" => {
            fn reflect(arr: ArrayView2<u8>) -> Array2<u8> {
                let (rows, cols) = arr.dim();
                let mut reflected = arr.iter().cloned().collect::<Vec<u8>>();
                reflected.reverse();
                Array::from_shape_vec(
                    (rows, cols),
                    reflected
                ).unwrap()
            }
            // Each block is taken from the opposite side and rotated by 180°
            // (this is not numpy's "reflect"). Row range first, column range second.
            padded_board
                .slice_mut(s![..width, ..width])
                .assign(&reflect(original.slice(s![rows - width.., cols - width..])));
            padded_board
                .slice_mut(s![..width, width + cols..])
                .assign(&reflect(original.slice(s![rows - width.., ..width])));
            padded_board
                .slice_mut(s![width + rows.., ..width])
                .assign(&reflect(original.slice(s![..width, cols - width..])));
            padded_board
                .slice_mut(s![width + rows.., width + cols..])
                .assign(&reflect(original.slice(s![..width, ..width])));
            padded_board
                .slice_mut(s![..width, inner_cols_range.clone()])
                .assign(&reflect(original.slice(s![rows - width.., ..])));
            padded_board
                .slice_mut(s![rows + width.., inner_cols_range.clone()])
                .assign(&reflect(original.slice(s![..width, ..])));
            padded_board
                .slice_mut(s![inner_rows_range.clone(), ..width])
                .assign(&reflect(original.slice(s![.., cols - width..])));
            padded_board
                .slice_mut(s![inner_rows_range.clone(), cols + width..])
                .assign(&reflect(original.slice(s![.., ..width])));
        }
        "symmetric" => {
            fn reflect(arr: ArrayView2<u8>) -> Array2<u8> {
                let (rows, cols) = arr.dim();
                let mut reflected = arr.iter().cloned().collect::<Vec<u8>>();
                reflected.reverse();
                Array::from_shape_vec(
                    (rows, cols),
                    reflected
                ).unwrap()
            }
            fn reverse(arr: ArrayView2<u8>) -> Array2<u8> {
                let mut ans = arr.to_owned();
                for i in 0..arr.dim().0 {
                    let row = arr.slice(s![i, ..]).to_owned();
                    ans.slice_mut(s![arr.dim().0 - i - 1, ..]).assign(&row);
                }
                ans
            }
            padded_board
                .slice_mut(s![..width, ..width])
                .assign(&reflect(original.slice(s![..width, ..width])));
            padded_board
                .slice_mut(s![..width, width + cols..])
                .assign(&reflect(original.slice(s![..width, cols-width..])));
            padded_board
                .slice_mut(s![width + rows.., ..width])
                .assign(&reflect(original.slice(s![rows-width.., ..width])));
            padded_board
                .slice_mut(s![width + rows.., width + cols..])
                .assign(&reflect(original.slice(s![rows-width.., cols-width..])));
            padded_board
                .slice_mut(s![..width, inner_cols_range.clone()])
                .assign(&reverse(original.slice(s![..width, ..])));
            padded_board
                .slice_mut(s![rows + width.., inner_cols_range.clone()])
                .assign(&reverse(original.slice(s![rows-width.., ..])));
            // The left/right bands mirror the columns, so flip along the column axis
            // (a negative step in s! walks the range backwards).
            padded_board
                .slice_mut(s![inner_rows_range.clone(), ..width])
                .assign(&original.slice(s![.., ..width; -1]));
            padded_board
                .slice_mut(s![inner_rows_range.clone(), cols + width..])
                .assign(&original.slice(s![.., cols - width..; -1]));
        }
        _ => panic!("Invalid padding type"),
    }
    padded_board
}

Specification and examples

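First argument: the original array. Note that ownership of it is moved into the function.
Second argument: the padding width.
Third argument: the padding type.
The available types are listed below. I tried to make each one behave basically like numpy.pad.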
constant

Plain zero padding.
fn main() {
    let board: Array2<u8> = array![
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]];
    let padded_board = pad(board, 2, "constant");
    println!("{:?}", padded_board);
    // [[0, 0, 0, 0, 0, 0, 0],
    //  [0, 0, 0, 0, 0, 0, 0],
    //  [0, 0, 1, 2, 3, 0, 0],
    //  [0, 0, 4, 5, 6, 0, 0],
    //  [0, 0, 7, 8, 9, 0, 0],
    //  [0, 0, 0, 0, 0, 0, 0],
    //  [0, 0, 0, 0, 0, 0, 0]], shape=[7, 7], strides=[7, 1], layout=Cc (0x5), const ndim=2
}

edge
Extends the values on the boundary outward as they are.

fn main() {
    let board: Array2<u8> = array![
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]];
    let padded_board = pad(board, 2, "edge");
    println!("{:?}", padded_board);
    // [[1, 1, 1, 2, 3, 3, 3],
    // [1, 1, 1, 2, 3, 3, 3],
    // [1, 1, 1, 2, 3, 3, 3],
    // [4, 4, 4, 5, 6, 6, 6],
    // [7, 7, 7, 8, 9, 9, 9],
    // [7, 7, 7, 8, 9, 9, 9],
    // [7, 7, 7, 8, 9, 9, 9]], shape=[7, 7], strides=[7, 1], layout=Cc (0x5), const ndim=2
}
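Another way to think about "edge": every padded cell reads from the nearest in-range position, i.e. the row index and the column index are clamped into the array independently. Below is a hypothetical std-only sketch of that idea (`edge_index` and `pad_edge` are illustrative names, not ndarray APIs); for the 3×3 example it should give the same grid as above.

```rust
// Hypothetical helper: map a possibly out-of-range index on an axis of length n
// to the nearest valid index (the "edge" rule). Assumes n >= 1.
fn edge_index(i: isize, n: usize) -> usize {
    i.clamp(0, n as isize - 1) as usize
}

// Pad a non-empty, rectangular 2-D grid by applying the edge rule
// to the row index and the column index separately.
fn pad_edge<T: Clone>(board: &[Vec<T>], width: usize) -> Vec<Vec<T>> {
    let (rows, cols) = (board.len(), board[0].len());
    let w = width as isize;
    (-w..rows as isize + w)
        .map(|i| {
            (-w..cols as isize + w)
                .map(|j| board[edge_index(i, rows)][edge_index(j, cols)].clone())
                .collect::<Vec<T>>()
        })
        .collect()
}

fn main() {
    let board = vec![vec![1u8, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
    for row in pad_edge(&board, 2) {
        println!("{:?}", row);
    }
}
```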

maximum
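Pads each column with that column's maximum and each row with that row's maximum; the corners get the maximum of the whole array.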

fn main() {
    let board: Array2<u8> = array![
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]];
    let padded_board = pad(board, 2, "maximum");
    println!("{:?}", padded_board);
    // [[9, 9, 7, 8, 9, 9, 9],
    // [9, 9, 7, 8, 9, 9, 9],
    // [3, 3, 1, 2, 3, 3, 3],
    // [6, 6, 4, 5, 6, 6, 6],
    // [9, 9, 7, 8, 9, 9, 9],
    // [9, 9, 7, 8, 9, 9, 9],
    // [9, 9, 7, 8, 9, 9, 9]], shape=[7, 7], strides=[7, 1], layout=Cc (0x5), const ndim=2
}
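One way to see why the corners end up as the maximum of the whole array: first pad the top and bottom with each column's maximum, then pad every row (including the new ones) with that row's maximum. The hypothetical std-only `pad_maximum` below does exactly that on `Vec<Vec<u8>>`; for the 3×3 example it reproduces the output above. Swapping `max` for `min` would give the "minimum" mode.

```rust
// Hypothetical sketch (std only): "maximum" padding in two passes.
// Assumes a non-empty, rectangular grid.
fn pad_maximum(board: &[Vec<u8>], width: usize) -> Vec<Vec<u8>> {
    let cols = board[0].len();
    // Pass 1: the maximum of each column becomes the new top and bottom rows.
    let col_max: Vec<u8> = (0..cols)
        .map(|j| board.iter().map(|row| row[j]).max().unwrap())
        .collect();
    let mut tall: Vec<Vec<u8>> = Vec::new();
    tall.extend(std::iter::repeat(col_max.clone()).take(width));
    tall.extend(board.iter().cloned());
    tall.extend(std::iter::repeat(col_max).take(width));
    // Pass 2: each row, including the new ones, is padded with its own maximum,
    // so the corners become the maximum of the column maxima (the global maximum).
    tall.into_iter()
        .map(|row| {
            let m = *row.iter().max().unwrap();
            let mut padded = vec![m; width];
            padded.extend(row);
            padded.extend(std::iter::repeat(m).take(width));
            padded
        })
        .collect()
}

fn main() {
    let board = vec![vec![1u8, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
    for row in pad_maximum(&board, 2) {
        println!("{:?}", row);
    }
}
```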

mean
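Pads each column and each row with its mean (computed with integer division, so it is truncated); the corners get the mean of the whole array.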

fn main() {
    let board: Array2<u8> = array![
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]];
    let padded_board = pad(board, 2, "mean");
    println!("{:?}", padded_board);
    // [[5, 5, 4, 5, 6, 5, 5],
    // [5, 5, 4, 5, 6, 5, 5],
    // [2, 2, 1, 2, 3, 2, 2],
    // [5, 5, 4, 5, 6, 5, 5],
    // [8, 8, 7, 8, 9, 8, 8],
    // [5, 5, 4, 5, 6, 5, 5],
    // [5, 5, 4, 5, 6, 5, 5]], shape=[7, 7], strides=[7, 1], layout=Cc (0x5), const ndim=2
}
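Something to watch for with "mean": summing u8 values with `sum::<u8>()` overflows as soon as the total exceeds 255 (a panic in debug builds), so the sum should be taken in a wider type. A minimal sketch of that idea on a plain slice; `mean_u8` is a hypothetical helper, not part of ndarray.

```rust
// Hypothetical helper: mean of u8 values, summed in u32 so the total cannot overflow u8.
// Integer division truncates, like the article's code. Assumes a non-empty slice.
fn mean_u8(values: &[u8]) -> u8 {
    let sum: u32 = values.iter().map(|&v| u32::from(v)).sum();
    (sum / values.len() as u32) as u8
}

fn main() {
    // 200 + 250 + 100 = 550 does not fit in u8, but the widened sum is fine.
    println!("{}", mean_u8(&[200, 250, 100])); // 183
    println!("{}", mean_u8(&[4, 5, 6])); // 5
}
```

As far as I know, numpy.pad rounds the mean for integer arrays instead of truncating it, so results may differ by one when the mean is not a whole number.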

minimum
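Pads each column with that column's minimum and each row with that row's minimum; the corners get the minimum of the whole array.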

fn main() {
    let board: Array2<u8> = array![
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]];
    let padded_board = pad(board, 2, "minimum");
    println!("{:?}", padded_board);
    // [[1, 1, 1, 2, 3, 1, 1],
    // [1, 1, 1, 2, 3, 1, 1],
    // [1, 1, 1, 2, 3, 1, 1],
    // [4, 4, 4, 5, 6, 4, 4],
    // [7, 7, 7, 8, 9, 7, 7],
    // [1, 1, 1, 2, 3, 1, 1],
    // [1, 1, 1, 2, 3, 1, 1]], shape=[7, 7], strides=[7, 1], layout=Cc (0x5), const ndim=2
}

wrap
Fills the padding with the data from the opposite side, as if the array were repeated periodically.

fn main() {
    let board: Array2<u8> = array![
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]];
    let padded_board = pad(board, 2, "wrap");
    println!("{:?}", padded_board);
    // [[5, 6, 4, 5, 6, 4, 5],
    // [8, 9, 7, 8, 9, 7, 8],
    // [2, 3, 1, 2, 3, 1, 2],
    // [5, 6, 4, 5, 6, 4, 5],
    // [8, 9, 7, 8, 9, 7, 8],
    // [2, 3, 1, 2, 3, 1, 2],
    // [5, 6, 4, 5, 6, 4, 5]], shape=[7, 7], strides=[7, 1], layout=Cc (0x5), const ndim=2
}
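The "wrap" rule can also be written as modular arithmetic on indices: an out-of-range index i on an axis of length n reads from i mod n, using the Euclidean remainder so that negative indices work too. Unlike slicing, this also copes with a width larger than the array. A hypothetical std-only sketch on a single row (`wrap_index` and `pad_wrap_1d` are illustrative names):

```rust
// Hypothetical helper: the "wrap" rule for an axis of length n (assumes n >= 1).
// rem_euclid always returns a value in 0..n, even for negative i.
fn wrap_index(i: isize, n: usize) -> usize {
    i.rem_euclid(n as isize) as usize
}

// Pad a 1-D slice on both sides using the wrap rule.
fn pad_wrap_1d<T: Clone>(xs: &[T], width: usize) -> Vec<T> {
    let n = xs.len();
    let w = width as isize;
    (-w..n as isize + w).map(|i| xs[wrap_index(i, n)].clone()).collect()
}

fn main() {
    println!("{:?}", pad_wrap_1d(&[1, 2, 3], 2)); // [2, 3, 1, 2, 3, 1, 2]
    // A width larger than the row still works.
    println!("{:?}", pad_wrap_1d(&[1, 2, 3], 4));
}
```

Applying `wrap_index` to the row index and the column index separately gives the 2-D result above. For a Game of Life on a torus, the same `rem_euclid` trick could also be used to look up neighbours directly, without building a padded copy at all.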

reflect
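Fills the padding with the data from the opposite side, flipped both vertically and horizontally (each block is rotated by 180°). Note that this differs from numpy.pad's "reflect" mode, which mirrors the array about its edge without repeating the edge element: for the row [1, 2, 3] with width 2, numpy gives [3, 2, 1, 2, 3, 2, 1].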

fn main() {
    let board: Array2<u8> = array![
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]];
    let padded_board = pad(board, 2, "reflect");
    println!("{:?}", padded_board);
    // [[9, 8, 9, 8, 7, 8, 7],
    // [6, 5, 6, 5, 4, 5, 4],
    // [9, 8, 1, 2, 3, 8, 7],
    // [6, 5, 4, 5, 6, 5, 4],
    // [3, 2, 7, 8, 9, 2, 1],
    // [6, 5, 6, 5, 4, 5, 4],
    // [3, 2, 3, 2, 1, 2, 1]], shape=[7, 7], strides=[7, 1], layout=Cc (0x5), const ndim=2
}
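For comparison, here is a hypothetical std-only sketch of numpy's "reflect" rule on one row: the array is mirrored about its first and last elements without repeating them, so the index pattern repeats every 2(n - 1) positions. `reflect_index` and `pad_reflect_1d` are illustrative names.

```rust
// Hypothetical helper: numpy-style "reflect" index for an axis of length n (n >= 1).
// The edge element is not repeated, so the pattern repeats every 2 * (n - 1) steps.
fn reflect_index(i: isize, n: usize) -> usize {
    if n == 1 {
        return 0; // a single element can only map onto itself
    }
    let period = 2 * (n as isize - 1);
    let m = i.rem_euclid(period);
    // The first n positions of a period walk forward, the rest walk back.
    if m < n as isize { m as usize } else { (period - m) as usize }
}

// Pad a 1-D slice on both sides using the reflect rule.
fn pad_reflect_1d<T: Clone>(xs: &[T], width: usize) -> Vec<T> {
    let n = xs.len();
    let w = width as isize;
    (-w..n as isize + w).map(|i| xs[reflect_index(i, n)].clone()).collect()
}

fn main() {
    println!("{:?}", pad_reflect_1d(&[1, 2, 3], 2)); // [3, 2, 1, 2, 3, 2, 1]
}
```

Applied to the row index and the column index separately, this rule would give numpy's 2-D "reflect" result rather than the rotated blocks shown above.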

symmetric
Behaves as if the boundary of the array were a mirror, so the edge element itself is repeated.

fn main() {
    let board: Array2<u8> = array![
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]];
    let padded_board = pad(board, 2, "symmetric");
    println!("{:?}", padded_board);
    // [[5, 4, 4, 5, 6, 6, 5],
    // [2, 1, 1, 2, 3, 3, 2],
    // [2, 1, 1, 2, 3, 3, 2],
    // [5, 4, 4, 5, 6, 6, 5],
    // [8, 7, 7, 8, 9, 9, 8],
    // [8, 7, 7, 8, 9, 9, 8],
    // [5, 4, 4, 5, 6, 6, 5]], shape=[7, 7], strides=[7, 1], layout=Cc (0x5), const ndim=2
}
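The mirror behaviour can also be expressed as an index rule: with the edge element repeated, the pattern along an axis of length n repeats every 2n positions, and applying the rule to rows and columns separately gives the 2-D result. A hypothetical std-only sketch (`symmetric_index` and `pad_symmetric` are illustrative names); for the 3×3 example, its rows 2 to 4 should be [2, 1, 1, 2, 3, 3, 2], [5, 4, 4, 5, 6, 6, 5] and [8, 7, 7, 8, 9, 9, 8].

```rust
// Hypothetical helper: numpy-style "symmetric" index for an axis of length n (n >= 1).
// The edge element is repeated, so the pattern repeats every 2 * n steps.
fn symmetric_index(i: isize, n: usize) -> usize {
    let period = 2 * n as isize;
    let m = i.rem_euclid(period);
    // The first n positions of a period walk forward, the next n walk back.
    if m < n as isize { m as usize } else { (period - 1 - m) as usize }
}

// Pad a non-empty, rectangular 2-D grid, mirroring rows and columns independently.
fn pad_symmetric<T: Clone>(board: &[Vec<T>], width: usize) -> Vec<Vec<T>> {
    let (rows, cols) = (board.len(), board[0].len());
    let w = width as isize;
    (-w..rows as isize + w)
        .map(|i| {
            (-w..cols as isize + w)
                .map(|j| board[symmetric_index(i, rows)][symmetric_index(j, cols)].clone())
                .collect::<Vec<T>>()
        })
        .collect()
}

fn main() {
    let board = vec![vec![1u8, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
    for row in pad_symmetric(&board, 2) {
        println!("{:?}", row);
    }
}
```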

Any other padding type causes a panic.
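If a panic is too harsh, one option (my own suggestion, not something ndarray provides) is to parse the mode string into an enum up front with the standard `FromStr` trait, so an invalid name becomes an `Err` the caller can handle. `PadMode` is a hypothetical name; `pad` could then take a `PadMode` and match on its variants instead of on strings.

```rust
use std::str::FromStr;

// Hypothetical enum covering the padding types accepted by the pad function above.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PadMode {
    Constant,
    Edge,
    Wrap,
    Reflect,
    Symmetric,
    Minimum,
    Maximum,
    Mean,
}

impl FromStr for PadMode {
    type Err = String;

    // Accept exactly the strings pad matches on; anything else is an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "constant" => Ok(PadMode::Constant),
            "edge" => Ok(PadMode::Edge),
            "wrap" => Ok(PadMode::Wrap),
            "reflect" => Ok(PadMode::Reflect),
            "symmetric" => Ok(PadMode::Symmetric),
            "minimum" => Ok(PadMode::Minimum),
            "maximum" => Ok(PadMode::Maximum),
            "mean" => Ok(PadMode::Mean),
            other => Err(format!("invalid padding type: {other}")),
        }
    }
}

fn main() {
    let mode: PadMode = "wrap".parse().unwrap();
    println!("{:?}", mode); // Wrap
    match "diagonal".parse::<PadMode>() {
        Ok(m) => println!("parsed {:?}", m),
        Err(e) => println!("{}", e), // invalid padding type: diagonal
    }
}
```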

Thoughts

I find Rust's ownership system interesting, but when I write my own functions I lose track of which arguments should take ownership and which should only be borrowed.
