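When GPS cannot get a fix, I wished I could determine the position from nearby Wi-Fi information, and it turned out machine learning could do it

Posted at 2021-02-28

Introduction

I am writing a program that gets the current position with a Raspberry Pi and a GPS module, but it inevitably takes a while after power-on before the GPS has a fix.
Sometimes no position is available for nearly 10 minutes, which makes it hard to rely on.

That made me wonder whether information about the surrounding Wi-Fi access points could be used in place of GPS to pin down the position to some extent.

So I tried it, and it worked with fairly good accuracy. I was pleased enough to write it up here.

Environment

  • Raspberry Pi 3B+
  • GPS module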
  • SORACOM Air
  • Python 3.x
  • SQLite3
  • scikit-learn
  • micropyGPS

Data collection

Creating the database

SQLite3 is used here.

sqlite3 data.db

Run the following SQL.

create table gps (
	dat text,
	lon real,
	lat real,
	wlan text
);
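
As a rough sketch of how one record fits this schema, the snippet below creates the same table in an in-memory SQLite database and inserts a single row with a parameterized query. The timestamp, coordinates, and MAC address are placeholder values for illustration, not data from this article.

```python
import sqlite3

# In-memory database, so the sketch leaves no file behind
con = sqlite3.connect(":memory:")
con.execute("create table gps (dat text, lon real, lat real, wlan text)")

# Placeholder values for illustration only
sample = (
    "2021-02-28 12:00:00",
    139.0,
    35.0,
    "          Cell 01 - Address: 00:11:22:33:44:55\n",
)
# The ? placeholders let sqlite3 handle quoting of the scan text
con.execute("INSERT INTO gps VALUES (?, ?, ?, ?)", sample)
con.commit()

rows = con.execute("SELECT dat, lon, lat, wlan FROM gps").fetchall()
print(rows)
con.close()
```

The wlan column holds the raw scan text, so one row carries both the GPS fix and the access points visible at that moment.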

Getting Wi-Fi data

Create a script that uses iwlist to display Wi-Fi information.

get_iwlist.sh
#!/bin/sh
sudo iwlist wlan0 scan | grep Address
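
The script prints one "Address" line per access point. The following is a minimal sketch of how those lines could be turned into a list of MAC addresses. The function name parse_mac_addresses and the sample text are assumptions made for illustration; the sample only imitates the shape of the iwlist output.

```python
def parse_mac_addresses(iwlist_output):
    """Return the MAC addresses found on 'Address' lines, in order, without duplicates."""
    macs = []
    for line in iwlist_output.split("\n"):
        if "Address: " in line:
            mac = line.split("Address: ")[-1].strip()
            if mac and mac not in macs:
                macs.append(mac)
    return macs


# Made-up sample in the shape of `iwlist wlan0 scan | grep Address` output
sample_output = (
    "          Cell 01 - Address: 00:11:22:33:44:55\n"
    "          Cell 02 - Address: 66:77:88:99:AA:BB\n"
    "          Cell 03 - Address: 00:11:22:33:44:55\n"
)
macs = parse_mac_addresses(sample_output)
print(macs)
```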

Getting GPS data

Connect a USB GPS module and create the following script to read its positioning data.

gps.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from time import sleep
import serial
import micropyGPS
from datetime import datetime, timedelta
import subprocess
import sqlite3

DEVICE_NAME = "/dev/ttyACM"
# 119200 is not a standard baud rate; it is kept as written, but 115200 may have been intended
RATE = 119200

DBNAME = "data.db"

gps = micropyGPS.MicropyGPS(9, 'dd')

lon = 0.0
lat = 0.0
dat = ""

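def init_device():
    # Try /dev/ttyACM0 through /dev/ttyACM4 and return the first port that opens and reads
    for i in range(5):
        device_name = "{}{}".format(DEVICE_NAME, i)
        try:
            s = serial.Serial(device_name, RATE, timeout=10)
            s.readline().decode('utf-8')
            print("device find : " + device_name)
            return s
        except (serial.SerialException, UnicodeDecodeError):
            print("no device : " + device_name)
    # No port could be opened
    return None

try:
    s = serial.Serial(DEVICE_NAME + "0", RATE, timeout=10)
except serial.SerialException:
    s = init_device()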

i = 0
t = datetime.now()
flg = True

while flg:
    try:
        sentence = s.readline().decode('utf-8')
        print("read data : {}".format(len(sentence)))
        # Skip empty reads (timeout) and anything that is not an NMEA sentence
        if not sentence.startswith('$'):
            continue
        # micropyGPS parses the sentence one character at a time
        for x in sentence:
            gps.update(x)

        dat = "{:04}-{:02}-{:02} {:02}:{:02}:{:02}".format(
            gps.date[2]+2000,
            gps.date[1],
            gps.date[0],
            gps.timestamp[0] if gps.timestamp[0] < 24 else gps.timestamp[0] - 24,
            gps.timestamp[1],
            int(gps.timestamp[2])
        )
        lon = gps.longitude[0]
        lat = gps.latitude[0]
    except Exception as e:
        print("GPS data error.", e)
        # Wait briefly, then reopen the port and keep the new handle
        sleep(1)
        s = init_device()
        continue
                
    try:
        # Record at most one row per second
        if (datetime.now() - t).seconds > 0:
            print(lon, lat, dat)

            completedProcess = subprocess.run(['/home/pi/get_iwlist.sh'], check=True, shell=True, stdout=subprocess.PIPE)
            wlan = completedProcess.stdout.decode('utf-8')

            # Store only rows with a valid fix (lon stays 0.0 until the GPS has one)
            if lon > 0:
                con = sqlite3.connect(DBNAME)
                cur = con.cursor()
                # Parameterized query, so quotes in the scan output cannot break the SQL
                cur.execute("INSERT INTO gps VALUES (?, ?, ?, ?)", (dat, lon, lat, wlan))
                con.commit()
                cur.close()
                con.close()

            t = datetime.now()
    except Exception as e:
        print("Data error.", e)

Running

Run the script to collect data.

python3 gps.py

Once it is running, walk around the neighborhood carrying the device.

Training

import pandas as pd
import sqlite3

DB_FILE = "data.db"
conn = sqlite3.connect(DB_FILE)

df = pd.read_sql_query('SELECT * FROM gps WHERE lon > 0', conn)

addr = []

for txt in df.wlan:
    for t in txt.split("\n"):
        if t.find("Address") > 0:
            mac = t.split("Address: ")[-1]
            if mac in addr:
                pass
            else:
                addr.append(mac)

for a in addr:
    df[a] = 0

for i in df.index:
    print("\r{}".format(i), end="")
    txt = df.loc[i, "wlan"]
    for t in txt.split("\n"):
        if t.find("Address") > 0:
            # Set the flag with a single .loc call; chained df[col].loc[i] = 1 may write to a copy
            df.loc[i, t.split("Address: ")[-1]] = 1

K = 1000000

X = df[df.columns[4:]]
y_lon = df.lon * K
y_lat = df.lat * K

from sklearn import model_selection
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor

model = {}

model["lon"] = RandomForestRegressor()
model["lat"] = RandomForestRegressor()

y = {}
y["lon"] = y_lon.astype("int")
y["lat"] = y_lat.astype("int")

y_act = {}

for m in ["lon", "lat"]:
    
    print(m)
    
    X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y[m], test_size=.1, random_state=42)
    y_act[m] = y_test

    scaler = StandardScaler()
    scaler.fit(X_train)

    X_train = scaler.transform(X_train)
    X_test = scaler.transform(X_test)

    model[m].fit(X_train, y_train)
    print(model[m].score(X_test, y_test))

import joblib

joblib.dump(scaler, "scaler.pkl")
joblib.dump(model["lon"], "mdl_lon.pkl")
joblib.dump(model["lat"], "mdl_lat.pkl")

df.tail().to_csv("data.csv")

Running it gives...

lon
0.9944890346537838
lat
0.9922297709277479

The R² scores on the held-out 10% of the data are above 0.99 for both models, so it trained quite accurately!
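
An R² score does not say directly how many meters off a prediction is. As a minimal sketch, the error between a predicted and an actual coordinate pair could be converted to meters with the haversine formula. The function name haversine_m and the sample coordinates are assumptions for illustration, not measurements from this article.

```python
import math


def haversine_m(lon1, lat1, lon2, lat2):
    """Great-circle distance in meters between two (lon, lat) points given in degrees."""
    r = 6371000.0  # mean Earth radius in meters
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    dphi = phi2 - phi1
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    return 2 * r * math.asin(math.sqrt(a))


# Placeholder points that differ by 0.001 degrees of latitude
d = haversine_m(139.0, 35.0, 139.0, 35.001)
print("{:.1f} m".format(d))
```

Applying this to each test row (after dividing the predictions by K) would give a distribution of errors in meters, which is easier to judge than R² alone.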

Position prediction

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import joblib

scaler = joblib.load('scaler.pkl')
model = {}
model["lon"] = joblib.load('mdl_lon.pkl')
model["lat"] = joblib.load('mdl_lat.pkl')

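import pandas as pd
import subprocess

# data.csv was written with its index as the first column, so read that column back as the index
df = pd.read_csv("data.csv", index_col=0)

completedProcess = subprocess.run(['/home/pi/get_iwlist.sh'], check=True, shell=True, stdout=subprocess.PIPE)
txt = completedProcess.stdout.decode('utf-8')

# Build one all-zero feature row with the same MAC-address columns used in training
X = pd.DataFrame(0, index=[0], columns=df.columns[4:])
for t in txt.split("\n"):
    if t.find("Address") > 0:
        mac = t.split("Address: ")[-1]
        # Access points that were not seen during training have no column and are skipped
        if mac in X.columns:
            X.loc[0, mac] = 1

# Same scale factor as in training
K = 1000000

X = scaler.transform(X)

pred = {}
for m in ["lon", "lat"]:
    pred[m] = model[m].predict(X)[-1] / K
    print("{} : {}".format(m, pred[m]))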

It works!!
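
The key step in prediction is turning a fresh scan into a 0/1 vector aligned with the access points seen during training. A minimal sketch of that encoding follows; the function name encode_scan and the MAC addresses are made up for illustration.

```python
def encode_scan(seen_macs, known_macs):
    """Return a 0/1 list aligned with known_macs; 1 where the access point was seen."""
    seen = set(seen_macs)
    return [1 if mac in seen else 0 for mac in known_macs]


# Hypothetical access points learned during training
known = ["00:11:22:33:44:55", "66:77:88:99:AA:BB", "CC:DD:EE:FF:00:11"]
# A new scan; the second address never appeared in training, so it is ignored
scan = ["66:77:88:99:AA:BB", "12:34:56:78:9A:BC"]

row = encode_scan(scan, known)
print(row)  # [0, 1, 0]
```

Access points that appear only after training contribute nothing to the vector, so the models would need retraining on new data before they could make use of them.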
