LoginSignup
1

More than 5 years have passed since last update.

【Kotlin】「初めて終了条件を満たしたitを含むSequence」を返すtakeWhile的なものを作った

Last updated at Posted at 2019-01-06

この記事で使っていたリポジトリはこちら

問題になったこと

takeWhileで返って来るSequenceは「初めて終了条件を満たしたitより前の範囲」となり、終了条件を初めて満たしたitは含まれません。
しかし、以下のように、反復法で解を求めるコードでは、終了条件を初めて満たしたitとは収束した解であるため、これが取れないと困ります。

Sequenceを使った反復法(二分法)の実装例、このままでは収束した解が得られない
//二分法の中点計算用
fun calcHalf(a: Double, b: Double): Double = (a + b) / 2.0
//二分法
fun binaryIterationSequence(
    f: (Double) -> Double, //入力関数
    trueValue: Double, //二分法の真値
    a0: Double, //小さい方の初期値
    b0: Double, //大きい方の初期値
    nMax: Int = 10000 //反復回数上限
): List<Double> = generateSequence(Triple(a0, b0, calcHalf(a0, b0))) { (a, b, c) ->
    when { //更新
        f(a) * f(c) > 0.0 -> Triple(c, b, calcHalf(c, b))
        f(b) * f(c) > 0.0 -> Triple(a, c,  calcHalf(a, c))
        else -> throw RuntimeException("条件判定に失敗しました")
    }
}.map {
    it.third //Tripleから配列へ変換
}.takeWhile {
    abs(f(it) - trueValue) >= 1.0E-15 //収束条件
}.take(nMax).toList() //nMaxまで行ったら終了、Listにして返却

実装を探してみましたが、初めて終了条件を満たしたitまで含めて返してくれるものは見つけられませんでした。

解決方法

ほぼコピペですが、takeForとして自力で実装しました1
以下のコードを適当な位置にコピペし、Sequence.takeForを呼び出すことで、takeWhileと同様に使えます。

takeFor.kt
fun <T> Sequence<T>.takeFor(predicate: (T) -> Boolean): Sequence<T> {
    return TakeForSequence(this, predicate)
}

internal class TakeForSequence<T>
constructor(
    private val sequence: Sequence<T>,
    private val predicate: (T) -> Boolean
) : Sequence<T> {
    override fun iterator(): Iterator<T> = object : Iterator<T> {
        val iterator = sequence.iterator()
        var nextState: Int = -1 // -1 for unknown, 0 for done, 1 for continue
        var nextItem: T? = null
        var doneFlag = false

        private fun calcNext() {
            if (iterator.hasNext()) {
                val item = iterator.next()
                if (predicate(item) && !doneFlag) {
                    nextState = 1
                    nextItem = item
                    return
                } else if(!doneFlag){
                    doneFlag = true
                    nextState = 1
                    nextItem = item
                    return
                }
            }
            nextState = 0
        }

        override fun next(): T {
            if (nextState == -1)
                calcNext() // will change nextState
            if (nextState == 0)
                throw NoSuchElementException()
            @Suppress("UNCHECKED_CAST")
            val result = nextItem as T

            // Clean next to avoid keeping reference on yielded instance
            nextItem = null
            nextState = -1
            return result
        }

        override fun hasNext(): Boolean {
            if (nextState == -1)
                calcNext() // will change nextState
            return nextState == 1
        }
    }
}

実装について

実装の簡単な解説とコピペ元をまとめます。

TakeForSequenceクラス

Kotlinのソースコードから、stdlib/src/kotlin/collections/Sequences.kt#TakeWhileSequenceをコピペし、doneFlag周りを追加しました。

takeFor関数

Kotlinのソースコードから、stdlib/common/src/generated/_Sequences.kt#takeWhileをコピペし、TakeForSequenceクラスに合わせ改変しました。

おまけ

冒頭で紹介したコードをtakeForで書き直したものと、これを使って円周率を計算するサンプルです。
これに関しては別の記事で解説したので詳しくは触れません。

binaryIterationSequence
//二分法の中点計算用
fun calcHalf(a: Double, b: Double): Double = (a + b) / 2.0
//二分法
fun binaryIterationSequence(
    f: (Double) -> Double, //入力関数
    trueValue: Double, //二分法の真値
    a0: Double, //小さい方の初期値
    b0: Double, //大きい方の初期値
    nMax: Int = 10000 //反復回数上限
): List<Double> = generateSequence(Triple(a0, b0, calcHalf(a0, b0))) { (a, b, c) ->
    when { //更新
        f(a) * f(c) > 0.0 -> Triple(c, b, calcHalf(c, b))
        f(b) * f(c) > 0.0 -> Triple(a, c,  calcHalf(a, c))
        else -> throw RuntimeException("条件判定に失敗しました")
    }
}.map {
    it.third //Tripleから配列へ変換
}.takeFor {
    abs(f(it) - trueValue) >= 1.0E-15 //収束条件
}.take(nMax).toList() //nMaxまで行ったら終了、Listにして返却
testBinarySequence
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import kotlin.math.PI

class SequenceTest{
    @Test
    fun testBinarySequence(){
        val ans = binaryIterationSequence(::binaryF, 0.0, 0.0, 5.0)
        Assertions.assertTrue(10000 > ans.size)
        Assertions.assertEquals(PI, ans.last(), 1.0E-15)
    }
}


  1. この名前はテキトーにつけたので、より良い名前が有れば教えてほしいです。 

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1