C++を思い出しながらハフマン木のプログラムを書いてみた (ソース: Wikipedia)
メモリ管理周りがあまり自信がない... (所要時間60分)
- ハフマン木構築 (make_huffman_tree) は文字頻度を格納した配列を受け取り,ハフマン木を構築する
- ハフマン木の構築は,初期状態ではすべての文字がノード数1の別々の木となっている状態からスタートして頻度が低い2つの木をマージしていくことで得られる.マージされた木の頻度は2つの木の頻度の合計.
- このマージの回数は常にアルファベット数-1回となる
頻度の低い木を求めるところで,当初のコードでは std::sort を使っていたけど,本来なら priority_queue を使うべきだった.
- 追記 (2/18) : priority_queue で書き直した (下記のコードは書き直し後のもの)
コード
[binary_tree.hpp]
#pragma once
#include <iostream>
// Node of a pointer-linked binary tree.
//
// Ownership: a node owns its `left` and `right` subtrees and deletes them in
// its destructor; `parent` is a non-owning back pointer. Children handed to
// the two-argument constructor must therefore be heap-allocated with `new`
// and must not be owned by anything else.
//
// The destructor is virtual because derived node types are stored in, and
// deleted through, `BinaryTree*` child pointers.
struct BinaryTree {
    int id = 0;                    // caller-defined identifier; not used by the tree itself
    BinaryTree *left = nullptr;    // owned left subtree (may be null)
    BinaryTree *right = nullptr;   // owned right subtree (may be null)
    BinaryTree *parent = nullptr;  // non-owning pointer to the parent (null for a root)

    // Creates an isolated node with no children and no parent.
    BinaryTree() {}

    // Creates a node that adopts `_left` and `_right` as its children and
    // points their `parent` back at this node. Either child may be null.
    BinaryTree(BinaryTree *_left, BinaryTree *_right) : left(_left), right(_right) {
        if (left != nullptr) left->parent = this;
        if (right != nullptr) right->parent = this;
    }

    // Copying would leave two nodes owning (and later deleting) the same
    // subtrees, so copies are disabled.
    BinaryTree(const BinaryTree &) = delete;
    BinaryTree &operator=(const BinaryTree &) = delete;

    // Recursively destroys both subtrees.
    virtual ~BinaryTree() {
        delete left;
        delete right;
    }

    // Returns the left child, creating an empty one first if none exists.
    BinaryTree *add_left() {
        if (left == nullptr) {
            left = new BinaryTree;
            left->parent = this;
        }
        return left;
    }

    // Returns the right child, creating an empty one first if none exists.
    BinaryTree *add_right() {
        if (right == nullptr) {
            right = new BinaryTree;
            right->parent = this;
        }
        return right;
    }

    // Number of nodes in the subtree rooted at this node (including itself).
    size_t size() const {
        size_t n = 1;
        if (left) n += left->size();
        if (right) n += right->size();
        return n;
    }

    // True when this node has no children.
    bool is_leaf() const {
        return (left == nullptr) && (right == nullptr);
    }

    // Number of leaves in the subtree rooted at this node.
    size_t leaf_size() const {
        if (is_leaf()) return 1;
        size_t n = 0;
        if (left) n += left->leaf_size();
        if (right) n += right->leaf_size();
        return n;
    }
};
[huffman.hpp]
#pragma once
#include <vector>
#include <iostream>
#include <algorithm>
#include <string>
#include <queue>
#include "binary_tree.hpp"
using namespace std;
// Node of a Huffman tree.
//
// A leaf represents one symbol (`term_id`) with its frequency; an internal
// node represents the merge of its two children and carries the sum of their
// frequencies. Each non-root node remembers in `label` which branch of its
// parent it hangs from, so a symbol's code can be read by walking from the
// leaf up to the root.
struct HuffmanTree : public BinaryTree {
    int freq = 0;        // frequency of this symbol / total frequency of the subtree
    int term_id = -1;    // symbol index; left at -1 until the caller assigns one
    bool label = false;  // branch bit from the parent: false = left ("0"), true = right ("1")

    // Creates an unlinked node with zero frequency and no symbol assigned.
    HuffmanTree() {}

    // Merges two subtrees under a new internal node. The new node's frequency
    // is the sum of the children's, and the children are labelled 0 (left)
    // and 1 (right). Both pointers must be non-null.
    HuffmanTree(HuffmanTree *left, HuffmanTree *right)
        : BinaryTree(left, right), freq(left->freq + right->freq) {
        left->label = false;
        right->label = true;
    }

    ~HuffmanTree() {}

    // Returns the code of this node as a string of '0'/'1' characters, most
    // significant (root-side) bit first. The root itself has the empty code.
    string huffcode_as_string() const {
        string code;
        // Collect branch bits from this node up to the root, then flip them
        // into root-to-node order. The cast assumes every ancestor is a
        // HuffmanTree, as the original recursive version did.
        for (const HuffmanTree *node = this; node->parent != nullptr;
             node = static_cast<const HuffmanTree *>(node->parent)) {
            code.push_back(node->label ? '1' : '0');
        }
        std::reverse(code.begin(), code.end());
        return code;
    }

