95
50

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

生成AIで古(いにしえ)のFORTRANコードに挑戦する

95
Last updated at Posted at 2026-03-16

古(いにしえ)のFORTRANコード

ChatGPT5.4 を使います。MacのCodexアプリを使います(5.4 Codexではなく5.4でした。)。
このWebサイトでは、物理学で使う計算手法のコードが公開されています。FORTRANです。DMFTと呼ばれる手法の超有名なレビューで使われている関数でして、今でもその一部が使われているかもしれません。
さて、このコードのうち、lisalanc.fというファイル内には、非線形関数を最適化する関数minimizeがあります。ここの文頭に、

C========+=========+=========+=========+=========+=========+=========+=$
C PROGRAM: minimize
C TYPE   : subroutine
C PURPOSE: conjugent gradient search
C I/O    : 
C VERSION: 30-Sep-95
C COMMENT: This is a most reliable conjugent gradient routine! It has
C          served us well for many years, and is capable to cope with
C          a very large number of variables. Unfortunately, we don't
C          know who wrote this routine (original name: 'va10a'), and 
C          we find it very obscure.
C          Don't worry, it works just fine.
Cnoprint=+=========+=========+=========+=========+=========+=========+=$

とあります。日本語に翻訳すると、

C========+=========+=========+=========+=========+=========+=========+=$
C プログラム名 : minimize
C 種類         : サブルーチン
C 目的         : 共役勾配法による探索
C 入出力       :
C バージョン   : 1995年9月30日
C コメント     : これは非常に信頼性の高い共役勾配法ルーチンである。
C                長年にわたり問題なく使われてきており、
C                非常に多数の変数を扱うこともできる。
C                ただし残念ながら、このルーチンを書いた人は
C                分かっていない(元の名前は 'va10a')。
C                また、コードはかなり分かりにくい。
C                心配しなくてもよい。きちんと動作する。
Cnoprint=+=========+=========+=========+=========+=========+=========+=$

となっています。誰が書いたかわからないけどうまく動くコード、ですね。
そして、このコードについて検索すると、R.Fletcherの"FORTRAN SUBROUTINES FOR MINIMIZATION BY QUASI-NEWTON METHODS", Research group report(United Kingdom Atomic Energy Authority)を見つけました。このコードva10aは1972年の4月に書かれたそうです。1977年よりも前ですから、FORTRAN77ですらないですね。まさに、古(いにしえ)に相応しいコードでしょう。

これを、最近の生成AIでモダンなコードに生まれ変わらせよう、という作戦です。使うのは、ChatGPT 5.4 です。MacのCodexアプリを使います。

コードの全体像

どのくらいFORTRANなのか見るために、minimize関数をここに貼り付けます。

C========+=========+=========+=========+=========+=========+=========+=$
C PROGRAM: minimize
C TYPE   : subroutine
C PURPOSE: conjugent gradient search
C I/O    : 
C VERSION: 30-Sep-95
C COMMENT: This is a most reliable conjugent gradient routine! It has
C          served us well for many years, and is capable to cope with
C          a very large number of variables. Unfortunately, we don't
C          know who wrote this routine (original name: 'va10a'), and 
C          we find it very obscure.
C          Don't worry, it works just fine.
Cnoprint=+=========+=========+=========+=========+=========+=========+=$
      subroutine minimize (funct, n, x, f, g, h, w, dfn, xm,
     $  hh, eps, mode, maxfn, iprint, iexit)
      implicit double precision (a-h,o-z)
      dimension x(*), g(*), h(*), w(*), xm(*)
      external funct
      data zero, half, one, two /0.0d0, 0.5d0, 1.0d0, 2.0d0/
      if (iprint .ne. 0) write (6,1000)
 1000 format (' entry into minimize')
      np = n + 1
      n1 = n - 1
      nn=(n*np)/2
      is = n
      iu = n
      iv = n + n
      ib = iv + n
      idiff = 1
      iexit = 0
      if (mode .eq. 3) go to 15
      if (mode .eq. 2) go to 10
      ij = nn + 1
      do 5 i = 1, n
      do 6 j = 1, i
      ij = ij - 1
   6  h(ij) = zero
   5  h(ij) = one
      go to 15
  10  continue
      ij = 1
      do 11 i = 2, n
      z = h(ij)
      if (z .le. zero) return
      ij = ij + 1
      i1 = ij
      do 11 j = i, n
      zz = h(ij)
      h(ij) = h(ij) / z
      jk = ij
      ik = i1
      do 12 k = i, j
      jk = jk + np - k
      h(jk) = h(jk) - h(ik) * zz
      ik = ik + 1
  12  continue
      ij = ij + 1
  11  continue
      if (h(ij) .le. zero) return
  15  continue
      ij = np
      dmin = h(1)
      do 16 i = 2, n
      if (h(ij) .ge. dmin) go to 16
      dmin = h(ij)
  16  ij = ij + np - i
      if (dmin .le. zero) return
      z = f
      itn = 0
      call funct (n, x, f)
      ifn = 1
      df = dfn
      if (dfn .eq. zero) df = f - z
      if (dfn .lt. zero) df = abs (df * f)
      if (df .le. zero) df = one
  17  continue
      do 19 i = 1, n
      w(i) = x(i)
  19  continue
      link = 1
      if (idiff - 1) 100, 100, 110
  18  continue
      if (ifn .ge. maxfn) go to 90
  20  continue
      if (iprint .eq. 0) go to 21
      if (mod (itn, iprint) .ne. 0) go to 21
       write (6,1001) itn, ifn
