LoginSignup
2
0

AHCで入力パラメータに応じて解法やハイパーパラメータを半自動調整して使い分ける方法

Last updated at Posted at 2023-10-23

0. はじめに

ヒューリスティック最適化において、特に今回(AHC025)などのような入力の値によって戦略が変わってくる問題の場合は、入力パラメータの値によって解法やハイパーパラメータを選択したくなります。
また、素朴にやるならルールベース的にif文で使い分けを実装すると思いますが、できれば入力に対して細やかに一番いい結果が期待できる解法を選択したいです。
下記の図のようなイメージです。

図1.png

今回は、あくまで一例ですが、そのような場合のやり方を書きます。

また実装もあくまで一例です。(言語はC++です。)

結論としては、ローカルでたくさんシミュレートしてハイパーパラメータを調整し、その値をハードコードで埋め込みます。

1. 入力パラメータを分割する

まず、入力パラメータを適当な大きさに分割します。

今回の問題の入力パラメータおよび入力の制約、生成方法は以下でした。

制約

N : 30 <= N <= 100
Q : 2N <= Q <= 32N
D : 2  <= D <= N/4

生成方法

N : 30~100までの整数の一様ランダム
Q : [1.0,5.0]の一様ランダムな実数を取り(仮にdとする)、N×2^dを四捨五入で整数に丸めたもの
D : 2~floor(D/4)までの整数の一様ランダム

これを下記のように、ざっくり 14×40×DivideSize_D (Dの分割サイズは可変)で約8000分割しました。

分割方法

N : 入力の取りうる値は30~100の71通りで、5刻みに(96~100は6刻み)14分割。
Q : 指数部分(1.0~5.0)を0.1刻みに40分割。
D : 1刻みに分割。(N=100のときDは2~25なので最大24分割)

ここで、AtCoderの提出コードの文字数制限は512KiBなので、埋め込みたいパラメータが64ビット整数で19桁だとしても数万個くらいは埋め込みが可能です。

ここまでをグローバルに埋め込みます。実装イメージは以下です。(実装を簡単にするためこの記事ではN,Q,Dやその他色々なパラメータをグローバルに定義します。)

#include <bits/stdc++.h>
using namespace std;
typedef long long int ll;

int N, Q, D;

// ハイパラのハードコード
// Dの最大値はNの値によって可変なため今回はvectorの2次元配列で定義する
// 中身は適当に1000000000000000000にしているが実際は初期値として意味のある値を記述する
vector<ll> hyperParams[14][40] = 
{
{
{1000000000000000000, 1000000000000000000, ..., 1000000000000000000},
{1000000000000000000, 1000000000000000000, ..., 1000000000000000000},
/* 省略 */
{1000000000000000000, 1000000000000000000, ..., 1000000000000000000}
}
};

int main(){
  return 0;
}

2. 配列hyperParamsの使い方

次にhyperParams配列の使い方を説明します。

まず、例えば以下の2種類の解法を自分が作ったとします。

// 解法1 ※parameterAはハイパラで0~99のいずれかの値を入れるとする
void Solver1(int parameterA){
  /* 省略 */
}

// 解法2 ※parameterB、parameterCはハイパラで0~99のいずれかの値を入れるとする
void Solver2(int parameterB, int parameterC){
  /* 省略 */
}

このようなとき、私は64ビット整数を使ってハイパーパラメータを以下のように埋め込んでいます。

ハイパーパラメータの埋め込み方式
下10桁はパラメータ用、11桁から上で解法の種類を指定する。
例として、あるN,Q,Dの値の場合、解法2を使ってparameterBに34、parameterCに78を渡したいとするならば、
hyperParams[NIndex][QIndex][DIndex]に20000003478を入れておく。

実装イメージは以下のようになります。(いくつか追加で関数を作成しています。)

