1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

watsonxのAPIをJavaから呼び出してみる (Llama 2) _V3

Posted at

コードをさらにコンパクトにする

最初の記事ではHelloWorld的なわかりやすさを優先して、単純に watsonx API を呼び出していました。

次の記事ではJavaのプログラムとして少しコンパクトにしてみました。

そしてこの記事ではさらにJavaプログラムとしてきちんと書いてみます。

// 利用したいLLM
String model_id = "meta-llama/llama-2-70b-chat";
WatsonX wx = new WatsonX(model_id);
System.out.println(wx.generate("日本で最も高い山は?"));
System.out.println(wx.generate("日本で最も長い川は?"));

やればできる!

WatsonX クラス

余計な処理を外から隠すために、以下のようなクラスを作成しました。(今回は練習用で、本番用にはさらに丁寧に作ることになります)

期限時間内であればアクセストークンを再利用するようにもしました。

package com.ibm.watsonx;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import org.apache.hc.client5.http.classic.methods.HttpPost;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.ParseException;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.ibm.util.HttpPostBuilder;

public class WatsonX {
	private String model_id;
	private String end_point = "https://us-south.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=2023-05-29";
	private String project_id;
	private String apikey;
	private String access_token = null;
	private long access_token_expiration = -1;

	public WatsonX(String model_id) {
		this.model_id = model_id;
		// 「watsonx > プロンプト・ラボ > コードの表示」から取得する
		project_id = System.getProperty("project_id", null);
		if (project_id == null) {
			throw new RuntimeException("Property not set: project_id");
		}
		// 「IBM Cloud > 管理 > アクセス(IAM) > サービスID」から取得する
		apikey = System.getProperty("apikey", null);
		if (apikey == null) {
			throw new RuntimeException("Property not set: apikey");
		}
	}

	public String generate(String input) throws IOException {
		return generate(input, "ja", 20, 0);
	}

	public String generate(String input, String lang, int max_new_tokens,
			int min_new_tokens) throws IOException {
		String access_token;
		if (this.access_token != null && this.access_token_expiration < System.currentTimeMillis() - 1000) {
			access_token = this.access_token;
		} else {
			access_token = access_token();
		}

		if (lang.equals("ja")) {
			input = String.format("入力:\\n%s\\n\\n出力:\\n", input);
		}
		try (CloseableHttpClient httpClient = HttpClients.createDefault();) {
			JsonObject request_body = new JsonObject();
			{
				request_body.addProperty("model_id", model_id);
				request_body.addProperty("input", input);
				JsonObject parameters = new JsonObject();
				{
					parameters.addProperty("decoding_method", "greedy");
					parameters.addProperty("max_new_tokens", max_new_tokens);
					parameters.addProperty("min_new_tokens", min_new_tokens);
					parameters.add("stop_sequences", new JsonArray());
					parameters.addProperty("repetition_penalty", 1);
				}
				request_body.add("request_body", parameters);
				request_body.addProperty("project_id", project_id);
			}

			HttpPost httpPost = (new HttpPostBuilder(end_point)) //
					.H("Content-Type", "application/json") //
					.H("Accept", "application/json") //
					.H("Authorization", "Bearer " + access_token) //
					.data(request_body.toString()) //
					.build();
			// Execute the request
			CloseableHttpResponse response = httpClient.execute(httpPost);
			try {
				// Get the response entity
				HttpEntity responseEntity = response.getEntity();
				if (responseEntity != null) {
					// Convert the response entity to a string
					String responseString = EntityUtils.toString(responseEntity, StandardCharsets.UTF_8);
					JsonObject jo = (new Gson()).fromJson(responseString, JsonObject.class);
					String generated_text = jo.get("results").getAsJsonArray()
							.get(0).getAsJsonObject().get("generated_text")
							.getAsString();
					return generated_text;
				}
				EntityUtils.consume(responseEntity);
			} catch (ParseException e) {
				throw new IOException(e);
			} finally {
				response.close();
			}
		} catch (URISyntaxException e) {
			throw new IOException(e);
		}
		return null;
	}
	private String access_token() throws IOException {
		{ // APIキーを利用してアクセストークンを取得する(1時間有効)
			String end_point_token = "https://iam.cloud.ibm.com/identity/token";
			try (CloseableHttpClient httpClient = HttpClients.createDefault();) {
				HttpPost httpPost = (new HttpPostBuilder(end_point_token)) //
						.H("Content-Type", "application/x-www-form-urlencoded") //
						.H("Accept", "application/json") //
						.D("grant_type", "urn:ibm:params:oauth:grant-type:apikey") //
						.D("apikey", apikey) //
						.build();
				// Execute the request
				CloseableHttpResponse response_token = httpClient.execute(httpPost);
				try {
					// Get the response entity
					HttpEntity responseEntity = response_token.getEntity();
					if (responseEntity != null) {
						// Convert the response entity to a string
						String responseString = EntityUtils.toString(responseEntity, StandardCharsets.UTF_8);
						{
							JsonObject jo = (new Gson()).fromJson(responseString, JsonObject.class);
							// 取得したアクセストークンを変数にセットする
							String access_token = jo.get("access_token").getAsString();
							this.access_token = access_token;
							this.access_token_expiration = jo.get("expiration").getAsLong();
							return access_token;
						}
					}
					EntityUtils.consume(responseEntity);
				} catch (ParseException e) {
					throw new IOException(e);
				} finally {
					response_token.close();
				}
			} catch (URISyntaxException e) {
				throw new IOException(e);
			}
		}
		return null;
	}
}

結果

出力結果は以下のような感じになりました。

日本で最も高い山は富士山です。標高は
日本で最も長い川は、 Shinano River です。

...「標高は」??
... 「Shinano River」??

どうやらLlama 2 で日本語を使うときには工夫が必要になりそうです。

ちなみに英語だと以下の回答になります。

The highest mountain in Japan is Mount Fuji, with a height of 3,776
The longest river in Japan is the Shinano River, which is approximately 367 kilometers

いい感じです。

以上です。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?