0
0

Delete article

Deleted articles cannot be recovered.

Any draft of this article will also be deleted.

Are you sure you want to delete this article?

関係データ学習 機械学習シリーズ

0
Last updated at Posted at 2018-12-03
# グラフラプラシアン

# 関係データ学習 機械学習シリーズ

# Graph Laplacian from pairwise correlation significance tests.
# Build an adjacency matrix by testing each correlation coefficient for
# significance (two-sided t-test at level alpha), form the Laplacian
# L = D - A, and extract the eigenvectors of the (near-)zero eigenvalues;
# their count equals the number of connected components of the graph.

data <- data.frame(
  num = 1:20,
  y  = c(167, 167.5, 168.4, 172, 155.3, 151.4, 163, 174, 168, 160.4,
         164.7, 171, 162.6, 164.8, 163.3, 167.6, 169.2, 168, 167.4, 172),
  x1 = c(84, 87, 86, 85, 82, 87, 92, 94, 88, 84.9,
         78, 90, 88, 87, 82, 84, 86, 83, 85.2, 82),
  x2 = c(61, 55.5, 57, 57, 50, 50, 66.5, 65, 60.5, 49.5,
         49.5, 61, 59.5, 58.4, 53.5, 54, 60, 58.8, 54, 56)
)

n <- nrow(data)

mat_data <- as.matrix(data[, colnames(data) %in% c("y", "x1", "x2")])

cor_mat <- cor(mat_data)

alpha <- 0.01

# Critical value for the two-sided test of H0: rho = 0 (df = n - 2).
t_value <- qt(1 - alpha / 2, df = n - 2)

# t statistic of each correlation coefficient; zero the diagonal so a
# variable is never linked to itself.
t_stat <- cor_mat * sqrt(n - 2) / sqrt(1 - cor_mat^2)
diag(t_stat) <- 0

# Adjacency matrix: edge iff the correlation is significant.
# NOTE: the original compared t > t_value (one-sided) and recomputed the
# quantile instead of using t_value; a two-sided test needs |t| > t_value.
adj <- ifelse(abs(t_stat) > t_value, 1, 0)

# Graph Laplacian L = D - A (degree matrix minus adjacency).
# The original built A - D, i.e. -L; the zero-eigenvalue count is the same,
# but D - A is the standard positive semi-definite convention.
laplacian <- diag(rowSums(adj), ncol = ncol(adj)) - adj

# Decompose once (the original called eigen() twice).
eig <- eigen(laplacian)
eigen_values <- eig$values

# Numerically-zero eigenvalues mark the connected components.
is_zero <- abs(eigen_values) < 1e-4

vectors <- eig$vectors[, which(is_zero)]
vectors <- vectors / mean(vectors)




# p.68 Chinese restaurant process モデル

# Escobar & West 補助変数法
# Escobar & West (1995) auxiliary-variable update for the DP concentration
# parameter alpha, given n observations, K occupied clusters, and a
# Gamma(a, b) prior on alpha. Returns a single posterior draw.
update_alpha_ew <- function(alpha, n, K, a, b) {
  # Auxiliary variable eta ~ Beta(alpha + 1, n).
  eta <- rbeta(1, alpha + 1, n)

  # Posterior of alpha is a two-component Gamma mixture; mix_w is the
  # weight on the shape = a + K component.
  rate_post <- b - log(eta)
  odds_num <- a + K - 1
  mix_w <- odds_num / (n * rate_post + odds_num)

  # Pick the mixture component, then draw from it.
  shape_post <- if (runif(1) < mix_w) a + K else a + K - 1
  rgamma(1, shape = shape_post, rate = rate_post)
}

# クラスタ平均更新(共役正規事前)
# Posterior draw of one cluster mean under the conjugate Normal model
#   x_i | theta ~ N(theta, sigma2),   theta ~ N(mu0, tau2).
# An empty cluster (length(xk) == 0) falls back to a draw from the prior.
update_theta <- function(xk, mu0, tau2, sigma2) {
  if (length(xk) == 0) {
    return(rnorm(1, mu0, sqrt(tau2)))
  }

  # Precision-weighted combination of data and prior.
  prec <- length(xk) / sigma2 + 1 / tau2
  post_var <- 1 / prec
  post_mean <- post_var * (sum(xk) / sigma2 + mu0 / tau2)

  rnorm(1, post_mean, sqrt(post_var))
}

