LoginSignup
12
9

More than 5 years have passed since last update.

栄養素多次元ベクトル間の類似度を計算する

Last updated at Posted at 2017-11-06

現在私は、大学の研究で料理の推薦システムを作っています。
今回はそのシステムに必要なスクリプトをいくつか書きました。

目的

料理名をキーに、その料理に近い栄養素を持つ他の料理を検索する。

そのために必要なもの

  • データセット{料理名:{栄養素情報}}
  • 類似度計算スクリプト
  • Pythonの知識

そのためにやったこと

  1. データセットの作成&整形
  2. 類似度計算スクリプトの作成

環境

  • OS(ホスト/コンテナ) : macOS 10.12.6 / Ubuntu 17.04
  • Docker : 17.06
  • Python : 3.5.3
  • Selenium : 3.5.0

1. データセットの作成&整形

料理名と栄養素情報の紐づいたオープンデータが欲しかった。でも...

仕方ないのでWeb検索サービス(https://www.eatsmart.jp/do/caloriecheck/index) からスクレイピングしました。
以下、ソースコード。(丸三日くらい掛かります

getNutrients.py
# coding: utf-8
from selenium import webdriver
from bs4 import BeautifulSoup
import time, json, codecs

# 最初に読み込むURLを設定
url = "https://www.eatsmart.jp/do/caloriecheck/index"

# JSONの準備
nutrientJSON = {}
dishJSON = {}
categoryJSON = {}

# PhantomJSのドライバを設定
driver = webdriver.PhantomJS()

# 暗黙的な待機を約1秒行う
driver.implicitly_wait(1)

# URLを読み込む
driver.get(url)

### HTMLを解析し、リンクの一覧を表示する ###
# 取得するテーブル要素のセレクターを指定
table = driver.find_element_by_css_selector("#form > div:nth-child(6) > table > tbody > tr:nth-child(1) > td > div > table > tbody")
# テーブル内のすべてのリンクを抽出する
trs = table.find_elements_by_tag_name("tr")
for i in range(0,len(trs)):
    tds = trs[i].find_elements_by_tag_name("td")
    for j in range(0,len(tds)):
        td = tds[j].find_element_by_css_selector("a")
        link = td.get_attribute("href")
        print(str(i)+", "+str(j)+" > "+link)
        print("------------------------------------------------------------")
        # カテゴリーHTMLを解析
        driver.get(link)
        # テーブル要素のセレクターを指定
        table2 = driver.find_element_by_css_selector("#main > div.categoryLists > table > tbody")
        trs2 = table2.find_elements_by_tag_name("tr")
        for k in range(0,len(trs2)):
            tds2 = trs2[k].find_elements_by_tag_name("td")
            for l in range(0, len(tds2)):
                print("= 1 Page =")
                td2 = tds2[l].find_element_by_css_selector("a")
                link2 = td2.get_attribute("href")
                category_name = td2.text
                # 種目HTMLを解析
                driver.get(link2)
                # 1ページ目の「料理名」テーブル要素のセレクターを指定
                table3 = driver.find_element_by_css_selector("#form > div.result > table > tbody")
                trs3 = table3.find_elements_by_tag_name("tr")
                for m in range(1,len(trs3)):
                    time.sleep(1)   # マナーの待機時間
                    link3 = trs3[m].find_element_by_css_selector("td:nth-child(1) > div > div > a").get_attribute("href")
                    print(link3)
                    # 個々の料理品目HTMLを解析
                    driver.get(link3)
                    dish_name = driver.find_element_by_css_selector("#main > p.manual > strong").text
                    divs = driver.find_elements_by_class_name("bargraph")
                    for n in range(0,len(divs)-1):
                        trs4 = divs[n].find_elements_by_css_selector("table > tbody > tr")
                        for o in range(1,len(trs4)):
                            nutrient_name = trs4[o].find_element_by_css_selector("td.item > a").text
                            nutrient_content = trs4[o].find_element_by_css_selector("td.capa").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
                            nutrientJSON[nutrient_name] = float(nutrient_content)
                    trs6 = driver.find_elements_by_css_selector("#detailContR > div:nth-child(7) > div > table > tbody > tr")
                    for n in range(1,len(trs6)):
                        try:
                            nutrient_name = trs6[n].find_element_by_css_selector("td:nth-child(1) > a").text
                            nutrient_content = trs6[n].find_element_by_css_selector("td.alignR").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
                            nutrientJSON[nutrient_name] = float(nutrient_content)
                        except:
                            nutrient_name = trs6[n].find_element_by_css_selector("td:nth-child(1)").text
                            nutrient_content = trs6[n].find_element_by_css_selector("td.alignR").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
                            nutrientJSON[nutrient_name] = float(nutrient_content)
                    print("*** "+str(i)+"-"+str(j)+"-"+str(k)+"-"+str(l)+"-"+str(m)+"-1"+" ok ***")
                    dishJSON[dish_name] = nutrientJSON
                    nutrientJSON = {}   #JSON初期化
                    # 料理品目ごとに書き出し
                    f = codecs.open("JSON/temp_nutrientsData"+str(i)+"-"+str(j)+"-"+str(k)+"-"+str(l)+"-"+str(m)+"-1"+".json", "w", "utf-8")
                    json.dump(dishJSON, f, ensure_ascii=False)
                    f.close()
                    driver.back()
                    table3 = driver.find_element_by_css_selector("#form > div.result > table > tbody")
                    trs3 = table3.find_elements_by_tag_name("tr")
                # 1 2 3 ... > 次のページ
                ul = driver.find_element_by_css_selector("#form > div.paginationT > ul")
                lis = ul.find_elements_by_tag_name("li")
                for m in range(1,len(lis)-1):
                    print("= "+str(m+1)+" Page =")
                    link4 = lis[m].find_element_by_css_selector("a").get_attribute("href")
                    driver.get(link4)
                    # 次ページ以降の「料理名」テーブル要素のセレクターを指定
                    table5 = driver.find_element_by_css_selector("#form > div.result > table > tbody")
                    trs5 = table5.find_elements_by_tag_name("tr")
                    for n in range(1,len(trs5)):
                        time.sleep(0.1) # マナーの待機時間
                        link5 = trs5[n].find_element_by_css_selector("td:nth-child(1) > div > div > a").get_attribute("href")
                        print(link5)
                        # 個々の料理を解析
                        driver.get(link5)
                        dish_name = driver.find_element_by_css_selector("#main > p.manual > strong").text
                        divs2 = driver.find_elements_by_class_name("bargraph")
                        for o in range(0,len(divs2)-1):
                            trs7 = divs2[o].find_elements_by_css_selector("table > tbody > tr")
                            for p in range(1,len(trs7)):
                                nutrient_name = trs7[p].find_element_by_css_selector("td.item > a").text
                                nutrient_content = trs7[p].find_element_by_css_selector("td.capa").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
                                nutrientJSON[nutrient_name] = float(nutrient_content)
                        trs8 = driver.find_elements_by_css_selector("#detailContR > div:nth-child(7) > div > table > tbody > tr")
                        for o in range(1,len(trs8)):
                            try:
                                nutrient_name = trs8[o].find_element_by_css_selector("td:nth-child(1) > a").text
                                nutrient_content = trs8[o].find_element_by_css_selector("td.alignR").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
                                nutrientJSON[nutrient_name] = float(nutrient_content)
                            except:
                                nutrient_name = trs8[o].find_element_by_css_selector("td:nth-child(1)").text
                                nutrient_content = trs8[o].find_element_by_css_selector("td.alignR").text.replace("kcal", "").replace("mg", "").replace("μg", "").replace("g", "").replace(",", ".")
                                nutrientJSON[nutrient_name] = float(nutrient_content)
                        print("*** "+str(i)+"-"+str(j)+"-"+str(k)+"-"+str(l)+"-"+str(m)+"-"+str(n)+" ok ***")
                        dishJSON[dish_name] = nutrientJSON
                        nutrientJSON = {}   #JSON初期化
                        # 料理品目ごとに書き出し
                        f = codecs.open("JSON/temp_nutrientsData"+str(i)+"-"+str(j)+"-"+str(k)+"-"+str(l)+"-"+str(m)+"-"+str(n)+".json", "w", "utf-8")
                        json.dump(dishJSON, f, ensure_ascii=False)
                        f.close()
                        driver.back()
                        table5 = driver.find_element_by_css_selector("#form > div.result > table > tbody")
                        trs5 = table5.find_elements_by_tag_name("tr")
                        #driver.save_screenshot("./nutrient/link2_captcha"+str(l)+"-"+str(m)+"-"+str(n)+".png")
                    driver.back()
                    ul = driver.find_element_by_css_selector("#form > div.paginationT > ul")
                    lis = ul.find_elements_by_tag_name("li")
                categoryJSON[category_name] = dishJSON
                dishJSON = {}   #JSON初期化
                # カテゴリごとにも書き出し
                f = codecs.open("JSON/temp_dishData"+str(i)+"-"+str(j)+"-"+str(k)+"-"+str(l)+".json", "w", "utf-8")
                json.dump(categoryJSON, f, ensure_ascii=False)
                f.close()
                driver.back()
                table2 = driver.find_element_by_css_selector("#main > div.categoryLists > table > tbody")
                trs2 = table2.find_elements_by_tag_name("tr")
                tds2 = trs2[k].find_elements_by_tag_name("td")
            table2 = driver.find_element_by_css_selector("#main > div.categoryLists > table > tbody")
            trs2 = table2.find_elements_by_tag_name("tr")
        driver.back()
        table = driver.find_element_by_css_selector("#form > div:nth-child(6) > table > tbody > tr:nth-child(1) > td > div > table > tbody")
        trs = table.find_elements_by_tag_name("tr")
        tds = trs[i].find_elements_by_tag_name("td")

# 最後にJSON形式で書き出し
f = codecs.open("JSON/categoryData.json", "w", "utf-8")
json.dump(categoryJSON, f, ensure_ascii=False)
f.close()

 

作成したデータセットの内容確認


caption2
横軸が「1料理ごとに含まれる栄養素データの数(= len(nutrients_in_dish))」のヒストグラム。
量はあるけど中身スッカスカなデータセットができました。

とりあえず、70種類以下の栄養素データを持たない料理は除外してみました。
1919/6403個の料理データが残りました。

また、データセットを各カテゴリ(メインディッシュ・サイドメニュー・お菓子・飲み物)に分けて、カテゴリごとに閾値を設定してみました。

  • メインディッシュ(mainDish):
     1021/1336個 <- 60種類
  • サイドメニュー(sideDish):
     1247/1514個 <- 60種類
  • お菓子(dessert):
     545/1805個 <- 60種類
  • 飲み物(drink):
     429/2454個 <- 10種類

以下、参考にしたヒストグラム。

hist_all_category.png

2. 類似度計算スクリプトの作成

こちらの記事を参考に、類似度計算スクリプトを作成しました。
http://qiita.com/hik0107/items/96c483afd6fb2f077985

simCalculation.py
# coding: utf-8
import json, math
import numpy as np

# データ読込
f = open("data/formatFilterData.json", "r")
dataset = json.load(f)
f.close()

# 料理キー入力
print("料理名を入力してください。")
inputDish = input(">>>  ")

# 入力された料理を指定(日本語辞書を扱う上での苦肉の策)
for key in dataset.keys():
    if(key==inputDish):
        choicedDish = inputDish

# 類似度を計算する関数
def get_similairty(dish1, dish2):

  ## 両料理で共通した栄養素の和集合を生成
  setDish1 = set(dataset[dish1].keys())
  setDish2 = set(dataset[dish2].keys())
  setBoth = setDish1.intersection(setDish2)

  # 栄養素が共通でない場合は類似度を0とする(とりあえず)
  if len(setBoth)==0:
    return 0

  listDestance = []

  for item in setBoth:
    # 同じ栄養素の含有量の差の2乗を計算
    # この数値が大きいほど「類似していない」と定義できる 
    distance = pow(dataset[dish1][item]-dataset[dish2][item], 2) 
    listDestance.append(distance)

  # 各栄養素の非類似度の合計の逆比的な指標を返す(とりあえず)
  return 1/(1+np.sqrt(sum(listDestance)))

# レコメンド関数
def get_recommend(dish, top_N):

  simDishList = {} # 類似度と料理名の辞書
  simList = [] # 類似度のリスト

  # 入力された料理を除いたユーザリストを作成
  # -> 各料理との類似度を計算するため
  listOthers = list()
  for key in dataset.keys():
    listOthers.append(key)
  listOthers.remove(dish)

  # 類似度と料理の辞書を作成
  for other in listOthers:
    sim = get_similairty(dish, other)
    simDishList[sim] = other
    simList.append(sim)
  simList.sort()
  simList.reverse()
  recommendDict = {}
  for i in range(0,top_N):
    recommendDict[simDishList[simList[i]].replace("\u3000", " ")] = simList[i] # 全角空白"\u3000"をreplaceで半角空白に変換

  return recommendDict

# コマンドラインに表示
print(get_recommend(choicedDish, 5))

出力結果例

試しに「うな重」に類似した料理を出力してみました。

$ python3 simCalculation.py
栄養素名を入力してください。
>>>  うな重
{'ひつまぶし': 0.0014330883248297468, 'いくら丼': 0.0011092730907257715, 'ぜんざい': 0.0010865128148115706, 'いわしのごま漬け': 0.0010664954000816627, 'もんじゃ焼き': 0.001064990896452261}

そこそこ類似してるように見えます。

続いて、z-score変換を用いて正規化を行った後の出力結果がこちら。
(ソースコードは付録参照)

$ python3 simCalculation.py
栄養素名を入力してください。
>>>  うな重
{'ひつまぶし': 0.11615374412085168, 'いくら丼': 0.075985385218824225, '押し寿司盛り合わせ': 0.0698463684540752, 'にしんそば': 0.069143743122642959, '海鮮丼': 0.067707191581806936}

少し類似度が高まったみたいです。

今後の予定

  • 類似度計算をユークリッド距離->コサイン類似度に置換
  • 栄養素の充足率をダッシュボードで表示
  • REST API化

付録

データセット整形スクリプト

formatJSON.py
# coding: utf-8
import json

# データの読込
f = open("data/nutrientsData.json", "r")
sampleJSON = json.load(f)
f.close()

### 第1キーをリストとして抽出 ###
# valueをリスト形式に抽出
valueList = list()
for value in sampleJSON.values():
    valueList.append(value)
# 抽出したリストを統合する
formatJSON = {}
for value in valueList:
    formatJSON.update(value)

# 70種類以下の栄養素しか持たぬ料理データは除外
keyMaterials = list()
for key in formatJSON.keys():
    if(len(formatJSON[key]) >= 70):
        key.replace("\u3000", "") # 全角空白"\u3000"をreplaceで半角空白に変換
        keyMaterials.append(key)
formatJSON2 = {}
for i in range(len(keyMaterials)):
    formatJSON2[keyMaterials[i]] = formatJSON[keyMaterials[i]]
    i += 1

# データの書き出し
f = open("data/formatFilterData.json", "w")
json.dump(formatJSON2, f, ensure_ascii=False, sort_keys=True, indent=2)
f.close()

データセット正規化スクリプト

normalizeJSON.py
# coding: utf-8
import json
import numpy as np

# データの読込
f = open("data/formatFilterData.json", "r")
targetJSON = json.load(f)
f.close()

## 事前準備
# keyリストとvalueリストの作成
keyList = list()
for key in targetJSON.keys():
    keyList.append(key)
valueList = list()
for value in targetJSON.values():
    valueList.append(value)

# 最大栄養素数の確認と、栄養素カタログの作成
nutrientNum = []
nutrientList = []
keys = targetJSON.keys() 
for i in range(len(targetJSON)):
    nutrientNum.append(len(valueList[i]))
    for nutrient in valueList[i]:
        nutrientList.append(nutrient)
nutrientCatalog = list(set(nutrientList))

# 各栄養素の数値を全て収集
allNutrientAllList = [ [] for i in range(len(nutrientCatalog)) ]
for i in range(len(nutrientCatalog)):
    for value in targetJSON.values():
        if(value.keys() >= {nutrientCatalog[i]}):
            allNutrientAllList[i].append(value[nutrientCatalog[i]])
        else:
            allNutrientAllList[i].append(0.0)

## 正規化
# 各栄養素の平均値と標準偏差を算出
z_allNutrientAllList = [ [] for i in range(len(nutrientCatalog)) ]
for i in range(len(allNutrientAllList)):
    npValues = np.array(allNutrientAllList[i])
    meanValue = npValues.mean()
    stdValue = npValues.std()
    # z-score normalizationによる正規化を行う
    for value in allNutrientAllList[i]:
        z_allNutrientAllList[i].append((value-meanValue)/stdValue)  # z-score = (各value - 平均) / 標準偏差

# 正規化した値を用いたJSONを作成
normalizedJSON = {}
for i in range(len(keyList)): #=1873
    normalizedJSON[keyList[i]] = {}
    for j in range(len(nutrientCatalog)): #=95
        if(z_allNutrientAllList[j][i]!=z_allNutrientAllList[j][i]): # NaN(不定)を判別
            z_allNutrientAllList[j][i] = 0.0
        normalizedJSON[keyList[i]][nutrientCatalog[j]] = z_allNutrientAllList[j][i]+2

# データの書き出し
with open("data/formatNormalizedFilterData.json", "w") as f:
    json.dump(normalizedJSON, f, ensure_ascii=False, sort_keys=True, indent=2)
12
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
9