paizaラーニングレベルアップ問題集の区間和をやってみました。
問題
参考
前回は「累積和」の問題を解きましたが、今回は「区間和」の問題を解いていきます(“前回”なんて知らねーよ)。
参考記事のIntArrayクラスに、新たに区間和メソッドinterval_sumを追加します。
#ifndef INTARRAY_H_
#define INTARRAY_H_
#include <stdbool.h>
typedef struct __IntArray _IntArray;
typedef struct IntArray {
_IntArray *_intArray;
bool (*cumulative_sum)(const struct IntArray*, size_t, int*);
+ bool (*interval_sum)(const struct IntArray*, size_t, size_t, int*);
} IntArray;
IntArray* new_IntArray(size_t, const int*);
void free_IntArray(IntArray**);
#endif /* INTARRAY_H_ */
テストコード
早速、テストコードを書いていきます。このテストをパスするように本体のコードを書いていきます。
今回はinterval_sumメソッドのテストのみ記述します。
interval_sumの要件:
- 入力添字は$0\le l\le r \le N$
問題文と異なりますが、この範囲の入力を受け付けるように実装しておきます(後々便利かも。Pythonのsum(A[l:r])と同じ答えになります)。- 上記条件を満たす$l, r$に対して、$a_{l+1}+a_{l+2}+\dots+a_{r-1}+a_r$を返す(実装では添字がズレるので
A[l]+A[l+1]+...+A[r-2]+A[r-1]) - 上記条件を満たさない入力は
falseをreturnする
例:$l=-1$(size_t型を引数としているので実現不可)、$r=l-1$($l=r+1$)、$r=N+1$、等
- 上記条件を満たす$l, r$に対して、$a_{l+1}+a_{l+2}+\dots+a_{r-1}+a_r$を返す(実装では添字がズレるので
- $l=r$の時は$0$を返す(今回の問題ではこのような入力は与えられない)
- $l=0, r=N$の時は、配列要素全ての和を返す
尚、単体テストツールとして、minunit.hを使います。
#include <stdio.h>
#include <string.h>
#include <stdarg.h>
#include <errno.h>
#include "minunit.h"
#include "IntArray.h"
int tests_run = 0;
static char* message(int expected, int actual) {
static char msg[72];
snprintf(msg, sizeof(msg), "Error: expected: <%d> but was: <%d>", expected, actual);
return msg;
}
static char* test_interval_sum(int expected, size_t l, size_t r, size_t n, ...) {
va_list args;
va_start(args, n);
int A[n];
for (size_t i = 0; i < n; i++) {
A[i] = va_arg(args, int);
}
va_end(args);
IntArray *intArray = new_IntArray(n, A);
int actual;
bool result = intArray->interval_sum(intArray, l, r, &actual);
free_IntArray(&intArray);
mu_assert("Error: expected: <true> but was: <false>", result);
mu_assert(message(expected, actual), actual == expected);
mu_assert(message(0, errno), errno == 0);
return 0;
}
static char* test_interval_sum_fail(size_t l, size_t r, size_t n, ...) {
va_list args;
va_start(args, n);
int A[n];
for (size_t i = 0; i < n; i++) {
A[i] = va_arg(args, int);
}
va_end(args);
IntArray *intArray = new_IntArray(n, A);
int sum;
bool result = intArray->interval_sum(intArray, l, r, &sum);
free_IntArray(&intArray);
mu_assert("Error: expected: <false> but was: <true>", !result);
mu_assert(message(ERANGE, errno), errno == ERANGE);
return 0;
}
static char* test_interval_sum_0() {return test_interval_sum(0, 0, 0, 0);}
static char* test_interval_sum_1() {return test_interval_sum_fail(0, 1, 0);}
static char* test_interval_sum_2() {return test_interval_sum_fail(1, 0, 1, 0);}
static char* test_interval_sum_3() {return test_interval_sum(0, 0, 0, 1, 100);}
static char* test_interval_sum_4() {return test_interval_sum(100, 0, 1, 1, 100);}
static char* test_interval_sum_5() {return test_interval_sum(0, 1, 1, 1, 100);}
static char* all_tests() {
mu_run_test(test_interval_sum_0);
mu_run_test(test_interval_sum_1);
mu_run_test(test_interval_sum_2);
mu_run_test(test_interval_sum_3);
mu_run_test(test_interval_sum_4);
mu_run_test(test_interval_sum_5);
return 0;
}
int main() {
char *result = all_tests();
if (result != 0) {
fprintf(stderr, "%s\n", result);
fprintf(stderr, "Tests run: %d\n", tests_run);
} else {
fprintf(stdout, "ALL TESTS PASSED\n");
fprintf(stdout, "Tests run: %d\n", tests_run);
}
return result != 0;
}
読者諸賢においては、様々なテストデータで試していただきたい。
本体実装
上記のテストをパスするように、interval_sumメソッドを実装していきます(上記のテストデータだけだと、他にもパスしてしまうコードも沢山ありそうだが)。
区間和$\sum_{i=l}^r a_i$は、$S_0=0$、$S_j=\sum_{i=1}^j a_i$を用いて
\begin{align}
\sum_{i=l}^r a_i&=a_l+a_{l+1}+\dots+a_{r-1}+a_r\\
&=\left(a_1+a_2+\dots+a_r\right)-\left(a_1+a_2+\dots+a_{l-1}\right)\\
&=S_r-S_{l-1}
\end{align}
と書けるので、interval_sumメソッドの実装は
static bool IntArray_interval_sum(const IntArray *self, size_t l, size_t r, int *sum) {
errno = 0;
if (r < l || self->_intArray->n < r) {
errno = ERANGE;
return false;
}
*sum = self->_intArray->sum[r] - self->_intArray->sum[l];
return true;
}
となります(添字のズレの詳細は省略)。
これをコンストラクタnew_IntArrayにて
intArray->interval_sum = IntArray_interval_sum;
を追加すれば完成。
入力チェック
入力後半のクエリ部分に$1\le l\le r\le N$を満たす整数$(l,r)$が入ることになるので、前回のコードからその部分を修正します。
#ifndef INPUT_H_
#define INPUT_H_
#include <stdbool.h>
bool read_size(const char*, size_t*, size_t*);
bool read_element(const char*, int*);
- bool read_query(const char*, size_t, size_t*);
+ bool read_query(const char*, size_t, size_t*, size_t*);
#endif /* INPUT_H_ */
(以下、read_query関数のみ)
bool read_query(const char *str, size_t n, size_t *l, size_t *r) {
if (!str) {
errno = EINVAL;
return false;
}
char s[] = "100000 100000\n";
if (strlen(str) > strlen(s)) {
errno = EINVAL;
return false;
}
if (!parse_ulong(strtok(chomp(strncpy(s, str, sizeof(s))), " "), l)) {
return false;
} else if (*l < 1 || n < *l) {
errno = ERANGE;
return false;
}
if (!parse_ulong(strtok(NULL, " "), r)) {
return false;
} else if (*r < *l || n < *r) {
errno = ERANGE;
return false;
}
if (strtok(NULL, " ")) {
errno = EINVAL;
return false;
}
return true;
}
以下、read_query関数にかかわるテストのみ記述
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include "minunit.h"
#include "input.h"
int tests_run = 0;
static char* message_d(int expected, int actual) {
static char msg[72];
snprintf(msg, sizeof(msg), "Error: expected: <%d> but was: <%d>", expected, actual);
return msg;
}
static char* message_z(size_t expected, size_t actual) {
static char msg[72];
snprintf(msg, sizeof(msg), "Error: expected: <%zu> but was: <%zu>", expected, actual);
return msg;
}
static char* test_read_query(const char *str, size_t n, size_t el, size_t er) {
size_t l, r;
mu_assert("Error: expected: <true> but was: <false>", read_query(str, n, &l, &r));
mu_assert(message_z(el, l), l == el);
mu_assert(message_z(er, r), r == er);
return 0;
}
static char* test_read_query_invalid(const char *str, size_t n) {
size_t l, r;
mu_assert("Error: expected: <false> but was: <true>", !read_query(str, n, &l, &r));
mu_assert(message_d(EINVAL, errno), errno = EINVAL);
return 0;
}
static char* test_read_query_out_of_range(const char *str, size_t n) {
size_t l, r;
mu_assert("Error: expected: <false> but was: <true>", !read_query(str, n, &l, &r));
mu_assert(message_d(ERANGE, errno), errno = ERANGE);
return 0;
}
static char* test_read_query_0() {return test_read_query_invalid(NULL, 0);}
static char* test_read_query_1() {return test_read_query_invalid("", 0);}
static char* test_read_query_2() {return test_read_query_invalid("1", 1);}
static char* test_read_query_3() {return test_read_query_invalid("1 1 1", 1);}
static char* test_read_query_4() {return test_read_query_out_of_range("0 1", 1);}
static char* test_read_query_5() {return test_read_query_out_of_range("1 2", 1);}
static char* test_read_query_6() {return test_read_query_out_of_range("2 1", 2);}
static char* test_read_query_7() {return test_read_query("1 1", 1, 1, 1);}
static char* test_read_query_8() {return test_read_query("100000 100000", 100000, 100000, 100000);}
static char* all_tests() {
mu_run_test(test_read_query_0);
mu_run_test(test_read_query_1);
mu_run_test(test_read_query_2);
mu_run_test(test_read_query_3);
mu_run_test(test_read_query_4);
mu_run_test(test_read_query_5);
mu_run_test(test_read_query_6);
mu_run_test(test_read_query_7);
mu_run_test(test_read_query_8);
return 0;
}
int main() {
char *result = all_tests();
if (result != 0) {
fprintf(stderr, "%s\n", result);
fprintf(stderr, "Tests run: %d\n", tests_run);
} else {
fprintf(stdout, "ALL TESTS PASSED\n");
fprintf(stdout, "Tests run: %d\n", tests_run);
}
return result != 0;
}
エントリポイント
main関数の後半部分を、2整数のクエリを受けて区間和を表示するように改修します。
#include <stdio.h>
#include <errno.h>
#include "IntArray.h"
#ifdef NDEBUG
#include <string.h>
#include "input.h"
#else
#include <time.h>
#endif
int main() {
#ifndef NDEBUG
clock_t clockt = clock();
#endif
size_t N, K;
#ifdef NDEBUG
do {
char s[] = "100000 100000\n";
if (!strchr(fgets(s, sizeof(s), stdin), '\n')) {
for (int c = getchar(); (c == '\0' || c == '\n' || c == EOF); c = getchar());
} else {
if (read_size(s, &N, &K)) {
break;
}
}
perror("The number of elements must be an integer between 1 and 100000.\n"
"The number of queries must be an integer between 1 and 100000.");
} while (true);
#else
N = 100000;
K = 100000;
#endif
int A[N];
for (size_t i = 0; i < N; i++) {
#ifdef NDEBUG
do {
char s[] = "-100\n";
if (!strchr(fgets(s, sizeof(s), stdin), '\n')) {
for (int c = getchar(); (c == '\0' || c == '\n' || c == EOF); c = getchar());
} else {
if (read_element(s, &A[i])) {
break;
}
}
perror("The element must be an integer between -100 and 100.");
} while (true);
#else
A[i] = ((int) i) % 201 - 100;
#endif
}
IntArray *intArray = new_IntArray(N, A);
while (K--) {
- size_t Q;
+ size_t l, r;
#ifdef NDEBUG
do {
- char s[] = "100000\n";
+ char s[] = "100000 100000\n";
if (!strchr(fgets(s, sizeof(s), stdin), '\n')) {
for (int c = getchar(); (c == '\0' || c == '\n' || c == EOF); c = getchar());
} else {
- if (read_query(s, N, &Q)) {
+ if (read_query(s, N, &l, &r)) {
break;
}
}
fprintf(stderr, "The index must be an integer between 1 and %zu.\n", N);
} while (true);
#else
- Q = N - K;
+ l = K + 1;
+ r = N;
#endif
int sum;
- if (!intArray->cumulative_sum(intArray, Q, &sum)) {
+ if (!intArray->interval_sum(intArray, --l, r, &sum)) {
int e = errno;
perror(NULL);
return e;
}
printf("%d\n", sum);
}
free_Intarray(&intArray);
#ifndef NDEBUG
fprintf(stderr, "%f sec.\n", (float) (clock() - clockt) / CLOCKS_PER_SEC);
#endif
return 0;
}
テストコード
テストの内容だけ記載します(前後は省略)。あとは前回の記事のtest.cに埋め込んでください。
- 最低限。1要素のみの配列に対して区間和(も何もないが)を求める
static char* test0() {
const char *lines[] = { "1 1", "0", "1 1" };
size_t n = test(sizeof(lines) / sizeof(lines[0]), lines);
mu_assert(message_u(1, n), n == 1);
mu_assert(message_s("0\n", result[0]), !strcmp(result[0], "0\n"));
return 0;
}
- 入出力例1
static char* test1() {
const char *lines[] = { "4 2", "16", "88", "10", "-65", "2 4", "1 2" };
size_t n = test(sizeof(lines) / sizeof(lines[0]), lines);
mu_assert(message_u(2, n), n == 2);
mu_assert(message_s("33\n", result[0]), !strcmp(result[0], "33\n"));
mu_assert(message_s("104\n", result[1]), !strcmp(result[1], "104\n"));
return 0;
}
- 入出力例2
static char* test2() {
const char *lines[] = { "10 5",
"82", "-37", "40", "-72", "-24", "-54", "57", "-6", "42", "-24",
"8 9", "6 9", "2 3", "4 4", "1 5" };
size_t n = test(sizeof(lines) / sizeof(lines[0]), lines);
mu_assert(message_u(5, n), n == 5);
mu_assert(message_s("36\n", result[0]), !strcmp(result[0], "36\n"));
mu_assert(message_s("39\n", result[1]), !strcmp(result[1], "39\n"));
mu_assert(message_s("3\n", result[2]), !strcmp(result[2], "3\n"));
mu_assert(message_s("-72\n", result[3]), !strcmp(result[3], "-72\n"));
mu_assert(message_s("-11\n", result[4]), !strcmp(result[4], "-11\n"));
return 0;
}
- 入力エラーを起こしながら実行
static char* test3() {
const char *lines[] = {
"100000 100000", "100001 100001", "0 0", "1 1 1", "2 3",
"101", "-101", "-100", "100",
"0 0", "1 1 1", "2", "2 1", "1 3", "1 1", "1 2", "2 2"
};
size_t n = test(sizeof(lines) / sizeof(lines[0]), lines);
mu_assert(message_u(3, n), n == 3);
mu_assert(message_s("-100\n", result[0]), !strcmp(result[0], "-100\n"));
mu_assert(message_s("0\n", result[1]), !strcmp(result[1], "0\n"));
mu_assert(message_s("100\n", result[2]), !strcmp(result[2], "100\n"));
return 0;
}
- $N=100000$、$K=100000$の場合(時間測定用)
static char* test4() {
static char *lines[200001] = { "100000 100000" };
for (size_t i = 1; i <= 100000; i++) {
lines[i] = "-100";
}
static char chars[100000][16] = {};
for (size_t i = 0; i < 100000; i++) {
snprintf(chars[i], sizeof(chars[i]), "%zu 100000", -~i);
lines[100001 + i] = chars[i];
}
size_t n = test(sizeof(lines) / sizeof(lines[0]), (const char**) lines);
mu_assert(message_u(100000, n), n == 100000);
for (size_t i = 0; i < 100000; i++) {
char expected[11];
snprintf(expected, sizeof(expected), "-%zu00\n", 100000 - i);
mu_assert(message_s(expected, result[i]), !strcmp(result[i], expected));
}
return 0;
}