LoginSignup
4
1

More than 1 year has passed since last update.

azure batch 使った機能を自社アプリケーションに組み込んでみた

Last updated at Posted at 2021-12-07

基礎知識

azureバッチの概要はこちら↓
azure Batch ってなんなん?

はじめに

私が管理しているアプリケーションで同じ処理を大量に並行で実行したい要件があり、Azure Batchを使って、計算する機構をつくりました。

アプリ概要

あるWebアプリで指定された複数のパラメータを元に各パラメータごとに計算用アプリを実行します。

仕込み

  • 元々Spring で作った自社Webアプリケーションがあります。
  • 今回計算用にSpringでCliアプリケーションを用意しました。(引数を2つ受け取って計算結果をDBに保存する)

仕組み

1.計算用のCliアプリケーションをazure storageに配置(batchCli.jar)
2.自社アプリ本体から、Azure Batchのリソースを作成するように指示
3.指示を受けて、プール、ノード、ジョブ、タスクが作られる
4.作られたノード上でタスクが実行され、DBに結果書き出し、実行ログがazure storageに保存される

仕組み.png

コード説明

こちらを参考に
上記の仕組みを実現するコードを作りました。

executeメソッドにて、
1.計算必要なパラメータを取得
2.プールを作成
3.タスクを実行
4.結果を保存
というシンプルなものですが、

ネット上にサンプルが少なく単純に作るのに時間かかったのと
本記事末尾にあるような認証まわりのハマりポイントにハマりかなり苦戦しました。