    // Returns the same code as huffcode_as_string() as a bit vector
    // (false for '0', true for '1').
    vector<bool> huffcode() const {
        const string huffcode_str = huffcode_as_string();
        vector<bool> code(huffcode_str.size());
        for (size_t i = 0; i < huffcode_str.size(); ++i) {
            code[i] = (huffcode_str[i] != '0');
        }
        return code;
    }
};
struct HuffmanCoder {
HuffmanTree *root;
vector<HuffmanTree*> term2node;
vector<vector<bool>> term2code;
HuffmanCoder(HuffmanTree *root_) : root(root_) {
int nleaf = root->leaf_size();
term2node.resize(nleaf, nullptr);
term2code.resize(nleaf);
run_leaf(root);
for(int i=0; i<nleaf; ++i){
term2code[i] = term2node[i]->huffcode(); // store huffman code as vector<bool>
}
}
vector<bool> encode(const vector<int> &x) const {
vector<bool> bitseq;
for(int xi : x){
for(bool bi : term2code[xi])
bitseq.push_back(bi);
}
return bitseq;
}
vector<int> decode(const vector<bool> &code) const {
vector<int> str;
HuffmanTree *node = this->root;
for(int i=0; i<code.size(); ++i){
if(!code[i]){
if(node->left==nullptr) throw "Invalid code";
if(node->left->is_leaf()){
str.push_back(static_cast<HuffmanTree*>(node->left)->term_id);
node = this->root;
} else {
node = static_cast<HuffmanTree*>(node->left);
}
} else {
if(node->right==nullptr) throw "Invalid code";
if(node->right->is_leaf()){
str.push_back(static_cast<HuffmanTree*>(node->right)->term_id);
node = this->root;
} else {
node = static_cast<HuffmanTree*>(node->right);
}
}
}
return str;
}
private:
void run_leaf(HuffmanTree *node){
if(node->is_leaf()){
term2node[node->term_id] = node;
} else {
if(node->left) run_leaf(static_cast<HuffmanTree*>(node->left));
if(node->right) run_leaf(static_cast<HuffmanTree*>(node->right));
}
}
};
// Builds a Huffman tree from per-symbol frequencies.
//
// frequency[w] is the frequency of symbol w; each symbol becomes a leaf with
// term_id == w. The two lowest-frequency trees are merged repeatedly (the
// first one popped becomes the left/"0" child) until a single tree remains.
// Internal nodes get term_id == -1.
//
// Returns the root of the tree, allocated with `new`; the caller owns it and
// deleting the root releases the whole tree. Returns nullptr for an empty
// frequency table, and the single leaf itself when there is only one symbol.
HuffmanTree* make_huffman_tree(const vector<int> &frequency){
    if (frequency.empty()) return nullptr;

    // std::priority_queue keeps the element that compares *greatest* on top,
    // so ordering by "has higher frequency" puts the lowest frequency on top.
    auto higher_freq = [](const HuffmanTree* a, const HuffmanTree* b){ return a->freq > b->freq; };
    priority_queue<HuffmanTree*, vector<HuffmanTree*>, decltype(higher_freq)> nodes(higher_freq);

    for (size_t w = 0; w < frequency.size(); ++w) {
        HuffmanTree *leaf = new HuffmanTree;
        leaf->freq = frequency[w];
        leaf->term_id = static_cast<int>(w);
        leaf->label = false;
        nodes.push(leaf);
    }

    // Each merge removes two trees and adds one, so this runs size-1 times.
    while (nodes.size() > 1) {
        HuffmanTree *n1 = nodes.top(); nodes.pop();
        HuffmanTree *n2 = nodes.top(); nodes.pop();
        HuffmanTree *merged = new HuffmanTree(n1, n2); // merge the least and 2nd least trees
        merged->term_id = -1;
        merged->label = false;
        nodes.push(merged);
    }
    return nodes.top();
}
[main.cpp]
#include "huffman.hpp"
#include <memory>
// Renders a bit sequence as text: one '1' per true bit and one '0' per false
// bit, in the same order as the input.
string bin2string(const vector<bool> &b){
    string text;
    for (bool bit : b) {
        text += (bit ? '1' : '0');
    }
    return text;
}
int main(){
//vector<int> f = {100, 20, 83, 2, 10};
vector<int> f = {5, 3, 2, 1, 1};
auto ht_root = shared_ptr<HuffmanTree>(make_huffman_tree(f));
auto coder = HuffmanCoder(ht_root.get());
for(int i=0; i<coder.term2node.size(); ++i){
auto huffcode = bin2string(coder.term2node[i]->huffcode());
cout << i << "\t"
<< coder.term2node[i]->freq << "\t"
<< huffcode
<< endl;
}
vector<int> s = {0, 1, 2, 0};
vector<bool> s_comp = coder.encode(s);
for(auto xi : s) cout << xi << " ";
cout << endl;
cout << bin2string(s_comp) << endl;
auto s2 = coder.decode(s_comp);
for(auto xi : s2) cout << xi << " ";
cout << endl;
}
実行結果
Wikipedia の例と同様の長さの符号が得られた.また,エンコードしたビット列をデコードしたらもとの入力に一致した.
0 5 0
1 3 10
2 2 111
3 1 1101
4 1 1100
input: 0 1 2 0
encoded: 0101110
decoded: 0 1 2 0