C
アルゴリズム
ソート

ヒープソートで最適なのは少なくとも2分木では無い事を示そう

目的

ヒープソートの改良

計算量の考察

サイズmの配列をn分木のヒープに構築するのに掛かる時間はO(m)なので無視する

ヒープになっている部分の最後と最初を交換し、サイズが1つ小さくなったヒープを修復する平均的なコストについて考える。

ヒープの高さHはヒープのサイズをSとすると

H = \log_n S = \log S / \log n

子の中の最大もしくは最小を探す比較回数は n - 1
親と子を比較する回数は 1
親と子を交換する回数は 1
それをヒープの高さに比例した回数繰り返す場合の比較回数C、交換回数Rは

C = (n - 1 + 1) * H =  (\log S) * n / \log n \\
R = 1 * H =  (\log S) * 1 / \log n \\

尚、計算量Tは比較と交換が同じコストで実行できると仮定すると、

\begin{align}
T &\varpropto \sum_{S=1}^{m}\log (C + R) \\
&\approx ((n + 1) / \log n) \sum_{S=1}^{m}\log S \\
&\approx ((n + 1) / \log n) \int_{1}^{m}\log S dS \\
&\approx ((n + 1) / \log n) m\log m
\end{align}

であるから計算量のオーダはO(m log m)でnによって変わる事は無い。

nが3の場合、2の場合と比べて、
比較回数は(3/log 3)/(2/log 2) = 0.946倍で少なくなり、
交換回数は(1/log 3)/(1/log 2) = 0.631倍でこれも少なくなる。
この時点で2分木は良くないという事が分かる。

nが4の場合、3の場合と比べて、
比較回数は(4/log 4)/(3/log 3) = 1.057倍で少し増え、
交換回数は(1/log 4)/(1/log 3) = 0.792倍で少なくなる。

比較のコストによっては3分木より4分木の方が早いかもしれない。
今時のCPUにはキャッシュが有るため、ある親ノードの子ノードが全て同一のキャッシュラインに乗るなら、比較するために読み込む追加コストがない分早いと思う。

実際

実験内容

head -c 134217728 /dev/urandom > /dev/shm/temp.bin
で生成したファイルを4バイト符号付き整数の配列として読み込みソートを行うプログラムの実行時間を測定。

実験環境

CPU: ファンが動かなくなって800MHzほぼ固定なSempron 3100+
Ram: 2GB
OS: Debian 32bit

結果

三回実行した中央値は

n分木 実行時間
2 39.710s
3 32.220s
4 29.210s

コード

二分木
#include <stdio.h>
#include <stdlib.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/stat.h>
#include <sys/types.h>

#define compair(a, b) (a < b)

#define repair_heap(end) \
do { \
    k = j * 2 + 1; \
    if (k + 1 < end) { \
        if (compair(data[k], data[k + 1])) { \
            if (compair(temp, data[k + 1])) { \
                data[j] = data[k + 1]; \
                j = k + 1; \
            } else \
                break; \
        } else { \
            if (compair(temp, data[k])) { \
                data[j] = data[k]; \
                j = k; \
            } else \
                break; \
        } \
    } else if (k < end) { \
        if (compair(temp, data[k])) { \
            data[j] = data[k]; \
            j = k; \
        } else \
            break; \
    } else \
        break; \
} while(1)

void sort(int *data,int size)
{
    int i, j, k;
    int temp;
    if (size < 2)
        return;

    for (i = (size - 2)/2; i >= 0; i--) {
        j = i;
        temp = data[j];
        repair_heap(size);
        data[j] = temp;
    }
    for (i = size - 1; i > 0; i--) {
        j = 0;
        temp = data[i];
        data[i] = data[0];
        repair_heap(i);
        data[j] = temp;
    }
}

int main ()
{
    int i,ret_code = -1,fd,*data = NULL;
    struct stat file_stat;
    off_t file_size;

    if ((fd = open("/dev/shm/temp.bin",O_RDONLY)) < 0)
        goto error01;
    if (fstat(fd,&file_stat))
        goto error02;
    if (file_stat.st_size <= 0 || (file_stat.st_size & (sizeof(int) - 1)))
        goto error02;
    if ((data = (int *)malloc(file_stat.st_size)) == NULL)
        goto error02;
    file_size = read(fd,data,file_stat.st_size);
    if (file_size != file_stat.st_size)
        goto error03;

    sort(data,(int)(file_size/sizeof(int)));

    for (i = 0; i < ((((int)(file_size/sizeof(int))) > 100)? 100 :  ((int)(file_size/sizeof(int)))); i++)
        printf("%d\n",data[i]);

    ret_code = 0;
error03:
    if (data != NULL)
        free(data);
error02:
    close(fd);
error01:
    return ret_code;
}

