14
19

More than 5 years have passed since last update.

word2vecをD3.jsのforce-simulationで可視化する

Last updated at Posted at 2017-12-04

概要

学習済みのword2vecモデルから、指定の単語に対して、その類似単語、さらに類似単語の類似単語を出力させ、各単語をノード、cos類似度をリンクの重みとしてd3.jsで力学グラフ化します。
t-sneなどで次元圧縮プロットする手法とはまた違った味があるかな、と思います。厳密さはあまりないので、ビジュアライゼーションの手法としては、解析向きというより説明向きの手法ですね。視覚化と言ったほうがいいかもしれません。
DOM/HTMLなのでカスタマイズが効きやすい点はメリットで、インタラクティブな感じに機能追加していくとウケが良さそうです。

出力例

graph.png

「AI」を指定単語として、その類似単語(緑)と、類似単語の類似単語(灰)までを出力しています。別途収集した、幾つかの新聞社のWeb記事(2MB程度)で学習させたモデルを使用しています。
ノード間の適正距離をcos類似度から計算しており、適正距離より近い場合は遠ざける方向へ、適正距離より遠い場合は近づける方向へ力が働きます。
「自殺」が出ていて物騒ですが、AIで自殺予防、の記事を取り込んだためと思われます。

実装

webサーバ側

Python3環境で実装しています。軽量webサーバライブラリのbottleでさくっとJSONサーバを仕立てます。dictを返すとそのままJSONにしてくれるのは便利ですね。

import os
import glob
import re
from bottle import route, run, request

from gensim.models import word2vec

hostname = "localhost"
port = 80
vecdir = ".\\"

@route("/words/vec", method="GET")
def words_vec():
    if any([x not in request.query.keys() for x in ["date","word"]]):
        return {"error": "invalid args"}

    date = request.query.date
    word = request.query.word

    if vecdir + "vec_" + date not in glob.glob(vecdir + 'vec_*'):
        return {"error": "invalid date"}

    model = word2vec.Word2Vec.load(vecdir + "vec_" + date)
    if word not in model.wv.vocab:
        return {"error": "'" + word + "' does not exist in vocabulary"}

    index = 0
    if "index" in request.query.keys():
        if not re.match("^[0-9]{1,2}$", request.query.index):
            return {"error": "invalid index"}
        index = int(request.query.index)

    nodes = {}
    links = []

    nodes[word] = 1

    for w1 in model.most_similar(positive=[word], negative=[], topn=index+12)[index:]:
        nodes[w1[0]] = 2
        links.append({"source":word, "target":w1[0], "value":w1[1]})

    for w1 in model.most_similar(positive=[word], negative=[], topn=index+12)[index:]:
        for w2 in model.most_similar(positive=[w1[0]], negative=[], topn=index+7)[index:]:
            if w2[0] not in nodes.keys():
                nodes[w2[0]] = 3
                links.append({"source":w1[0], "target":w2[0], "value":w2[1]})
            elif nodes[w2[0]] == 3:
                links.append({"source":w1[0], "target":w2[0], "value":w2[1]})

    nodes = [{"id":k, "group":v} for k, v in sorted(nodes.items(), key=lambda x:x[1])]

    return {"nodes":nodes, "links":links}

run(host=hostname, port=int(os.environ.get("PORT", port)))

localhost/words/vecにword、date、indexのクエリを指定してGETすると、jsonでd3.jsへの入力用データを返します。
wordは検索語、indexは類似度上位N件の除外(デフォルト0)の意です。
dateはモデルのファイル名に紐付きます。日付別にモデルを作っており、出力を切り替えたかったが為にこうなってます。そんな機能はいらんという人の方が多いと思いますので、適当に変更してください。

処理結果としてはnodesとlinksを返します。
nodesはid=単語、group=レベル(1は検索語、2は検索語の類似語、3は類似語の類似語)で構成される配列です。
linksはsource=起点nodeのid、target=終点nodeのid、value=node間のcos類似度で構成される配列になります。
javascript側ではnodesとlinksからグラフを生成していくことになります。

なお、処理を追えば分かりますが、一部のパターンのリンクは意図的に無視しています。例えば、レベル3→レベル1や、レベル2同士の類似度は、多くの場合ランキング上位に入ってきますが、あえて無視しています。
ちょっと恣意的で気持ち悪さはありますが、リンクが多いと表示結果が過度に密集したりとコントロールし辛く、今の形に落ち着きました。d3.js側のパラメタや、出力単語数等で調整はできると思いますので、色々試してみると良いと思います。

HTML/javascript側

コア部分だけ書きます。

//namespace
var ns = {};