/**
 * Azure Batchを利用して複数パラメータの計算を実行する
 *
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class BatchExecTask {

    //実行に必要な設定を持ってるクラス
    private final @NonNull BatchConfig batchConfig;

    //計算対象を撮ってくるサービス
    private final @NonNull BatchService batchService;

    //結果を書き込むサービス
    private final @NonNull BatchResultService batchResultService;

    //1ノードで実行できる最大タスク
    private static Integer MAX_TASKS_PER_NODE = 4;

    //パラメータ埋めるためのクラス
    @Value(staticConstructor = "of")
    public static class BatchExecParameter {
        private Long batchId;
        private Long taskId;
        private int order;
        private String taskKey;
        private BigDecimal parameter1;
        private BigDecimal parameter2;
    }

    //実行用のメインメソッド
    @Transactional(rollbackFor = Exception.class)
    public void execute(Long taskId, Long batchId) {
        log.info(taskId, "計算リソースの準備を開始します"));
        // Batch関連リソースの後処理設定
        // 基本は削除
        Boolean shouldDeleteJob = true;
        Boolean shouldDeletePool = true;

        var TASK_COMPLETE_TIMEOUT = Duration.ofMinutes(30);

        var cred = getCredentials();
        BatchClient client = BatchClient.open(cred);

        //作った時間でpoolとjobの名前を決める
        var timeKey = DateTime.now().toString("yyyy-MM-dd-HH-mm-ss");
        var poolId = "pool-" + timeKey;
        var jobId = String.format("job-%s", timeKey);

        try {
            log.info(taskId, "計算条件を取得します"));
            var params = getParameters(batchId, taskId);

            log.info(taskId, "計算リソースを作成します。数分お待ち下さい"));
            var sharedPool = createPoolIfNotExists(client, poolId, params.size());
            log.info(taskId, "計算処理を開始します。数分お待ち下さい"));
            submitJobAndAddTask(client, sharedPool.id(), jobId, params, taskId);
            //計算実行したパラメータとタスクの名前のマップを記録しておく
            var paramMap = Seq.seq(params).toMap(p -> p.getTaskKey());
            if (!waitForTasksToComplete(client, jobId, TASK_COMPLETE_TIMEOUT)) {
                throw new TimeoutException("計算処理がタイムアウトしました。");
            }
            log.info(taskId, "計算が完了しました。結果を保存します。"));
            //結果保存
            batchResultService.store(batchId, taskId, paramMap);
            log.info(taskId, "保存処理が完了しました。"));
        } catch (BatchErrorException err) {
            log.info(taskId, "エラーが発生しました。"));
            printBatchException(err);
        } catch (ErrorException err) {
            log.info(taskId, "エラーが発生しました。"));
            throw err;
        } catch (Exception ex) {
            log.info(taskId, "エラーが発生しました。"));
            ex.printStackTrace();
            throw new ErrorException("計算処理中に予期せぬエラーが発生しました。");
        } finally {
            // 必要に応じてリソースを削除する
            if (shouldDeleteJob) {
                try {
                    client.jobOperations().deleteJob(jobId);
                } catch (BatchErrorException err) {
                    printBatchException(err);
                } catch (IOException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }

            if (shouldDeletePool) {
                try {
                    client.poolOperations().deletePool(poolId);
                } catch (BatchErrorException err) {
                    printBatchException(err);
                } catch (IOException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * 実行用のパラメータを取得する
     */
    private List<BatchExecParameter> getParameters(
            long batchId, long taskId) {
        //実行パラメータのリストを取得する
        //実行するparameter1,parameter2のマトリックスで決まる。
        //5×4なら20個のパラメータのリストを返す  
        var detail = batchService.getDeatil(batchId);
        List<BatchExecParameter> params = new ArrayList<>();
        var keyId = 1;
        for (var i = 0; i < detail.getParameter1List().size(); i++) {
            for (var j = 0; j < detail.getParameter2List().size(); j++) {
                //タスクの並び順が文字列純だとばらつくので0埋めにする
                var taskKey = String.format("task%03d", keyId);
                var parameter1 = detail.getParameter1List().get(i);
                var parameter2 = detail.getParameter2List().get(j);
                var param = BatchExecParameter.of(batchId,
                        taskId, keyId, taskKey, parameter1, parameter2);
                params.add(param);
                keyId++;
            }
        }

        return params;
    }

    /**
     * Azure Batchの資格情報を取得する
     */
    private BatchCredentials getCredentials() {
        //Azure ADで認証する(そうしないと仮想ネットワークが使えない)
        var batchEndpoint = batchConfig.getBatchUri();
        //ActiveDirectoryのアプリケーションのクライアントID
        var clientId = batchConfig.getClientId();
        //アプリケーションの認証シークレット
        var applicationSecret = batchConfig.getSecret();
        //このアプリケーションを含むドメインまたはテナントID。
        var tenantId = batchConfig.getTenantId();
        //エンドポイントはnullにしとくとコンストラクタでセットされる
        var cred = new BatchApplicationTokenCredentials(batchEndpoint, clientId, applicationSecret,
                tenantId, null, null);
        return cred;
    }

    /**
     * azure batchのプールにジョブとタスクを紐づける
     */
    private void submitJobAndAddTask(BatchClient client,
            String poolId,
            String jobId,
            List<BatchExecParameter> params, Long taskId)
            throws BatchErrorException, IOException, StorageException, InvalidKeyException,
            URISyntaxException {

        // プールにジョブを追加する
        var poolInfo = new PoolInformation();
        poolInfo.withPoolId(poolId);
        client.jobOperations().createJob(jobId, poolInfo);

        // タスクにファイルを紐付ける
        String sas = retrieveJarUriWithSas();
        var file = new ResourceFile();
        file.withFilePath(batchConfig.getCliJarName()).withHttpUrl(sas);
        List<ResourceFile> files = new ArrayList<>();
        files.add(file);
        for (var param : params) {
            // タスクを作る
            var taskToAdd = new TaskAddParameter();
            taskToAdd.withId(param.getTaskKey()).withCommandLine(createCommandLine(param));
            taskToAdd.withResourceFiles(files);
            //タスクログとストレージの紐付け
            var outputFiles = getOutputFiles(taskId, jobId, param.getTaskKey());
            taskToAdd.withOutputFiles(outputFiles);
            // ジョブにタスクを追加する
            client.taskOperations().createTask(jobId, taskToAdd);
        }

    }

    /**
     * 標準出力と標準エラー出力のファイルを作る
     */
    private List<OutputFile> getOutputFiles(Long taskId, String jobId, String taskKey)
            throws IOException {
        var stdout = createOutputFile(taskId, jobId, taskKey, "stdout");
        var stderr = createOutputFile(taskId, jobId, taskKey, "stderr");
        return List.of(stdout, stderr);
    }

    /**
     * 標準出力と標準エラー出力のファイルを作る
     */
    private OutputFile createOutputFile(Long taskId, String jobId, String taskKey,
            String fileName) throws IOException {
        var output = new OutputFile();
        var destination = new OutputFileDestination();
        var storage = new OutputFileBlobContainerDestination();
        var containerUrl = retrieveLogContainerUriWithSas();
        var path = String.format("/%s-%s/%s/%s.txt", jobId, taskId.toString(), taskKey, fileName);
        storage.withContainerUrl(containerUrl).withPath(path);
        destination.withContainer(storage);
        var uploadOptions = new OutputFileUploadOptions()
                .withUploadCondition(OutputFileUploadCondition.TASK_COMPLETION);
        output.withDestination(destination)
                //ログファイルができる階層は実行ディレクトリの一つ上にできる
                .withFilePattern("../" + fileName + ".txt")
                .withUploadOptions(uploadOptions);
        return output;
    }

    /**
     * プールが存在しなかったら作ります
     * 
     * @param Batchクライアントのインスタンス
     * @param プールID
     * @return プールのインスタンス
     * @throws Exception 
     */
    private CloudPool createPoolIfNotExists(BatchClient client, String poolId, Integer paramSize)
            throws Exception {
        //プールが安定するまで待つ時間:5分
        Duration POOL_STEADY_TIMEOUT = Duration.ofMinutes(5);
        //VMが準備されるまでに待つ時間:20分
        Duration VM_READY_TIMEOUT = Duration.ofMinutes(20);

        // Batch poolが存在するかどうか確認。しなければ作る。
        if (!client.poolOperations().existsPool(poolId)) {

            var configuration = createVmConfiguration(client);
            // ノード数(並列数を超える分は別ノードで実行)
            int poolVMCount = (int) Math
                    .ceil(paramSize.doubleValue() / MAX_TASKS_PER_NODE.doubleValue());
            var startTask = new StartTask();
            PoolAddParameter poolAddParameter = new PoolAddParameter();
            poolAddParameter.withId(poolId);
            poolAddParameter.withVmSize(batchConfig.getVmSize());
            // 標準のイメージではjavaがインストールされていないため、開始タスクでインストール
            startTask.withCommandLine(
                    "/bin/bash -c \"apt-get update && apt-get install -y openjdk-11-jdk\"");
            //プール autouser 管理者相当
            var userIdentity = (new UserIdentity()).withAutoUser(new AutoUserSpecification()
                    .withElevationLevel(ElevationLevel.ADMIN).withScope(AutoUserScope.POOL));
            startTask.withUserIdentity(userIdentity);
            poolAddParameter.withStartTask(startTask);
            poolAddParameter.withVirtualMachineConfiguration(configuration);
            poolAddParameter.withTargetDedicatedNodes(poolVMCount);
            if (batchConfig.getSubnetId() != null) {
                //仮想ネットワークの設定
                //※設定する仮想ネットワークは同一サブスクリプションである必要がある
                //https://docs.microsoft.com/ja-jp/java/api/com.microsoft.azure.batch.protocol.models.networkconfiguration.withsubnetid?view=azure-java-stable#com_microsoft_azure_batch_protocol_models_NetworkConfiguration_withSubnetId_java_lang_String_
                var networkConfig = new NetworkConfiguration();
                networkConfig.withSubnetId(batchConfig.getSubnetId());
                poolAddParameter.withNetworkConfiguration(networkConfig);
            }

            //タスクの並列実行数(コア数の4倍が最大)
            poolAddParameter.withMaxTasksPerNode(MAX_TASKS_PER_NODE);
            client.poolOperations().createPool(poolAddParameter);
        }

        boolean steady = false;

        // プールが安定するまで待機する
        var steadyStopWatch = Stopwatch.createStarted();
        while (steadyStopWatch.elapsed().toMillis() < POOL_STEADY_TIMEOUT.toMillis()) {
            var pool = client.poolOperations().getPool(poolId);

            if (pool.allocationState() == AllocationState.STEADY) {
                steady = true;
                var creationElapsedSeconds = Seconds.secondsBetween(pool.creationTime(),
                        pool.allocationStateTransitionTime());
                log.info("プールが安定状態になりました。所要時間:" + creationElapsedSeconds.getSeconds() + "秒");
                break;
            }
            log.info("プールが安定するまで10秒待機します...");
            Thread.sleep(10 * 1000);
        }

        if (!steady) {
            throw new TimeoutException("プールの割当がタイムアウトしました");
        }

        // VMがアイドル状態になるまで待機する
        boolean hasIdleVM = false;
        var readyStopWatch = Stopwatch.createStarted();
        while (readyStopWatch.elapsed().toMillis() < VM_READY_TIMEOUT.toMillis()) {
            List<ComputeNode> nodeCollection = client.computeNodeOperations().listComputeNodes(
                    poolId,
                    new DetailLevel.Builder().withSelectClause("id, state")
                            .withFilterClause("state eq 'idle'")
                            .build());
            if (!nodeCollection.isEmpty()) {
                hasIdleVM = true;
                log.info("仮想マシンがアイドル状態になりました。所要時間:" + readyStopWatch.elapsed(TimeUnit.SECONDS)
                        + "秒");
                break;
            }

            log.info("仮想マシンが開始するまで10秒待機します...");
            Thread.sleep(10 * 1000);
        }

        if (!hasIdleVM) {
            throw new TimeoutException("仮想マシンの開始がタイムアウトしました");
        }
        return client.poolOperations().getPool(poolId);
    }


    /**
     * VMイメージの定義を作る
     */
    private VirtualMachineConfiguration createVmConfiguration(BatchClient client) throws Exception {

       //ubuntu 18.04を指定。必要に応じて変える
        String osPublisher = "canonical";
        String osOffer = "ubuntuserver";
        String imageVersion = "18.04-lts";
        // sku image の参照を取得する
        List<ImageInformation> skus = client.accountOperations().listSupportedImages();

        //SKUの取得
        var skuOpt = Seq.seq(skus)
                .filter(sku -> sku.osType() == OSType.LINUX)
                .filter(sku -> sku.verificationType() == VerificationType.VERIFIED)
                .filter(sku -> sku.imageReference().publisher().equalsIgnoreCase(osPublisher)
                        && sku.imageReference().offer().equalsIgnoreCase(osOffer)
                        && sku.imageReference().sku().equals(imageVersion))
                .findFirst();
        if (!skuOpt.isPresent()) {
            throw new Exception("image not found");
        }

        var sku = skuOpt.get();
        var imageRef = sku.imageReference();
        var skuId = sku.nodeAgentSKUId();

        // イメージを指定してpoolを作る
        var configuration = new VirtualMachineConfiguration();
        configuration.withNodeAgentSKUId(skuId).withImageReference(imageRef);
        return configuration;
    }

    /**
     * ジョブ内のタスクが終わるまで待機します。
     * 
     * @param client azure batchクライアント
     * @param jobId  ジョブID
     * @param expiryTime タイムアウトまでの時間
     * @return 時間内に全てのタスクが完了したらtrue, そうでない場合はfalseを返します
     * @throws BatchErrorException
     * @throws IOException
     * @throws InterruptedException
     */
    private boolean waitForTasksToComplete(BatchClient client, String jobId,
            Duration expiryTime)
            throws BatchErrorException, IOException, InterruptedException {

        var stopWatch = Stopwatch.createStarted();
        while (stopWatch.elapsed().toMillis() < expiryTime.toMillis()) {
            List<CloudTask> taskCollection = client.taskOperations().listTasks(jobId,
                    new DetailLevel.Builder().withSelectClause("id, state, executionInfo").build());

            // 全てのタスクが完了したかどうか
            var allComplete = Seq.seq(taskCollection)
                    .allMatch(t -> t.state() == TaskState.COMPLETED);
            if (allComplete) {
                //全部完了したら抜ける。
                StringBuilder errorDetailBuilder = new StringBuilder();
                for (var task : taskCollection) {
                    var failureInfo = task.executionInfo().failureInfo();
                    if (failureInfo != null) {
                        errorDetailBuilder.append(System.lineSeparator());
                        errorDetailBuilder.append(
                                String.format("処理名:%s エラー内容:%s", task.id(), failureInfo.message()));
                    }
                }

                log.info("タスクが全て完了しました。実行時間:" + stopWatch.elapsed(TimeUnit.SECONDS) + "秒");
                if (errorDetailBuilder.length() > 0) {
                    var erroMessage = "計算処理でエラーが発生しました" + System.lineSeparator();
                    throw new ErrorException(erroMessage + errorDetailBuilder.toString());
                }
                //タスクからメッセージが帰ってくるまで10秒待つ
                Thread.sleep(10 * 1000);
                return true;
            }

            log.debug("タスクが完了するまで10秒待ちます..");

            // 10秒ごとのチェックする
            Thread.sleep(10 * 1000);
        }

        // タイムアウトしたら抜ける。
        return false;
    }

    /**
     * 実行用のコマンドライン文字列を生成します
     * @param parameter1
     * @param parameter2
     * @return
     */
    private String createCommandLine(BatchExecParameter param) {

        // 出来上がりの文字列イメージ
        //        java -jar 
        //        -Dspring.profiles.active=xxx,xxx-local 
        //        -Duser.language=ja 
        //        -Duser.country=JP 
        //        -Duser.timezone=Asia/Tokyo 
        //        -Dfile.encoding=UTF-8 
        //        batchCli.jar BatchExec -p1 0.35 -p2 123456
        return "java -jar " +
                "-Dspring.profiles.active=" + batchConfig.getExecuteProfile() + " " +
                "-Dspring.cloud.config.label=" + batchConfig.getConfigLabel() + " " +
                "-Duser.language=ja -Duser.country=JP -Duser.timezone=Asia/Tokyo " +
                "-Dfile.encoding=UTF-8 " +
                //同一ノードでアプリを実行する場合、tomcatのportが重複するためタスクごとに変更する
                "-Dserver.port=" + (8080 + param.order) + " " +
                //アクチュエーターのportが重複するためタスクごとに変更する
                "-Dmanagement.server.port=" + (9999 - param.order) + " " +
                batchConfig.getCliJarName() + " " +
                "BatchExec " +
                "-p1 " + param.getParameter1().toString() + " " +
                "-p2 " + param.getParameter2().toString() + " ";
    }

    /**
     * jarの格納先blobへのsas付きパスを返却する
     * @return
     * @throws IOException
     */
    private String retrieveJarUriWithSas() throws IOException {
        try {
            var container = getContainerReference(batchConfig.getStorageContainerName());
            var blob = container.getBlockBlobReference(batchConfig.getCliJarPath());
            //一日だけ読み取り権限
            var accessPolicy = StorageUtil.createSharedAccessPolicy(
                    EnumSet.of(SharedAccessBlobPermissions.READ),
                    60 * 24);
            return blob.getUri() + "?" + blob.generateSharedAccessSignature(accessPolicy, null);

        } catch (URISyntaxException | StorageException | InvalidKeyException e) {
            throw new IOException(e);
        }
    }

    /**
     * ログファイルの格納先blobへのsas付きパスを返却する
     * @return
     * @throws IOException
     */
    private String retrieveLogContainerUriWithSas() throws IOException {
        try {
            var container = getContainerReference("tasklogs");
            container.createIfNotExists();
            //一日だけ書き込み権限
            var accessPolicy = StorageUtil.createSharedAccessPolicy(
                    EnumSet.of(SharedAccessBlobPermissions.WRITE),
                    60 * 24);
            return container.getUri() + "?"
                    + container.generateSharedAccessSignature(accessPolicy, null);

        } catch (URISyntaxException | StorageException | InvalidKeyException e) {
            throw new IOException(e);
        }
    }

    /**
     * コンテナへの参照を取得する
     * 
     * @param storageAccountName ストレージアカウント名
     * @param storageAccountKey ストレージアカウントキー
     * @return コンテナへの参照
     * @throws URISyntaxException
     * @throws StorageException
     */
    private CloudBlobContainer getContainerReference(String containerName)
            throws URISyntaxException, StorageException {

        // ストレージの資格情報を生成する
        var credentials = new StorageCredentialsAccountAndKey(batchConfig.getStorageAccount(),
                batchConfig.getStorageAccountKey());

        // https接続でストレージアカウントを生成する
        var storageAccount = new CloudStorageAccount(credentials, true);

        // blobクライアントを生成する
        var blobClient = storageAccount.createCloudBlobClient();

        // コンテナへの参照を取得する
        return blobClient.getContainerReference(containerName);
    }

    /**
     * バッチエラーを出力します
     * 
     * @param err バッチエラーの内容
     */
    private static void printBatchException(BatchErrorException err) {
        var builder = new StringBuilder();
        builder.append(String.format("計算中にエラーが発生しました %s", err.toString()));
        if (err.body() != null) {
            builder.append(System.lineSeparator());
            builder.append(String.format("エラーコード = %s, message = %s", err.body().code(),
                    err.body().message().value()));
            if (err.body().values() != null) {
                for (var detail : err.body().values()) {
                    builder.append(System.lineSeparator());
                    builder.append(String.format("エラー詳細 %s=%s", detail.key(), detail.value()));
                }
            }
        }
        var errMsg = builder.toString();
        log.error(errMsg);
        throw new ErrorException(errMsg);
    }

    private TaskMessage createMessage(long taskId, String message) {
        var taskMesssage = new TaskMessage();
        taskMesssage.setTaskId(taskId);
        taskMesssage.setMessage(message);
        taskMesssage.setPublishedDateTime(LocalDateTime.now());
        return taskMesssage;
    }
}

ハマったポイント

・ノードからSQLサーバーのアクセスが弾かれてエラー
・batch poolに仮想ネットワーク設定しようとしたらAD認証じゃないとだめでエラー
・各リソースのIAM設定しないとazure batchからリソース操作できないエラー
・アプリのポート重複により起動できないエラー

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1