1001  format (1x,'itn = ',i5,' ifn = ',i5)
      write (6,1002) f
1002  format (1x,'f = ',e15.7)
      if (iprint .lt. 0) go to 21
      write (6,1003) (x(i), i = 1, n)
***
***
1003  format (1x,'x = ',4e15.7 / (5x, 4e15.7))
      write (6,1004) (g(i), i = 1, n)
1004  format (1x,'g = ',4e15.7 / (5x, 4e15.7))
**
***
  21  continue
      itn = itn + 1
      w(1) = -g(1)
      do 22 i = 2, n
      ij = i
      i1 = i - 1
      z = -g(i)
      do 23 j = 1, i1
      z = z - h(ij) * w(j)
      ij = ij + n - j
  23  continue
  22  w(i) = z
      w(is+n) = w(n) / h(nn)
      ij = nn
      do 25 i = 1, n1
      ij = ij - 1
      z = zero
      do 26 j = 1, i
      z = z + h(ij) * w(is+np-j)
      ij = ij - 1
  26  continue
  25  w(is+n-i) = w(n-i) / h(ij) - z
      z = zero
      gs0 = zero
      do 29 i = 1, n
      if (z * xm(i) .ge. abs (w(is+i))) go to 28
      z = abs (w(is+i)) / xm(i)
  28  gs0 = gs0 + g(i) * w(is+i)
  29  continue
      aeps = eps / z
      iexit = 2
      if (gs0 .ge. zero) go to 92
      alpha = -two * df / gs0
      if (alpha .gt. one) alpha = one
      ff = f
      tot = zero
      int = 0
      iexit = 1
  30  continue
      if (ifn .ge. maxfn) go to 90
      do 31 i = 1, n
      w(i) = x(i) + alpha * w(is+i)
  31  continue
      call funct (n, w, f1)
      ifn = ifn + 1
      if (f1 .ge. f) go to 40
      f2 = f
      tot = tot + alpha
  32  continue
      do 33 i = 1, n
      x(i) = w(i)
  33  continue
      f = f1
      if (int - 1) 35, 49, 50
  35  continue
      if (ifn .ge. maxfn) go to 90
      do 34 i = 1, n
      w(i) = x(i) + alpha * w(is+i)
  34  continue
      call funct (n, w, f1)
      ifn = ifn + 1
      if (f1 .ge. f) go to 50
      if ((f1 + f2 .ge. f + f) .and.
     $  (7.0d0 * f1 + 5.0d0 * f2 .gt. 12.0d0 * f)) int = 2
      tot = tot + alpha
      alpha = two * alpha
      go to 32
  40  continue
      if (alpha .lt. aeps) go to 92
      if (ifn .ge. maxfn) go to 90
      alpha = half * alpha
      do 41 i = 1, n
      w(i) = x(i) + alpha * w(is+i)
  41  continue
      call funct (n, w, f2)
      ifn = ifn + 1
      if (f2 .ge. f) go to 45
      tot = tot + alpha
      f = f2
      do 42 i = 1, n
      x(i) = w(i)
  42  continue
      go to 49
  45  continue
      z = 0.1d0
      if (f1 + f .gt. f2 + f2)
     $  z = one + half * (f - f1) / (f + f1 - f2 - f2)
      if (z .lt. 0.1d0) z = 0.1d0
      alpha = z * alpha
      int = 1
      go to 30
  49  continue
      if (tot .lt. aeps) go to 92
  50  continue
      alpha = tot
      do 56 i = 1, n
      w(i) = x(i)
      w(ib+i) = g(i)
  56  continue
      link = 2
      if (idiff - 1) 100, 100, 110
  54  continue
      if (ifn .ge. maxfn) go to 90
      gys = zero
      do 55 i = 1, n
      w(i) = w(ib+i)
      gys = gys + g(i) * w(is+i)
  55  continue
      df = ff - f
      dgs = gys - gs0
      if (dgs .le. zero) go to 20
      link = 1
      if (dgs + alpha * gs0 .gt. zero) go to 52
      do 51 i = 1, n
      w(iu + i) = g(i) - w(i)
  51  continue
      sig = one / (alpha * dgs)
      go to 70
  52  continue
      zz = alpha / (dgs - alpha * gs0)
      z = dgs * zz - one
      do 53 i = 1, n
      w(iu+i) = z * w(i) + g(i)
  53  continue
      sig = one / (zz * dgs * dgs)
      go to 70
  60  continue
      link = 2
      do 61 i = 1, n
      w(iu+i) = w(i)
  61  continue
      if (dgs + alpha * gs0 .gt. zero) go to 62
      sig = one / gs0
      go to 70
  62  continue
      sig = -zz
  70  continue
      w(iv+1) = w(iu+1)
      do 71 i = 2, n
      ij = i
      i1 = i - 1
      z = w(iu+i)
      do 72 j = 1, i1
      z = z - h(ij) * w(iv+j)
      ij = ij + n - j
  72  continue
      w(iv+i) = z
  71  continue
      ij = 1
      do 75 i = 1, n
      z = h(ij) + sig * w(iv+i) * w(iv+i)
      if (z .le. zero) z = dmin
      if (z .lt. dmin) dmin = z
      h(ij) = z
      w(ib+i) = w(iv+i) * sig / z
      sig = sig - w(ib+i) * w(ib+i) * z
      ij = ij + np - i
  75  continue
      ij = 1
      do 80 i = 1, n1
      ij = ij + 1
      i1 = i + 1
      do 80 j = i1, n
      w(iu+j) = w(iu+j) - h(ij) * w(iv+i)
      h(ij) = h(ij) + w(ib+i) * w(iu+j)
      ij = ij + 1
  80  continue
      go to (60, 20), link
  90  continue
      iexit = 3
      go to 94
  92  continue
      if (idiff .eq. 2) go to 94
      idiff = 2
      go to 17
  94  continue
      if (iprint .eq. 0) return
      write (6,1005) itn, ifn, iexit
