C
algorithm
paiza

Cで新人女子に褒めてもらう #paizahack_01

More than 5 years have passed since last update.


成績

http://paiza.jp/poh/ec-campaign/result/3c7e1ae7a3b70790db8da0ef6c287c11

背景や課題の詳細は

https://paiza.jp/poh/ec-campaign

を参照してください。#paizahack_01


当初のアルゴリズム

普通に全商品の組合せで最適解を求めようとすると、$O(DN^2)$になってしまいます(paiza作成の模範解答はこれなのでcase1しか通らない)。それを避け、$O(DN\log N)$にするため、次のように二分探索を使いました。


  1. ごく普通にscanf()で入力

  2. ごく普通に商品価格リストをqsort()でソート

  3. 各キャンペーン価格につき以下を実行


    1. 各商品につき、対になる価格をもつ商品を二分探索

    2. それらのうち最適値を表示



しかし、このコードではcase3がタイムアウトしてしまうという恥ずかしい結果に。急遽、性能を手元で定量化できるように、case3相当のベンチマーク用入力を作成し、チューニングを開始しました。


最終形態への道


探索の改善

軽くプロファイリングすると、探索に時間がかかりすぎています。$D=300$と大きいので、1探索あたり0.01秒以下にしないとどうしようもありません。そこで探索アルゴリズムを抜本的に変更しました。

最適化したいのは二つの商品価格の合計ですから、探索は両側から一つずつペアを作っていけばよいだけです。合計がキャンペーン価格を超えていれば、高い方の商品をより安い商品に変えてみる。キャンペーン価格を下回っていれば、それまでの暫定解を改善する場合は更新した上で、安い方の商品をより高い商品に変えてみる。キャンペーン価格に一致すればそれが最終解、最後まで一致する解がなければそのときの暫定解が最終解です。

探索する初期値として、安い方は最低額商品から始める必要がありますが、高い方はキャンペーン価格から最低額商品価格を引いたところから始められます。この位置は当初のアルゴリズムにもある二分探索で求めることができ、多少の改善が可能です。

この方針なら、探索部分に関しては$O(DN)$で済みます。実際、ベンチマークに対しても探索時間は激減し、300個全部を0.01秒以下で実行できるようになりました。しかし、全体としては0.30秒弱くらいかかってしまいます。


ソートの改善

次に手をつけたのがソートです。qsort()は内部で比較関数を何度も呼ぶのでオーバーヘッドになります。これを解消するため、非再帰版マージソートを自作しました。しかし、0.20秒強くらいです。

ふと、商品価格の種類はせいぜい1000000通りしかないことに気づき、ならば$O(N)$のバケツソートにしようと思い至りました。これを実装すると0.10秒強くらいになりました。

しかし、そもそもソートする必要があるのか、入力値をバケツソートの各バケツに入れたまま探索してしまっても、商品数とバケツ数がコンパラなので探索時間にはあまり影響ないはずだと考え、実装しました。副次的効果として、高い方の開始価格を二分探索する必要もなくなりました。この改良により、オーダーが変わるわけではありませんが、0.10秒を切るくらいになりました。


入力の改善

ここまでの改善で、ソート時間ゼロ(バケツに入れただけ)、探索そのものも0.01秒以内、ということは残るは入力処理です。

まずは、scanf()をやめ、fgets()strtol()にしました。これで0.04秒くらい。"%d"だけでもscanf()は重かった。

次に、関数呼出しのオーバーヘッドを完全になくすため、各行を処理するのではなく、ファイル全体が入る大きなバッファを用意してfread()で全体を読み込んでしまい、自前の入力関数内で数値をすべてデコードしてしまうように変更したところ、ようやく0.01秒と0.02秒をふらつくくらいまで高速化できました。

入力内のホワイトスペース有無や改行コードなどを決め打ちすれば、もうほんの少しは改善できるかもしれません。


まとめ

大域変数使ってるのはどうなのとか、想定外の入力が与えられたらコアダンプするよとか、テストコードがないよとか、そういう細かいことは(いや細かくないけど)無視しています。

やはりオーダー重要。オーダーがそれ以上改善できないとなってから、細かいチューニングの出番ですね。


コード

下に最終コードを示しますが、かなり長くなりました。

#include <stdio.h>

#include <stdlib.h>
#include <ctype.h>

#define MIN_PRICE (10)
#define MAX_PRICE (1000000)
#define MIN_LIMIT (10)
#define MAX_LIMIT (1000000)
#define MAX_PRODUCTS (500000)
#define MAX_DAYS (300)
#define MAX_INPUT_BYTES (6+1+3+1+(7+1)*MAX_PRICE+(7+1)*MAX_DAYS)

int pricebucket[MAX_PRICE+1];
int m_j[MAX_DAYS];
int N;
int D;

// read prices from buffer. it is faster than calling strtol many times.
void
readprices(char **p) {
char c;
int price;
int i;

for (i = 0; i < N; i++) {
price = 0;
for (;;) {
c = *(*p)++;
if (isdigit(c)) {
price = price*10 + c - '0';
} else {
if (price >= MIN_PRICE) {
pricebucket[price]++;
break;
}
}
}
}
}

// read limits from buffer. it is faster than calling strtol many times.
void
readlimits(char **p) {
char c;
int limit;
int j;

for (j = 0; j < D; j++) {
limit = 0;
for (;;) {
c = *(*p)++;
if (isdigit(c)) {
limit = limit*10 + c - '0';
} else {
if (limit >= MIN_LIMIT) {
m_j[j] = limit;
break;
}
}
}
}
}

// search prices for the pair of products so that
// their total price is nearest to and less than or equal to the limit.
// return the total price.
int
searchpair(int limit) {
int l = MIN_PRICE;
int r = (limit - l <= MAX_PRICE) ? limit - l : MAX_PRICE;
int currentmax = 0;
int total;

// check if there is a pair of product with prices l and r.
while (l <= r) {
total = l+r;
if (total <= limit && pricebucket[r] > 0) {
if (pricebucket[l] > 0) {
if (total > currentmax && (l < r || pricebucket[l] > 1)) {
// better pair of different products
currentmax = total;
if (currentmax == limit) {
break;
}
}
}
l++;
} else {
r--;
}
}
return currentmax;
}

int
main(void) {
int i, j, p;
char *buf;
int bufsize;
char *dp;

// use 1 fread + strtol's for read numbers.
// they are faster than scanf.
buf = malloc(MAX_INPUT_BYTES+1);
if (!buf) exit(1);
bufsize = fread(buf, 1, MAX_INPUT_BYTES, stdin);
buf[bufsize] = '\0'; // any non-digit character

N = strtol(buf, &dp, 10);
D = strtol(dp, &dp, 10);

// read prices.
// prices are not stored in an array. just use buckets of prices.
readprices(&dp);

// read limits.
readlimits(&dp);

// compute the answer.
for (j = 0; j < D; j++) {
printf("%d\n", searchpair(m_j[j]));
}

return 0;
}