LoginSignup
6
0

Elasticsearch Learning to Rankプラグインのコードリーディング

Last updated at Posted at 2023-12-18

はじめに

Elasticsearch Learning to Rankプラグイン (elasticsearch-learning-to-rank) は
機械学習モデル(XGBoost, Ranklib, Simple Linear)を使用しての検索結果のリスコアを行うことが出来るElasticsearchプラグインです。

今年の春頃にelasticsearch-learning-to-rankプラグインのデモを動かすところまで行いました(以下Qiita記事参照)。
もっと詳細に理解したいと思いつつもあまり調査が進まずにいたため、このタイミングでもっと詳細に理解すべくソースコードリーディングを行ったのでその内容をまとめていきます。

elasticsearch-learning-to-rankプラグインに関しては、以下など記事やTechBlogが既に公開されているため、内容を知りたい方はまずこちらを参照。

本記事について

プラグインの実装レベルの詳細な話は上記とは別のTechBlogに既にまとまっており、これを読めば大まかに理解できるものの、自分ではまだコードを読み込めていなかったため今回のテーマにしました。

内容としては以下となる。

  • パッケージ構成、関係図
  • プラグインで追加されるエンドポイント
  • 主要処理詳細
    • Elasticsearch起動時
    • ltr用index作成
    • リスコア処理

調査対象のElasticsearchとプラグインのバージョンは、2023/12/19時点のプラグインの最新バージョンである8.9.2となります。

パッケージ構成、関係図

ソースコードを見る前にelasticsearch-learning-to-rankのパッケージ構成がどうなっているかをパッケージの関連図を通して見ていきます。

packages.png

restやqueryパッケージなど具体的な処理が書かれているパッケージは上の方に、utilsやrankerなど多くのパッケージから参照されているパッケージは下の方に記述されています。
これを見るとパッケージ間にどのようなつながりがありそうかを見て取れます。
各パッケージの内容は以下。

パッケージ 内容
action FeatureStoreActionやCreateModelFromSetActionのようなrestパッケージから呼び出され具体的な処理が記述されたクラス
feature FeatureやFeatureSet, StoredLtrModelなどのデータ用クラス
logging LoggingSearchExtBuilderやLoggingFetchSubPhaseのような特徴量ログ出力関連のクラス
query StoredLtrQueryBuilder, StoredLtrQueryBuilder, RankerQueryなどクエリ関連のクラス
ranker LtrRankerや各種機械学習ライブラリ用のParser, Rankerクラス
rest RestStoreManagerやRestFeatureManagerのようなAPIリクエストを受け付けるハンドラ関連のクラス
stats LTRプラグイン用の統計情報をするための処理が記述されたクラス
utils ユーティリティクラス

プラグインで追加されるエンドポイント

elasticsearch-learning-to-rankのプラグインをインストールすると、ElasticsearchへのLTR用エンドポイントが追加されてリクエスト出来るようになります。
追加されるエンドポイントは以下の通り。これらのエンドポイントへリクエストしつつ処理を追っていくと理解が深まりやすいと思います。

<特徴量セット、モデル関連>

エンドポイント 説明
PUT _ltr ストアを作成する
* 特徴量セットやモデルなどの情報を保存するための .ltrstore インデックスを作成する
DELETE _ltr ストアを削除する
* .ltrstore インデックスを削除してプラグインをリセットする
POST _ltr/_featureset/<特徴量セット名> 特徴量セットを作成(定義)する
DELETE _ltr/_featureset/<特徴量セット名> 特徴量セットを削除する
GET _ltr/_featureset/<特徴量セット名> 特徴量セットを取得する
POST _ltr/_featureset/<特徴量セット名>/_createmodel モデルを作成(アップロード)する
DELETE _ltr/_model/<モデル名> モデルを削除する
GET _ltr/_model/<モデル名> モデルを取得する
POST /_ltr/_feature/<特徴量> 特徴量を作成(定義)する。用途は再利用で利用するため
DELETE /_ltr/_feature/<特徴量> 特徴量を削除する
GET /_ltr/_feature/<特徴量> 特徴量を取得する
POST /_ltr/_featureset/<特徴量セット名>/_addfeatures/<特徴量> 特徴量セットに指定した特徴量を追加・更新する
  • ※検索結果のリスコア処理が行われる検索用エンドポイント (/{index}/_search) はプラグイン無しでも存在しているためこの一覧には記載していない。

<統計情報、キャッシュ関連>

エンドポイント 説明
GET _ltr/_stats,
GET _ltr/_stats/<stat>,
GET _ltr/_stats/nodes/<nodeID>,
GET _ltr/_stats/<stat>/nodes/<nodeID>
ノードやクラスター別のプラグイン用統計情報を取得する
GET _ltr/_cachestats プラグイン用キャッシュとメモリの使用状況についての統計を返却する
POST _ltr/_clearcache プラグイン用キャッシュをクリアする