1005  format (1x,'itn = ',i5, ' ifn = ',i5,' iexit = ',i5)
      write (6,1002) f
      write (6,1003) (x(i), i = 1, n)
      write (6,1004) (g(i), i = 1, n)
      return
 100  continue
      do 101 i = 1, n
      z = hh * xm(i)
      w(i) = w(i) + z
      call funct (n, w, f1)
      g(i) = (f1 - f) / z
      w(i) = w(i) - z
 101  continue
      ifn = ifn + n
      go to (18, 54), link
 110  continue
      do 111 i = 1, n
      z = hh * xm(i)
      w(i) = w(i) + z
      call funct (n, w, f1)
      w(i) = w(i) - z - z
      call funct (n, w, f2)
      g(i) = (f1 - f2) / (two * z)
      w(i) = w(i) + z
 111  continue
      ifn = ifn + n + n
      go to (18, 54), link
      end 

さて、これを手作業でgoto文を除去できる猛者は少ないのではないでしょうか。そこで、生成AIの力を使いましょう。

Codexによる分析

実際にCodexに何を入れて解析しているか、興味のある方もいると思いますので、それをそのままお見せします。まず、lisalancコードのファイル一式をダウンロードし、一つのフォルダに入れておきます。

README_lisalanc
lisalanc.input
lisalanc.green
lisalanc.f
lisalanc.diff
lisalanc.dat
lisalanc.andpar
lisadiag.green

そして、分析させます。プロンプトは、


