I Tried Writing AdaBoost

Posted at 2012-10-21

A review exercise from MLSS. Using the AdaBoost algorithm, I built a classifier that decides whether a point lies inside or outside the circle x^2 + y^2 = 4.

The idea: build a rough linear classifier from the training data as a "weak classifier" (generate_linear_classifier), then use construct_hardest_distribution to build the probability distribution over the training examples on which that classifier performs worst, train another "weak classifier" on that distribution, and repeat; at the end the weak classifiers are simply combined as a weighted sum.
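For reference, this loop is the textbook AdaBoost update (it matches the weight computed in generate_linear_classifier and the reweighting in construct_hardest_distribution below). With $D_t$ the distribution over the training examples at round $t$, $h_t$ the weak classifier trained under it, and $\epsilon_t$ its weighted error:

$$\alpha_t = \frac{1}{2} \ln \frac{1 - \epsilon_t}{\epsilon_t}, \qquad D_{t+1}(i) = \frac{D_t(i)\, e^{-\alpha_t y_i h_t(x_i)}}{Z_t},$$

where $Z_t$ normalizes $D_{t+1}$ to sum to 1. The final classifier is the weighted vote $H(x) = \operatorname{sign}\left(\sum_t \alpha_t h_t(x)\right)$.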

Each individual "weak classifier" is only about 60% accurate (as the warn output along the way shows), but the final combined classifier reaches about 90%.

adaboost.pl
#!/usr/bin/env perl

use 5.014;
use strict;
use warnings;
use List::MoreUtils qw/minmax/;
use List::Util qw/sum/;
use Math::Cartesian::Product;

# srand 42;

# 100 training points drawn uniformly from [-2, 2]^2, labeled +1 inside
# the circle x^2 + y^2 = 4 and -1 outside.
my @training_set;
until (@training_set == 100) {
  my ($x, $y) = (rand(4) - 2, rand(4) - 2);
  push @training_set, [$x, $y, $x ** 2 + $y ** 2 < 4 ? +1 : -1];
}
my $classifier = adaboost(\@training_set, 100);

# A held-out test set drawn the same way, for estimating the final accuracy.
my @test_set;
until (@test_set == 1000) {
  my ($x, $y) = (rand(4) - 2, rand(4) - 2);
  push @test_set, [$x, $y, $x ** 2 + $y ** 2 < 4 ? +1 : -1];
}
say 'Accuracy: ', evaluate_accuracy($classifier, \@test_set);
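
# The returned $classifier is an ordinary closure, so individual points can
# be classified directly. The origin lies inside the circle, so this should
# print 1 (assuming the learned classifier gets that point right):
say 'classifier(0, 0) = ', $classifier->(0, 0);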

sub PI() { 4 * atan2(1, 1) }

sub adaboost {
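  # Train one weak classifier per round: round 1 uses the uniform
  # distribution, and each later round reweights the examples against the
  # previous classifier before training the next. The returned closure is
  # the weighted majority vote H(x) described above.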
  my ($training_set, $num_iterations) = @_;
  my @initial_distribution = (1 / @$training_set) x @$training_set;
  my ($initial_classifier, $initial_weight) =
    generate_linear_classifier($training_set, \@initial_distribution);
  my @weak_classifiers = (+{
    classifier => $initial_classifier,
    distribution => \@initial_distribution,
    weight => $initial_weight,
  });
  for (2 .. $num_iterations) {
    my $distribution = construct_hardest_distribution(
      $training_set,
      $weak_classifiers[-1]{distribution},
      $weak_classifiers[-1]{classifier},
      $weak_classifiers[-1]{weight}
    );
    my ($classifier, $weight) =
      generate_linear_classifier($training_set, $distribution);
    push @weak_classifiers, +{
      classifier => $classifier,
      distribution => $distribution,
      weight => $weight,
    };
  }
  return sub {
    my ($x, $y) = @_;
    my $sum = sum map {
      $_->{weight} * $_->{classifier}($x, $y);
    } @weak_classifiers;
    return sign($sum);
  };
}

sub construct_hardest_distribution {
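  # The AdaBoost reweighting step: D_{t+1}(i) ~ D_t(i) * exp(-a_t y_i h_t(x_i)).
  # Misclassified examples (y_i h_t(x_i) = -1) gain weight, and on the new
  # distribution the previous classifier's weighted accuracy is exactly 1/2.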
  my ($training_set, $prev_distribution, $prev_classifier, $prev_weight) = @_;
  my @distribution;
  for my $i (0 .. $#$prev_distribution) {
    my ($x, $y, $label) = @{ $training_set->[$i] };
    push @distribution, $prev_distribution->[$i]
      * exp(-$prev_weight * $label * $prev_classifier->($x, $y));
  }
  my $normalization_factor = sum @distribution;
  return [ map { $_ / $normalization_factor } @distribution ];
}

sub evaluate_accuracy {
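  # Accuracy of $classifier weighted by $distribution; with the default
  # uniform distribution this is the ordinary accuracy on $test_set.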
  my ($classifier, $test_set, $distribution) = @_;
  $distribution //= [ (1 / @$test_set) x @$test_set ];
  my $accuracy = 0;
  for my $i (0 .. $#$test_set) {
    my ($x, $y, $label) = @{ $test_set->[$i] };
    $accuracy += $distribution->[$i] if $classifier->($x, $y) == $label;
  }
  return $accuracy;
}

sub _generate_linear_classifier {
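  # A linear threshold function: points strictly above the line of slope
  # tan($radian) through @origin get $label, everything else gets -$label.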
  my ($radian, $label, @origin) = @_;
  my $inclination = sin($radian) / cos($radian);
  sub {
    my ($x, $y) = @_;
    ($x - $origin[0]) * $inclination + $origin[1] < $y ? $label : $label * -1;
  };
}

sub generate_linear_classifier {
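  # Pick a random orientation, then run a coarse-to-fine grid search (32
  # halvings of the step size) over the line's anchor point and the label
  # sign, keeping the candidate with the best weighted accuracy.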
  my ($training_set, $distribution) = @_;
  my $radian = rand(PI);
  my @origin = (0, 0);
  my @minmax_x = minmax(map { $_->[0] } @$training_set);
  my @minmax_y = minmax(map { $_->[1] } @$training_set);
  my $step_x = ($minmax_x[1] - $minmax_x[0]) / 2;
  my $step_y = ($minmax_y[1] - $minmax_y[0]) / 2;

  my $best_classifier;
  my $best_accuracy = 0;
  for (1 .. 32) {
    cartesian {
      my $classifier = _generate_linear_classifier($radian, @_);
      my $accuracy =
        evaluate_accuracy($classifier, $training_set, $distribution);
      if ($best_accuracy <= $accuracy) {
        $best_accuracy = $accuracy;
        $best_classifier = $classifier;
        @origin = @_[1, 2];
      }
    } (
      [+1, -1],
      [$origin[0] - $step_x, $origin[0], $origin[0] + $step_x],
      [$origin[1] - $step_y, $origin[1], $origin[1] + $step_y],
    );
  } continue {
    $step_x /= 2;
    $step_y /= 2;
  }
  warn "weak classifier accuracy: $best_accuracy\n";
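  # The standard AdaBoost classifier weight: with error e = 1 - accuracy,
  # alpha = (1/2) * ln((1 - e) / e) = (1/2) * ln(accuracy / (1 - accuracy)).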
  my $weight = log($best_accuracy / (1 - $best_accuracy)) / 2;
  return ($best_classifier, $weight);
}

# Treat an exact zero sum as +1 to avoid dividing by zero.
sub sign { $_[0] < 0 ? -1 : 1 }

Addendum (2012-10-22): I generalized this and released it on CPAN as Algorithm::AdaBoost.
