paizaラーニングレベルアップ問題集の二次元区間和をやってみました。
問題
参考
二次元区間和メソッド
プロトタイプ
早速ですが、二次元累積和の問題で作成したIntArray2クラスに、二次元区間和メソッドinterval_sumを追加します。
#ifndef INTARRAY2_H_
#define INTARRAY2_H_
#include <stdbool.h>
typedef struct __IntArray2 _IntArray2;
typedef struct IntArray2 {
_IntArray2 *_array;
bool (*cumulative_sum)(const struct IntArray2*, size_t, size_t, int*);
+ bool (*interval_sum)(const struct IntArray2*, size_t, size_t, size_t, size_t, int*);
} IntArray2;
IntArray2* new_IntArray2(size_t, size_t, const int**);
void free_IntArray2(IntArray2**);
#endif /* INTARRAY2_H_ */
テストコード
今回追加したinterval_sumメソッドのテストコードを作成します。
-
入力添字として
- $0\le a\le c\le H$
- $0\le b\le d\le W$
を許容。上記の条件から外れた入力値ははじく。
(問題の入力条件とは合致しませんが、メソッドとしてはこのように定義しておきます) -
$a=0$, $b=0$, $c=H$, $d=W$の時は二次元配列全要素の合計値を返す
-
$a=c$または$b=d$の時は
0を返す
(問題では与えられませんが、というのは、問題で与えられる$a$, $b$はデクリメントして渡すためです。ですが、このように定義しておきます)
ことを確認します。
単体テストツールとしてminunit.hを使用します。
static char* test_interval_sum(size_t h, size_t w, const int **A, size_t a, size_t b, size_t c, size_t d, int expected) {
IntArray2 *intArray2 = new_IntArray2(h, w, A);
if (!intArray2) {
return strerror(errno);
}
int actual;
bool result = intArray2->interval_sum(intArray2, a, b, c, d, &actual);
free_IntArray2(&intArray2);
mu_assert("Error: expected: <true> but was: <false>", result);
mu_assert(message(expected, actual), actual == expected);
return 0;
}
static char* test_interval_sum_fail(size_t h, size_t w, const int **A, size_t a, size_t b, size_t c, size_t d) {
IntArray2 *intArray2 = new_IntArray2(h, w, A);
if (!intArray2) {
return strerror(errno);
}
int actual;
bool result = intArray2->interval_sum(intArray2, a, b, c, d, &actual);
free_IntArray2(&intArray2);
mu_assert("Error: expected: <false> but was: <true>", !result);
mu_assert(message(ERANGE, errno), errno == ERANGE);
return 0;
}
static char* test_interval_sum_0() {
return test_interval_sum(0, 0, NULL, 0, 0, 0, 0, 0);
}
static char* test_interval_sum_1() {
const int a[1][1] = { { } };
const int *A[] = { a[0] };
return test_interval_sum_fail(1, 1, A, 1, 0, 0, 0);
}
static char* test_interval_sum_2() {
const int a[1][1] = { { } };
const int *A[] = { a[0] };
return test_interval_sum_fail(1, 1, A, 0, 1, 0, 0);
}
static char* test_interval_sum_3() {
const int a[1][1] = { { } };
const int *A[] = { a[0] };
return test_interval_sum_fail(1, 1, A, 0, 0, 2, 1);
}
static char* test_interval_sum_4() {
const int a[1][1] = { { } };
const int *A[] = { a[0] };
return test_interval_sum_fail(1, 1, A, 0, 0, 1, 2);
}
static char* test_interval_sum_5() {
const int a[3][3] = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
const int *A[] = { a[0], a[1], a[2] };
return test_interval_sum(3, 3, A, 1, 1, 3, 3, 28);
}
static char* test_interval_sum_6() {
const int a[2][2] = { { 1, 2 }, { 3, 4 } };
const int *A[] = { a[0], a[1] };
return test_interval_sum(2, 2, A, 0, 0, 2, 2, 10);
}
static char* test_interval_sum_7() {
const int a[2][2] = { { 1, 2 }, { 3, 4 } };
const int *A[] = { a[0], a[1] };
return test_interval_sum(2, 2, A, 1, 0, 1, 2, 0);
}
static char* test_interval_sum_8() {
const int a[2][2] = { { 1, 2 }, { 3, 4 } };
const int *A[] = { a[0], a[1] };
return test_interval_sum(2, 2, A, 0, 1, 2, 1, 0);
}
本体
上記のテストコードありきで、テストを通過させるようにコードを書いていきます。
static bool IntArray2_interval_sum(const IntArray2 *self, size_t a, size_t b, size_t c, size_t d, int *sum) {
if (c < a || self->_array->h < c || d < b || self->_array->w < d) {
errno = ERANGE;
return false;
}
*sum = self->_array->sum[c][d] - self->_array->sum[c][b]
- self->_array->sum[a][d] + self->_array->sum[a][b];
return true;
}
これをコンストラクタにて
array->interval_sum = IntArray2_interval_sum;
すればよいです。
入力チェック
後半のクエリ部分で
- $1\le a\le c\le H$
- $1\le b\le d\le W$
を満たすような整数$a,b,c,d$が入力されるので、この条件を満たすかどうかチェックすればよいです。
尚、上述のIntArray2クラスのinterval_sumメソッドに渡すときは、$a$, $b$はデクリメントして渡します。
プロトタイプ
#ifndef INPUT_H_
#define INPUT_H_
#include <stdbool.h>
bool read_size(const char*, size_t*, size_t*, size_t*);
bool read_elements(const char*, size_t, int*);
- bool read_query(const char*, size_t, size_t, size_t*, size_t*);
+ bool read_query(const char*, size_t, size_t, size_t*, size_t*, size_t*, size_t*);
#endif /* INPUT_H_ */
テストコード
以下にread_queryのテストコードを示しますが、数字以外の文字が入ってきた場合にはじくチェックは省略しております。
static char* test_read_query(const char *str, size_t h, size_t w, size_t ea, size_t eb, size_t ec, size_t ed) {
size_t a, b, c, d;
mu_assert("Error: expected: <true> but was: <false>", read_query(str, h, w, &a, &b, &c, &d));
mu_assert(message_d(0, errno), errno == 0);
mu_assert(message_z(ea, a), a == ea);
mu_assert(message_z(eb, b), b == eb);
mu_assert(message_z(ec, c), c == ec);
mu_assert(message_z(ed, d), d == ed);
return 0;
}
static char* test_read_query_invalid(const char *str, size_t h, size_t w) {
size_t a, b, c, d;
mu_assert("Error: expected: <false> but was: <true>", !read_query(str, h, w, &a, &b, &c, &d));
mu_assert(message_d(EINVAL, errno), errno == EINVAL);
return 0;
}
static char* test_read_query_out_of_range(const char *str, size_t h, size_t w) {
size_t a, b, c, d;
mu_assert("Error: expected: <false> but was: <true>", !read_query(str, h, w, &a, &b, &c, &d));
mu_assert(message_d(ERANGE, errno), errno == ERANGE);
return 0;
}
static char* test_read_query_0() {return test_read_query_invalid(NULL, 0, 0);}
static char* test_read_query_1() {return test_read_query_invalid("", 0, 0);}
static char* test_read_query_2() {return test_read_query_invalid("1", 1, 1);}
static char* test_read_query_3() {return test_read_query_invalid("1 1 1", 1, 1);}
static char* test_read_query_4() {return test_read_query_invalid("1 1 1 1 1", 1, 1);}
static char* test_read_query_5() {return test_read_query_out_of_range("0 1 1 1", 1, 1);}
static char* test_read_query_6() {return test_read_query_out_of_range("1 0 1 1", 1, 1);}
static char* test_read_query_7() {return test_read_query_out_of_range("1 1 2 1", 1, 1);}
static char* test_read_query_8() {return test_read_query_out_of_range("1 1 1 2", 1, 1);}
static char* test_read_query_9() {return test_read_query("1 1 1 1", 1, 1, 1, 1, 1, 1);}
static char* test_read_query_10() {return test_read_query_out_of_range("2 1 1 1", 2, 2);}
static char* test_read_query_11() {return test_read_query_out_of_range("1 2 1 1", 2, 2);}
static char* test_read_query_12() {return test_read_query("1000 1000 1000 1000\n", 1000, 1000, 1000, 1000, 1000, 1000);}
本体
本体のread_query関数を修正します。
bool read_query(const char *str, size_t h, size_t w, size_t *a, size_t *b, size_t *c, size_t *d) {
errno = 0;
if (!str) {
errno = EINVAL;
return false;
}
char s[] = "1000 1000 1000 1000\n";
if (strlen(str) > strlen(s)) {
errno = EINVAL;
return false;
}
if (!parse_ulong(strtok(chomp(strncpy(s, str, sizeof(s))), " "), a)) {
return false;
} else if (*a < 1 || h < *a) {
errno = ERANGE;
return false;
}
if (!parse_ulong(strtok(NULL, " "), b)) {
return false;
} else if (*b < 1 || w < *b) {
errno = ERANGE;
return false;
}
if (!parse_ulong(strtok(NULL, " "), c)) {
return false;
} else if (*c < *a || h < *c) {
errno = ERANGE;
return false;
}
if (!parse_ulong(strtok(NULL, " "), d)) {
return false;
} else if (*d < *b || w < *d) {
errno = ERANGE;
return false;
}
if (strtok(NULL, " ")) {
errno = EINVAL;
return false;
}
return true;
}
エントリポイント
main関数を修正します。
テストコード
まず、テストコードを作成します。
static char* test0() {
const char *lines[] = { "1 1 1", "0", "1 1 1 1" };
size_t n = test(sizeof(lines) / sizeof(lines[0]), lines);
mu_assert(message_u(1, n), n == 1);
mu_assert(message_s("0\n", result[0]), !strcmp(result[0], "0\n"));
return 0;
}
static char* test1() {
const char *lines[] = { "3 3 2", "1 2 3", "4 5 6", "7 8 9", "1 1 3 3", "1 2 2 3" };
size_t n = test(sizeof(lines) / sizeof(lines[0]), lines);
mu_assert(message_u(2, n), n == 2);
mu_assert(message_s("45\n", result[0]), !strcmp(result[0], "45\n"));
mu_assert(message_s("16\n", result[1]), !strcmp(result[1], "16\n"));
return 0;
}
static char* test2() {
const char *lines[] = { "10 10 4",
"-74 -92 65 11 -96 66 17 33 -86 29",
"26 -83 100 -72 85 51 -29 8 49 72",
"-47 52 -69 85 23 80 -59 79 92 -97",
"14 -26 -15 9 -22 -65 -29 -66 -30 -48",
"-17 -68 -22 -50 -48 14 -29 96 -77 -23",
"-96 -83 -31 46 8 -67 -94 92 -70 -49",
"-97 8 42 -49 87 -72 -73 -80 68 66",
"100 94 -57 -62 -58 -18 -42 -80 55 47",
"18 -86 97 14 -37 -69 -19 56 58 -96",
"48 12 75 -68 19 52 -88 54 -24 -45",
"6 2 7 7", "7 7 8 10", "1 5 2 5", "6 6 10 7"
};
size_t n = test(sizeof(lines) / sizeof(lines[0]), lines);
mu_assert(message_u(4, n), n == 4);
mu_assert(message_s("-278\n", result[0]), !strcmp(result[0], "-278\n"));
mu_assert(message_s("-39\n", result[1]), !strcmp(result[1], "-39\n"));
mu_assert(message_s("-11\n", result[2]), !strcmp(result[2], "-11\n"));
mu_assert(message_s("-490\n", result[3]), !strcmp(result[3], "-490\n"));
return 0;
}
static char* test3() {
const char *lines[] = {
"1000 1000 100000", "1001 1001 100001",
"000000000000000001 1 1", "0 0 0", "1 1", "1 1 1 1", "2 2 1",
"-100 -100", "-101 -101", "101 101", "0", "0 0 0", "1 -2", "-3 4",
"0 0 3 3", "1", "1 1 2", "1 1 2 2 2", "1 1 2 2"
};
size_t n = test(sizeof(lines) / sizeof(lines[0]), lines);
mu_assert(message_u(1, n), n == 1);
mu_assert(message_s("0\n", result[0]), !strcmp(result[0], "0\n"));
return 0;
}
static char* test4() {
char *lines[101001] = { };
size_t k = 0;
lines[k] = (char*) calloc(-~strlen("1000 1000 100000"), sizeof(char));
if (!lines[k]) {
return strerror(errno);
}
strcpy(lines[k++], "1000 1000 100000");
for (size_t i = 0; i < 1000; i++) {
lines[k] = (char*) calloc(strlen("-100 ") * 1000, sizeof(char));
if (!lines[k]) {
int e = errno;
while (k--) {
free(lines[k]);
}
return strerror(e);
}
strcpy(lines[k], "-100");
for (size_t j = 1; j < 1000; j++) {
strcat(lines[k], " -100");
}
k++;
}
for (size_t i = 0; i < 100000; i++) {
lines[k] = (char*) calloc(-~strlen("1000 1000 1000 1000"), sizeof(char));
if (!lines[k]) {
int e = errno;
while (k--) {
free(lines[k]);
}
return strerror(e);
}
size_t y = i / 100;
size_t x = (i * 10 + y % 10) % 1000;
sprintf(lines[k], "%zu %zu 1000 1000", -~y, -~x);
k++;
}
size_t n = test(sizeof(lines) / sizeof(lines[0]), (const char**) lines);
mu_assert(message_u(100000, n), n == 100000);
for (size_t i = 0; i < 100000; i++) {
char expected[] = "-100000000\n";
size_t y = i / 100;
size_t x = (i * 10 + y % 10) % 1000;
snprintf(expected, sizeof(expected), "-%zu00\n", (1000 - y) * (1000 - x));
mu_assert(message_s(expected, result[i]), !strcmp(result[i], expected));
}
while (k--) {
free(lines[k]);
}
return 0;
}
本体
#include <stdio.h>
#include <stdlib.h>
#include "IntArray2.h"
#include <errno.h>
#ifdef NDEBUG
#include <string.h>
#include "input.h"
#else
#include <time.h>
#define PRINT(fmt, ...) \
do { \
fprintf(stderr, fmt, ##__VA_ARGS__); \
fflush(stderr); \
} while (0)
#endif
int main() {
#ifndef NDEBUG
clock_t clockt = clock();
#endif
size_t H, W, N;
#ifdef NDEBUG
do {
char s[] = "1000 1000 100000\n";
if (!strchr(fgets(s, sizeof(s), stdin), '\n')) {
for (int c = getchar(); (c == '\0' || c == '\n' || c == EOF); c = getchar());
} else {
if (read_size(s, &H, &W, &N)) {
break;
}
}
perror("The number of rows must be an integer between 1 and 1000.\n"
"The number of columns must be an integer between 1 and 1000.\n"
"The number of queries must be an integer between 1 and 100000.");
} while (true);
#else
H = 1000;
W = 1000;
N = 100000;
#endif
// int A[H][W];
int *A[H];
for (size_t i = 0; i < H; i++) {
A[i] = calloc(W, sizeof(A[i][0]));
#ifdef NDEBUG
do {
char s[strlen("-100 ") * W + 1];
if (!strchr(fgets(s, sizeof(s), stdin), '\n')) {
for (int c = getchar(); (c == '\0' || c == '\n' || c == EOF); c = getchar());
} else {
if (read_elements(s, W, A[i])) {
break;
}
}
fprintf(stderr, "The number of elements must be %zu.\n", W);
perror("The element must be an integer between -100 and 100.");
} while (true);
#else
for (size_t j = 0; j < W; j++) {
A[i][j] = (int) (i * W + j) % 201 - 100;
}
#endif
}
IntArray2 *intArray2 = new_IntArray2(H, W, (const int**)A);
if (!intArray2) {
return errno;
}
while (N--) {
- size_t y, x;
+ size_t a, b, c, d;
#ifdef NDEBUG
do {
- char s[] = "1000 1000\n";
+ char s[] = "1000 1000 1000 1000\n";
if (!strchr(fgets(s, sizeof(s), stdin), '\n')) {
for (int c = getchar(); (c == '\0' || c == '\n' || c == EOF); c = getchar());
} else {
- if (read_query(s, H, W, &y, &x)) {
+ if (read_query(s, H, W, &a, &b, &c, &d)) {
break;
}
}
fprintf(stderr, "The row number must be an integer between 1 and %zu.\n", N);
fprintf(stderr, "The column number must be an integer between 1 and %zu.\n", W);
perror(NULL);
} while (true);
#else
- y = N * H / 100000 + 1;
- x = N % W + 1;
+ a = N * H / 100000 + 1;
+ b = N % W + 1;
+ c = 1000;
+ d = 1000;
#endif
int sum;
- if (!intArray2->cumulative_sum(intArray2, y, x, &sum)) {
+ if (!intArray2->interval_sum(intArray2, --a, --b, c, d, &sum)) {
int e = errno;
perror(NULL);
return e;
}
printf("%d\n", sum);
}
free_IntArray2(&intArray2);
for (size_t i = H; i; i--) {
free(A[~-i]);
}
#ifndef NDEBUG
PRINT("%f sec.\n", (float) (clock() - clockt) / CLOCKS_PER_SEC);
#endif
return 0;
}