# クラスタ割当 z 更新(Gibbs sampling)
# One Gibbs sweep over cluster assignments z for a DP mixture of
# univariate Gaussians (Chinese restaurant process representation).
#
# NOTE(review): relies on the globals mu0, tau2, sigma2 defined by the
# surrounding script (cluster-mean prior mean/variance, observation
# variance); kept as free variables for interface compatibility.
#
# Args:
#   x:     numeric data vector.
#   z:     integer cluster labels in 1..length(theta).
#   theta: current cluster means, one per occupied cluster.
#   alpha: DP concentration parameter.
# Returns: list(z = relabeled assignments with no empty clusters,
#               theta = cluster means matching the new labels).
update_z <- function(x, z, theta, alpha) {
  n <- length(x)
  K <- length(theta)

  for (i in seq_len(n)) {
    # Cluster sizes with observation i removed ("unseat the customer").
    counts <- table(factor(z[-i], levels = seq_len(K)))

    probs <- numeric(K + 1)

    # Existing clusters: P(z_i = k) proportional to
    # n_{-i,k} * N(x_i | theta_k, sigma2).
    # (The original hard-coded sd = 1; sqrt(sigma2) keeps this consistent
    # with the new-cluster marginal below. Identical when sigma2 == 1.)
    for (k in seq_len(K)) {
      probs[k] <- counts[k] * dnorm(x[i], mean = theta[k], sd = sqrt(sigma2))
    }

    # New cluster: alpha * marginal likelihood N(x_i | mu0, sigma2 + tau2).
    probs[K + 1] <- alpha * dnorm(x[i], mean = mu0, sd = sqrt(sigma2 + tau2))

    probs <- probs / sum(probs)

    new_k <- sample(seq_len(K + 1), 1, prob = probs)

    if (new_k > K) {
      # Open a new cluster: draw its mean from the prior.
      theta <- c(theta, rnorm(1, mu0, sqrt(tau2)))
    }
    z[i] <- new_k

    # Drop any now-empty cluster and relabel to 1..K.
    kept <- sort(unique(z))
    z <- match(z, kept)
    theta <- theta[kept]
    K <- length(theta)
  }

  list(z = z, theta = theta)
}


# Initial values and sample data
set.seed(123)

# Sample data: 1-D Gaussian mixture (three components with means 15, 5, 10)
n <- 500
x <- c(rnorm(200, 15, 1), rnorm(150, 5, 1), rnorm(150, 10, 1))

# Hyperparameters
a <- 1       # shape of the Gamma prior on alpha
b <- 1       # rate of the Gamma prior on alpha
sigma2 <- 1  # observation variance (fixed, for conjugate simplification)
mu0 <- 0     # prior mean of a cluster mean
tau2 <- 10   # prior variance of a cluster mean
alpha <- 1            # initial concentration parameter
z <- sample(1:3, n, replace=TRUE)  # initial cluster assignments


# Gibbs Sampling

n_iter <- 300
theta <- tapply(x, z, mean)  # initial cluster means

# Main Gibbs loop: alternate assignment, cluster-mean, and alpha updates.
for (iter in seq_len(n_iter)) {
  # 1. Update cluster assignments z (may create or remove clusters).
  res <- update_z(x, z, theta, alpha)
  z <- res$z
  theta <- res$theta

  # 2. Update each cluster mean from its conjugate Normal posterior.
  # (seq_along replaces 1:length(theta), which breaks on length 0.)
  for (k in seq_along(theta)) {
    theta[k] <- update_theta(x[z == k], mu0, tau2, sigma2)
  }

  # 3. Update the DP concentration parameter alpha (Escobar & West).
  K <- length(theta)
  alpha <- update_alpha_ew(alpha, n, K, a, b)

  # Progress report every 10 iterations.
  if (iter %% 10 == 0) {
    cat("Iter:", iter, "Clusters:", K, "Alpha:", round(alpha, 2), "\n")
  }
}


table(z)       # cluster sizes
theta           # cluster means

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
0
0

Delete article

Deleted articles cannot be recovered.

Any draft of this article will also be deleted.

Are you sure you want to delete this article?