// 10のx乗
const ll D1 = 10LL;
const ll D2 = 100LL;
const ll D3 = 1000LL;
const ll D4 = 10000LL;
const ll D5 = 100000LL;
const ll D6 = 1000000LL;
const ll D7 = 10000000LL;
const ll D8 = 100000000LL;
const ll D9 = 1000000000LL;
const ll D10 = 10000000000LL;
const ll D11 = 100000000000LL;
const ll D12 = 1000000000000LL;
const ll D13 = 10000000000000LL;
const ll D14 = 100000000000000LL;
const ll D15 = 1000000000000000LL;
const ll D16 = 10000000000000000LL;
const ll D17 = 100000000000000000LL;
const ll D18 = 1000000000000000000LL;

const int MAX_N = 100;
const int MAX_Q = 3200;
const int MAX_D = 25;
int N, Q, D;
int NIndex, QIndex, DIndex; // ハイパラを参照するためのインデックス
int ans[MAX_N];

vector<ll> hyperParams[14][40] = {/*省略*/}

// シミュレート用関数
// テストケースとその解答に対してスコア計算式に従ってans配列から実際のスコアを計算する
ll CalcScore(){
  /* 省略 */
}

// Nの値から0~14のいずれかの値を取るNIndexを計算してグローバル変数に記憶
// Qの値から0~39のいずれかの値を取るQIndexを計算してグローバル変数に記憶
// Dの値から0~23のいずれかの値を取るDIndexを計算してグローバル変数に記憶
void TranslateNQDIndexFromNQD(){
  /* 省略 */
}

// 解法1 ※parameterAはハイパラで0~99のいずれかの値を入れるとする
void Solver1(int parameterA){
  /* 省略 */
}

// 解法2 ※parameterB、parameterCはハイパラで0~99のいずれかの値を入れるとする
void Solver2(int parameterB, int parameterC){
  /* 省略 */
}

// 標準入力またはファイル入力で入力を受け取る
void Input(){
  /* 省略 */
}

// 1ケースの入力に対して解答処理を行う関数
// 戻り値はスコア(提出の場合は適当に0など)を返すようにする
// N,Q,D,NIndex,QIndex,DIndex,hyperParamsには値がセットされているものとする
ll Solve(){
  ll para = hyperParams[NIndex][QIndex][DIndex]; // ここで64ビット整数の値を取得

  ll selectedSolver = para % D10;
  
  /* 解法の数だけ記述する */
  if(selectedSolver == 1){
    ll paraA = para % D2;
    Solver(paraA);
  }else if(selectedSolver == 2){
    ll paraB = para % D4 / D2;
    ll paraC = para % D2;
    Solver(paraB, paraC);
  }

  return CalcScore(); 
}

int main(){
  Input();

  TranslateNQDIndexFromNQD(); // ここでN,Q,DからNIndex,QIndex,DIndexを計算しておく

  ll score = Solve(); // Solveはスコアを返すようにしておく

  return 0;
}

3. hyperParamsの改善(その1)

次に hyperParams を改善していきます。

やり方は、 新しいハイパラを作成してhyperParams配列に入っている値と複数回戦わせ、勝率が良ければ更新する。 を長時間回します。

具体的なやり方

具体的には、以下のやり方をします。

  1. NIndex,QIndex,DIndex を1ケースジェネレートする。また新しいハイパーパラメータを1つジェネレートする。
  2. 以下を決めた回数行い、勝ち/負け/引き分けの数をそれぞれカウントする。
    1. NIndex,QIndex,DIndexから、テストケースを1ケース作成する。(範囲に従うN,Q,Dおよび、隠しパラメータのN個のアイテムの重みジェネレートする。)
    2. 作成したテストケースに対して hyperParams[NIndex][QIndex][DIndex] と新しいハイパラのそれぞれで解法処理を行いそれぞれのスコアを計算する。
    3. 新しいパラメータが勝ち/負け/引き分けのいずれであるかを求める。(今回はスコアが小さければ勝ち。)
  3. あらかじめ決めた勝率を上回っていれば新しいハイパラで hyperParams[NIndex][QIndex][DIndex] を上書きする。

