ZMの中のRANSACコード
RANSAC(Random sample consensus)のオリジナル文献は次のものです。
M. A. Fischler and R. C. Bolles, ``Random Sample Consensus: A Paradigm for Model Fitting with Applications to Image Analysis and Automated Cartography,'' Communications of the ACM 24(6):381-395 (1981).
Wikipediaも参考になります。
あまり悩まずに実装できそうだと思ったので、試してみました。ただし、オリジナルのものと少しだけ違う次のような手順にしました。
- 全ての標本から比較的少数の標本をランダムに選択
- 選択した標本群を用いてシステムパラメータ(モデル)同定
- 上記モデルからの誤差がある閾値内に収まる標本(consensus set)の数をカウント
- 上記カウント数が最大となるモデルを抽出
- 上記モデルのconsensus setを用いてモデルをリファイン
オリジナルの方法ではconsensus setを用いたモデルリファインを上記3と4の間で毎回行いますが、それを最後一回に省いてしまった点が違いです。
ZMでの実装を載せます。
/* RANSAC: random sample consensus. */
zVec zRANSAC(zVec q, zVecList *sample, zVec (* fit_fp)(zVec,zVecList*,void*), double (* error_fp)(zVec,zVec,void*), void *util, int ns, int nt, double th)
{
zVec qt;
zVecList va;
int count, count_prev = 0;
if( !( qt = zVecAlloc( zVecSizeNC(q) ) ) ) return NULL;
if( ns < zVecSizeNC(q) )
ZRUNWARN( ZM_WARN_INSUFFICIENT_SAMPLES, ns, zVecSizeNC(q) );
/* find a candidate of the most likely model */
while( --nt >= 0 ){
if( !_zRANSACSelectRandom( sample, &va, ns ) ) break;
fit_fp( qt, &va, util );
zListAppend( sample, &va );
if( ( count = _zRANSACCountConsensus( qt, sample, error_fp, util, th ) ) > count_prev ){
count_prev = count;
zVecCopyNC( qt, q );
}
}
zVecFree( qt );
/* refine the model with the doubled threshold */
_zRANSACSelectConsensus( &va, q, sample, error_fp, util, th*2 );
fit_fp( q, &va, util );
zListAppend( sample, &va );
return q;
}
モデルフィット関数fit_fp()
と誤差計算関数error_fp()
は引数で与えるようにしています。
最後のモデルリファインメントにおいて閾値を2倍しているのは、経験的にそうする方がロバスト性が少し上がると分かったためです。
標本の乱数的選択_zRANSACSelectRandom()
、consensus setの要素数カウント_zRANSACCountConsensus()
、consensus set抽出_zRANSACSelectConsensus()
の実装は次の通りです。
/* randomly select test data from original data set. */
static bool _zRANSACSelectRandom(zVecList *sample, zVecList *va, int n)
{
zVecListCell *sp;
int i;
if( n > zListSize(sample) ){
ZRUNERROR( ZM_ERR_INVALID_NUMSAMP, n, zListSize(sample) );
return false;
}
zListInit( va );
while( zListSize(va) < n ){
i = zRandI( 0, zListSize(sample)-1 );
zListItem( sample, i, &sp );
zListPurge( sample, sp );
zListInsertHead( va, sp );
}
return true;
}
/* count number of consensus samples with respect to a guess. */
static int _zRANSACCountConsensus(zVec q, zVecList *sample, double (* error_fp)(zVec,zVec,void*), void *util, double th)
{
zVecListCell *sp;
int count = 0;
zListForEach( sample, sp ){
if( fabs( error_fp( q, sp->data, util ) ) < th ) count++;
}
return count;
}
/* select consensus samples for model refinement from original data set. */
static int _zRANSACSelectConsensus(zVecList *va, zVec q, zVecList *sample, double (* error_fp)(zVec,zVec,void*), void *util, double th)
{
zVecListCell *sp, *sp_prev;
zListInit( va );
zListForEach( sample, sp ){
if( fabs( error_fp( q, sp->data, util ) ) < th ){
sp_prev = zListCellPrev( sp );
zListPurge( sample, sp );
zListInsertHead( va, sp );
sp = sp_prev;
}
}
return zListSize(va);
}
コーディングで悩むところはあまりありませんでした。
テスト
上記を使って、直線フィッティングと放物線フィッティングを試してみました。全標本数は100、アウトライア混入確率は0.5とし、全標本にも±5以内のノイズを加えました。結果の例を示します。
結構うまくいってます。
パラメータについて
上記の実装では、乱択標本数ns
、試行回数nt
、consensus判定閾値th
を与える必要があります。結果はこれらにかなり依存します。経験則に過ぎませんが、次のように設定するのが良さそうだと分かりました。
- ns = min( 5m, (1-r)n/2 )
- nt = 3ns
- th = l/10
ただしm
はシステムパラメータ数、n
は全標本数、r
はアウトライア混入確率、l
はノイズレベルです。閾値をノイズレベルの10分の1まで落とさないとアウトライアの影響を十分低減できないというのは少し意外でした。
最後に、上記テストに使ったコードを載せます。
# include <zm/zm_data.h>
# include <zm/zm_le.h>
/* TEST 1: plane, 2: parabola */
# define TEST 2
# if TEST == 1
/* test: plane */
int q_size = 2;
double y_test(double x)
{
/* plane: y = 5 ( x - 2 ) */
return 5 * ( x - 2 );
}
double error_case(zVec q, zVec sample, void *util)
{
return zVecInnerProd( q, sample ) - 1;
}
zVec fit_case(zVec q, zVecList *list, void *util)
{
zMat m;
zVec p;
zVecListCell *sp;
m = zMatAllocSqr( zVecSizeNC(q) );
p = zVecAlloc( zVecSizeNC(q) );
if( !m || !p ){
q = NULL;
goto TERMINATE;
}
zListForEach( list, sp ){
zMatAddDyadNC( m, sp->data, sp->data );
zVecAddDRC( p, sp->data );
}
zLESolveGauss( m, p, q );
TERMINATE:
zMatFree( m );
zVecFree( p );
return q;
}
void print_case(zVec q)
{
printf( "plot [-10:10] 's' u 1:2 w p lt 2, (%.10g)*x+(%.10g) w l lt 7\n", -zVecElemNC(q,0)/zVecElemNC(q,1), 1.0/zVecElemNC(q,1) );
}
# else
/* test: parabola */
int q_size = 3;
double y_test(double x)
{
/* parabola: y = ( x - 0.5 )^2 - 1 = x^2 - x - 0.75 */
return zSqr( x - 0.5 ) - 1;
}
double error_case(zVec q, zVec sample, void *util)
{
return zVecElemNC(q,0)*zSqr(zVecElemNC(sample,0)) +
zVecElemNC(q,1)*zVecElemNC(sample,0) +
zVecElemNC(q,2) - zVecElemNC(sample,1);
}
zVec fit_case(zVec q, zVecList *list, void *util)
{
zMat m;
zVec p, pi;
zVecListCell *sp;
m = zMatAllocSqr( zVecSizeNC(q) );
p = zVecAlloc( zVecSizeNC(q) );
pi = zVecAlloc( zVecSizeNC(q) );
if( !m || !p || !pi ){
q = NULL;
goto TERMINATE;
}
zListForEach( list, sp ){
zVecSetElemList( pi, zSqr(zVecElemNC(sp->data,0)), zVecElemNC(sp->data,0), 1.0 );
zMatAddDyadNC( m, pi, pi );
zVecCatDRC( p, zVecElemNC(sp->data,1), pi );
}
zLESolveGauss( m, p, q );
TERMINATE:
zMatFree( m );
zVecFree( p );
zVecFree( pi );
return q;
}
void print_case(zVec q)
{
printf( "plot [-10:10] 's' u 1:2 w p lt 2, (%.10g)*x*x+(%.10g)*x+(%.10g) w l lt 7\n", zVecElemNC(q,0), zVecElemNC(q,1), zVecElemNC(q,2) );
}
# endif
void sample_list(zVecList *sample, int n, double r, double nl)
{
zVec v;
double x, y;
register int i;
FILE *fp;
zListInit( sample );
fp = fopen( "s", "w" );
for( i=0; i<n; i++ ){
v = zVecAlloc( 2 );
y = y_test( ( x = zRandF( -10, 10 ) ) ) + zRandF(-nl,nl);
/* an outlier model with a probability of r */
if( zRandF(0,1) < r ) y += zRandF(-50,50);
zVecSetElemList( v, x, y );
zVecDataFPrint( fp, v );
zVecListInsertHead( sample, v );
}
fclose( fp );
}
# define N 100
# define R 0.5
# define NS 20
# define NT 50
# define TH 0.5
# define NL 5.0
int main(int argc, char *argv[])
{
zVecList sample;
zVec q;
zRandInit();
sample_list( &sample, N, R, NL );
q = zVecAlloc( q_size );
# if 0
zRANSAC( q, &sample, fit_case, error_case, NULL, NS, NT, TH );
# else
zRANSACAuto( q, &sample, fit_case, error_case, NULL, R, NL );
# endif
print_case( q );
zVecFree( q );
zVecListDestroy( &sample );
return 0;
}
fit_case()
は最小二乗法を素直に実装しています。またzRANSACAuto()
は、前述のパラメータ自動設定を採用したバージョンです。