このlisalanc.fの中のminimizeを現代風のFortranに書き換えることを目的としています。まず、このコードが何をしているかを調べて下さい。また、その中で特にminimizeが何をしているか調べて下さい。多分、非線形関数の最小化をしています。引数が何を示しているか調べて下さい


としましょう。
出力結果はスクリーンショットとして貼ります。

スクリーンショット 2026-03-16 16.44.53.png

スクリーンショット 2026-03-16 16.45.10.png
スクリーンショット 2026-03-16 16.45.21.png
スクリーンショット 2026-03-16 16.45.58.png
スクリーンショット 2026-03-16 16.46.10.png
スクリーンショット 2026-03-16 16.46.16.png
のように分析してくれました。

テストの作成

次は、コードを書き換えるので、オリジナルの出力を正解としたいです。ですので、


次に、minimize単体を取り出してテストするテストコードを実装したいです。これは適当な非線形関数で問題ありません。解析的に答えがわかっているものを使って、その答えになるかどうかをテストしましょう。まず、minimizeを取り出して別のファイルとして、非線形関数をテストするようなコードを作り、cmakeファイルを作成し、makeして走るか確認しましょう。なお、テストする関数は二つにします。まず実装を計画して下さい


としました。
その結果、

スクリーンショット 2026-03-16 16.55.23.png
スクリーンショット 2026-03-16 16.55.30.png
スクリーンショット 2026-03-16 16.55.35.png
となりました。いい感じに作れそうですね。

テストの方針が立ったので、テストを作ります。


では、それで実装して下さい


とします。

その結果、
スクリーンショット 2026-03-16 16.57.03.png
スクリーンショット 2026-03-16 16.57.08.png

となりました。テストもされたようですね。

goto文の削除

悪名高いgoto文を削除したいわけです。ということでまず分析をさせます。


次に、現代Fortranにしますが、まず、goto文の解析を行なって下さい。どのようにすればdoループでかけるか分析して下さい


とします。

スクリーンショット 2026-03-16 16.59.31.png
スクリーンショット 2026-03-16 16.59.43.png
スクリーンショット 2026-03-16 16.59.53.png
スクリーンショット 2026-03-16 17.00.00.png
スクリーンショット 2026-03-16 17.00.11.png
スクリーンショット 2026-03-16 17.00.19.png
スクリーンショット 2026-03-16 17.00.26.png
スクリーンショット 2026-03-16 17.00.34.png
スクリーンショット 2026-03-16 17.00.42.png
スクリーンショット 2026-03-16 17.00.49.png
スクリーンショット 2026-03-16 17.00.54.png

では、解析が終わりましたので、


了解です。これでいきましょう。リファクタリングをするたびに、テスト関数が正しくパスすることを確認して下さい。本質的に同じコードのはずなので、機械精度で同じ結果になるべきです。


として、リファクタリングをしましょう。テストをしながら正しいことを確認しつつ動きます。

以下は出力は省略します。次のプロンプトとしては、


次に、コードの意味ごとに関数やサブルーチンに分けたいです。メインはそれらの関数やサブルーチンを呼ぶようにすれば、読んでわかるサブルーチンになるはずです。どう分ければいいか実装の計画を立てて下さい。


です。
その次は、


ではそれでいきましょう


です。
そのあとは色々言ってくるので確認しつつOKを出します。


それでいきましょう


どんどんやります、

結果

色々やった結果、できたコードが以下の通りです。

module minimize_module
  implicit none
  private

  public :: minimize

  abstract interface
    subroutine objective_function(n, x, f)
      integer, intent(in) :: n
      real(8), intent(in) :: x(n)
      real(8), intent(out) :: f
    end subroutine objective_function
  end interface

  type :: workspace_layout
    integer :: n
    integer :: np
    integer :: n1
    integer :: nn
    integer :: search_offset
    integer :: update_offset
    integer :: transformed_offset
    integer :: buffer_offset
  end type workspace_layout

  type :: minimize_state
    integer :: idiff
    integer :: itn
    integer :: ifn
    integer :: iexit
    real(8) :: dmin
    real(8) :: df
    real(8) :: gs0
    real(8) :: aeps
    real(8) :: alpha
    real(8) :: ff
    real(8) :: tot
  end type minimize_state