戦わせる回数と更新する閾値の勝率

適当ですが、自分は以下の3種類を主に使っています。

  • 20戦して18回以上(90%以上)勝ったら更新
  • 200戦して120回以上(60%以上)勝ったら更新
  • 1000戦して550回以上(55%以上)勝ったら更新

弱いパラメータが誤って採択されないようにするためにこのように(少し厳しめに?)しています。

上のものほどたくさん回せますが既存のパラメータに対してかなり強くないと更新できません。

ちなみに、50%の確率で表が出るコインを200回投げたときに120回以上表が出る確率等は以下のサイトで調べることができます。回数や勝率はここで値を確認して決めています。

ハイパラ同士の勝敗コード.cpp
// NIndex、QIndex、DIndexをジェネレートする
void GenerateNQDIndex(){
  NIndex = rand() % 14;
  QIndex = rand() % 40;
  DIndex = rand() % hyperParams[NIndex][QIndex].size(); // Dの最大値は可変なため
}

// NIndex、QIndex、DIndexを用いてN、Q、D、各アイテムの重さをジェネレートする
void GenerateCaseFromNQDIndex(){
  /* 省略 */
}

int main(){
  while(true){
    // NIndex,QIndex,DIndexを1ケースジェネレート
    GenerateNQDIndex();

    // 新しいハイパラ(挑戦者)を作成
    ll newPara = 0;
    int solverNum = rand() % 2 + 1; // 解法はとりあえず1or2
    /* 解法の数だけ記述する */
    if(solverNum == 1){
      ll paraA = rand() % 100; // parameterAは0~99
      newPara = solverNum * D10 + paraA;
    }else if(solverNum == 2){
      ll paraB = rand() % 100; // parameterBは0~99
      ll paraC = rand() % 100; // parameterCは0~99
      newPara = solverNum * D10 + paraB * D2 + paraC;
    }

    int win = 0;
    int lose = 0;
    int draw = 0;
    
    int judgeMode = 1;
    bool isWin = false;
    if(jedgeMode == 0){
      // 20戦18勝
      for(int _ = 0; _ < 20; _++){
        GenerateCaseFromNQDIndex(); // ここで毎回テストケースを1ケース作成
        // 既存パラメータでのスコア計算
        ll oldScore = Solve();
        // 新パラメータでのスコア計算
        ll oldPara = hyperParams[NIndex][QIndex][DIndex];
        hyperParams[NIndex][QIndex][DIndex] = newPara;
        ll newScore = Solve();
        hyperParams[NIndex][QIndex][DIndex] = oldPara;

        if(newScore < oldScore){
          // AHC025はスコアが小さいほど良い
          win++;
        }else if(newScore == oldScore){
          draw++;
        }else{
          lose++;
        }

        // 時間短縮(結果が明らかなときはbreak)
        if (lose >= 3) break;
      }

      if(win >= 3 && (double)win / (win + lose) >= 0.9) { // 1勝0敗19分などは避ける
        isWin = true;
      }
    }else if(judgeMode == 1){
      // 200戦120勝
      for(int _ = 0; _ < 200; _++){
        GenerateCaseFromNQDIndex(); // ここで毎回テストケースを1ケース作成
        // 既存パラメータでのスコア計算
        ll oldScore = Solve();
        // 新パラメータでのスコア計算
        ll oldPara = hyperParams[NIndex][QIndex][DIndex];
        hyperParams[NIndex][QIndex][DIndex] = newPara;
        ll newScore = Solve();
        hyperParams[NIndex][QIndex][DIndex] = oldPara;

        if(newScore < oldScore){
          // AHC025はスコアが小さいほど良い
          win++;
        }else if(newScore == oldScore){
          draw++;
        }else{
          lose++;
        }

        // 時間短縮(結果が明らかなときはbreak)
        if (win >= 120) break;
        if (lose >= 80) break;
        if (win <= lose - 10) break;
        if (lose >= 20 && win <= lose) break;
        if (win - lose >= 40) break;
      }

      if(win >= 10 && (double)win / (win + lose) >= 0.6) { // 3勝1敗116分などは避ける
        isWin = true;
      }
    }else if(judgeMode == 2){
      // 1000戦550勝
      for(int _ = 0; _ < 1000; _++){
        GenerateCaseFromNQDIndex(); // ここで毎回テストケースを1ケース作成
        // 既存パラメータでのスコア計算
        ll oldScore = Solve();
        // 新パラメータでのスコア計算
        ll oldPara = hyperParams[NIndex][QIndex][DIndex];
        hyperParams[NIndex][QIndex][DIndex] = newPara;
        ll newScore = Solve();
        hyperParams[NIndex][QIndex][DIndex] = oldPara;

        if(newScore < oldScore){
          // AHC025はスコアが小さいほど良い
          win++;
        }else if(newScore == oldScore){
          draw++;
        }else{
          lose++;
        }

        // 時間短縮(結果が明らかなときはbreak)
        if (win >= 550) break;
        if (lose >= 450) break;
        if (win - lose >= 100) break;
        if (win <= lose - 10) break;
        if (lose >= 20 && win <= lose) break;
      }

      if(win >= 30 && (double)win / (win + lose) >= 0.55) { 
        isWin = true;
      }
    }

    if(isWin){
      // hyperParamsを新パラメータで上書き
      hyperParams[NIndex][QIndex][DIndex] = newPara;
    }
  }

  return 0;
}

