Kotlin
algorithm
Sequence
反復法

【Kotlin】「初めて終了条件を満たしたitを含むSequence」を返すtakeWhile的なものを作った

この記事で使っていたリポジトリはこちら

問題になったこと

takeWhileで返って来るSequenceは「初めて終了条件を満たしたitより前の範囲」となり、終了条件を初めて満たしたitは含まれません。
しかし、以下のように、反復法で解を求めるコードでは、終了条件を初めて満たしたitとは収束した解であるため、これが取れないと困ります。

Sequenceを使った反復法(二分法)の実装例、このままでは収束した解が得られない
//二分法の中点計算用
fun calcHalf(a: Double, b: Double): Double = (a + b) / 2.0
//二分法
fun binaryIterationSequence(
    f: (Double) -> Double, //入力関数
    trueValue: Double, //二分法の真値
    a0: Double, //小さい方の初期値
    b0: Double, //大きい方の初期値
    nMax: Int = 10000 //反復回数上限
): List<Double> = generateSequence(Triple(a0, b0, calcHalf(a0, b0))) { (a, b, c) ->
    when { //更新
        f(a) * f(c) > 0.0 -> Triple(c, b, calcHalf(c, b))
        f(b) * f(c) > 0.0 -> Triple(a, c,  calcHalf(a, c))
        else -> throw RuntimeException("条件判定に失敗しました")
    }
}.map {
    it.third //Tripleから配列へ変換
}.takeWhile {
    abs(f(it) - trueValue) >= 1.0E-15 //収束条件
}.take(nMax).toList() //nMaxまで行ったら終了、Listにして返却

実装を探してみましたが、初めて終了条件を満たしたitまで含めて返してくれるものは見つけられませんでした。

解決方法

ほぼコピペですが、takeForとして自力で実装しました1
以下のコードを適当な位置にコピペし、Sequence.takeForを呼び出すことで、takeWhileと同様に使えます。

takeFor.kt
fun <T> Sequence<T>.takeFor(predicate: (T) -> Boolean): Sequence<T> {
    return TakeForSequence(this, predicate)
}

internal class TakeForSequence<T>
constructor(
    private val sequence: Sequence<T>,
    private val predicate: (T) -> Boolean
) : Sequence<T> {
    override fun iterator(): Iterator<T> = object : Iterator<T> {
        val iterator = sequence.iterator()
        var nextState: Int = -1 // -1 for unknown, 0 for done, 1 for continue
        var nextItem: T? = null
        var doneFlag = false

        private fun calcNext() {
            if (iterator.hasNext()) {
                val item = iterator.next()
                if (predicate(item) && !doneFlag) {
                    nextState = 1
                    nextItem = item
                    return
                } else if(!doneFlag){
                    doneFlag = true
                    nextState = 1
                    nextItem = item
                    return
                }
            }
            nextState = 0
        }

        override fun next(): T {
            if (nextState == -1)
                calcNext() // will change nextState
            if (nextState == 0)
                throw NoSuchElementException()
            @Suppress("UNCHECKED_CAST")
            val result = nextItem as T

            // Clean next to avoid keeping reference on yielded instance
            nextItem = null
            nextState = -1
            return result
        }

        override fun hasNext(): Boolean {
            if (nextState == -1)
                calcNext() // will change nextState
            return nextState == 1
        }
    }
}

実装について

実装の簡単な解説とコピペ元をまとめます。

TakeForSequenceクラス

Kotlinのソースコードから、stdlib/src/kotlin/collections/Sequences.kt#TakeWhileSequenceをコピペし、doneFlag周りを追加しました。

takeFor関数

Kotlinのソースコードから、stdlib/common/src/generated/_Sequences.kt#takeWhileをコピペし、TakeForSequenceクラスに合わせ改変しました。

おまけ

冒頭で紹介したコードをtakeForで書き直したものと、これを使って円周率を計算するサンプルです。
これに関しては別の記事で解説したので詳しくは触れません。

binaryIterationSequence
//二分法の中点計算用
fun calcHalf(a: Double, b: Double): Double = (a + b) / 2.0
//二分法
fun binaryIterationSequence(
    f: (Double) -> Double, //入力関数
    trueValue: Double, //二分法の真値
    a0: Double, //小さい方の初期値
    b0: Double, //大きい方の初期値
    nMax: Int = 10000 //反復回数上限
): List<Double> = generateSequence(Triple(a0, b0, calcHalf(a0, b0))) { (a, b, c) ->
    when { //更新
        f(a) * f(c) > 0.0 -> Triple(c, b, calcHalf(c, b))
        f(b) * f(c) > 0.0 -> Triple(a, c,  calcHalf(a, c))
        else -> throw RuntimeException("条件判定に失敗しました")
    }
}.map {
    it.third //Tripleから配列へ変換
}.takeFor {
    abs(f(it) - trueValue) >= 1.0E-15 //収束条件
}.take(nMax).toList() //nMaxまで行ったら終了、Listにして返却
testBinarySequence
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import kotlin.math.PI

class SequenceTest{
    @Test
    fun testBinarySequence(){
        val ans = binaryIterationSequence(::binaryF, 0.0, 0.0, 5.0)
        Assertions.assertTrue(10000 > ans.size)
        Assertions.assertEquals(PI, ans.last(), 1.0E-15)
    }
}


  1. この名前はテキトーにつけたので、より良い名前が有れば教えてほしいです。