contains

  subroutine minimize(funct, n, x, f, g, h, w, dfn, xm, hh, eps, mode, maxfn, iprint, iexit)
    implicit none

    procedure(objective_function) :: funct
    integer, intent(in) :: n, mode, maxfn, iprint
    real(8), intent(inout) :: x(n)
    real(8), intent(out) :: f, g(n)
    real(8), intent(inout) :: h(n * (n + 1) / 2)
    real(8), intent(inout) :: w(4 * n)
    real(8), intent(in) :: dfn, xm(n), hh, eps
    integer, intent(out) :: iexit

    real(8), parameter :: zero = 0.0d0, one = 1.0d0, two = 2.0d0
    type(workspace_layout) :: layout
    type(minimize_state) :: state
    integer :: i
    real(8) :: z
    logical :: converged_in_line_search, maxfn_reached

    if (iprint /= 0) write (6, '(a)') ' entry into minimize'

    call initialize_layout(n, layout)
    call initialize_state(state)

    call initialize_metric(mode, layout, h, state%dmin)
    if (state%dmin <= zero) then
      iexit = state%iexit
      return
    end if

    z = f
    call funct(n, x, f)
    call initialize_objective_state(dfn, z, f, one, state)

    main_loop: do
      call set_work_to_current_point(n, x, w)
      call compute_gradient(funct, n, w, f, g, xm, hh, state%idiff, state%ifn)
      if (state%ifn >= maxfn) then
        state%iexit = 3
        exit main_loop
      end if

      call print_iteration_if_needed(iprint, state%itn, state%ifn, n, f, x, g)

      state%itn = state%itn + 1
      call compute_search_direction(layout, h, g, w, xm, eps, state%gs0, state%aeps)
      state%iexit = 2
      if (state%gs0 >= zero) then
        if (state%idiff == 2) exit main_loop
        state%idiff = 2
        cycle main_loop
      end if

      state%alpha = -two * state%df / state%gs0
      if (state%alpha > one) state%alpha = one
      state%iexit = 1
      call perform_line_search(funct, layout, x, f, w, state%alpha, state%aeps, state%ifn, maxfn, state%tot, state%ff, &
        converged_in_line_search, maxfn_reached)

      if (maxfn_reached) then
        state%iexit = 3
        exit main_loop
      end if

      if (converged_in_line_search .or. state%tot < state%aeps) then
        state%iexit = 2
        if (state%idiff == 2) exit main_loop
        state%idiff = 2
        cycle main_loop
      end if

      state%alpha = state%tot
      call update_quasi_newton_metric(funct, layout, x, f, g, h, w, xm, hh, state%idiff, state%ifn, maxfn, &
        state%gs0, state%alpha, state%ff, state%dmin, state%df, maxfn_reached)
      if (maxfn_reached) then
        state%iexit = 3
        exit main_loop
      end if
    end do main_loop

    iexit = state%iexit
    if (iprint == 0) return
    write (6, '(1x,a,i5,a,i5,a,i5)') 'itn = ', state%itn, ' ifn = ', state%ifn, ' iexit = ', iexit
    write (6, '(1x,a,e15.7)') 'f = ', f
    write (6, '(1x,a,4e15.7,/(5x,4e15.7))') 'x = ', (x(i), i = 1, n)
    write (6, '(1x,a,4e15.7,/(5x,4e15.7))') 'g = ', (g(i), i = 1, n)
    return

  end subroutine minimize

  subroutine initialize_layout(n, layout)
    implicit none

    integer, intent(in) :: n
    type(workspace_layout), intent(out) :: layout

    layout%n = n
    layout%np = n + 1
    layout%n1 = n - 1
    layout%nn = (n * layout%np) / 2
    layout%search_offset = n
    layout%update_offset = n
    layout%transformed_offset = 2 * n
    layout%buffer_offset = 3 * n
  end subroutine initialize_layout

  subroutine initialize_state(state)
    implicit none

    type(minimize_state), intent(out) :: state

    state%idiff = 1
    state%itn = 0
    state%ifn = 0
    state%iexit = 0
    state%dmin = 0.0d0
    state%df = 0.0d0
    state%gs0 = 0.0d0
    state%aeps = 0.0d0
    state%alpha = 0.0d0
    state%ff = 0.0d0
    state%tot = 0.0d0
  end subroutine initialize_state

  subroutine initialize_objective_state(dfn, previous_f, current_f, one, state)
    implicit none

    real(8), intent(in) :: dfn, previous_f, current_f, one
    type(minimize_state), intent(inout) :: state

    state%ifn = 1
    state%df = dfn
    if (dfn == 0.0d0) state%df = current_f - previous_f
    if (dfn < 0.0d0) state%df = abs(state%df * current_f)
    if (state%df <= 0.0d0) state%df = one
  end subroutine initialize_objective_state

  subroutine set_work_to_current_point(n, x, w)
    implicit none

    integer, intent(in) :: n
    real(8), intent(in) :: x(n)
    real(8), intent(inout) :: w(4 * n)

    w(1:n) = x
  end subroutine set_work_to_current_point

  subroutine initialize_metric(mode, layout, h, dmin)
    implicit none

    integer, intent(in) :: mode
    type(workspace_layout), intent(in) :: layout
    real(8), intent(inout) :: h(layout%nn)
    real(8), intent(out) :: dmin

    integer :: i, j, i1, ij, jk, ik, k
    real(8) :: z, zz

    select case (mode)
    case (1)
      ij = size(h) + 1
      do i = 1, layout%n
        do j = 1, i
          ij = ij - 1
          h(ij) = 0.0d0
        end do
        h(ij) = 1.0d0
      end do
    case (2)
      ij = 1
      do i = 2, layout%n
        z = h(ij)
        if (z <= 0.0d0) then
          dmin = z
          return
        end if
        ij = ij + 1
        i1 = ij
        do j = i, layout%n
          zz = h(ij)
          h(ij) = h(ij) / z
          jk = ij
          ik = i1
          do k = i, j
            jk = jk + layout%np - k
            h(jk) = h(jk) - h(ik) * zz
            ik = ik + 1
          end do
          ij = ij + 1
        end do
      end do
      if (h(ij) <= 0.0d0) then
        dmin = h(ij)
        return
      end if
    case (3)
      continue
    case default
      error stop 'unsupported mode in minimize'
    end select

    ij = layout%np
    dmin = h(1)
    do i = 2, layout%n
      if (h(ij) < dmin) dmin = h(ij)
      ij = ij + layout%np - i
    end do
  end subroutine initialize_metric

  subroutine compute_gradient(funct, n, w, f, g, xm, hh, idiff, ifn)
    implicit none

    procedure(objective_function) :: funct
    real(8), parameter :: two = 2.0d0
    integer, intent(in) :: n, idiff
    integer, intent(inout) :: ifn
    real(8), intent(inout) :: w(4 * n)
    real(8), intent(in) :: f, xm(n), hh
    real(8), intent(out) :: g(n)

    integer :: i
    real(8) :: z, f1, f2

    if (idiff <= 1) then
      do i = 1, n
        z = hh * xm(i)
        w(i) = w(i) + z
        call funct(n, w(1:n), f1)
        g(i) = (f1 - f) / z
        w(i) = w(i) - z
      end do
      ifn = ifn + n
      return
    end if

    do i = 1, n
      z = hh * xm(i)
      w(i) = w(i) + z
      call funct(n, w(1:n), f1)
      w(i) = w(i) - z - z
      call funct(n, w(1:n), f2)
      g(i) = (f1 - f2) / (two * z)
      w(i) = w(i) + z
    end do
    ifn = ifn + n + n
  end subroutine compute_gradient

  subroutine print_iteration_if_needed(iprint, itn, ifn, n, f, x, g)
    implicit none

    integer, intent(in) :: iprint, itn, ifn, n
    integer :: i
    real(8), intent(in) :: f, x(n), g(n)

    if (iprint == 0) return
    if (mod(itn, iprint) /= 0) return

    write (6, '(1x,a,i5,a,i5)') 'itn = ', itn, ' ifn = ', ifn
    write (6, '(1x,a,e15.7)') 'f = ', f
    if (iprint >= 0) then
      write (6, '(1x,a,4e15.7,/(5x,4e15.7))') 'x = ', (x(i), i = 1, n)
      write (6, '(1x,a,4e15.7,/(5x,4e15.7))') 'g = ', (g(i), i = 1, n)
    end if
  end subroutine print_iteration_if_needed

  subroutine compute_search_direction(layout, h, g, w, xm, eps, gs0, aeps)
    implicit none

    type(workspace_layout), intent(in) :: layout
    real(8), intent(in) :: h(layout%nn), g(layout%n), xm(layout%n), eps
    real(8), intent(inout) :: w(4 * layout%n)
    real(8), intent(out) :: gs0, aeps

    integer :: i, j, i1, ij
    real(8) :: z

    w(1) = -g(1)
    do i = 2, layout%n
      ij = i
      i1 = i - 1
      z = -g(i)
      do j = 1, i1
        z = z - h(ij) * w(j)
        ij = ij + layout%n - j
      end do
      w(i) = z
    end do

    w(layout%search_offset + layout%n) = w(layout%n) / h(layout%nn)
    ij = layout%nn
    do i = 1, layout%n1
      ij = ij - 1
      z = 0.0d0
      do j = 1, i
        z = z + h(ij) * w(layout%search_offset + layout%np - j)
        ij = ij - 1
      end do
      w(layout%search_offset + layout%n - i) = w(layout%n - i) / h(ij) - z
    end do

    z = 0.0d0
    gs0 = 0.0d0
    do i = 1, layout%n
      if (z * xm(i) < abs(w(layout%search_offset + i))) z = abs(w(layout%search_offset + i)) / xm(i)
      gs0 = gs0 + g(i) * w(layout%search_offset + i)
    end do
    aeps = eps / z
  end subroutine compute_search_direction

  subroutine perform_line_search(funct, layout, x, f, w, alpha, aeps, ifn, maxfn, tot, ff, &
      converged, maxfn_reached)
    implicit none

    procedure(objective_function) :: funct
    real(8), parameter :: half = 0.5d0, one = 1.0d0, two = 2.0d0
    type(workspace_layout), intent(in) :: layout
    integer, intent(in) :: maxfn
    integer, intent(inout) :: ifn
    real(8), intent(inout) :: x(layout%n), f, w(4 * layout%n), alpha
    real(8), intent(in) :: aeps
    real(8), intent(out) :: tot, ff
    logical, intent(out) :: converged, maxfn_reached

    integer :: i, int_mode
    real(8) :: z, f1, f2
    logical :: line_search_done

    ff = f
    tot = 0.0d0
    int_mode = 0
    converged = .false.
    maxfn_reached = .false.
    line_search_done = .false.

    do while (.not. line_search_done)
      if (ifn >= maxfn) then
        maxfn_reached = .true.
        exit
      end if
      do i = 1, layout%n
        w(i) = x(i) + alpha * w(layout%search_offset + i)
      end do
      call funct(layout%n, w(1:layout%n), f1)
      ifn = ifn + 1

      if (f1 < f) then
        f2 = f
        tot = tot + alpha
        x(1:layout%n) = w(1:layout%n)
        f = f1

        if (int_mode < 1) then
          if (ifn >= maxfn) then
            maxfn_reached = .true.
            exit
          end if
          do i = 1, layout%n
            w(i) = x(i) + alpha * w(layout%search_offset + i)
          end do
          call funct(layout%n, w(1:layout%n), f1)
          ifn = ifn + 1
          if (f1 >= f) then
            line_search_done = .true.
          else
            if ((f1 + f2 >= f + f) .and. (7.0d0 * f1 + 5.0d0 * f2 > 12.0d0 * f)) int_mode = 2
            tot = tot + alpha
            alpha = two * alpha
          end if
        else
          line_search_done = .true.
        end if
      else
        if (alpha < aeps) then
          converged = .true.
          exit
        end if
        if (ifn >= maxfn) then
          maxfn_reached = .true.
          exit
        end if
        alpha = half * alpha
        do i = 1, layout%n
          w(i) = x(i) + alpha * w(layout%search_offset + i)
        end do
        call funct(layout%n, w(1:layout%n), f2)
        ifn = ifn + 1
        if (f2 < f) then
          tot = tot + alpha
          f = f2
          x(1:layout%n) = w(1:layout%n)
          line_search_done = .true.
        else
          z = 0.1d0
          if (f1 + f > f2 + f2) z = one + half * (f - f1) / (f + f1 - f2 - f2)
          if (z < 0.1d0) z = 0.1d0
          alpha = z * alpha
          int_mode = 1
        end if
      end if
    end do
  end subroutine perform_line_search

  subroutine update_quasi_newton_metric(funct, layout, x, f, g, h, w, xm, hh, &
      idiff, ifn, maxfn, gs0, alpha, ff, dmin, df, maxfn_reached)
    implicit none

    procedure(objective_function) :: funct
    real(8), parameter :: one = 1.0d0
    type(workspace_layout), intent(in) :: layout
    integer, intent(in) :: maxfn, idiff
    integer, intent(inout) :: ifn
    real(8), intent(inout) :: x(layout%n), f, g(layout%n), h(layout%nn), w(4 * layout%n), dmin, df
    real(8), intent(in) :: xm(layout%n), hh, gs0, alpha, ff
    logical, intent(out) :: maxfn_reached

    integer :: i
    real(8) :: gys, dgs, sig, zz, z

    do i = 1, layout%n
      w(i) = x(i)
      w(layout%buffer_offset + i) = g(i)
    end do
    call compute_gradient(funct, layout%n, w, f, g, xm, hh, idiff, ifn)
    if (ifn >= maxfn) then
      maxfn_reached = .true.
      return
    end if

    gys = 0.0d0
    do i = 1, layout%n
      w(i) = w(layout%buffer_offset + i)
      gys = gys + g(i) * w(layout%search_offset + i)
    end do
    df = ff - f
    dgs = gys - gs0
    maxfn_reached = .false.
    if (dgs <= 0.0d0) return

    if (dgs + alpha * gs0 <= 0.0d0) then
      do i = 1, layout%n
        w(layout%update_offset + i) = g(i) - w(i)
      end do
      sig = one / (alpha * dgs)
      call apply_metric_update(layout, h, w, sig, dmin)

      do i = 1, layout%n
        w(layout%update_offset + i) = w(i)
      end do
      sig = one / gs0
      call apply_metric_update(layout, h, w, sig, dmin)
      return
    end if

    zz = alpha / (dgs - alpha * gs0)
    z = dgs * zz - one
    do i = 1, layout%n
      w(layout%update_offset + i) = z * w(i) + g(i)
    end do
    sig = one / (zz * dgs * dgs)
    call apply_metric_update(layout, h, w, sig, dmin)

    do i = 1, layout%n
      w(layout%update_offset + i) = w(i)
    end do
    sig = -zz
    call apply_metric_update(layout, h, w, sig, dmin)
  end subroutine update_quasi_newton_metric

  subroutine apply_metric_update(layout, h, w, sig, dmin)
    implicit none

    type(workspace_layout), intent(in) :: layout
    real(8), intent(inout) :: h(layout%nn), w(4 * layout%n), sig, dmin

    integer :: i, j, i1, ij
    real(8) :: z

    w(layout%transformed_offset + 1) = w(layout%update_offset + 1)
    do i = 2, layout%n
      ij = i
      i1 = i - 1
      z = w(layout%update_offset + i)
      do j = 1, i1
        z = z - h(ij) * w(layout%transformed_offset + j)
        ij = ij + layout%n - j
      end do
      w(layout%transformed_offset + i) = z
    end do

    ij = 1
    do i = 1, layout%n
      z = h(ij) + sig * w(layout%transformed_offset + i) * w(layout%transformed_offset + i)
      if (z <= 0.0d0) z = dmin
      if (z < dmin) dmin = z
      h(ij) = z
      w(layout%buffer_offset + i) = w(layout%transformed_offset + i) * sig / z
      sig = sig - w(layout%buffer_offset + i) * w(layout%buffer_offset + i) * z
      ij = ij + layout%np - i
    end do

    ij = 1
    do i = 1, layout%n1
      ij = ij + 1
      i1 = i + 1
      do j = i1, layout%n
        w(layout%update_offset + j) = w(layout%update_offset + j) - h(ij) * w(layout%transformed_offset + i)
        h(ij) = h(ij) + w(layout%buffer_offset + i) * w(layout%update_offset + j)
        ij = ij + 1
      end do
    end do
  end subroutine apply_metric_update

end module minimize_module

なんということでしょう。あれほど大量にあったgoto文が、見事に構造化されたサブルーチン群で表現されているではありませんか。minimizeには名前に意味のあるサブルーチンが並び、何をしているか一目瞭然です。匠の技が光りますね。

速度比較

リファクタリングして遅くなったら困るので、速度比較もしました。その結果は、

original Fortran: 平均 0.422 s
refactored Fortran: 平均 0.430 s
1 回あたりに直すと、おおよそ
original Fortran: 42.2 us
refactored Fortran: 43.0 us

です。ほぼ変わらない、と言っていいと思います。

95
50
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
95
50

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?