Data Every Day: Avocado Prices

tldr

This post works through Kaggle's Avocado Prices dataset, following "Classifying Avocados by Type - Data Every Day #037".

The execution environment is Google Colaboratory.

Imports

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import sklearn.preprocessing as sp
from sklearn.model_selection import train_test_split
import sklearn.linear_model as slm

import tensorflow as tf

Downloading the data

Mount Google Drive.

from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

Initialize the Kaggle API client and authenticate.
The credentials are stored in Google Drive (/content/drive/My Drive/Colab Notebooks/Kaggle) as kaggle.json.

import os
kaggle_path = "/content/drive/My Drive/Colab Notebooks/Kaggle"
os.environ['KAGGLE_CONFIG_DIR'] = kaggle_path

from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate() 
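
Before calling `api.authenticate()`, it can help to confirm that the credentials file really sits where `KAGGLE_CONFIG_DIR` points. The following is a minimal sketch of such a check. The helper name `has_kaggle_credentials` is hypothetical and not part of the Kaggle client, and a temporary directory stands in for the Drive folder.

```python
import os
import tempfile


def has_kaggle_credentials(config_dir):
    """Return True if a kaggle.json file exists in config_dir (hypothetical helper)."""
    return os.path.isfile(os.path.join(config_dir, "kaggle.json"))


# Demonstration with a temporary directory standing in for the Drive folder.
with tempfile.TemporaryDirectory() as tmp_dir:
    missing = has_kaggle_credentials(tmp_dir)
    with open(os.path.join(tmp_dir, "kaggle.json"), "w") as f:
        f.write("{}")
    present = has_kaggle_credentials(tmp_dir)

print(missing, present)
```

In the notebook, the same helper could be called with `kaggle_path` before authenticating.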

Download the data with the Kaggle API.

dataset_id = 'neuromusic/avocado-prices'
dataset = api.dataset_list_files(dataset_id)
file_name = dataset.files[0].name
file_path = os.path.join(api.get_default_download_dir(), file_name)
file_path
Warning: Looks like you're using an outdated API Version, please consider updating (server 1.5.10 / client 1.5.9)
'/content/avocado.csv'
api.dataset_download_file(dataset_id, file_name, force=True, quiet=False)
100%|██████████| 629k/629k [00:00<00:00, 103MB/s]
Downloading avocado.csv.zip to /content
True

Loading the data

Read the downloaded CSV file with pandas. The file was saved as avocado.csv.zip, so '.zip' is appended to the path.

data = pd.read_csv(file_path+'.zip')
data
       Unnamed: 0        Date  AveragePrice  Total Volume     4046       4225    4770  Total Bags  Small Bags  Large Bags  XLarge Bags          type  year            region
    0           0  2015-12-27          1.33      64236.62  1036.74   54454.85   48.16     8696.87     8603.62       93.25          0.0  conventional  2015            Albany
    1           1  2015-12-20          1.35      54876.98   674.28   44638.81   58.33     9505.56     9408.07       97.49          0.0  conventional  2015            Albany
    2           2  2015-12-13          0.93     118220.22   794.70  109149.67  130.50     8145.35     8042.21      103.14          0.0  conventional  2015            Albany
    3           3  2015-12-06          1.08      78992.15  1132.00   71976.41   72.58     5811.16     5677.40      133.76          0.0  conventional  2015            Albany
    4           4  2015-11-29          1.28      51039.60   941.48   43838.39   75.78     6183.95     5986.26      197.69          0.0  conventional  2015            Albany
  ...         ...         ...           ...           ...      ...        ...     ...         ...         ...         ...          ...           ...   ...               ...
18244           7  2018-02-04          1.63      17074.83  2046.96    1529.20    0.00    13498.67    13066.82      431.85          0.0       organic  2018  WestTexNewMexico
18245           8  2018-01-28          1.71      13888.04  1191.70    3431.50    0.00     9264.84     8940.04      324.80          0.0       organic  2018  WestTexNewMexico
18246           9  2018-01-21          1.87      13766.76  1191.92    2452.79  727.94     9394.11     9351.80       42.31          0.0       organic  2018  WestTexNewMexico
18247          10  2018-01-14          1.93      16205.22  1527.63    2981.04  727.01    10969.54    10919.54       50.00          0.0       organic  2018  WestTexNewMexico
18248          11  2018-01-07          1.62      17489.58  2894.77    2356.13  224.53    12014.15    11988.14       26.01          0.0       organic  2018  WestTexNewMexico

18249 rows × 14 columns
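
`pd.read_csv` infers the compression from the file extension by default, which is why the `.zip` path can be passed to it directly. Here is a small self-contained sketch of that behaviour. It uses a throwaway file and made-up values rather than the avocado data.

```python
import os
import tempfile

import pandas as pd

# Toy frame with made-up values (not taken from the avocado dataset).
toy = pd.DataFrame({"AveragePrice": [1.0, 1.5], "type": ["conventional", "organic"]})

with tempfile.TemporaryDirectory() as tmp_dir:
    zip_path = os.path.join(tmp_dir, "toy.csv.zip")
    # to_csv also infers zip compression from the ".zip" extension.
    toy.to_csv(zip_path, index=False)
    # No explicit compression argument is needed when reading it back.
    loaded = pd.read_csv(zip_path)

print(loaded)
```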

data = data.drop(data.columns[0], axis=1)

Preparation

data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 18249 entries, 0 to 18248
Data columns (total 13 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   Date          18249 non-null  object 
 1   AveragePrice  18249 non-null  float64
 2   Total Volume  18249 non-null  float64
 3   4046          18249 non-null  float64
 4   4225          18249 non-null  float64
 5   4770          18249 non-null  float64
 6   Total Bags    18249 non-null  float64
 7   Small Bags    18249 non-null  float64
 8   Large Bags    18249 non-null  float64
 9   XLarge Bags   18249 non-null  float64
 10  type          18249 non-null  object 
 11  year          18249 non-null  int64  
 12  region        18249 non-null  object 
dtypes: float64(9), int64(1), object(3)
memory usage: 1.8+ MB

Visualization

Let's display box plots.

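plt.figure(figsize=(20, 10))

# Draw a horizontal box plot for every non-object (numeric) column.
for i in range(len(data.columns)):
    # .iloc makes the positional lookup explicit; dtypes[i] is a label-based lookup.
    if data.dtypes.iloc[i] != 'object':
        plt.subplot(3, 5, i+1)
        plt.boxplot(data[data.columns[i]], vert=False)
        plt.title(data.columns[i])
plt.show()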

[Figure: box plots of the numeric columns]
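
By default, matplotlib box plots draw points beyond 1.5 times the interquartile range as outliers. As a rough numeric companion to the figure, the sketch below counts such points per column. `count_iqr_outliers` is a hypothetical helper and the sample values are made up.

```python
import pandas as pd


def count_iqr_outliers(df, whisker=1.5):
    """Count values outside [Q1 - whisker*IQR, Q3 + whisker*IQR] per numeric column."""
    numeric = df.select_dtypes(include="number")
    q1 = numeric.quantile(0.25)
    q3 = numeric.quantile(0.75)
    iqr = q3 - q1
    outside = (numeric < q1 - whisker * iqr) | (numeric > q3 + whisker * iqr)
    return outside.sum()


# Made-up sample: one extreme value in "volume", none in "price".
sample = pd.DataFrame({
    "price": [1.0, 1.1, 1.2, 1.3, 1.4],
    "volume": [10.0, 11.0, 12.0, 13.0, 1000.0],
})
outlier_counts = count_iqr_outliers(sample)
print(outlier_counts)
```

On the real data, `count_iqr_outliers(data)` would give one count per numeric column.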

Handling missing values

data.isna().sum()
Date            0
AveragePrice    0
Total Volume    0
4046            0
4225            0
4770            0
Total Bags      0
Small Bags      0
Large Bags      0
XLarge Bags     0
type            0
year            0
region          0
dtype: int64

Encoding

def get_uniques(df, columns):
    return {column: list(df[column].unique()) for column in columns}
categorical_columns = ['region', 'Date', 'type']

get_uniques(data, categorical_columns)
{'Date': ['2015-12-27',
  '2015-12-20',
  '2015-12-13',
  '2015-12-06',
  '2015-11-29',
  '2015-11-22',
  '2015-11-15',
  '2015-11-08',
  '2015-11-01',
  '2015-10-25',
  '2015-10-18',
  '2015-10-11',
  '2015-10-04',
  '2015-09-27',
  '2015-09-20',
  '2015-09-13',
  '2015-09-06',
  '2015-08-30',
  '2015-08-23',
  '2015-08-16',
  '2015-08-09',
  '2015-08-02',
  '2015-07-26',
  '2015-07-19',
  '2015-07-12',
  '2015-07-05',
  '2015-06-28',
  '2015-06-21',
  '2015-06-14',
  '2015-06-07',
  '2015-05-31',
  '2015-05-24',
  '2015-05-17',
  '2015-05-10',
  '2015-05-03',
  '2015-04-26',
  '2015-04-19',
  '2015-04-12',
  '2015-04-05',
  '2015-03-29',
  '2015-03-22',
  '2015-03-15',
  '2015-03-08',
  '2015-03-01',
  '2015-02-22',
  '2015-02-15',
  '2015-02-08',
  '2015-02-01',
  '2015-01-25',
  '2015-01-18',
  '2015-01-11',
  '2015-01-04',
  '2016-12-25',
  '2016-12-18',
  '2016-12-11',
  '2016-12-04',
  '2016-11-27',
  '2016-11-20',
  '2016-11-13',
  '2016-11-06',
  '2016-10-30',
  '2016-10-23',
  '2016-10-16',
  '2016-10-09',
  '2016-10-02',
  '2016-09-25',
  '2016-09-18',
  '2016-09-11',
  '2016-09-04',
  '2016-08-28',
  '2016-08-21',
  '2016-08-14',
  '2016-08-07',
  '2016-07-31',
  '2016-07-24',
  '2016-07-17',
  '2016-07-10',
  '2016-07-03',
  '2016-06-26',
  '2016-06-19',
  '2016-06-12',
  '2016-06-05',
  '2016-05-29',
  '2016-05-22',
  '2016-05-15',
  '2016-05-08',
  '2016-05-01',
  '2016-04-24',
  '2016-04-17',
  '2016-04-10',
  '2016-04-03',
  '2016-03-27',
  '2016-03-20',
  '2016-03-13',
  '2016-03-06',
  '2016-02-28',
  '2016-02-21',
  '2016-02-14',
  '2016-02-07',
  '2016-01-31',
  '2016-01-24',
  '2016-01-17',
  '2016-01-10',
  '2016-01-03',
  '2017-12-31',
  '2017-12-24',
  '2017-12-17',
  '2017-12-10',
  '2017-12-03',
  '2017-11-26',
  '2017-11-19',
  '2017-11-12',
  '2017-11-05',
  '2017-10-29',
  '2017-10-22',
  '2017-10-15',
  '2017-10-08',
  '2017-10-01',
  '2017-09-24',
  '2017-09-17',
  '2017-09-10',
  '2017-09-03',
  '2017-08-27',
  '2017-08-20',
  '2017-08-13',
  '2017-08-06',
  '2017-07-30',
  '2017-07-23',
  '2017-07-16',
  '2017-07-09',
  '2017-07-02',
  '2017-06-25',
  '2017-06-18',
  '2017-06-11',
  '2017-06-04',
  '2017-05-28',
  '2017-05-21',
  '2017-05-14',
  '2017-05-07',
  '2017-04-30',
  '2017-04-23',
  '2017-04-16',
  '2017-04-09',
  '2017-04-02',
  '2017-03-26',
  '2017-03-19',
  '2017-03-12',
  '2017-03-05',
  '2017-02-26',
  '2017-02-19',
  '2017-02-12',
  '2017-02-05',
  '2017-01-29',
  '2017-01-22',
  '2017-01-15',
  '2017-01-08',
  '2017-01-01',
  '2018-03-25',
  '2018-03-18',
  '2018-03-11',
  '2018-03-04',
  '2018-02-25',
  '2018-02-18',
  '2018-02-11',
  '2018-02-04',
  '2018-01-28',
  '2018-01-21',
  '2018-01-14',
  '2018-01-07'],
 'region': ['Albany',
  'Atlanta',
  'BaltimoreWashington',
  'Boise',
  'Boston',
  'BuffaloRochester',
  'California',
  'Charlotte',
  'Chicago',
  'CincinnatiDayton',
  'Columbus',
  'DallasFtWorth',
  'Denver',
  'Detroit',
  'GrandRapids',
  'GreatLakes',
  'HarrisburgScranton',
  'HartfordSpringfield',
  'Houston',
  'Indianapolis',
  'Jacksonville',
  'LasVegas',
  'LosAngeles',
  'Louisville',
  'MiamiFtLauderdale',
  'Midsouth',
  'Nashville',
  'NewOrleansMobile',
  'NewYork',
  'Northeast',
  'NorthernNewEngland',
  'Orlando',
  'Philadelphia',
  'PhoenixTucson',
  'Pittsburgh',
  'Plains',
  'Portland',
  'RaleighGreensboro',
  'RichmondNorfolk',
  'Roanoke',
  'Sacramento',
  'SanDiego',
  'SanFrancisco',
  'Seattle',
  'SouthCarolina',
  'SouthCentral',
  'Southeast',
  'Spokane',
  'StLouis',
  'Syracuse',
  'Tampa',
  'TotalUS',
  'West',
  'WestTexNewMexico'],
 'type': ['conventional', 'organic']}
ordinal_features = ['Date']
nominal_features = ['region']
target_column = 'type'
def ordinal_encode(df, column, ordering):
    df = df.copy()
    df[column] = df[column].apply(lambda x: ordering.index(x))
    return df

def onehot_encode(df, column):
    df = df.copy()
    dummies = pd.get_dummies(df[column])
    df = pd.concat([df, dummies], axis=1)
    df = df.drop(column, axis=1)
    return df
date_ordering = sorted(data['Date'].unique())
data = ordinal_encode(data, 'Date', date_ordering)
data = onehot_encode(data, 'region')
label_encoder = sp.LabelEncoder()
data[target_column] = label_encoder.fit_transform(data[target_column])
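
To make the two encoders concrete, here is a small sketch on made-up rows. For the ordinal step it uses a dictionary lookup instead of `list.index`. That swap is my own assumption, made for speed, and the result is the same. The one-hot step uses `pd.get_dummies` as above.

```python
import pandas as pd

# Made-up rows that only mimic the shape of the avocado data.
toy = pd.DataFrame({
    "Date": ["2015-01-11", "2015-01-04", "2015-01-11"],
    "region": ["Albany", "Boise", "Albany"],
})

# Ordinal encoding: ISO dates sort chronologically as strings,
# so the sorted position doubles as a time index.
date_ordering = sorted(toy["Date"].unique())
date_to_index = {date: i for i, date in enumerate(date_ordering)}
toy["Date"] = toy["Date"].map(date_to_index)

# One-hot encoding: one indicator column per region.
dummies = pd.get_dummies(toy["region"], dtype=int)
encoded = pd.concat([toy.drop("region", axis=1), dummies], axis=1)

print(encoded)
```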

Splitting into X and y

y = data[target_column]
X = data.drop(target_column, axis=1)

Scaling

scaler = sp.StandardScaler()
X = scaler.fit_transform(X)
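
`StandardScaler` replaces each column with z = (x - mean) / std, using the population standard deviation. A minimal NumPy sketch of the same computation on made-up numbers:

```python
import numpy as np

# Made-up feature matrix: two columns on very different scales.
X_toy = np.array([[1.0, 100.0],
                  [2.0, 200.0],
                  [3.0, 300.0]])

# Column-wise standardization; np.std defaults to ddof=0 (population std),
# which matches what StandardScaler computes.
mean = X_toy.mean(axis=0)
std = X_toy.std(axis=0)
X_scaled = (X_toy - mean) / std

print(X_scaled.mean(axis=0), X_scaled.std(axis=0))
```

After scaling, both columns have mean 0 and standard deviation 1, so neither dominates purely because of its units.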

Splitting into training and test data

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7)
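
This article scales the full matrix before splitting, so the test rows influence the scaler's mean and variance. A common alternative is to split first and fit the scaler on the training rows only. The sketch below shows that on synthetic data. It is an assumption on my part, not the procedure used in this article. `random_state` and `stratify` are optional extras that make the split reproducible and preserve the class ratio.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data: 100 rows, 3 features, balanced binary labels.
rng = np.random.default_rng(0)
X_toy = rng.normal(loc=5.0, scale=2.0, size=(100, 3))
y_toy = np.array([0, 1] * 50)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_toy, y_toy, train_size=0.7, random_state=0, stratify=y_toy
)

# Fit on the training rows only, then apply the same transform to the test rows.
toy_scaler = StandardScaler()
X_tr_scaled = toy_scaler.fit_transform(X_tr)
X_te_scaled = toy_scaler.transform(X_te)

print(X_tr_scaled.shape, X_te_scaled.shape)
```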

Training

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(X_train.shape[1],)),  # 65 features
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

model.summary()

model.compile(
    optimizer='adam',
    loss='mse',
    metrics=['accuracy'],
)

batch_size=64
epochs=100

history = model.fit(
    X_train,
    y_train,
    validation_split=0.2,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[tf.keras.callbacks.ReduceLROnPlateau()],
    verbose=0,
)
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_15 (Dense)             (None, 64)                4224      
_________________________________________________________________
dense_16 (Dense)             (None, 64)                4160      
_________________________________________________________________
dense_17 (Dense)             (None, 1)                 65        
=================================================================
Total params: 8,449
Trainable params: 8,449
Non-trainable params: 0
_________________________________________________________________
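
The model above is trained with `loss='mse'` on a sigmoid output. Binary cross-entropy is the more common pairing for a sigmoid classifier, because it penalizes confident wrong predictions much more heavily. The NumPy sketch below compares the two losses on made-up probabilities. It does not retrain the model, and the results reported in this article are for MSE.

```python
import numpy as np


def mse(y_true, y_prob):
    """Mean squared error between labels and predicted probabilities."""
    return float(np.mean((y_true - y_prob) ** 2))


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    p = np.clip(y_prob, eps, 1 - eps)
    return float(-np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p)))


y_true = np.array([1.0, 1.0])
mild_miss = np.array([0.9, 0.4])        # second prediction is mildly wrong
confident_miss = np.array([0.9, 0.01])  # second prediction is confidently wrong

for name, probs in [("mild", mild_miss), ("confident", confident_miss)]:
    print(name, round(mse(y_true, probs), 3), round(binary_cross_entropy(y_true, probs), 3))
```

The per-sample squared error is bounded by 1, while cross-entropy grows without bound as a wrong prediction approaches full confidence.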

Result

plt.figure(figsize=(14, 10))

epochs_range = range(1, epochs+1)
train_loss = history.history['loss']
val_loss = history.history['val_loss']

plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')

plt.title('Training and Validation Loss')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.show()

[Figure: training and validation loss per epoch]

Find the epoch just before overfitting sets in, that is, the epoch where the validation loss is lowest.

np.argmin(val_loss)
42
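
Note that `np.argmin` returns a zero-based position, while the x-axis of the loss plot starts at epoch 1. The sketch below shows the conversion on a made-up loss curve. It also shows a simple patience rule of the kind early stopping uses. `first_stop_epoch` is a hypothetical helper, not a Keras function.

```python
import numpy as np

# Made-up validation losses: improving, then drifting upward.
val_loss_toy = [0.50, 0.30, 0.20, 0.25, 0.22, 0.27, 0.31]

best_index = int(np.argmin(val_loss_toy))  # zero-based position
best_epoch = best_index + 1                # epoch number as shown on the plot


def first_stop_epoch(losses, patience):
    """Return the 1-based epoch at which training would stop after `patience`
    consecutive epochs without a new best loss, or None if it never stops."""
    best = float("inf")
    waited = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            best = loss
            waited = 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return None


print(best_index, best_epoch, first_stop_epoch(val_loss_toy, patience=3))
```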
model.evaluate(X_test, y_test)
172/172 [==============================] - 0s 949us/step - loss: 4.9961e-04 - accuracy: 0.9998
[0.0004996125353500247, 0.9998173713684082]
The model reached very high accuracy: about 99.98% on the test set.
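
For a single sigmoid output, the accuracy figure amounts to thresholding the predicted probability at 0.5 and counting matches. As a sketch of what that number summarizes, the code below computes accuracy and a 2x2 confusion matrix from made-up probabilities. With the real model, one would use `model.predict(X_test).ravel()` in place of `probs`.

```python
import numpy as np

# Made-up labels and predicted probabilities (0 = conventional, 1 = organic).
y_true = np.array([0, 0, 1, 1, 1])
probs = np.array([0.1, 0.6, 0.8, 0.4, 0.9])

# Threshold at 0.5 to turn probabilities into class predictions.
y_pred = (probs > 0.5).astype(int)
accuracy = float(np.mean(y_pred == y_true))

# Confusion matrix: rows are true classes, columns are predicted classes.
confusion = np.zeros((2, 2), dtype=int)
for t, p in zip(y_true, y_pred):
    confusion[t, p] += 1

print(accuracy)
print(confusion)
```

The confusion matrix shows which class the errors fall on, which a single accuracy number hides.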