4. hyperParamsの改善(その2)

また、ある N Index、Q Index、D Index のハイパラが新しいパラメータに上書きされたときは、付近に対しても新しいパラメータは良い結果を出す可能性が高いことは容易に想像できます。

なので更新が発生したときは付近に伝播させ(付近とも戦わせ)ます。

具体的には、キューを導入して更新が発生したときは {次に戦わせたい場所, 試したいパラメータ} のペアをストックして、キューに中身があるときは優先してその中身のシミュレーションを行います。

下記のように実装すると、強いパラメータほど勝利による更新が発生しパラメータが伝播していきます。

ハイパラ更新実装例.cpp
int main(){
  queue<pair<int, ll>> winQueue; // {次に戦わせたい場所, 試したいパラメータ}のペアを入れる
  while(true){
    ll newPara = 0;
    if(!winQueue.empty()){
      // winQueueに中身がある場合はそちらからシミュレーション
      // 次に試したい場所
      int nextNQD = winQueue.front().first;
      NIndex = nextNQD / D4;
      QIndex = nextNQD % D4 / D2;
      DIndex = nextNQD % D2;
      // 試したいパラメータ
      newPara = winQueue.front().second;
      winQueue.pop();
    }else{
      // ランダムに1箇所選びシミュレーション
      GenerateNQDIndex();

      /* newParaを作成する実装は省略 */
    }

    // 同じ値同士のシミュレーションが頻繁に発生するので弾く
    if(hyperParams[NIndex][QIndex][DIndex] == newPara) continue;

    /* 戦わせる実装は省略 */

    if(isWin){
      // hyperParamsを新パラメータで上書き
      hyperParams[NIndex][QIndex][DIndex] = newPara;

      // 付近(各インデックス±1)をキューに入れる
      for(int i = 0; i < 6; i++) {
        int nextNIndex = NIndex;
        int nextQIndex = QIndex;
        int nextDIndex = DIndex;
        if      (i == 0) nextNIndex--;
        else if (i == 1) nextNIndex++;
        else if (i == 2) nextQIndex--;
        else if (i == 3) nextQIndex++;
        else if (i == 4) nextDIndex--;
        else if (i == 5) nextDIndex++;
      
        // 範囲チェック
        if (nextNIndex < 0 || 14 <= nextNIndex) continue;
        if (nextQIndex < 0 || 40 <= nextQIndex) continue;
        if (nextDIndex < 0 || hyperParams[nextNIndex][nextQIndex].size() <= nextDIndex) continue;

        // キューに入れる
        winQueue.push(make_pair(nextNIndex * D4 + nextQIndex * D2 + nextDIndex, newPara));
      }
    }
  }

  return 0;
}