3分木
#include <stdio.h>
#include <stdlib.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/stat.h>
#include <sys/types.h>

#define compair(a, b) (a < b)

#define repair_heap(end) \
do { \
    k = j * 3 + 1; \
    if (k + 2 < end) { \
        if (compair(data[k + 1], data[k + 2])) { \
            if (compair(data[k], data[k + 2])) { \
                if (compair(temp, data[k + 2])) { \
                    data[j] = data[k + 2]; \
                    j = k + 2; \
                } else \
                    break; \
            } else { \
                if (compair(temp, data[k])) { \
                    data[j] = data[k]; \
                    j = k; \
                } else \
                    break; \
            } \
        } else { \
            if (compair(data[k], data[k + 1])) { \
                if (compair(temp, data[k + 1])) { \
                    data[j] = data[k + 1]; \
                    j = k + 1; \
                } else \
                    break; \
            } else { \
                if (compair(temp, data[k])) { \
                    data[j] = data[k]; \
                    j = k; \
                } else \
                    break; \
            } \
        } \
    } else if (k + 1 < end) { \
        if (compair(data[k], data[k + 1])) { \
            if (compair(temp, data[k + 1])) { \
                data[j] = data[k + 1]; \
                j = k + 1; \
            } else \
                break; \
        } else { \
            if (compair(temp, data[k])) { \
                data[j] = data[k]; \
                j = k; \
            } else \
                break; \
        } \
    } else if (k < end) { \
        if (compair(temp, data[k])) { \
            data[j] = data[k]; \
            j = k; \
        } else \
            break; \
    } else \
        break; \
} while(1)

void sort(int *data,int size)
{
    int i, j, k;
    int temp;
    if (size < 2)
        return;

    for (i = (size - 2)/3; i >= 0; i--) {
        j = i;
        temp = data[j];
        repair_heap(size);
        data[j] = temp;
    }
    for (i = size - 1; i > 0; i--) {
        j = 0;
        temp = data[i];
        data[i] = data[0];
        repair_heap(i);
        data[j] = temp;
    }
}

int main ()
{
    int i,ret_code = -1,fd,*data = NULL;
    struct stat file_stat;
    off_t file_size;

    if ((fd = open("/dev/shm/temp.bin",O_RDONLY)) < 0)
        goto error01;
    if (fstat(fd,&file_stat))
        goto error02;
    if (file_stat.st_size <= 0 || (file_stat.st_size & (sizeof(int) - 1)))
        goto error02;
    if ((data = (int *)malloc(file_stat.st_size)) == NULL)
        goto error02;
    file_size = read(fd,data,file_stat.st_size);
    if (file_size != file_stat.st_size)
        goto error03;

    sort(data,(int)(file_size/sizeof(int)));

    for (i = 0; i < ((((int)(file_size/sizeof(int))) > 100)? 100 :  ((int)(file_size/sizeof(int)))); i++)
        printf("%d\n",data[i]);

    ret_code = 0;
error03:
    if (data != NULL)
        free(data);
error02:
    close(fd);
error01:
    return ret_code;
}

4分木
#include <stdio.h>
#include <stdlib.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/stat.h>
#include <sys/types.h>

#define AlignmentSize 64
#define RoundUp_Size(a) (((a) + (AlignmentSize - 1)) & ~(AlignmentSize - 1))

#define compair(a, b) (a < b)

