この記事で使っていたリポジトリはこちら。
問題になったこと
takeWhile
で返って来るSequence
は「初めて終了条件を満たしたit
より前の範囲」となり、終了条件を初めて満たしたit
は含まれません。
しかし、以下のように、反復法で解を求めるコードでは、終了条件を初めて満たしたit
とは収束した解であるため、これが取れないと困ります。
Sequenceを使った反復法(二分法)の実装例、このままでは収束した解が得られない
//二分法の中点計算用
fun calcHalf(a: Double, b: Double): Double = (a + b) / 2.0
//二分法
fun binaryIterationSequence(
f: (Double) -> Double, //入力関数
trueValue: Double, //二分法の真値
a0: Double, //小さい方の初期値
b0: Double, //大きい方の初期値
nMax: Int = 10000 //反復回数上限
): List<Double> = generateSequence(Triple(a0, b0, calcHalf(a0, b0))) { (a, b, c) ->
when { //更新
f(a) * f(c) > 0.0 -> Triple(c, b, calcHalf(c, b))
f(b) * f(c) > 0.0 -> Triple(a, c, calcHalf(a, c))
else -> throw RuntimeException("条件判定に失敗しました")
}
}.map {
it.third //Tripleから配列へ変換
}.takeWhile {
abs(f(it) - trueValue) >= 1.0E-15 //収束条件
}.take(nMax).toList() //nMaxまで行ったら終了、Listにして返却
実装を探してみましたが、初めて終了条件を満たしたit
まで含めて返してくれるものは見つけられませんでした。
解決方法
ほぼコピペですが、takeFor
として自力で実装しました1。
以下のコードを適当な位置にコピペし、Sequence.takeFor
を呼び出すことで、takeWhile
と同様に使えます。
takeFor.kt
fun <T> Sequence<T>.takeFor(predicate: (T) -> Boolean): Sequence<T> {
return TakeForSequence(this, predicate)
}
internal class TakeForSequence<T>
constructor(
private val sequence: Sequence<T>,
private val predicate: (T) -> Boolean
) : Sequence<T> {
override fun iterator(): Iterator<T> = object : Iterator<T> {
val iterator = sequence.iterator()
var nextState: Int = -1 // -1 for unknown, 0 for done, 1 for continue
var nextItem: T? = null
var doneFlag = false
private fun calcNext() {
if (iterator.hasNext()) {
val item = iterator.next()
if (predicate(item) && !doneFlag) {
nextState = 1
nextItem = item
return
} else if(!doneFlag){
doneFlag = true
nextState = 1
nextItem = item
return
}
}
nextState = 0
}
override fun next(): T {
if (nextState == -1)
calcNext() // will change nextState
if (nextState == 0)
throw NoSuchElementException()
@Suppress("UNCHECKED_CAST")
val result = nextItem as T
// Clean next to avoid keeping reference on yielded instance
nextItem = null
nextState = -1
return result
}
override fun hasNext(): Boolean {
if (nextState == -1)
calcNext() // will change nextState
return nextState == 1
}
}
}
実装について
実装の簡単な解説とコピペ元をまとめます。
TakeForSequence
クラス
Kotlinのソースコードから、stdlib/src/kotlin/collections/Sequences.kt#TakeWhileSequenceをコピペし、doneFlag
周りを追加しました。
takeFor
関数
Kotlinのソースコードから、stdlib/common/src/generated/_Sequences.kt#takeWhileをコピペし、TakeForSequence
クラスに合わせ改変しました。
おまけ
冒頭で紹介したコードをtakeFor
で書き直したものと、これを使って円周率を計算するサンプルです。
これに関しては別の記事で解説したので詳しくは触れません。
binaryIterationSequence
//二分法の中点計算用
fun calcHalf(a: Double, b: Double): Double = (a + b) / 2.0
//二分法
fun binaryIterationSequence(
f: (Double) -> Double, //入力関数
trueValue: Double, //二分法の真値
a0: Double, //小さい方の初期値
b0: Double, //大きい方の初期値
nMax: Int = 10000 //反復回数上限
): List<Double> = generateSequence(Triple(a0, b0, calcHalf(a0, b0))) { (a, b, c) ->
when { //更新
f(a) * f(c) > 0.0 -> Triple(c, b, calcHalf(c, b))
f(b) * f(c) > 0.0 -> Triple(a, c, calcHalf(a, c))
else -> throw RuntimeException("条件判定に失敗しました")
}
}.map {
it.third //Tripleから配列へ変換
}.takeFor {
abs(f(it) - trueValue) >= 1.0E-15 //収束条件
}.take(nMax).toList() //nMaxまで行ったら終了、Listにして返却
testBinarySequence
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import kotlin.math.PI
class SequenceTest{
@Test
fun testBinarySequence(){
val ans = binaryIterationSequence(::binaryF, 0.0, 0.0, 5.0)
Assertions.assertTrue(10000 > ans.size)
Assertions.assertEquals(PI, ans.last(), 1.0E-15)
}
}
-
この名前はテキトーにつけたので、より良い名前が有れば教えてほしいです。 ↩