5. hyperParamsをテキスト出力し、ソースコードに張り付ける

更新した hyperParams はテキストファイルで出力します。

出力した hyperParams は手動でソースコードにコピペします。

実装イメージ

// hyperParamsをテキストファイルに出力
void PrintHyperParams(int loop) {
  string str = "./hyperParams/hyperParams" + to_string(loop) + ".txt";
  std::ofstream file(str);

  // 出力したい文字列
  file << "vector<ll> hyperParams[14][40] = {" << endl;
  for(int _nIndex = 0; _nIndex < 14; _nIndex++) {
    file << "{" << endl;
    for(int _qIndex = 0; _qIndex < 40; _qIndex++) {
      file << "{";
      for(int _dIndex = 0; _dIndex < hyperParams[_nIndex][_qIndex].size(); _dIndex++) {
        file << setw(19) << hyperParams[_nIndex][_qIndex][_dIndex];
        if (_dIndex < hyperParams[_nIndex][_qIndex].size() - 1) {
          file << ",";
        }
      }
      if (_qIndex < 39) {
        file << "}," << endl;
      }
      else {
        file << "}" << endl;
      }
    }
    if (_nIndex < 13) {
      file << "}," << endl;
    }
    else {
      file << "}" << endl;
    }
  }
  file << "};" << endl;
  file << endl;

  file.close();
}

int main(){
  int loop = 0; // ループ回数のカウンタ
  queue<pair<int, ll>> winQueue;
  while(true){
    /* 前半は省略 */

    if(isWin){
      /* 中身は省略 */
    }

    loop++;
    if(loop % 100 == 0){ // 適当なタイミングで出力
      PrintHyperParams(loop);
    }
  }
  return 0;
}

6. 全体の実装

全体の実装を繋げると以下のような感じになります。

実装イメージ

#include <bits/stdc++.h>
using namespace std;
typedef long long int ll;

// 10のx乗
const ll D1 = 10LL;
const ll D2 = 100LL;
const ll D3 = 1000LL;
const ll D4 = 10000LL;
const ll D5 = 100000LL;
const ll D6 = 1000000LL;
const ll D7 = 10000000LL;
const ll D8 = 100000000LL;
const ll D9 = 1000000000LL;
const ll D10 = 10000000000LL;
const ll D11 = 100000000000LL;
const ll D12 = 1000000000000LL;
const ll D13 = 10000000000000LL;
const ll D14 = 100000000000000LL;
const ll D15 = 1000000000000000LL;
const ll D16 = 10000000000000000LL;
const ll D17 = 100000000000000000LL;
const ll D18 = 1000000000000000000LL;

const int MAX_N = 100;
const int MAX_Q = 3200;
const int MAX_D = 25;
int N, Q, D;
int NIndex, QIndex, DIndex; // ハイパラを参照するためのインデックス
int ans[MAX_N];

vector<ll> hyperParams[14][40] = 
{
{
{10000000077, 20000004365, ...,},
/* 以下省略 */
}
};

