コードをさらにコンパクトにする
最初の記事ではHelloWorld的なわかりやすさを優先して、単純に watsonx API を呼び出していました。
次の記事ではJavaのプログラムとして少しコンパクトにしてみました。
そしてこの記事ではさらにJavaプログラムとしてきちんと書いてみます。
// 利用したいLLM
String model_id = "meta-llama/llama-2-70b-chat";
WatsonX wx = new WatsonX(model_id);
System.out.println(wx.generate("日本で最も高い山は?"));
System.out.println(wx.generate("日本で最も長い川は?"));
やればできる!
WatsonX クラス
余計な処理を外から隠すために、以下のようなクラスを作成しました。(今回は練習用で、本番用にはさらに丁寧に作ることになります)
期限時間内であればアクセストークンを再利用するようにもしました。
package com.ibm.watsonx;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import org.apache.hc.client5.http.classic.methods.HttpPost;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.ParseException;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.ibm.util.HttpPostBuilder;
public class WatsonX {
private String model_id;
private String end_point = "https://us-south.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=2023-05-29";
private String project_id;
private String apikey;
private String access_token = null;
private long access_token_expiration = -1;
public WatsonX(String model_id) {
this.model_id = model_id;
// 「watsonx > プロンプト・ラボ > コードの表示」から取得する
project_id = System.getProperty("project_id", null);
if (project_id == null) {
throw new RuntimeException("Property not set: project_id");
}
// 「IBM Cloud > 管理 > アクセス(IAM) > サービスID」から取得する
apikey = System.getProperty("apikey", null);
if (apikey == null) {
throw new RuntimeException("Property not set: apikey");
}
}
public String generate(String input) throws IOException {
return generate(input, "ja", 20, 0);
}
public String generate(String input, String lang, int max_new_tokens,
int min_new_tokens) throws IOException {
String access_token;
if (this.access_token != null && this.access_token_expiration < System.currentTimeMillis() - 1000) {
access_token = this.access_token;
} else {
access_token = access_token();
}
if (lang.equals("ja")) {
input = String.format("入力:\\n%s\\n\\n出力:\\n", input);
}
try (CloseableHttpClient httpClient = HttpClients.createDefault();) {
JsonObject request_body = new JsonObject();
{
request_body.addProperty("model_id", model_id);
request_body.addProperty("input", input);
JsonObject parameters = new JsonObject();
{
parameters.addProperty("decoding_method", "greedy");
parameters.addProperty("max_new_tokens", max_new_tokens);
parameters.addProperty("min_new_tokens", min_new_tokens);
parameters.add("stop_sequences", new JsonArray());
parameters.addProperty("repetition_penalty", 1);
}
request_body.add("request_body", parameters);
request_body.addProperty("project_id", project_id);
}
HttpPost httpPost = (new HttpPostBuilder(end_point)) //
.H("Content-Type", "application/json") //
.H("Accept", "application/json") //
.H("Authorization", "Bearer " + access_token) //
.data(request_body.toString()) //
.build();
// Execute the request
CloseableHttpResponse response = httpClient.execute(httpPost);
try {
// Get the response entity
HttpEntity responseEntity = response.getEntity();
if (responseEntity != null) {
// Convert the response entity to a string
String responseString = EntityUtils.toString(responseEntity, StandardCharsets.UTF_8);
JsonObject jo = (new Gson()).fromJson(responseString, JsonObject.class);
String generated_text = jo.get("results").getAsJsonArray()
.get(0).getAsJsonObject().get("generated_text")
.getAsString();
return generated_text;
}
EntityUtils.consume(responseEntity);
} catch (ParseException e) {
throw new IOException(e);
} finally {
response.close();
}
} catch (URISyntaxException e) {
throw new IOException(e);
}
return null;
}
private String access_token() throws IOException {
{ // APIキーを利用してアクセストークンを取得する(1時間有効)
String end_point_token = "https://iam.cloud.ibm.com/identity/token";
try (CloseableHttpClient httpClient = HttpClients.createDefault();) {
HttpPost httpPost = (new HttpPostBuilder(end_point_token)) //
.H("Content-Type", "application/x-www-form-urlencoded") //
.H("Accept", "application/json") //
.D("grant_type", "urn:ibm:params:oauth:grant-type:apikey") //
.D("apikey", apikey) //
.build();
// Execute the request
CloseableHttpResponse response_token = httpClient.execute(httpPost);
try {
// Get the response entity
HttpEntity responseEntity = response_token.getEntity();
if (responseEntity != null) {
// Convert the response entity to a string
String responseString = EntityUtils.toString(responseEntity, StandardCharsets.UTF_8);
{
JsonObject jo = (new Gson()).fromJson(responseString, JsonObject.class);
// 取得したアクセストークンを変数にセットする
String access_token = jo.get("access_token").getAsString();
this.access_token = access_token;
this.access_token_expiration = jo.get("expiration").getAsLong();
return access_token;
}
}
EntityUtils.consume(responseEntity);
} catch (ParseException e) {
throw new IOException(e);
} finally {
response_token.close();
}
} catch (URISyntaxException e) {
throw new IOException(e);
}
}
return null;
}
}
結果
出力結果は以下のような感じになりました。
日本で最も高い山は富士山です。標高は
日本で最も長い川は、 Shinano River です。
...「標高は」??
... 「Shinano River」??
どうやらLlama 2 で日本語を使うときには工夫が必要になりそうです。
ちなみに英語だと以下の回答になります。
The highest mountain in Japan is Mount Fuji, with a height of 3,776
The longest river in Japan is the Shinano River, which is approximately 367 kilometers
いい感じです。
以上です。