主要処理詳細

それでは処理を順に見ていく。今回は以下3項目をピックアップして整理しました。

  • Elasticsearch起動時
  • ltr用index作成
  • リスコア処理

特徴量セットやモデル周りの処理の実装レベルの記載も出来ればよかったのですが、ボリュームが出てしまうため省いています。

処理を追うにあたり、Elasticsearchとelasticsearch-learning-to-rankのソースコードの両方を見る必要があり、行き来しますがご了承ください。

Elasticsearch起動時

まず、Elasticsearch起動時にプラグイン関連でどのような処理が行われているか、主な処理としては以下3点があります。
LtrQueryParserPluginインスタンスを生成して、LtrQueryParserPluginクラスで定義されているgetメソッドが呼び出され登録される形となっています。

ここで、前提としてElasticsearchにelasticsearch-learning-to-rankプラグインが既にインストールされているものとして進めます。

▼ LtrQueryParserPluginインスタンスの生成

  • PluginServiceの (Plugin) constructor.newInstance(settings) (code) の呼び出しで生成
  • このタイミングでLtrRankerParser(モデルパース用)の登録も行う。Ranklib, Linear, XGboostがサポートされている
    public LtrQueryParserPlugin(Settings settings) {
        caches = new Caches(settings);
        // Use memoize to Lazy load the RankerFactory as it's a heavy object to construct
        Supplier<RankerFactory> ranklib = Suppliers.memoize(RankerFactory::new);
        parserFactory = new LtrRankerParserFactory.Builder()
                .register(RanklibModelParser.TYPE, () -> new RanklibModelParser(ranklib.get()))
                .register(LinearRankerParser.TYPE, LinearRankerParser::new)
                .register(XGBoostJsonParser.TYPE, XGBoostJsonParser::new)
                .build();
    }

▼ QueryParserの登録

  • プラグインロード時の SearchModule searchModule = new SearchModule(settings, pluginsService.filterPlugins(SearchPlugin.class)); (code)で以下処理が呼び出される
public List<QuerySpec<?>> getQueries() {
        return asList(
                new QuerySpec<>(ExplorerQueryBuilder.NAME, ExplorerQueryBuilder::new, ExplorerQueryBuilder::fromXContent),
                new QuerySpec<>(LtrQueryBuilder.NAME, LtrQueryBuilder::new, LtrQueryBuilder::fromXContent),
                new QuerySpec<>(StoredLtrQueryBuilder.NAME,
                        (input) -> new StoredLtrQueryBuilder(getFeatureStoreLoader(), input),
                        (ctx) -> StoredLtrQueryBuilder.fromXContent(getFeatureStoreLoader(), ctx)),
                new QuerySpec<>(TermStatQueryBuilder.NAME, TermStatQueryBuilder::new, TermStatQueryBuilder::fromXContent),
                new QuerySpec<>(ValidatingLtrQueryBuilder.NAME,
                        (input) -> new ValidatingLtrQueryBuilder(input, parserFactory),
                        (ctx) -> ValidatingLtrQueryBuilder.fromXContent(ctx, parserFactory)));
    }

▼ プラグインで追加されるエンドポイント用のHandler登録

  • plugin.getRestHandlersで取得し、 registerHandler.accept(handler) (code)で登録される
  • 以下コードを見るとrestパッケージ関連のクラスが登録されている
    @Override
    public List<RestHandler> getRestHandlers(Settings settings, RestController restController,
                                             ClusterSettings clusterSettings, IndexScopedSettings indexScopedSettings,
                                             SettingsFilter settingsFilter, IndexNameExpressionResolver indexNameExpressionResolver,
                                             Supplier<DiscoveryNodes> nodesInCluster) {
        List<RestHandler> list = new ArrayList<>();

        for (String type : ValidatingLtrQueryBuilder.SUPPORTED_TYPES) {
            list.add(new RestFeatureManager(type));
            list.add(new RestSearchStoreElements(type));
        }
        list.add(new RestStoreManager());

        list.add(new RestFeatureStoreCaches());
        list.add(new RestCreateModelFromSet());
        list.add(new RestAddFeatureToSet());
        list.add(new RestLTRStats());
        return unmodifiableList(list);
    }

LTR用indexの作成

Elasticsearch起動が無事完了すると PUT _ltr リクエストが行えるようになり、 .ltrstore という名のLTR用インデックスを作成することが出来ます。
このリクエスト用の処理は RestStoreManager で行われ、CreateIndexRequestを生成してElasticsearchにリクエストします。
PUT _ltr のリクエストした結果は以下となります。

PUT _ltr

{
  "acknowledged" : true,
  "shards_acknowledged" : true,
  "index" : ".ltrstore"
}