//grapharea1
(function() {
    var svg = d3.select("#ga1")
    .call(d3.zoom()
          .scaleExtent([-8, 8])
          .on("zoom", function() {
        lines.attr("transform", d3.event.transform);
        circles.attr("transform", d3.event.transform);
        labels.attr("transform", d3.event.transform);
    }));

    var width = svg.attr("width");
    var height = svg.attr("height");

    var lines = svg.append("g").selectAll("line");
    var circles = svg.append("g").selectAll("circle");
    var labels = svg.append("g").selectAll("text");

    var nodes;
    var links;

    //※1
    var simulation = d3.forceSimulation()
        .force("link", d3.forceLink()
               .id(function(d){return d.id;})
               .distance(function(d){
                    var min = d3.min(links, function(e){return parseFloat(e.value);});
                    var max = d3.max(links, function(e){return parseFloat(e.value);});
                    var n = (max - parseFloat(d.value)) / (max - min);
                    return parseInt(n * 200);
        }))
        .force("charge",    d3.forceManyBody().strength(-50))
        .force("collide",   d3.forceCollide().radius(20))
        .force("center",    d3.forceCenter(width/2, height/2))
        .on("tick", ticked);

    function ticked() {
        lines
            .attr("x1", function(d){return d.source.x;})
            .attr("y1", function(d){return d.source.y;})
            .attr("x2", function(d){return d.target.x;})
            .attr("y2", function(d){return d.target.y;});
        circles
            .attr("cx", function(d){return d.x;})
            .attr("cy", function(d){return d.y;});
        labels
            .attr("x", function(d){return d.x;})
            .attr("y", function(d){return d.y;});
    }

    function update(word,date,index) {
        d3.json("/words/vec?date="+date+"&word="+word+"&index="+index, function(err, json){
            if (err) throw err;
            if ("error" in json) {
                console.log(json.error);
                d3.select("#msg").text(json.error);
                return;
            }

            nodes = json.nodes;
            links = json.links;

            nodes[0].fx = width/2;
            nodes[0].fy = height/2;

            //※2
            simulation.nodes(nodes);
            simulation.force("link").links(links);

            //※3
            circles = circles.data([]);
            circles.exit().remove();
            circles = circles.data(nodes).enter().append("circle")
                .attr("r",      10)
                .attr("fill",   function(d){return d3.schemeCategory20c[d.group*5%20];})
                .call(d3.drag()
                      .on("start", dragstarted)
                      .on("drag", dragged)
                      .on("end", dragended));

            function dragstarted(d) {
                if (!d3.event.active) simulation.alphaTarget(0.3).restart();
                d.fx = d.x;
                d.fy = d.y;
            }
            function dragged(d) {
                d.fx = d3.event.x;
                d.fy = d3.event.y;
            }
            function dragended(d) {
                if (!d3.event.active) simulation.alphaTarget(0);
                d.fx = null;
                d.fy = null;
            } 

            labels = labels.data([]);
            labels.exit().remove();
            labels = labels.data(nodes).enter().append("text")
                .text(function(d){return d.id;});

            lines = lines.data([]);
            lines.exit().remove();
            lines = lines.data(links).enter().append("line")

            simulation.alpha(1).restart();
        });
    }

    ns.ga1_update = update;
})();

だいたいはd3.jsのお作法通りなので、ポイントだけ補足しておきます。なお、d3.jsはv4を使用しています(v3とは大きく異なるので注意)。

上記のjs以外に必要な実装

  • https://d3js.org/d3.v4.min.jsを読み込んでおくこと
  • id="ga1"のSVG要素を定義しておくこと
  • ns.ga1_update(word,date,index)をどこかから呼び出すこと

※1
d3.forceLinkのdistanceに指定するコールバックにおいて、d.value(=links[n].value)を使用し、cos類似度を距離に反映しています。また、cos類似度はスケールがまちまちになる(1位が0.999で10位が0.995のようなパターンもあれば、1位が0.910で10位が0.750みたいなこともある)ため、スケールを揃える計算処理を入れています。

※2
simulationとnodes/links配列を紐付ける箇所です。
なかなか分かりづらいですが、実は、circleやlineといったsvg要素の座標とnodes/linksは直接的には紐付いていません。simulationのtickに指定したコールバックtickedにより、circle等の座標が間接的に書き換えられる形になっています。

※3
各単語に対応するcircle要素を生成する箇所ですが、fillにd.group(=nodes[n].group)を使用し、group値による色分けをさせています。

補足
updateは繰り返し呼ばれる前提で書いてあります。検索語を変えて再度呼び出せば、ページを再読み込みせずにグラフだけ書き直すことが可能です。
冒頭のns、及び即時関数での処理のラッピングは名前空間を汚さないためのjavascriptのお作法です。(最近はletやconstが使えるので余りお勧めすべき書き方ではないのでしょうが…)

所見等

やはりDOMで表現することによる柔軟性は魅力です。単語をクリックすると元ネタ記事が参照できる、みたいなインターフェースも容易に付け加えられますし、circleの大きさや色で他の情報(例えば元テキスト内の単語数)を表現したり、等も考えられます。

14
19
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
19