LoginSignup
0
0

More than 1 year has passed since last update.

折り返しのある系列を巻き戻す その2

Last updated at Posted at 2021-08-02

折り返しのある系列を巻き戻すことを検討している。前回は、折り返し何周目かをあらわす変数$y_n$を使っていたが、そもそもどこで折り返されているのかを判断するのも簡単ではない。そこで、折り返し何周目なのかも自動で調べるようにした。

$y_n$には次のような制約がある。最初は0であり、ある$y_k$は直前と同じか、それより1つ大きい。

$$y_1 = 0$$

$$y_{k+1}\in \{y_k ,y_k+1 \}$$

したがって、$n$が有限であれば、$y_n$の系列は有限個しかない。例えば$n=4$なら、下に示す8通りしかない。

$$0,0,0,0$$

$$0,0,0,1$$

$$0,0,1,1$$

$$0,0,1,2$$

$$0,1,1,1$$

$$0,1,1,2$$

$$0,1,2,2$$

$$0,1,2,3$$

そこですべての$y_n$の系列について前回のアルゴリズムで巻き戻しを行い、その時の誤差が最小のものを選べばよい。

しかし、$n$が大きくなってくると、$y_n$の系列の個数は爆発的に大きくなるので、全探索は難しくなる。一方、$n$を逐次的に大きくしていった場合、全体の直線近似の係数がその都度変化するので、動的計画法は使えない。

しかたないのでビーム探索を使う。ある$n$について、そこまでの$y_n$の系列のうち、誤差が小さいものを一定個数残し、それを1サンプルずつ伸ばしていく。

まず$y_n$を既知として巻き戻しの係数を求める。前回示したものと同じアルゴリズムだが、最適な係数だけでなく誤差も返すようにしている。

findlinear <- function(pos,x,y,n=length(pos)) {
  if (max(y[1:n]) == 0) {
    # only one segment
    model <- lm(x~pos,data=data.frame(x=x[1:n],pos=pos[1:n]))
    return(list(param=c(0,model$coefficients),error=sum(model$residual^2)))
  } 
  sx <- sum(x[1:n])
  sy <- sum(y[1:n])
  sxy <- sum(x[1:n]*y[1:n])
  syy <- sum(y[1:n]^2)
  mx <- sum(pos[1:n]*x[1:n])
  my <- sum(pos[1:n]*y[1:n])
  v <- -c(sxy,sx,mx)
  sn <- sum(pos[1:n])
  sn2 <- sum(pos[1:n]^2)
  m <- matrix(c(syy,-sy,-my,sy,-n,-sn,my,-sn,-sn2),
              nrow=3,ncol=3,byrow=TRUE)
  r <- solve(m,v)
  alpha <- r[1]
  a <- r[2]
  b <- r[3]
  e <- sum((x[1:n]+alpha*y[1:n]-a-b*pos[1:n])^2)
  list(param=r,error=e)
}

次にビーム探索によって最適っぽい$y_n$を逐次求めていく関数。ビーム幅を100で決め打ちにしてハードコードしているところがちょいダサ。


findbest <- function(pos,x) {
  n <- length(x)
  y <- c(0,0)
  cand <- list(y)
  err <- 0
  for (i in 3:n) {
    ncd <- list()
    nerr <- c()
    nparam <- list()
    for (y in cand) {
      y0 <- c(y,y[i-1])
      y1 <- c(y,y[i-1]+1)
      e0 <- findlinear(1:i,x,y0,i)
      e1 <- findlinear(1:i,x,y1,i)
      ncd[[length(ncd)+1]] <- y0 
      ncd[[length(ncd)+1]] <- y1
      nerr[length(nerr)+1] <- e0$error
      nerr[length(nerr)+1] <- e1$error
      nparam[[length(nparam)+1]] <- e0$param
      nparam[[length(nparam)+1]] <- e1$param
    }
    if (length(ncd) < 100) {
      cand <- ncd
      err <- nerr[order(nerr)]
    } else {
      od <- order(nerr)
      cand <- list()
      err <- c()
      for (j in 1:100) {
        cand[[length(cand)+1]] <- ncd[[od[j]]]
        err[length(err)+1] <- nerr[od[j]]
      }
    }
  }
  list(score=err,cand=cand,param=nparam)
}

最後に、折り返し系列を入れると勝手に巻き戻してくれる関数。uが巻き戻した系列値になる。

unroll <- function(pos,x) {
  bs <- findbest(pos,x)
  od <- order(bs$score)
  bi <- which.min(od)
  list(pos=pos,x=x,u=x+bs$cand[[bi]]*bs$param[[bi]][1],
       alpha=bs$param[[bi]][1],
       intercept=bs$param[[bi]][2],
       slope=bs$param[[bi]][3])
}

テスト。

x0 <- seq(1,100,3)
z0 <- seq(1,100,20)
x <- rep(0,length(x0))
for (i in 1:length(x0)) {
  x[i] <- x0[i]-z0[which.min(abs(x0[i]-z0))]
}
x <- x+runif(length(x))
plot(x,type="b")

image.png

u <- unroll(1:length(x),x)
plot(u$u)

image.png

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0