これにより、 .ltrstoreインデックスが作成され、GET .ltrstore リクエストを行うと通常インデックスと同様にmappingやsettingsを取得することが出来ます。
以下はmappingを取得した結果となります。

  • feature, featureset, modelはobject型で定義されている
GET .ltrstore/_mapping

{
    ".ltrstore": {
        "mappings": {
            "dynamic": "strict",
            "properties": {
                "feature": {
                    "type": "object",
                    "enabled": false
                },
                "featureset": {
                    "type": "object",
                    "enabled": false
                },
                "model": {
                    "type": "object",
                    "enabled": false
                },
                "name": {
                    "type": "text",
                    "fields": {
                        "prefix": {
                            "type": "text",
                            "analyzer": "name_prefix",
                            "search_analyzer": "name_prefix_search"
                        }
                    },
                    "analyzer": "keyword"
                },
                "type": {
                    "type": "keyword"
                }
            }
        }
    }
}

.ltrstoreのmapping, settingsは以下で実装されており固定の定義が用意されています。

private static String readResourceFile(String indexName, String resource) {
    try (InputStream is = IndexFeatureStore.class.getResourceAsStream(resource)) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        Streams.copy(is.readAllBytes(), out);
        return out.toString(StandardCharsets.UTF_8.name());
    } catch (Exception e) {
・・・
  • settings
private static Settings storeIndexSettings(String indexName) {
    return Settings.builder()
            .put(IndexMetadata.INDEX_NUMBER_OF_SHARDS_SETTING.getKey(), 1)
            .put(IndexMetadata.INDEX_AUTO_EXPAND_REPLICAS_SETTING.getKey(), "0-2")
            .put(STORE_VERSION_PROP.getKey(), VERSION)
            .put(IndexMetadata.SETTING_PRIORITY, Integer.MAX_VALUE)
            .put(IndexMetadata.SETTING_INDEX_HIDDEN, true)
            .put(Settings.builder()
                    .loadFromSource(readResourceFile(indexName, ANALYSIS_FILE), XContentType.JSON)
                    .build())
            .build();
}

リスコア計算

コードリーディングする中で、一番気になったのがリスコア計算をどう行っているかについて気になったのでその部分を中心に見ていきたいと思います。
このあたりは、Elasticsearchのリスコア機能自体の実装でもあり、またElasticsearch LTRプラグインと特徴量キャッシュ機能の基本 - ZOZO TECH BLOGでもご説明のあった処理となります。

Elasticsearchに検索リクエスト (/_search) を行うと、まず最初に以下コードでモデルがロードされます。

  • RanklibModelParser でモデルをロード
  • デモに続きRanklibを用いているため、RanklibModel関連のクラスのコードを通過
    @Override
    public LtrRanker parse(FeatureSet set, String model) {
        Ranker ranklibRanker = factory.loadRankerFromString(model);
        int numFeatures = ranklibRanker.getFeatures().length;
        if (set != null) {
            numFeatures = set.size();
        }
        return new RanklibRanker(ranklibRanker, numFeatures);
    }

リスコア計算処理は以下のQueryRescorerのrescoreメソッドとなります。

  @Override
  public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN)
      throws IOException {
    ScoreDoc[] hits = firstPassTopDocs.scoreDocs.clone();

    Arrays.sort(
        hits,
        new Comparator<ScoreDoc>() {
          @Override
          public int compare(ScoreDoc a, ScoreDoc b) {
            return a.doc - b.doc;
          }
        });

    List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();

    Query rewritten = searcher.rewrite(query);
    Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);

    // Now merge sort docIDs from hits, with reader's leaves:
    int hitUpto = 0;
    int readerUpto = -1;
    int endDoc = 0;
    int docBase = 0;
    Scorer scorer = null;

    while (hitUpto < hits.length) {
      ScoreDoc hit = hits[hitUpto];
      int docID = hit.doc;
      LeafReaderContext readerContext = null;
      while (docID >= endDoc) {
        readerUpto++;
        readerContext = leaves.get(readerUpto);
        endDoc = readerContext.docBase + readerContext.reader().maxDoc();
      }

      if (readerContext != null) {
        // We advanced to another segment:
        docBase = readerContext.docBase;
        scorer = weight.scorer(readerContext);
      }

      if (scorer != null) {
        int targetDoc = docID - docBase;
        int actualDoc = scorer.docID();
        if (actualDoc < targetDoc) {
          actualDoc = scorer.iterator().advance(targetDoc);
        }

        if (actualDoc == targetDoc) {
          // Query did match this doc:
          hit.score = combine(hit.score, true, scorer.score());
        } else {
          // Query did not match this doc:
          assert actualDoc > targetDoc;
          hit.score = combine(hit.score, false, 0.0f);
        }
      } else {
        // Query did not match this doc:
        hit.score = combine(hit.score, false, 0.0f);
      }

      hitUpto++;
    }