#define repair_heap(end) \
do { \
    k = j * 4 + 1; \
    if (k + 3 < end) { \
        if (compair(data[k + 2], data[k + 3])) { \
            if (compair(data[k + 1], data[k + 3])) { \
                if (compair(data[k], data[k + 3])) { \
                    if (compair(temp, data[k + 3])) { \
                        data[j] = data[k + 3]; \
                        j = k + 3; \
                    } else \
                        break; \
                } else { \
                    if (compair(temp, data[k])) { \
                        data[j] = data[k]; \
                        j = k; \
                    } else \
                        break; \
                } \
            } else { \
                if (compair(data[k], data[k + 1])) { \
                    if (compair(temp, data[k + 1])) { \
                        data[j] = data[k + 1]; \
                        j = k + 1; \
                    } else \
                        break; \
                } else { \
                    if (compair(temp, data[k])) { \
                        data[j] = data[k]; \
                        j = k; \
                    } else \
                        break; \
                } \
            } \
        } else { \
            if (compair(data[k + 1], data[k + 2])) { \
                if (compair(data[k], data[k + 2])) { \
                    if (compair(temp, data[k + 2])) { \
                        data[j] = data[k + 2]; \
                        j = k + 2; \
                    } else \
                        break; \
                } else { \
                    if (compair(temp, data[k])) { \
                        data[j] = data[k]; \
                        j = k; \
                    } else \
                        break; \
                } \
            } else { \
                if (compair(data[k], data[k + 1])) { \
                    if (compair(temp, data[k + 1])) { \
                        data[j] = data[k + 1]; \
                        j = k + 1; \
                    } else \
                        break; \
                } else { \
                    if (compair(temp, data[k])) { \
                        data[j] = data[k]; \
                        j = k; \
                    } else \
                        break; \
                } \
            } \
        } \
    } else if (k + 2 < end) { \
        if (compair(data[k + 1], data[k + 2])) { \
            if (compair(data[k], data[k + 2])) { \
                if (compair(temp, data[k + 2])) { \
                    data[j] = data[k + 2]; \
                    j = k + 2; \
                } else \
                    break; \
            } else { \
                if (compair(temp, data[k])) { \
                    data[j] = data[k]; \
                    j = k; \
                } else \
                    break; \
            } \
        } else { \
            if (compair(data[k], data[k + 1])) { \
                if (compair(temp, data[k + 1])) { \
                    data[j] = data[k + 1]; \
                    j = k + 1; \
                } else \
                    break; \
            } else { \
                if (compair(temp, data[k])) { \
                    data[j] = data[k]; \
                    j = k; \
                } else \
                    break; \
            } \
        } \
    } else if (k + 1 < end) { \
        if (compair(data[k], data[k + 1])) { \
            if (compair(temp, data[k + 1])) { \
                data[j] = data[k + 1]; \
                j = k + 1; \
            } else \
                break; \
        } else { \
            if (compair(temp, data[k])) { \
                data[j] = data[k]; \
                j = k; \
            } else \
                break; \
        } \
    } else if (k < end) { \
        if (compair(temp, data[k])) { \
            data[j] = data[k]; \
            j = k; \
        } else \
            break; \
    } else \
        break; \
} while(1)

void sort(int *data,int size)
{
    int i, j, k;
    int temp;
    if (size < 2)
        return;

    for (i = (size - 2)/4; i >= 0; i--) {
        j = i;
        temp = data[j];
        repair_heap(size);
        data[j] = temp;
    }
    for (i = size - 1; i > 0; i--) {
        j = 0;
        temp = data[i];
        data[i] = data[0];
        repair_heap(i);
        data[j] = temp;
    }
}

int main ()
{
    int i,ret_code = -1,fd,*data = NULL;
    struct stat file_stat;
    off_t file_size;

    if ((fd = open("/dev/shm/temp.bin",O_RDONLY)) < 0)
        goto error01;
    if (fstat(fd,&file_stat))
        goto error02;
    if (file_stat.st_size <= 0 || (file_stat.st_size & (sizeof(int) - 1)))
        goto error02;
    if (posix_memalign((void **)&data, AlignmentSize, RoundUp_Size(file_stat.st_size + 3 * sizeof(int))))
        goto error02;
    data += 3;
    file_size = read(fd,data,file_stat.st_size);
    if (file_size != file_stat.st_size)
        goto error03;

    sort(data,(int)(file_size/sizeof(int)));

    for (i = 0; i < ((((int)(file_size/sizeof(int))) > 100)? 100 :  ((int)(file_size/sizeof(int)))); i++)
        printf("%d\n",data[i]);

    ret_code = 0;
error03:
    data -= 3;
    if (data != NULL)
        free(data);
error02:
    close(fd);
error01:
    return ret_code;
}

終わりに

こうして人類は少しだけ速いヒープソートを手に入れた。
実際にはイントロソートの中で使う為に作っただけです。