LoginSignup
0
0

More than 5 years have passed since last update.

chainerの作法 その9

Posted at 2018-05-05

概要

chainerの作法を調べてみた。
PyQt5のキャンバスに手書きした数字を、学習済みのMLPモデル（mnist3.model）で推定してみた。

写真

image.png

サンプルコード

import sys
import numpy as np
from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PIL import Image
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training, Chain, datasets, serializers
from chainer.training import extensions
import PIL

class MLP(Chain):
    """Three-layer perceptron: two ReLU hidden layers followed by a linear output layer.

    Args:
        n_units: Width of each of the two hidden layers.
        n_out: Number of output units (class scores); this script uses 10.
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        # Register child links inside init_scope (the Chainer v2+ idiom)
        # instead of passing them as constructor keyword arguments. The
        # attribute names l1/l2/l3 are unchanged, so parameter names in
        # previously saved .npz files still match.
        with self.init_scope():
            # in_size=None: the input width is inferred on the first call
            # (or from the stored array shape when weights are loaded).
            self.l1 = L.Linear(None, n_units)
            self.l2 = L.Linear(None, n_units)
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        """Return the unnormalized class scores for the input batch ``x``."""
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

class pline(QWidget):
    """Freehand drawing canvas used as input for digit recognition.

    Strokes are painted in black on the widget itself. They are also mirrored
    in white onto an off-screen black image (``image2``), which ``save``
    writes to disk as a PNG.
    """

    # Side length (pixels) of the off-screen image that strokes are mirrored onto.
    CANVAS_SIZE = 180
    # Pen width (pixels) used for both the on-screen and the off-screen strokes.
    PEN_WIDTH = 8
    # Path written by save() when no path is given (the original hard-coded path).
    DEFAULT_SAVE_PATH = r'd:/test.png'

    def __init__(self, parent = None):
        QWidget.__init__(self, parent)
        # px/py are not read anywhere in this class; kept for compatibility.
        self.px = None
        self.py = None
        self.pressed = False
        self.points = []        # points of the stroke currently being drawn
        self.psets = []         # finished strokes, one list of points each
        self.points_saved = []  # every point of every finished stroke, flattened
        self.image2 = self._blank_image()

    def _blank_image(self):
        """Return a new CANVAS_SIZE x CANVAS_SIZE RGBA image filled with opaque black."""
        image = QImage(self.CANVAS_SIZE, self.CANVAS_SIZE, QImage.Format_RGBA8888)
        image.fill(qRgb(0, 0, 0))
        return image

    def _strokes(self):
        """Yield each non-empty stroke to draw: finished ones, then the one in progress."""
        for points in self.psets:
            # drawPolyline(*[]) would raise, so never yield an empty stroke.
            if points:
                yield points
        if self.points:
            yield self.points

    def mousePressEvent(self, event):
        self.pressed = True
        self.points.append(event.pos())
        self.update()

    def mouseMoveEvent(self, event):
        self.points.append(event.pos())
        self.update()

    def mouseReleaseEvent(self, event):
        self.pressed = False
        # Only keep strokes that actually contain points.
        if self.points:
            self.psets.append(self.points)
            self.points_saved.extend(self.points)
        self.points = []
        self.update()

    def paintEvent(self, event):
        # On-screen rendering: black strokes on the widget.
        painter = QPainter(self)
        try:
            painter.setPen(QPen(Qt.black, self.PEN_WIDTH, Qt.SolidLine))
            for points in self._strokes():
                painter.drawPolyline(*points)
        finally:
            painter.end()
        # Off-screen mirror: white strokes on the black image that save() writes.
        # end() is called explicitly so painting on image2 is finished before
        # it is saved, instead of relying on garbage collection of the painter.
        painter2 = QPainter(self.image2)
        try:
            painter2.setPen(QPen(Qt.white, self.PEN_WIDTH, Qt.SolidLine))
            for points in self._strokes():
                painter2.drawPolyline(*points)
        finally:
            painter2.end()

    def save(self, path=None):
        """Write the off-screen image as PNG.

        Args:
            path: Destination file; defaults to DEFAULT_SAVE_PATH.

        Returns:
            The bool returned by QImage.save (True on success).
        """
        target = self.DEFAULT_SAVE_PATH if path is None else path
        return self.image2.save(target, 'png')

    def clear(self):
        """Discard all strokes, reset the off-screen image, and repaint immediately."""
        self.points = []
        self.psets = []
        self.points_saved = []
        self.image2 = self._blank_image()
        self.repaint()

class MainWindow(QWidget):
    """Top-level window: a drawing canvas (``pline``) above a "recog" button."""

    def __init__(self, parent = None):
        super(MainWindow, self).__init__(parent)
        recog_button = QPushButton("recog")
        recog_button.clicked.connect(self.recog)
        self.pain = pline()
        self.setGeometry(200, 200, 200, 200)
        layout = QVBoxLayout()
        layout.addWidget(self.pain)
        layout.addWidget(recog_button)
        self.setLayout(layout)
        # Not read by this class; kept so the attribute still exists.
        # Uses the module-level paint_width/paint_height globals.
        self.image = np.zeros((paint_width, paint_height, 3), np.uint8)
        self.setWindowTitle('MNIST')

    def recog(self):
        """Classify the current drawing, show the predicted digit, then clear the canvas.

        The canvas is saved to disk, reloaded, reduced to a 28x28 grayscale
        image, flattened to shape (1, 784) with values scaled to [0, 1], and
        passed through the module-level ``model``'s predictor.
        """
        self.pain.save()
        # The context manager closes the file handle opened by Image.open.
        with Image.open('d:/test.png') as src:
            # Convert to grayscale first, then downscale with an antialiasing
            # filter so the thin pen strokes are not dropped by subsampling.
            small = src.convert('L').resize((28, 28), Image.LANCZOS)
        x = np.asarray(small, dtype = np.float32).reshape(1, -1) / 255
        # Inference only: test-mode config and no computational graph.
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            scores = model.predictor(x).data
        p = int(scores.argmax(axis = 1)[0])
        msg = QMessageBox()
        msg.setText("推定は、" + str(p))
        msg.setWindowTitle("recog")
        msg.setStandardButtons(QMessageBox.Ok)
        msg.exec_()
        self.pain.clear()

# Drawing-area size; read by MainWindow.__init__ to shape its self.image buffer.
paint_width = paint_height = 180

# Build the classifier and restore trained weights from disk at import time.
# The MLP(1000, 10) architecture presumably has to match the one used when
# 'mnist3.model' was saved — confirm against the training script.
model = L.Classifier(MLP(1000, 10))
serializers.load_npz('mnist3.model', model)

if __name__ == '__main__':
    # Standard PyQt5 start-up: create the application, show the main window,
    # run the event loop, and exit with its return code.
    app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_status = app.exec_()
    sys.exit(exit_status)


以上。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0