Edited at

chainerの作法 その9

More than 1 year has passed since last update.


概要

chainerの作法を調べてみた。

手書きの数字を、学習済みのMNISTモデルで推定してみた。


写真

image.png


サンプルコード

import sys

import numpy as np
from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PIL import Image
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training, Chain, datasets, serializers
from chainer.training import extensions
import PIL

class MLP(Chain):
    """Three-layer perceptron: two ReLU hidden layers and a linear output layer.

    Args:
        n_units: Width of each hidden layer.
        n_out: Number of output units (one score per class).
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        # Register child links via init_scope(); passing them as keyword
        # arguments to Chain.__init__ is deprecated since Chainer v2.
        # The attribute names (l1, l2, l3) are unchanged, so parameters
        # restored with serializers.load_npz keep the same keys.
        with self.init_scope():
            # in_size=None lets Chainer infer the input width lazily.
            self.l1 = L.Linear(None, n_units)
            self.l2 = L.Linear(None, n_units)
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        """Return the unnormalized class scores for the batch ``x``."""
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

class pline(QWidget):
    """Drawing canvas that records mouse strokes.

    Strokes are painted in black on the widget itself. They are also mirrored
    in white onto an off-screen 180x180 black image (``image2``). :meth:`save`
    writes that image to disk.
    """

    def __init__(self, parent=None):
        QWidget.__init__(self, parent)
        self.px = None  # NOTE(review): not read anywhere in this file
        self.py = None  # NOTE(review): not read anywhere in this file
        self.points = []  # points of the stroke currently being drawn
        self.psets = []  # finished strokes, each a list of QPoint
        self.points_saved = []  # flat list of all finished strokes' points
        self.pressed = False  # reset on release; initialised so it always exists
        self.image2 = self._new_canvas()

    @staticmethod
    def _new_canvas():
        """Return a fresh 180x180 RGBA image filled with opaque black."""
        image = QImage(180, 180, QImage.Format_RGBA8888)
        image.fill(qRgb(0, 0, 0))
        return image

    def _draw_strokes(self, painter):
        """Draw every finished stroke plus the in-progress one with ``painter``."""
        for points in self.psets + [self.points]:
            # drawPolyline() needs at least one point. Skip empty strokes
            # instead of raising TypeError on every later repaint.
            if points:
                painter.drawPolyline(*points)

    def mousePressEvent(self, event):
        """Start the current stroke at the press position."""
        self.points.append(event.pos())
        self.update()

    def mouseMoveEvent(self, event):
        """Extend the current stroke.

        Without mouse tracking, Qt only sends move events while a button is
        held down.
        """
        self.points.append(event.pos())
        self.update()

    def mouseReleaseEvent(self, event):
        """Finish the current stroke and store it."""
        self.pressed = False
        self.psets.append(self.points)
        self.points_saved.extend(self.points)
        self.points = []
        self.update()

    def paintEvent(self, event):
        """Render strokes on screen and mirror them onto ``image2``."""
        painter = QPainter(self)
        painter.setPen(QPen(Qt.black, 8, Qt.SolidLine))
        self._draw_strokes(painter)
        painter.end()
        # The off-screen copy is white strokes on a black background.
        canvas_painter = QPainter(self.image2)
        canvas_painter.setPen(QPen(Qt.white, 8, Qt.SolidLine))
        self._draw_strokes(canvas_painter)
        # End explicitly rather than relying on garbage collection.
        canvas_painter.end()

    def save(self, path=r'd:/test.png'):
        """Write the off-screen image to ``path`` as PNG.

        Returns:
            The ``bool`` result of ``QImage.save`` (True on success).
        """
        return self.image2.save(path, 'png')

    def clear(self):
        """Discard all strokes, reset the off-screen image and repaint."""
        self.points = []
        self.psets = []
        self.points_saved = []
        self.image2 = self._new_canvas()
        self.repaint()

class MainWindow(QWidget):
    """Top-level window: a ``pline`` drawing canvas and a "recog" button."""

    def __init__(self, parent=None):
        super(MainWindow, self).__init__(parent)
        recog_button = QPushButton("recog")
        recog_button.clicked.connect(self.recog)
        self.pain = pline()
        self.setGeometry(200, 200, 200, 200)
        layout = QVBoxLayout()
        layout.addWidget(self.pain)
        layout.addWidget(recog_button)
        self.setLayout(layout)
        # NOTE(review): not read anywhere in this file. numpy shape order is
        # (rows, cols), i.e. (height, width); it only works here because the
        # two values are equal.
        self.image = np.zeros((paint_width, paint_height, 3), np.uint8)
        self.setWindowTitle('MNIST')

    def recog(self):
        """Classify the current drawing, show the prediction, then clear it.

        Uses the module-level ``model`` (an ``L.Classifier``); only its
        ``predictor`` is called.
        """
        self.pain.save()
        # Read back the PNG written by pline.save(). `with` closes the file
        # handle deterministically.
        with Image.open('d:/test.png') as img:
            gray = img.resize((28, 28)).convert('L')
            # One row of 28*28 grayscale values scaled to [0, 1].
            x = np.array(gray, dtype=np.float32).reshape(1, -1) / 255
        # Inference only: run in test mode and build no computational graph.
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            y = model.predictor(x)
        p = int(y.data.argmax(axis=1)[0])
        msg = QMessageBox()
        msg.setText("推定は、" + str(p))
        msg.setWindowTitle("recog")
        msg.setStandardButtons(QMessageBox.Ok)
        msg.exec_()
        self.pain.clear()

# Canvas dimensions in pixels. They are only used for MainWindow.image;
# pline hard-codes its own 180x180 image.
paint_width, paint_height = 180, 180

# Classifier wrapping the MLP. Its parameters are restored from the
# 'mnist3.model' NPZ file, resolved relative to the working directory.
model = L.Classifier(MLP(n_units=1000, n_out=10))
serializers.load_npz('mnist3.model', model)

if __name__ == '__main__':
    # A QApplication must exist before any widget is created.
    app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    # Run the Qt event loop and hand its exit status back to the shell.
    status = app.exec_()
    sys.exit(status)

以上。