Edited at

【Kotlin】「初めて終了条件を満たしたitを含むSequence」を返すtakeWhile的なものを作った

この記事で使っていたリポジトリはこちら


問題になったこと

takeWhileで返って来るSequenceは「初めて終了条件を満たしたitより前の範囲」となり、終了条件を初めて満たしたitは含まれません。

しかし、以下のように、反復法で解を求めるコードでは、終了条件を初めて満たしたitとは収束した解であるため、これが取れないと困ります。


Sequenceを使った反復法(二分法)の実装例、このままでは収束した解が得られない

//二分法の中点計算用

fun calcHalf(a: Double, b: Double): Double = (a + b) / 2.0
//二分法
fun binaryIterationSequence(
f: (Double) -> Double, //入力関数
trueValue: Double, //二分法の真値
a0: Double, //小さい方の初期値
b0: Double, //大きい方の初期値
nMax: Int = 10000 //反復回数上限
): List<Double> = generateSequence(Triple(a0, b0, calcHalf(a0, b0))) { (a, b, c) ->
when { //更新
f(a) * f(c) > 0.0 -> Triple(c, b, calcHalf(c, b))
f(b) * f(c) > 0.0 -> Triple(a, c, calcHalf(a, c))
else -> throw RuntimeException("条件判定に失敗しました")
}
}.map {
it.third //Tripleから配列へ変換
}.takeWhile {
abs(f(it) - trueValue) >= 1.0E-15 //収束条件
}.take(nMax).toList() //nMaxまで行ったら終了、Listにして返却

実装を探してみましたが、初めて終了条件を満たしたitまで含めて返してくれるものは見つけられませんでした。


解決方法

ほぼコピペですが、takeForとして自力で実装しました1

以下のコードを適当な位置にコピペし、Sequence.takeForを呼び出すことで、takeWhileと同様に使えます。


takeFor.kt

fun <T> Sequence<T>.takeFor(predicate: (T) -> Boolean): Sequence<T> {

return TakeForSequence(this, predicate)
}

internal class TakeForSequence<T>
constructor(
private val sequence: Sequence<T>,
private val predicate: (T) -> Boolean
) : Sequence<T> {
override fun iterator(): Iterator<T> = object : Iterator<T> {
val iterator = sequence.iterator()
var nextState: Int = -1 // -1 for unknown, 0 for done, 1 for continue
var nextItem: T? = null
var doneFlag = false

private fun calcNext() {
if (iterator.hasNext()) {
val item = iterator.next()
if (predicate(item) && !doneFlag) {
nextState = 1
nextItem = item
return
} else if(!doneFlag){
doneFlag = true
nextState = 1
nextItem = item
return
}
}
nextState = 0
}

override fun next(): T {
if (nextState == -1)
calcNext() // will change nextState
if (nextState == 0)
throw NoSuchElementException()
@Suppress("UNCHECKED_CAST")
val result = nextItem as T

// Clean next to avoid keeping reference on yielded instance
nextItem = null
nextState = -1
return result
}

override fun hasNext(): Boolean {
if (nextState == -1)
calcNext() // will change nextState
return nextState == 1
}
}
}



実装について

実装の簡単な解説とコピペ元をまとめます。


TakeForSequenceクラス

Kotlinのソースコードから、stdlib/src/kotlin/collections/Sequences.kt#TakeWhileSequenceをコピペし、doneFlag周りを追加しました。


takeFor関数

Kotlinのソースコードから、stdlib/common/src/generated/_Sequences.kt#takeWhileをコピペし、TakeForSequenceクラスに合わせ改変しました。


おまけ

冒頭で紹介したコードをtakeForで書き直したものと、これを使って円周率を計算するサンプルです。

これに関しては別の記事で解説したので詳しくは触れません。


binaryIterationSequence

//二分法の中点計算用

fun calcHalf(a: Double, b: Double): Double = (a + b) / 2.0
//二分法
fun binaryIterationSequence(
f: (Double) -> Double, //入力関数
trueValue: Double, //二分法の真値
a0: Double, //小さい方の初期値
b0: Double, //大きい方の初期値
nMax: Int = 10000 //反復回数上限
): List<Double> = generateSequence(Triple(a0, b0, calcHalf(a0, b0))) { (a, b, c) ->
when { //更新
f(a) * f(c) > 0.0 -> Triple(c, b, calcHalf(c, b))
f(b) * f(c) > 0.0 -> Triple(a, c, calcHalf(a, c))
else -> throw RuntimeException("条件判定に失敗しました")
}
}.map {
it.third //Tripleから配列へ変換
}.takeFor {
abs(f(it) - trueValue) >= 1.0E-15 //収束条件
}.take(nMax).toList() //nMaxまで行ったら終了、Listにして返却


testBinarySequence

import org.junit.jupiter.api.Assertions

import org.junit.jupiter.api.Test
import kotlin.math.PI

class SequenceTest{
@Test
fun testBinarySequence(){
val ans = binaryIterationSequence(::binaryF, 0.0, 0.0, 5.0)
Assertions.assertTrue(10000 > ans.size)
Assertions.assertEquals(PI, ans.last(), 1.0E-15)
}
}






  1. この名前はテキトーにつけたので、より良い名前が有れば教えてほしいです。