現在私は、大学の研究で料理の推薦システムを作っています。
今回はそのシステムに必要なスクリプトをいくつか書きました。
目的
料理名をキーに、その料理に近い栄養素を持つ他の料理を検索する。
そのために必要なもの
- データセット{料理名:{栄養素情報}}
- 類似度計算スクリプト
- Pythonの知識
そのためにやったこと
- データセットの作成&整形
- 類似度計算スクリプトの作成
環境
- OS(ホスト/コンテナ) : macOS 10.12.6 / Ubuntu 17.04
- Docker : 17.06
- Python : 3.5.3
- Selenium : 3.5.0
1. データセットの作成&整形
料理名と栄養素情報の紐づいたオープンデータが欲しかった。でも...
- 唯一のオープンデータが貧弱
https://fooddb.mext.go.jp/index.pl - 有料データは高い
https://imd.jp/fdb/
https://www.eatsmart.jp/do/corporate/calData/index
仕方ないのでWeb検索サービス(https://www.eatsmart.jp/do/caloriecheck/index) からスクレイピングしました。
以下、ソースコード。(丸三日くらい掛かります
getNutrients.py
# coding: utf-8
from selenium import webdriver
from bs4 import BeautifulSoup
import time, json, codecs
# 最初に読み込むURLを設定
url = "https://www.eatsmart.jp/do/caloriecheck/index"
# JSONの準備
nutrientJSON = {}
dishJSON = {}
categoryJSON = {}
# PhantomJSのドライバを設定
driver = webdriver.PhantomJS()
# 暗黙的な待機を約1秒行う
driver.implicitly_wait(1)
# URLを読み込む
driver.get(url)
### HTMLを解析し、リンクの一覧を表示する ###
# 取得するテーブル要素のセレクターを指定
table = driver.find_element_by_css_selector("#form > div:nth-child(6) > table > tbody > tr:nth-child(1) > td > div > table > tbody")
# テーブル内のすべてのリンクを抽出する
trs = table.find_elements_by_tag_name("tr")
for i in range(0,len(trs)):
tds = trs[i].find_elements_by_tag_name("td")
for j in range(0,len(tds)):
td = tds[j].find_element_by_css_selector("a")
link = td.get_attribute("href")
print(str(i)+", "+str(j)+" > "+link)
print("------------------------------------------------------------")
# カテゴリーHTMLを解析
driver.get(link)
# テーブル要素のセレクターを指定
table2 = driver.find_element_by_css_selector("#main > div.categoryLists > table > tbody")
trs2 = table2.find_elements_by_tag_name("tr")
for k in range(0,len(trs2)):
tds2 = trs2[k].find_elements_by_tag_name("td")
for l in range(0, len(tds2)):
print("= 1 Page =")
td2 = tds2[l].find_element_by_css_selector("a")
link2 = td2.get_attribute("href")
category_name = td2.text
# 種目HTMLを解析
driver.get(link2)
# 1ページ目の「料理名」テーブル要素のセレクターを指定
table3 = driver.find_element_by_css_selector("#form > div.result > table > tbody")
trs3 = table3.find_elements_by_tag_name("tr")
for m in range(1,len(trs3)):
time.sleep(1) # マナーの待機時間
link3 = trs3[m].find_element_by_css_selector("td:nth-child(1) > div > div > a").get_attribute("href")
print(link3)
# 個々の料理品目HTMLを解析
driver.get(link3)
dish_name = driver.find_element_by_css_selector("#main > p.manual > strong").text
divs = driver.find_elements_by_class_name("bargraph")
for n in range(0,len(divs)-1):
trs4 = divs[n].find_elements_by_css_selector("table > tbody > tr")
for o in range(1,len(trs4)):
nutrient_name = trs4[o].find_element_by_css_selector("td.item > a").text
nutrient_content = trs4[o].find_element_by_css_selector("td.capa").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
nutrientJSON[nutrient_name] = float(nutrient_content)
trs6 = driver.find_elements_by_css_selector("#detailContR > div:nth-child(7) > div > table > tbody > tr")
for n in range(1,len(trs6)):
try:
nutrient_name = trs6[n].find_element_by_css_selector("td:nth-child(1) > a").text
nutrient_content = trs6[n].find_element_by_css_selector("td.alignR").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
nutrientJSON[nutrient_name] = float(nutrient_content)
except:
nutrient_name = trs6[n].find_element_by_css_selector("td:nth-child(1)").text
nutrient_content = trs6[n].find_element_by_css_selector("td.alignR").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
nutrientJSON[nutrient_name] = float(nutrient_content)
print("*** "+str(i)+"-"+str(j)+"-"+str(k)+"-"+str(l)+"-"+str(m)+"-1"+" ok ***")
dishJSON[dish_name] = nutrientJSON
nutrientJSON = {} #JSON初期化
# 料理品目ごとに書き出し
f = codecs.open("JSON/temp_nutrientsData"+str(i)+"-"+str(j)+"-"+str(k)+"-"+str(l)+"-"+str(m)+"-1"+".json", "w", "utf-8")
json.dump(dishJSON, f, ensure_ascii=False)
f.close()
driver.back()
table3 = driver.find_element_by_css_selector("#form > div.result > table > tbody")
trs3 = table3.find_elements_by_tag_name("tr")
# 1 2 3 ... > 次のページ
ul = driver.find_element_by_css_selector("#form > div.paginationT > ul")
lis = ul.find_elements_by_tag_name("li")
for m in range(1,len(lis)-1):
print("= "+str(m+1)+" Page =")
link4 = lis[m].find_element_by_css_selector("a").get_attribute("href")
driver.get(link4)
# 次ページ以降の「料理名」テーブル要素のセレクターを指定
table5 = driver.find_element_by_css_selector("#form > div.result > table > tbody")
trs5 = table5.find_elements_by_tag_name("tr")
for n in range(1,len(trs5)):
time.sleep(0.1) # マナーの待機時間
link5 = trs5[n].find_element_by_css_selector("td:nth-child(1) > div > div > a").get_attribute("href")
print(link5)
# 個々の料理を解析
driver.get(link5)
dish_name = driver.find_element_by_css_selector("#main > p.manual > strong").text
divs2 = driver.find_elements_by_class_name("bargraph")
for o in range(0,len(divs2)-1):
trs7 = divs2[o].find_elements_by_css_selector("table > tbody > tr")
for p in range(1,len(trs7)):
nutrient_name = trs7[p].find_element_by_css_selector("td.item > a").text
nutrient_content = trs7[p].find_element_by_css_selector("td.capa").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
nutrientJSON[nutrient_name] = float(nutrient_content)
trs8 = driver.find_elements_by_css_selector("#detailContR > div:nth-child(7) > div > table > tbody > tr")
for o in range(1,len(trs8)):
try:
nutrient_name = trs8[o].find_element_by_css_selector("td:nth-child(1) > a").text
nutrient_content = trs8[o].find_element_by_css_selector("td.alignR").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
nutrientJSON[nutrient_name] = float(nutrient_content)
except:
nutrient_name = trs8[o].find_element_by_css_selector("td:nth-child(1)").text
nutrient_content = trs8[o].find_element_by_css_selector("td.alignR").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
nutrientJSON[nutrient_name] = float(nutrient_content)
print("*** "+str(i)+"-"+str(j)+"-"+str(k)+"-"+str(l)+"-"+str(m)+"-"+str(n)+" ok ***")
dishJSON[dish_name] = nutrientJSON
nutrientJSON = {} #JSON初期化
# 料理品目ごとに書き出し
f = codecs.open("JSON/temp_nutrientsData"+str(i)+"-"+str(j)+"-"+str(k)+"-"+str(l)+"-"+str(m)+"-"+str(n)+".json", "w", "utf-8")
json.dump(dishJSON, f, ensure_ascii=False)
f.close()
driver.back()
table5 = driver.find_element_by_css_selector("#form > div.result > table > tbody")
trs5 = table5.find_elements_by_tag_name("tr")
#driver.save_screenshot("./nutrient/link2_captcha"+str(l)+"-"+str(m)+"-"+str(n)+".png")
driver.back()
ul = driver.find_element_by_css_selector("#form > div.paginationT > ul")
lis = ul.find_elements_by_tag_name("li")
categoryJSON[category_name] = dishJSON
dishJSON = {} #JSON初期化
# カテゴリごとにも書き出し
f = codecs.open("JSON/temp_dishData"+str(i)+"-"+str(j)+"-"+str(k)+"-"+str(l)+".json", "w", "utf-8")
json.dump(categoryJSON, f, ensure_ascii=False)
f.close()
driver.back()
table2 = driver.find_element_by_css_selector("#main > div.categoryLists > table > tbody")
trs2 = table2.find_elements_by_tag_name("tr")
tds2 = trs2[k].find_elements_by_tag_name("td")
table2 = driver.find_element_by_css_selector("#main > div.categoryLists > table > tbody")
trs2 = table2.find_elements_by_tag_name("tr")
driver.back()
table = driver.find_element_by_css_selector("#form > div:nth-child(6) > table > tbody > tr:nth-child(1) > td > div > table > tbody")
trs = table.find_elements_by_tag_name("tr")
tds = trs[i].find_elements_by_tag_name("td")
# 最後にJSON形式で書き出し
f = codecs.open("JSON/categoryData.json", "w", "utf-8")
json.dump(categoryJSON, f, ensure_ascii=False)
f.close()
作成したデータセットの内容確認
横軸が「1料理ごとに含まれる栄養素データの数(= len(nutrients_in_dish))」のヒストグラム。
量はあるけど中身スッカスカなデータセットができました。
とりあえず、70種類以下の栄養素データを持たない料理は除外してみました。
1919/6403個の料理データが残りました。
また、データセットを各カテゴリ(メインディッシュ・サイドメニュー・お菓子・飲み物)に分けて、カテゴリごとに閾値を設定してみました。
- メインディッシュ(mainDish):
1021/1336個 <- 60種類 - サイドメニュー(sideDish):
1247/1514個 <- 60種類 - お菓子(dessert):
545/1805個 <- 60種類 - 飲み物(drink):
429/2454個 <- 10種類
以下、参考にしたヒストグラム。
2. 類似度計算スクリプトの作成
こちらの記事を参考に、類似度計算スクリプトを作成しました。
http://qiita.com/hik0107/items/96c483afd6fb2f077985
simCalculation.py
# coding: utf-8
import json, math
import numpy as np
# データ読込
f = open("data/formatFilterData.json", "r")
dataset = json.load(f)
f.close()
# 料理キー入力
print("料理名を入力してください。")
inputDish = input(">>> ")
# 入力された料理を指定(日本語辞書を扱う上での苦肉の策)
for key in dataset.keys():
if(key==inputDish):
choicedDish = inputDish
# 類似度を計算する関数
def get_similairty(dish1, dish2):
## 両料理で共通した栄養素の和集合を生成
setDish1 = set(dataset[dish1].keys())
setDish2 = set(dataset[dish2].keys())
setBoth = setDish1.intersection(setDish2)
# 栄養素が共通でない場合は類似度を0とする(とりあえず)
if len(setBoth)==0:
return 0
listDestance = []
for item in setBoth:
# 同じ栄養素の含有量の差の2乗を計算
# この数値が大きいほど「類似していない」と定義できる
distance = pow(dataset[dish1][item]-dataset[dish2][item], 2)
listDestance.append(distance)
# 各栄養素の非類似度の合計の逆比的な指標を返す(とりあえず)
return 1/(1+np.sqrt(sum(listDestance)))
# レコメンド関数
def get_recommend(dish, top_N):
simDishList = {} # 類似度と料理名の辞書
simList = [] # 類似度のリスト
# 入力された料理を除いたユーザリストを作成
# -> 各料理との類似度を計算するため
listOthers = list()
for key in dataset.keys():
listOthers.append(key)
listOthers.remove(dish)
# 類似度と料理の辞書を作成
for other in listOthers:
sim = get_similairty(dish, other)
simDishList[sim] = other
simList.append(sim)
simList.sort()
simList.reverse()
recommendDict = {}
for i in range(0,top_N):
recommendDict[simDishList[simList[i]].replace("\u3000", " ")] = simList[i] # 全角空白"\u3000"をreplaceで半角空白に変換
return recommendDict
# コマンドラインに表示
print(get_recommend(choicedDish, 5))
出力結果例
試しに「うな重」に類似した料理を出力してみました。
$ python3 simCalculation.py
栄養素名を入力してください。
>>> うな重
{'ひつまぶし': 0.0014330883248297468, 'いくら丼': 0.0011092730907257715, 'ぜんざい': 0.0010865128148115706, 'いわしのごま漬け': 0.0010664954000816627, 'もんじゃ焼き': 0.001064990896452261}
そこそこ類似してるように見えます。
続いて、z-score変換を用いて正規化を行った後の出力結果がこちら。
(ソースコードは付録参照)
$ python3 simCalculation.py
栄養素名を入力してください。
>>> うな重
{'ひつまぶし': 0.11615374412085168, 'いくら丼': 0.075985385218824225, '押し寿司盛り合わせ': 0.0698463684540752, 'にしんそば': 0.069143743122642959, '海鮮丼': 0.067707191581806936}
少し類似度が高まったみたいです。
今後の予定
- 類似度計算をユークリッド距離->コサイン類似度に置換
- 栄養素の充足率をダッシュボードで表示
- REST API化
付録
データセット整形スクリプト
formatJSON.py
# coding: utf-8
import json
# データの読込
f = open("data/nutrientsData.json", "r")
sampleJSON = json.load(f)
f.close()
### 第1キーをリストとして抽出 ###
# valueをリスト形式に抽出
valueList = list()
for value in sampleJSON.values():
valueList.append(value)
# 抽出したリストを統合する
formatJSON = {}
for value in valueList:
formatJSON.update(value)
# 70種類以下の栄養素しか持たぬ料理データは除外
keyMaterials = list()
for key in formatJSON.keys():
if(len(formatJSON[key]) >= 70):
key.replace("\u3000", "") # 全角空白"\u3000"をreplaceで半角空白に変換
keyMaterials.append(key)
formatJSON2 = {}
for i in range(len(keyMaterials)):
formatJSON2[keyMaterials[i]] = formatJSON[keyMaterials[i]]
i += 1
# データの書き出し
f = open("data/formatFilterData.json", "w")
json.dump(formatJSON2, f, ensure_ascii=False, sort_keys=True, indent=2)
f.close()
データセット正規化スクリプト
normalizeJSON.py
# coding: utf-8
import json
import numpy as np
# データの読込
f = open("data/formatFilterData.json", "r")
targetJSON = json.load(f)
f.close()
## 事前準備
# keyリストとvalueリストの作成
keyList = list()
for key in targetJSON.keys():
keyList.append(key)
valueList = list()
for value in targetJSON.values():
valueList.append(value)
# 最大栄養素数の確認と、栄養素カタログの作成
nutrientNum = []
nutrientList = []
keys = targetJSON.keys()
for i in range(len(targetJSON)):
nutrientNum.append(len(valueList[i]))
for nutrient in valueList[i]:
nutrientList.append(nutrient)
nutrientCatalog = list(set(nutrientList))
# 各栄養素の数値を全て収集
allNutrientAllList = [ [] for i in range(len(nutrientCatalog)) ]
for i in range(len(nutrientCatalog)):
for value in targetJSON.values():
if(value.keys() >= {nutrientCatalog[i]}):
allNutrientAllList[i].append(value[nutrientCatalog[i]])
else:
allNutrientAllList[i].append(0.0)
## 正規化
# 各栄養素の平均値と標準偏差を算出
z_allNutrientAllList = [ [] for i in range(len(nutrientCatalog)) ]
for i in range(len(allNutrientAllList)):
npValues = np.array(allNutrientAllList[i])
meanValue = npValues.mean()
stdValue = npValues.std()
# z-score normalizationによる正規化を行う
for value in allNutrientAllList[i]:
z_allNutrientAllList[i].append((value-meanValue)/stdValue) # z-score = (各value - 平均) / 標準偏差
# 正規化した値を用いたJSONを作成
normalizedJSON = {}
for i in range(len(keyList)): #=1873
normalizedJSON[keyList[i]] = {}
for j in range(len(nutrientCatalog)): #=95
if(z_allNutrientAllList[j][i]!=z_allNutrientAllList[j][i]): # NaN(不定)を判別
z_allNutrientAllList[j][i] = 0.0
normalizedJSON[keyList[i]][nutrientCatalog[j]] = z_allNutrientAllList[j][i]+2
# データの書き出し
with open("data/formatNormalizedFilterData.json", "w") as f:
json.dump(normalizedJSON, f, ensure_ascii=False, sort_keys=True, indent=2)