▼ combineメソッド

この中で最終スコアを算出している箇所が以下のcombineメソッド部分となります。

        if (actualDoc == targetDoc) {
          // Query did match this doc:
          hit.score = combine(hit.score, true, scorer.score());
        } else {

combineメソッドの実装はQueryRescorerで定義されており以下のような実装になっています。

  • combineメソッドの引数は以下
    • 第一引数:クエリにマッチしたドキュメントのスコア
    • 第二引数:secondPassMatches(rescoreスコアを組み合わせるかどうかのフラグ)
    • 第三引数:RankerScorer
  • firstPassScoreとsecondPassScoreの値をscoreModeのcombineメソッドに渡している
  • scoreModeはTotalがセットされている
public final class QueryRescorer implements Rescorer {

    public static final Rescorer INSTANCE = new QueryRescorer();

    @Override
    public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext rescoreContext) throws IOException {

            @Override
            protected float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore) {
                if (secondPassMatches) {
                    return rescore.scoreMode.combine(
                        firstPassScore * rescore.queryWeight(),
                        secondPassScore * rescore.rescoreQueryWeight()
                    );
                }
                // TODO: shouldn't this be up to the ScoreMode? I.e., we should just invoke ScoreMode.combine, passing 0.0f for the
                // secondary score?
                return firstPassScore * rescore.queryWeight();
            }
・・・

ScoreMode TotalのcombineメソッドはQueryRescoreModeで定義されており足し合わせ。

public enum QueryRescoreMode implements Writeable {
	・・・
    Total {
        @Override
        public float combine(float primary, float secondary) {
            return primary + secondary;
        }
・・・

▼ elasticsearch-learning-to-rankによるリスコア計算

機械学習モデルを用いてのリスコア計算はRankerQueryのサブクラスであるRankerQueryで行われます。

public class RankerQuery extends Query {
	・・・
    public static class RankerWeight extends Weight {
		・・・
        class RankerScorer extends Scorer {
			・・・
            @Override
            public float score() throws IOException {
                fv = ranker.newFeatureVector(fv);
                if (featureScoreCache == null) {  // Cache disabled
                    int ordinal = -1;
                    // a DisiPriorityQueue could help to avoid
                    // looping on all scorers
                    for (Scorer scorer : scorers) {
                        ordinal++;
                        // FIXME: Probably inefficient, again we loop over all scorers..
                        if (scorer.docID() == docID()) {
                            // XXX: bold assumption that all models are dense
                            // do we need a some indirection to infer the featureId?
                            fv.setFeatureScore(ordinal, scorer.score());
                        }
                    }
					・・・
                return ranker.score(fv);
            }
・・・

ここで、setFeatureScoreメソッドはDenseProgramaticDataPointで以下のように定義されており、feature毎のスコアを格納します。

public class DenseProgramaticDataPoint extends DataPoint implements LtrRanker.FeatureVector {
	・・・
    @Override
    public void setFeatureScore(int featureIdx, float score) {
        // add 1 because RankLib features 1 based
        this.setFeatureValue(featureIdx+1, score);
    }
	・・・
}

一番最後の return ranker.score(fv)FVLtrRankerWrapperのscoreメソッドの引数にFeatureVectorをセットして計算します。以下のwrappedがRanklibRankerにあたります。

public class RankerQuery extends Query {
 	・・・
    static class FVLtrRankerWrapper implements LtrRanker {
		・・・
        @Override
        public float score(FeatureVector point) {
            return wrapped.score(point);
        }
・・・

RanklibRankerでのスコア計算は以下で行っているのですが、 ciir.umass.edu.learning.Ranker のevalまでは追いきれませんでした。

import ciir.umass.edu.learning.Ranker;
・・・

public class RanklibRanker implements LtrRanker {
	private final Ranker ranker;
	・・・
    @Override
    public float score(FeatureVector point) {
        assert point instanceof DenseProgramaticDataPoint;
        return (float) ranker.eval((DenseProgramaticDataPoint) point);
    }
}

おわりに

elasticsearch-learning-to-rankの詳細についてまだまだ細かいところまで見切れてはいないですが、実装レベルである程度知ることが出来ました。
リスコア周りはfirstPassScoreとsecondPassScoreのスコアを足し合わせるようなことをしていることも見ることが出来ました。
ただし、処理の複雑なところがあり、理解が追いつかず「こうなってました」で終わってしまったので、推測や工夫している箇所を読み取れるくらいまで理解を進めたいと思いました。
あと今回はelasticsearch-learning-to-rankプラグインをデバッグしながら読み進めることが出来ました。プラグインのデバッグをどうやったかについても近いうちに投稿出来ればと思ってます。

6
0
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
0