LoginSignup
1
3

More than 3 years have passed since last update.

tidymodelsとworkflow(その2)

Last updated at Posted at 2021-01-21

1.問題点とエラーコード

前回は、tidymodelsのうち、workflowオブジェクトによる方法を行いました。
workflowオブジェクトは非常に便利で、まさに神!な感じです。
しかし、workflowオブジェクトを使った方法で、次のようなエラーがでました。

まず次のコードを実行します。

library(tidymodels)
library(tidyverse)

d = mtcars

set.seed(12)
initial_split = initial_split(d, p = 0.75)
train_data = training(initial_split)
test_data  = testing(initial_split)

#前処理を作成 :目的変数を対数化
recipe1 = recipe(mpg ~., train_data) %>%
  step_log(mpg)

#モデルは線形モデル(lm)を指定する
model_lm = linear_reg() %>%  set_engine("lm")

#workflowオブジェクトを作成する
lm_workflow = workflow() %>% add_recipe(recipe1) %>% add_model(model_lm)

lm_workflow

$>══ Workflow ════════════════════════════════════════════════════════════════════════════
$>Preprocessor: Recipe
$>Model: linear_reg()
$>
$> Preprocessor ────────────────────────────────────
$>1 Recipe Step
$>
$> step_log()
$>
$> Model ────────────────────────────────────────
$>Linear Regression Model Specification (regression)
$>
$>Computational engine: lm 

訓練データにfitさせて、検証データにpredictさせます。

lm_workflow_fit = lm_workflow %>% fit(train_data)
lm_workflow_fit %>% predict(test_data)

$>エラー: Assigned data `log(new_data[[col_names[i]]] + object$offset, base = object$base)` must be compatible with existing data. x Existing data has 8 rows. x Assigned data has 0 rows.  Only vectors of size 1 are recycled.

2.解決方法

上記のようなエラーが出て、大変悩みました。
(上のコード事例ではありませんが、同じエラーがでました)

結局、今のところは、workflowオブジェクト独特の問題でrecipeの処理に目的変数を当てる(上記の場合、目的変数の対数化)てpredictをworkflowオブジェクトに当てると生じます。

そこで、解決方法です。

#recipe1を使って直接データを前処理(目的変数を対数化)
train_data_juice = juice(recipe1 %>% prep())
test_data_bake = bake(recipe1 %>% prep(),test_data)
recipe2 = recipe(mpg ~., train_data_juice)

#recipeのアップデート:目的変数の対数化を取り消し
lm_workflow = lm_workflow %>% update_recipe(recipe2)

lm_workflow_fit = lm_workflow %>% fit(train_data_juice)
lm_workflow_fit %>% predict(test_data_bake)

$>.pred
$><dbl>
$>3.094213              
$>2.878973              
$>2.802263              
$>2.468462              
$>2.940477              
$>3.370685              
$>3.456701              
$>2.413078
$>8 rows

今のところ、目的変数を対象にrecipeを使って前処理をする場合、predictするときにworkflowではエラーがでるので一度、prep(),juice()で前処理を済ませてから、workflowにかける必要があるみたいです。
(目的変数は、前処理過程に入れられない)

3.参考

Predict doesn't work for recipe that modify outcome #63

enjoy!

1
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
3