// hyperParamsをテキストファイルに出力
void PrintHyperParams(int loop) {
  string str = "./hyperParams/hyperParams" + to_string(loop) + ".txt";
  std::ofstream file(str);

  // 出力したい文字列
  file << "vector<ll> hyperParams[14][40] = {" << endl;
  for(int _nIndex = 0; _nIndex < 14; _nIndex++) {
    file << "{" << endl;
    for(int _qIndex = 0; _qIndex < 40; _qIndex++) {
      file << "{";
      for(int _dIndex = 0; _dIndex < hyperParams[_nIndex][_qIndex].size(); _dIndex++) {
        file << setw(19) << hyperParams[_nIndex][_qIndex][_dIndex];
        if (_dIndex < hyperParams[_nIndex][_qIndex].size() - 1) {
          file << ",";
        }
      }
      if (_qIndex < 39) {
        file << "}," << endl;
      }
      else {
        file << "}" << endl;
      }
    }
    if (_nIndex < 13) {
      file << "}," << endl;
    }
    else {
      file << "}" << endl;
    }
  }
  file << "};" << endl;
  file << endl;

  file.close();
}

// スコア計算(シミュレート用)
ll CalcScore(){/* 省略 */}

// N,Q,Dの値からNIndex,QIndex,DIndexの値を計算
void TranslateNQDIndexFromNQD(){/* 省略 */}

// NIndex、QIndex、DIndexをジェネレートする
void GenerateNQDIndex(){/* 省略 */}

// NIndex、QIndex、DIndexを用いてN、Q、D、各アイテムの重さをジェネレートする
void GenerateCaseFromNQDIndex(){/* 省略 */}

// 解法1、解法2、解法3、解法4、...
void Solver1(int parameterA){/* 省略 */}
void Solver2(int parameterB, int parameterC){/* 省略 */}

// 1ケースの入力に対して解答処理を行う関数
// N,Q,D,NIndex,QIndex,DIndex,hyperParamsには値がセットされているものとする
ll Solve(){
  ll para = haipar[NIndex][QIndex][DIndex]; // ここで64ビット整数の値を取得

  ll selectedSolver = para % D10;
  
  /* 解法の数だけ記述する */
  if(selectedSolver == 1){
    ll paraA = para % D2;
    Solver(paraA);
  }else if(selectedSolver == 2){
    ll paraB = para % D4 / D2;
    ll paraC = para % D2;
    Solver(paraB, paraC);
  }

  return CalcScore(); 
}

