PyQt5
Chainer
手書き
MNIST

概要

Chainerの作法を調べてみた。
PyQt5のキャンバスに手書きした数字を、学習済みモデル（mnist3.model）で推定してみた。

写真

image.png

サンプルコード

import sys
import numpy as np
from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PIL import Image
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training, Chain, datasets, serializers
from chainer.training import extensions
import PIL

class MLP(Chain):
    """Multi-layer perceptron: two ReLU hidden layers followed by a linear output.

    Args:
        n_units: width of each of the two hidden layers.
        n_out: number of output units (raw scores, no activation applied).
    """

    def __init__(self, n_units, n_out):
        # Input size is None so Chainer infers it from the first forward pass.
        links = {
            'l1': L.Linear(None, n_units),
            'l2': L.Linear(None, n_units),
            'l3': L.Linear(None, n_out),
        }
        super(MLP, self).__init__(**links)

    def __call__(self, x):
        """Return the un-normalized output scores for the batch ``x``."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return self.l3(hidden)

class pline(QWidget):
    """Drawing canvas that records mouse strokes.

    Strokes are painted in black on the widget itself and mirrored in white
    onto ``self.image2``, a 180x180 RGBA image filled with black. That image
    is what ``save()`` writes to disk.
    """

    def __init__(self, parent = None):
        QWidget.__init__(self, parent)
        self.px = None
        self.py = None
        self.points = []        # points of the stroke currently being drawn
        self.psets = []         # finished strokes, each a list of QPoint
        self.points_saved = []  # flat list of every point of finished strokes
        self.image2 = self._new_canvas()

    @staticmethod
    def _new_canvas():
        """Return a fresh 180x180 RGBA image filled with black."""
        image = QImage(180, 180, QImage.Format_RGBA8888)
        image.fill(qRgb(0, 0, 0))
        return image

    def _draw_strokes(self, painter):
        """Draw every finished stroke plus the in-progress one using ``painter``."""
        strokes = list(self.psets)
        if self.points:
            strokes.append(self.points)
        for points in strokes:
            if len(points) == 1:
                # A one-vertex polyline renders nothing; draw a dot instead so
                # that a plain click still leaves a mark.
                painter.drawPoint(points[0])
            elif points:
                painter.drawPolyline(QPolygon(points))

    def mousePressEvent(self, event):
        self.points.append(event.pos())
        self.update()

    def mouseMoveEvent(self, event):
        self.points.append(event.pos())
        self.update()

    def mouseReleaseEvent(self, event):
        self.pressed = False
        # Only keep non-empty strokes.
        if self.points:
            self.psets.append(self.points)
            self.points_saved.extend(self.points)
        self.points = []
        self.update()

    def paintEvent(self, event):
        # On-screen view: black strokes.
        painter = QPainter(self)
        painter.setPen(QPen(Qt.black, 8, Qt.SolidLine))
        self._draw_strokes(painter)
        painter.end()
        # Off-screen copy: white strokes on black.
        painter2 = QPainter(self.image2)
        painter2.setPen(QPen(Qt.white, 8, Qt.SolidLine))
        self._draw_strokes(painter2)
        # End explicitly so the image is completely written before save() reads it.
        painter2.end()

    def save(self, path=r'd:/test.png'):
        """Write the off-screen image to ``path`` in PNG format.

        Returns:
            The success flag reported by ``QImage.save``.
        """
        return self.image2.save(path, 'png')

    def clear(self):
        """Discard all strokes, reset the off-screen image and repaint immediately."""
        self.points = []
        self.psets = []
        self.points_saved = []
        self.image2 = self._new_canvas()
        self.repaint()

class MainWindow(QWidget):
    """Top-level window: a drawing canvas plus a "recog" button that classifies it."""

    def __init__(self, parent = None):
        super(MainWindow, self).__init__(parent)
        recog_button = QPushButton("recog")
        recog_button.clicked.connect(self.recog)
        self.pain = pline()
        self.setGeometry(200, 200, 200, 200)
        layout = QVBoxLayout()
        layout.addWidget(self.pain)
        layout.addWidget(recog_button)
        self.setLayout(layout)
        # NOTE(review): not read anywhere in this file; kept for compatibility.
        # Relies on the module-level paint_width / paint_height.
        self.image = np.zeros((paint_width, paint_height, 3), np.uint8)
        self.setWindowTitle('MNIST')

    def recog(self):
        """Classify the current drawing, show the predicted class, then clear the canvas.

        The canvas is round-tripped through the PNG that ``pline.save()`` writes.
        That image is converted to grayscale and downsampled to 28x28. It is then
        flattened to a (1, 784) float32 vector scaled to [0, 1] and fed to the
        module-level ``model``.
        """
        self.pain.save()
        # `with` closes the underlying file deterministically.
        with Image.open('d:/test.png') as src:
            # Use an explicit antialiasing filter for the 180 -> 28 downsample.
            # Older Pillow versions default to NEAREST, which drops thin strokes.
            img = src.convert('L').resize((28, 28), Image.LANCZOS)
        x = np.array(img, dtype = np.float32).reshape(1, -1) / 255
        # Inference only: avoid building the backprop graph.
        with chainer.no_backprop_mode():
            y = model.predictor(x)
        p = y.data.argmax(axis = 1)[0]
        msg = QMessageBox()
        msg.setText("推定は、" + str(p))
        msg.setWindowTitle("recog")
        msg.setStandardButtons(QMessageBox.Ok)
        msg.exec_()
        self.pain.clear()

# Drawing-area size in pixels (read by MainWindow.__init__).
paint_width, paint_height = 180, 180

# Classifier around the MLP (1000 hidden units, 10 outputs). Weights are loaded
# at import time from 'mnist3.model', resolved relative to the working directory.
model = L.Classifier(MLP(n_units=1000, n_out=10))
serializers.load_npz('mnist3.model', model)

if __name__ == '__main__':
    # Standard Qt entry point: create the application, show the window,
    # and exit with the event loop's return code.
    qt_app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(qt_app.exec_())


以上。