LoginSignup
0
0

ReplicateモデルをLangchain越しに使いたいとき

Last updated at Posted at 2024-05-02

はじめに

OpenAIやAnthropicのAPIだけを使って満足していませんか?

Replicateとは、クローズドモデル以外のオープンモデル1のAIを簡単にAPI経由で実行できるサービスです。この記事では、ReplicateをLangchainと組み合わせて使用した際のメモを紹介します。Langchainは、分散型マシン学習モデルと分散型データソースを統合するためのフレームワークであり、自然言語処理(NLP)の応用に特に適しています。

Langchainから使うとき

LangchainでReplicate越しにモデルを使う場合は、replicateライブラリと異なりversionが必須。

Screenshot from 2024-05-01 09-55-43.png

E import { Replicate  } from "npm:@langchain/community/llms/replicate";     ■ Cannot find module 'npm:@langchain/community/llms/replicate' or its
const model = new Replicate({
E   model: "meta/llama-2-7b-chat",     ■ Type '"meta/llama-2-7b-chat"' is not assignable to type '${string}/${string}:${string}'.

${string}/${string}:${string} 型が欲しいと仰る…

mistralai/mistral-7b-instruct-v0.2

を指定したとき

error: Uncaught (in promise) ApiError: Request to https://api.replicate.com/v1/models
/mistralai/mistral-7b-instruct-v0.2/versions/undefined failed with status 404 Not Fou
nd: {"detail":"Not found."}.

というエラーが返るから、

https://api.replicate.com/v1/models /mistralai/mistral-7b-instruct-v0.2/versions/undefinedというURLが生成されている。

undefinedというバージョンは存在しない

error: Uncaught (in promise) ApiError: Request to https://api.replicate.com/v1/models
/mistralai/mistral-7b-instruct-v0.2/versions/latest failed with status 404 Not Found:
 {"detail":"Not found."}.

latestというバージョンも存在しない。ハッシュ値のような文字列がバージョン情報に必要?

npm Replicate を使った場合

error: Uncaught (in promise) ApiError: Request to https://api.replicate.com/v1/model
s/mistralai/mistral-7b-instruct-v0.2/predictions failed with status 422 Unprocessabl
e Entity: {"detail":"- Additional property max_tokens is not allowed\n","status":422
,"title":"Input validation failed","invalid_fields":[{"type":"additional_property_no
t_allowed","field":"","description":"Additional property max_tokens is not allowed"}
]}

https://api.replicate.com/v1/model s/mistralai/mistral-7b-instruct-v0.2/predictions このようなURLが使用される。

versionsをpredictionsに置き換えてくれればいいのに

versionが書いてあるところを見つけた

Screenshot from 2024-05-01 06-32-05.png

画像の選択範囲。

要するに、JSONスキーマのversionプロパティを使う。これでlangchainからReplicate APIを使って返答を得られるようになった。

import { Replicate } from "npm:@langchain/community/llms/replicate"
const name = "meta/llama-2-13b-chat"
const version = "6b4da803a2382c08868c5af10a523892f38e2de1aafb2ee55b020d9efef2fdb8"
// モデルを指定
const model = new Replicate({
	model: `${name}:${version}`
})
// 回答を取得
const res = await model.invoke("hello world")
console.log(res)

meta/llama-2-13b-chat – Replicate

モデルのboot状態

Screenshot from 2024-05-01 06-36-36.png

❤️‍🔥Warmとなっていれば、直ちに回答を生成してくれるが、

Screenshot from 2024-05-01 06-39-21.png

中には画像のように❄️Coldと表示されているものもある。

Coldモデルはboot, すなわちモデルを立ち上げてから実行するので回答の生成に数十秒から数分待たされるので注意。

最近の人気のある公式モデルであれば❤️‍🔥Warm状態であることが多いが、個人が立ち上げているようなものや、公式であっても古いモデルは❄️Coldと表示されていることが多い2

詳細

We have a huge catalogue of models. To make good use of resources, we only run the models that are actually being used. When a model hasn’t been used for a little while, we turn it off.

When you make a request to run a prediction on a model, you’ll get a fast response if the model is “warm” (already running), and a slower response if the model is “cold” (starting up). Machine learning models are often very large and resource intensive, and we have to fetch and load several gigabytes of code for some models. In some cases this process can take several minutes.

Cold boots can also happen when there’s a big spike in demand. We autoscale by running multiple copies of a model on different machines, but the model can take a while to become ready.

For popular public models, cold boots are uncommon because the model is kept “warm” from all the activity. For less-frequently used models, cold boots are more frequent.

If you’re using the API to create predictions in the background, then cold boots probably aren’t a big deal: we only charge for the time that your prediction is actually running, so it doesn’t affect your costs.

我々は膨大なモデルのカタログを持っている。リソースを有効に活用するために、実際に使用されているモデルだけを動かしています。しばらく使われていないモデルは、電源を切っています。

モデルに対して予測を実行するリクエストを出すと、モデルが「ウォーム」(すでに実行中)であれば速いレスポンスが得られ、モデルが「コールド」(起動中)であれば遅いレスポンスが得られる。機械学習モデルは非常に大きく、リソースを大量に消費することが多いため、モデルによっては数ギガバイトのコードをフェッチしてロードしなければなりません。場合によっては、この処理に数分かかることもある。

コールドブートは、需要が急増したときにも発生します。モデルの複数のコピーを異なるマシンで実行することでオートスケールを行いますが、モデルの準備が整うまでに時間がかかることがあります。

人気のある公開モデルの場合、コールドブーツは珍しい。使用頻度の低いモデルでは、コールドブーツが頻繁に発生します。

APIを使用してバックグラウンドで予測を作成している場合、コールドブーツはおそらく大きな問題ではありません。

APIからモデルのバージョン情報を取得する

バージョンをリストアップするAPIが用意されていた3

Replicate公式のAPIリファレンスに取得の仕方が載っていたので引用します。

GET https://api.replicate.com/v1/models/{model_owner}/{model_name}/versions

Example cURL request:

curl -s \
 -H "Authorization: Bearer <paste-your-token-here>" \
https://api.replicate.com/v1/models/replicate/hello-world/versions

meta/llama2-7b-chatモデルのバージョンを取得してみます。

なお、事前にAPIキーを発行4し、REPLICATE_API_TOKENにセットしておく必要があります。

$ curl -s -H "Authorization: Bearer ${REPLICATE_API_TOKEN}" \
	[https://api.replicate.com/v1/models/meta/llama-2-7b-chat/versions](https://api.replicate.com/v1/models/meta/llama-2-7b-chat/versions)

エラーが出ました。

{
  "status": 404,
  "title": "Not found",
  "detail": "This model is an official model and does not expose a list of versions."
}

エラーによると、公式モデルはlist versionを公開していないようです。

versionsを取って再びGETしてみますと866行のJSONが返ってきました。

$ curl -s -H "Authorization: Bearer ${REPLICATE_API_TOKEN}" \
	https://api.replicate.com/v1/models/meta/llama-2-7b-chat | jq | head

jqでパースして最初の行だけ見るとこんな感じ

{
  "cover_image_url": "https://tjzk.replicate.delivery/models_models_cover_image/33ad1ead-8954-4b3b-bd46-125c7e18143f/llama-logo
.png",
  "created_at": "2023-07-14T06:11:39.905511Z",
  "default_example": {
    "completed_at": "2023-07-22T00:58:11.785972Z",
    "created_at": "2023-07-22T00:57:53.459888Z",
    "error": null,
    "id": "7xhvtyjbutohmf2ia5k2men57y",
    "input": {
      "top_p": 1,

jtrでパースしてみるとこんな構造のJSON

github.com/u1and0/jtr: 自作のJSONツリービューワーコマンド

.
├── cover_image_url <string>
├── created_at <string>
├── default_example
│   ├── completed_at <string>
│   ├── created_at <string>
│   ├── error <null>
│   ├── id <string>
│   ├── input
│   │   ├── top_p <int>
│   │   ├── prompt <string>
│   │   ├── temperature <float>
│   │   ├── system_prompt <string>
│   │   ├── max_new_tokens <int>
│   │   └── repetition_penalty <int>
│   ├── logs <string>
│   ├── metrics
│   │   ├── predict_time <float>
│   │   └── total_time <float>
│   ├── output []string
│   ├── started_at <string>
│   ├── status <string>
│   ├── urls
│   │   ├── get <string>
│   │   └── cancel <string>
│   ├── model <string>
│   ├── version <string>
│   └── webhook_completed <null>
├── description <string>
├── github_url <string>
├── latest_version
│   ├── id <string>
│   ├── created_at <string>
│   ├── cog_version <string>
│   └── openapi_schema
│       ├── info
│       │   ├── title <string>
│       │   └── version <string>
│       ├── paths
│       │   ├── /
│       │   │   └── get
│       │   │       ├── summary <string>
│       │   │       ├── responses
│       │   │       │   └── 200
│       │   │       │       ├── content
│       │   │       │       │   └── application/json
│       │   │       │       │       └── schema
│       │   │       │       │           └── title <string>
│       │   │       │       └── description <string>
│       │   │       └── operationId <string>
│       │   ├── /shutdown
│       │   │   └── post
│       │   │       ├── summary <string>
│       │   │       ├── responses
│       │   │       │   └── 200
│       │   │       │       ├── content
│       │   │       │       │   └── application/json
│       │   │       │       │       └── schema
│       │   │       │       │           └── title <string>
│       │   │       │       └── description <string>
│       │   │       └── operationId <string>
│       │   ├── /predictions
│       │   │   └── post
│       │   │       ├── summary <string>
│       │   │       ├── responses
│       │   │       │   ├── 200
│       │   │       │   │   ├── content
│       │   │       │   │   │   └── application/json
│       │   │       │   │   │       └── schema
│       │   │       │   │   │           └── $ref <string>
│       │   │       │   │   └── description <string>
│       │   │       │   └── 422
│       │   │       │       ├── content
│       │   │       │       │   └── application/json
│       │   │       │       │       └── schema
│       │   │       │       │           └── $ref <string>
│       │   │       │       └── description <string>
│       │   │       ├── parameters [].
│       │   │       │     ├── in <string>
│       │   │       │     ├── name <string>
│       │   │       │     ├── schema
│       │   │       │     │   ├── type <string>
│       │   │       │     │   └── title <string>
│       │   │       │     └── required <bool>
│       │   │       ├── description <string>
│       │   │       ├── operationId <string>
│       │   │       └── requestBody
│       │   │           └── content
│       │   │               └── application/json
│       │   │                   └── schema
│       │   │                       └── $ref <string>
│       │   ├── /health-check
│       │   │   └── get
│       │   │       ├── summary <string>
│       │   │       ├── responses
│       │   │       │   └── 200
│       │   │       │       ├── content
│       │   │       │       │   └── application/json
│       │   │       │       │       └── schema
│       │   │       │       │           └── title <string>
│       │   │       │       └── description <string>
│       │   │       └── operationId <string>
│       │   ├── /predictions/{prediction_id}
│       │   │   └── put
│       │   │       ├── summary <string>
│       │   │       ├── responses
│       │   │       │   ├── 200
│       │   │       │   │   ├── content
│       │   │       │   │   │   └── application/json
│       │   │       │   │   │       └── schema
│       │   │       │   │   │           └── $ref <string>
│       │   │       │   │   └── description <string>
│       │   │       │   └── 422
│       │   │       │       ├── content
│       │   │       │       │   └── application/json
│       │   │       │       │       └── schema
│       │   │       │       │           └── $ref <string>
│       │   │       │       └── description <string>
│       │   │       ├── parameters [].
│       │   │       │     ├── in <string>
│       │   │       │     ├── name <string>
│       │   │       │     ├── schema
│       │   │       │     │   ├── type <string>
│       │   │       │     │   └── title <string>
│       │   │       │     └── required <bool>
│       │   │       ├── description <string>
│       │   │       ├── operationId <string>
│       │   │       └── requestBody
│       │   │           ├── content
│       │   │           │   └── application/json
│       │   │           │       └── schema
│       │   │           │           ├── allOf [].
│       │   │           │           │     └── $ref <string>
│       │   │           │           └── title <string>
│       │   │           └── required <bool>
│       │   └── /predictions/{prediction_id}/cancel
│       │       └── post
│       │           ├── summary <string>
│       │           ├── responses
│       │           │   ├── 200
│       │           │   │   ├── content
│       │           │   │   │   └── application/json
│       │           │   │   │       └── schema
│       │           │   │   │           └── title <string>
│       │           │   │   └── description <string>
│       │           │   └── 422
│       │           │       ├── content
│       │           │       │   └── application/json
│       │           │       │       └── schema
│       │           │       │           └── $ref <string>
│       │           │       └── description <string>
│       │           ├── parameters [].
│       │           │     ├── in <string>
│       │           │     ├── name <string>
│       │           │     ├── schema
│       │           │     │   ├── type <string>
│       │           │     │   └── title <string>
│       │           │     └── required <bool>
│       │           ├── description <string>
│       │           └── operationId <string>
│       ├── openapi <string>
│       └── components
│           └── schemas
│               ├── Input
│               │   ├── type <string>
│               │   ├── title <string>
│               │   ├── required []string
│               │   └── properties
│               │       ├── seed
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── x-order <int>
│               │       │   └── description <string>
│               │       ├── debug
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── default <bool>
│               │       │   ├── x-order <int>
│               │       │   └── description <string>
│               │       ├── top_k
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── default <int>
│               │       │   ├── minimum <float>
│               │       │   ├── x-order <int>
│               │       │   └── description <string>
│               │       ├── top_p
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── default <float>
│               │       │   ├── maximum <float>
│               │       │   ├── minimum <float>
│               │       │   ├── x-order <int>
│               │       │   └── description <string>
│               │       ├── prompt
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── x-order <int>
│               │       │   └── description <string>
│               │       ├── temperature
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── default <float>
│               │       │   ├── maximum <float>
│               │       │   ├── minimum <float>
│               │       │   ├── x-order <int>
│               │       │   └── description <string>
│               │       ├── system_prompt
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── default <string>
│               │       │   ├── x-order <int>
│               │       │   └── description <string>
│               │       ├── max_new_tokens
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── default <int>
│               │       │   ├── minimum <float>
│               │       │   ├── x-order <int>
│               │       │   └── description <string>
│               │       ├── min_new_tokens
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── default <int>
│               │       │   ├── minimum <float>
│               │       │   ├── x-order <int>
│               │       │   └── description <string>
│               │       ├── stop_sequences
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── x-order <int>
│               │       │   └── description <string>
│               │       ├── replicate_weights
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── x-order <int>
│               │       │   └── description <string>
│               │       └── repetition_penalty
│               │           ├── type <string>
│               │           ├── title <string>
│               │           ├── default <float>
│               │           ├── minimum <float>
│               │           ├── x-order <int>
│               │           └── description <string>
│               ├── Output
│               │   ├── type <string>
│               │   ├── items
│               │   │   └── type <string>
│               │   ├── title <string>
│               │   ├── x-cog-array-type <string>
│               │   └── x-cog-array-display <string>
│               ├── Status
│               │   ├── enum []string
│               │   ├── type <string>
│               │   ├── title <string>
│               │   └── description <string>
│               ├── WebhookEvent
│               │   ├── enum []string
│               │   ├── type <string>
│               │   ├── title <string>
│               │   └── description <string>
│               ├── ValidationError
│               │   ├── type <string>
│               │   ├── title <string>
│               │   ├── required []string
│               │   └── properties
│               │       ├── loc
│               │       │   ├── type <string>
│               │       │   ├── items
│               │       │   │   └── anyOf [].
│               │       │   │         └── type <string>
│               │       │   └── title <string>
│               │       ├── msg
│               │       │   ├── type <string>
│               │       │   └── title <string>
│               │       └── type
│               │           ├── type <string>
│               │           └── title <string>
│               ├── PredictionRequest
│               │   ├── type <string>
│               │   ├── title <string>
│               │   └── properties
│               │       ├── id
│               │       │   ├── type <string>
│               │       │   └── title <string>
│               │       ├── input
│               │       │   └── $ref <string>
│               │       ├── webhook
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   ├── format <string>
│               │       │   ├── maxLength <int>
│               │       │   └── minLength <int>
│               │       ├── created_at
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   └── format <string>
│               │       ├── output_file_prefix
│               │       │   ├── type <string>
│               │       │   └── title <string>
│               │       └── webhook_events_filter
│               │           ├── type <string>
│               │           ├── items
│               │           │   └── $ref <string>
│               │           └── default []string
│               ├── PredictionResponse
│               │   ├── type <string>
│               │   ├── title <string>
│               │   └── properties
│               │       ├── id
│               │       │   ├── type <string>
│               │       │   └── title <string>
│               │       ├── logs
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   └── default <string>
│               │       ├── error
│               │       │   ├── type <string>
│               │       │   └── title <string>
│               │       ├── input
│               │       │   └── $ref <string>
│               │       ├── output
│               │       │   └── $ref <string>
│               │       ├── status
│               │       │   └── $ref <string>
│               │       ├── metrics
│               │       │   ├── type <string>
│               │       │   └── title <string>
│               │       ├── version
│               │       │   ├── type <string>
│               │       │   └── title <string>
│               │       ├── created_at
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   └── format <string>
│               │       ├── started_at
│               │       │   ├── type <string>
│               │       │   ├── title <string>
│               │       │   └── format <string>
│               │       └── completed_at
│               │           ├── type <string>
│               │           ├── title <string>
│               │           └── format <string>
│               └── HTTPValidationError
│                   ├── type <string>
│                   ├── title <string>
│                   └── properties
│                       └── detail
│                           ├── type <string>
│                           ├── items
│                           │   └── $ref <string>
│                           └── title <string>
├── license_url <string>
├── name <string>
├── owner <string>
├── paper_url <string>
├── run_count <int>
├── url <string>
└── visibility <string>

jq で ‘.default_example.version’キーを取得するとバージョンハッシュを取得できました。

$ curl -s -H "Authorization: Bearer ${REPLICATE_API_TOKEN}" \
    https://api.replicate.com/v1/models/meta/llama-2-7b-chat
    | jq '.default_example.version'
"f1d50bb24186c52daae319ca8366e53debdaa9e0ae7ff976e918df752732ccc4"

モデル指定とともにfetchしてバージョンを獲得する

上記のcurlをfetch APIに置き換えてtypescriptでバージョン取得する関数を作りました。

type Model = `${string}/${string}`;
type ModelVersion = `${Model}:${string}`;

async function getReplicateModelVersion(model: Model): Promise<ModelVersion> {
  // APIトークンを取得
  const token = Deno.env.get("REPLICATE_API_TOKEN");
  if (!token) {
    throw new Error("REPLICATE_API_TOKEN is not set");
  }

  const url = `https://api.replicate.com/v1/models/${model}`;
  const headers = {
    "Authorization": `Bearer ${token}`,
  };
  try {
    // バージョン取得用 GET API
    const response = await fetch(url, { headers });
    if (!response.ok) {
      throw new Error(
        `Error fetch Replicate model: ${response.status} - ${response.statusText}`,
      );
    }
    const data = await response.json();
    // 取得したJSONのversionオブジェクトにハッシュが格納されている
    // .
    // ├── default_example
    // │   ├── version <string>
    const version = data.default_example.version;
    return `${model}:${version}`;
  } catch (error) {
    console.error(error);
    throw error;
  }
}

回答まで取得する全文

import { Replicate } from "npm:@langchain/community/llms/replicate";
// replicate モジュールのempty import
// またはnpm install -D reqlicate が必要
import { _ } from "npm:replicate";

type Model = `${string}/${string}`;
type ModelVersion = `${Model}:${string}`;

// モデルオーナーとモデル名から最新のバージョンを取得する
async function getReplicateModelVersion(model: Model): Promise<ModelVersion> {
  // APIトークンを取得
  const token = Deno.env.get("REPLICATE_API_TOKEN");
  if (!token) {
    throw new Error("REPLICATE_API_TOKEN is not set");
  }

  const url = `https://api.replicate.com/v1/models/${model}`;
  const headers = {
    "Authorization": `Bearer ${token}`,
  };
  try {
    // バージョン取得用 GET API
    const response = await fetch(url, { headers });
    if (!response.ok) {
      throw new Error(
        `Error fetch Replicate model: ${response.status} - ${response.statusText}`,
      );
    }
    const data = await response.json();
    // 取得したJSONのversionオブジェクトにハッシュが格納されている
    // .
    // ├── default_example
    // │   ├── version <string>
    const version = data.default_example.version;
    return `${model}:${version}`;
  } catch (error) {
    console.error(error);
    throw error;
  }
}

// モデルの指定とバージョンの獲得
const model = "meta/llama-2-7b-chat";
const modelWithVersion = await getReplicateModelVersion(model);

// モデルへ質問と回答の取得
const llm = new Replicate({ model: modelWithVersion });
const res = await llm.invoke("generate code hello world using python");
console.log(res);

まとめ

ReplicateとLangchainを組み合わせて使う際、モデルのバージョンを指定する必要がありました。バージョンリストはAPIから簡単に取得できました。
しかしながら、この辺の使い方が書かれたドキュメントがLangchain側にはなく、Replicate側は調べるのにとても時間がかかったので、記事を書いて皆様に共有いたします。

  1. 例えばLLaMA3とかMistralとか。replicate/exploreに一覧があります。

  2. How does Replicate work?

  3. List model versions

  4. APIキーの取得はReplicate にサインインする & API キーを発行する - Qiita参照。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0