int main(){
  int loop = 0; // ループ回数のカウンタ
  queue<pair<int, ll>> winQueue;
  while(true){
    ll newPara = 0;
    if(!winQueue.empty()){
      // winQueueに中身がある場合はそちらからシミュレーション
      // 次に試したい場所
      int nextNQD = winQueue.front().first;
      NIndex = nextNQD / D4;
      QIndex = nextNQD % D4 / D2;
      DIndex = nextNQD % D2;
      // 試したいパラメータ
      newPara = winQueue.front().second;
      winQueue.pop();
    }else{
      // ランダムに1箇所選びシミュレーション
      // NIndex,QIndex,DIndexを1ケースジェネレート
      GenerateNQDIndex();

      // 新しいハイパラ(挑戦者)を作成
      int solverNum = rand() % 2 + 1; // 解法はとりあえず1or2
      /* 解法の数だけ記述する */
      if(solverNum == 1){
        ll paraA = rand() % 100; // parameterAは0~99
        newPara = solverNum * D10 + paraA;
      }else if(solverNum == 2){
        ll paraB = rand() % 100; // parameterBは0~99
        ll paraC = rand() % 100; // parameterCは0~99
        newPara = solverNum * D10 + paraB * D2 + paraC;
      }
    }

    // 同じ値の場合はcontinue
    if(hyperParams[NIndex][QIndex][DIndex] == newPara) continue;

    int win = 0;
    int lose = 0;
    int draw = 0;
    
    int judgeMode = 1;
    bool isWin = false;
    if(jedgeMode == 0){
      // 20戦18勝
      for(int _ = 0; _ < 20; _++){
        GenerateCaseFromNQDIndex(); // ここで毎回テストケースを1ケース作成
        ll oldScore = Solve();
        ll oldPara = hyperParams[NIndex][QIndex][DIndex];
        hyperParams[NIndex][QIndex][DIndex] = newPara;
        ll newScore = Solve();
        hyperParams[NIndex][QIndex][DIndex] = oldPara;

        if(newScore < oldScore){
          // AHC025はスコアが小さいほど良い
          win++;
        }else if(newScore == oldScore){
          draw++;
        }else{
          lose++;
        }

        // 時間短縮(結果が明らかなときはbreak)
        if (lose >= 3) break;
      }

      if(win >= 3 && (double)win / (win + lose) >= 0.9) { // 1勝0敗19分などは避ける
        isWin = true;
      }
    }else if(judgeMode == 1){
      // 200戦120勝
      for(int _ = 0; _ < 200; _++){
        GenerateCaseFromNQDIndex(); // ここで毎回テストケースを1ケース作成
         ll oldScore = Solve();
        ll oldPara = hyperParams[NIndex][QIndex][DIndex];
        hyperParams[NIndex][QIndex][DIndex] = newPara;
        ll newScore = Solve();
        hyperParams[NIndex][QIndex][DIndex] = oldPara;

        if(newScore < oldScore){
          // AHC025はスコアが小さいほど良い
          win++;
        }else if(newScore == oldScore){
          draw++;
        }else{
          lose++;
        }

        // 時間短縮(結果が明らかなときはbreak)
        if (win >= 120) break;
        if (lose >= 80) break;
        if (win <= lose - 10) break;
        if (lose >= 20 && win <= lose) break;
        if (win - lose >= 40) break;
      }

      if(win >= 10 && (double)win / (win + lose) >= 0.6) { // 3勝1敗116分などは避ける
        isWin = true;
      }
    }else if(judgeMode == 2){
      // 1000戦550勝
      for(int _ = 0; _ < 1000; _++){
        GenerateCaseFromNQDIndex(); // ここで毎回テストケースを1ケース作成
        ll oldScore = Solve();
        ll oldPara = hyperParams[NIndex][QIndex][DIndex];
        hyperParams[NIndex][QIndex][DIndex] = newPara;
        ll newScore = Solve();
        hyperParams[NIndex][QIndex][DIndex] = oldPara;

        if(newScore < oldScore){
          // AHC025はスコアが小さいほど良い
          win++;
        }else if(newScore == oldScore){
          draw++;
        }else{
          lose++;
        }

        // 時間短縮(結果が明らかなときはbreak)
        if (win >= 550) break;
        if (lose >= 450) break;
        if (win - lose >= 100) break;
        if (win <= lose - 10) break;
        if (lose >= 20 && win <= lose) break;
      }

      if(win >= 30 && (double)win / (win + lose) >= 0.55) { 
        isWin = true;
      }
    }

    if(isWin){
      // hyperParamsを新パラメータで上書き
      hyperParams[NIndex][QIndex][DIndex] = newPara;

      // 付近(各インデックス±1)をキューに入れる
      for(int i = 0; i < 6; i++) {
        int nextNIndex = NIndex;
        int nextQIndex = QIndex;
        int nextDIndex = DIndex;
        if      (i == 0) nextNIndex--;
        else if (i == 1) nextNIndex++;
        else if (i == 2) nextQIndex--;
        else if (i == 3) nextQIndex++;
        else if (i == 4) nextDIndex--;
        else if (i == 5) nextDIndex++;
      
        // 範囲チェック
        if (nextNIndex < 0 || 14 <= nextNIndex) continue;
        if (nextQIndex < 0 || 40 <= nextQIndex) continue;
        if (nextDIndex < 0 || hyperParams[nextNIndex][nextQIndex].size() <= nextDIndex) continue;

        // キューに入れる
        winQueue.push(make_pair(nextNIndex * D4 + nextQIndex * D2 + nextDIndex, newPara));
      }
    }

    loop++;
    if(loop % 100 == 0){ // 適当なタイミングで出力
      PrintHyperParams(loop);
    }
  }

  return 0;
}

7. おわりに

ここまでの実装を行うと以下のような解答ができます。

図1.png

読んでいただきありがとうございました。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0