Posted at 2018-07-16


All-Optical Machine Learning Using Diffractive Deep Neural Networks






import keras
from keras.models import Model, Input
from keras.layers import *
import numpy as np

# Training hyper-parameters.
batch_size, epochs = 128, 50
lr = 1e-3  # learning rate

# Each input is a 14x14 image flattened to a vector of D0 pixels.
D0 = 14 * 14
num_classes = 10

# File the best checkpoint is written to (and reloaded from for export).
model_name = 'model.h5'

# Load the pre-split dataset (images and integer labels) from .npy files.
x_train, y_train = np.load("x_train.npy"), np.load("y_train.npy")
x_test, y_test = np.load("x_test.npy"), np.load("y_test.npy")

# Flatten every image to a D0-vector and invert the intensities
# (background: white, pen: black), so background pixels map to ~1.
x_train, x_test = (1 - arr.reshape((-1, D0)) / 255 for arr in (x_train, x_test))
print('x_train shape:', x_train.shape)
for _split, _label in ((x_train, 'train samples'), (x_test, 'test samples')):
    print(_split.shape[0], _label)

# Add Gaussian noise (sigma 0.1) for robustness. Note that the test split is
# perturbed as well, so the reported test metrics are on noisy inputs.
x_train += np.random.normal(loc=0, scale=0.1, size=x_train.shape)
x_test += np.random.normal(loc=0, scale=0.1, size=x_test.shape)

# One-hot encode the integer labels.
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes) for labels in (y_train, y_test)
)

# A single Dense layer maps the D0 input pixels to num_classes scores.
# It has no bias and its weights are constrained to be non-negative, because
# the weights are later exported as brightness masks, which cannot be negative.
x = Input(shape=(D0,))
h = Dense(
    num_classes,
    use_bias=False,
    kernel_constraint=keras.constraints.non_neg(),
)(x)
y = Activation("softmax")(h)
model = Model(inputs=[x], outputs=[y])



# train
# Save the model with the best validation loss to model_name. Halve the
# learning rate when the monitored metric (the Keras default, val_loss)
# stops improving.
checkpoint = keras.callbacks.ModelCheckpoint(model_name, save_best_only=True)
reducelr = keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=8, cooldown=2, verbose=1)

# The model must be compiled before fit(). Softmax output with one-hot labels
# calls for categorical cross-entropy. "accuracy" is requested so that
# evaluate() below returns [loss, accuracy]. The learning rate is passed
# positionally, which works across Keras versions (`lr` vs `learning_rate`).
model.compile(loss="categorical_crossentropy",
              optimizer=keras.optimizers.Adam(lr),
              metrics=["accuracy"])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test),
          callbacks=[checkpoint, reducelr])

# evaluate (uses the final in-memory weights, which may differ from the best
# checkpoint saved on disk)
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])





import keras
import numpy as np
import matplotlib.pyplot as plt

# Reload the trained classifier and take the Dense kernel. layers[0] is the
# Input layer; the Dense layer has no bias, so it holds exactly one array.
model = keras.models.load_model("model.h5")
[w] = model.layers[1].get_weights()
# Transpose so that row i holds the per-pixel weights of class i.
w = w.T

## convert mask to RGB565 format
# Each row of w (one class) becomes 16-bit grey levels for the LCD. The same
# 5-bit value is used for red and blue, and a 6-bit value for green.
a = np.max(w)
if np.min(w)<0:
    # Masks are shown as brightness, so negative weights cannot be represented.
    print("error: mask contains negative value")
w = w/a # normalize so the largest weight maps to full brightness
# The RGB565 channel maxima are 31 (5 bits) and 63 (6 bits). Clipping to 32/64
# made a weight equal to the maximum overflow the packed uint16 value
# (2048*32 == 65536 wraps to 0), turning the brightest pixel dark (2080).
w5 = np.clip((32*w),0,31).astype(np.uint16)
w6 = np.clip((64*w),0,63).astype(np.uint16)
w565 = 2048*w5+32*w6+w5  # max 2048*31 + 32*63 + 31 == 65535, fits in uint16
print("weights of 1st layer:",w565.shape)
## display components
# ref: https://pythonmemo.hatenablog.jp/entry/2018/04/22/204614
# Show each class's normalized weight vector as a 14x14 greyscale image.
# The scale is fixed to [0, 1], so relative brightness between masks matches
# what the LCD masks encode.
N = w.shape[0]  # one row of w per class
for i in range(N):
    plt.subplot(1, N, i + 1)
    plt.imshow(w[i].reshape(14, 14), cmap="gray", vmin=0, vmax=1)
    plt.title(str(i))
    plt.axis("off")
plt.show()



#include <M5Stack.h>

#define CLS 10  // number of classes, i.e. number of masks in wmask

// Sensing parameters.
constexpr int adc_pin = 36;          // analog pin sampled with analogRead()
constexpr int rep = 100;             // ADC samples averaged per class
constexpr int disp_wait = 30;        // settle time after drawing a mask (presumably ms - confirm)
constexpr int capture_wait_us = 100; // gap between ADC samples, microseconds per the name - confirm

// Mask layout on the LCD: a 14x14 grid of cell-by-cell squares whose
// top-left corner is (ox, oy).
constexpr int ox = 104;
constexpr int oy = 68;
constexpr int cell = 8;
constexpr int len = cell * 14;       // side length of the whole grid in pixels

// RGB565 masks: one 196-entry (14x14, row-major) mask per class. These appear
// to be the output of the Python RGB565 conversion of the trained weights.
// One entry of mask 0 was 2080 (0x0820). That value is the uint16 overflow
// produced by clipping to 32/64 instead of 31/63 when packing RGB565 at the
// maximum weight, so it is corrected to full white (65535).
const uint16_t wmask[CLS][196] = {{0, 32, 4226, 0, 32, 2145, 0, 32, 32, 0, 0, 32, 0, 2113, 6339, 2113, 4226, 4226, 2145, 2113, 6339, 6339, 8484, 12710, 6371, 2113, 2113, 2145, 32, 0, 4258, 4226, 10565, 10597, 8484, 4226, 0, 8452, 8484, 10565, 2113, 2145, 0, 2113, 4226, 10565, 16936, 6339, 32, 4258, 0, 0, 0, 6371, 19049, 2145, 2113, 4226, 8452, 10597, 8452, 8452, 8452, 6339, 0, 0, 0, 0, 12710, 32, 32, 32, 4258, 32, 10565, 16904, 10565, 33808, 25388, 32, 0, 0, 0, 0, 2113, 2145, 0, 0, 6371, 6371, 25356, 65535, 50712, 23243, 6371, 0, 0, 4226, 2145, 4226, 0, 0, 0, 0, 35953, 65535, 42292, 19049, 2145, 0, 0, 0, 2145, 2113, 0, 0, 0, 6339, 52825, 57051, 21162, 8452, 0, 0, 0, 0, 2113, 4226, 0, 0, 0, 0, 33808, 23275, 6371, 8484, 4226, 2113, 2145, 32, 0, 6371, 2145, 2113, 0, 0, 6339, 16936, 19017, 21130, 12710, 10597, 4258, 4226, 2113, 2145, 8484, 0, 0, 0, 0, 0, 8484, 19049, 14823, 4258, 2113, 2113, 0, 2113, 4226, 4226, 10565, 10597, 8484, 12710, 14791, 12678, 6339, 2113, 2145, 2145, 2145, 32, 0, 4226, 2113, 2113, 4258, 6339, 6371, 2145, 0, 4258, 32, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4226, 2113, 0, 2113, 0, 32, 0, 0, 0, 0, 0, 0, 0, 12678, 16904, 2145, 2145, 6339, 0, 0, 0, 0, 0, 0, 0, 0, 8452, 27501, 21162, 14791, 29582, 33808, 8452, 0, 4258, 4226, 0, 0, 0, 0, 12710, 27469, 29582, 14823, 0, 12678, 21130, 27469, 27469, 10565, 0, 0, 0, 2113, 8452, 16936, 27501, 12710, 0, 6371, 27469, 25388, 12710, 2145, 0, 0, 0, 0, 8484, 23243, 52857, 14791, 0, 19049, 25356, 12678, 2145, 0, 0, 0, 0, 0, 12710, 23275, 50712, 0, 0, 35921, 31727, 14791, 8452, 0, 0, 0, 0, 0, 23275, 25356, 25356, 0, 0, 38034, 31695, 12710, 4226, 32, 0, 0, 0, 8452, 31695, 21162, 21130, 0, 2145, 27469, 16904, 10597, 8452, 0, 0, 0, 0, 0, 8484, 6371, 21130, 19049, 21162, 6371, 12678, 12710, 4258, 0, 0, 0, 0, 0, 0, 0, 16904, 31727, 6371, 0, 2113, 4226, 32, 0, 0, 0, 0, 0, 2113, 10597, 19017, 19049, 14823, 2145, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {2113, 4226, 32, 2145, 0, 
2113, 4226, 0, 32, 2145, 2113, 2145, 4258, 4226, 2145, 2113, 0, 32, 0, 0, 0, 0, 2113, 10565, 8484, 4226, 2145, 32, 32, 2145, 0, 0, 0, 0, 0, 0, 0, 4226, 14791, 12678, 6339, 32, 2113, 4226, 0, 0, 0, 0, 4226, 10565, 14791, 8452, 10597, 19017, 6371, 2113, 2113, 4226, 2113, 0, 6339, 2145, 8484, 8452, 16936, 10597, 12678, 19049, 10597, 32, 4258, 4226, 6339, 16936, 29582, 40147, 40179, 35921, 16904, 14823, 12678, 12710, 8484, 32, 2145, 4226, 16904, 44373, 44373, 42260, 40147, 38066, 35953, 12710, 10565, 16936, 2113, 2145, 32, 2113, 10597, 14823, 12678, 12710, 2145, 14823, 25388, 21162, 14823, 12678, 0, 32, 0, 32, 0, 0, 6339, 6339, 0, 12678, 8484, 10565, 0, 0, 0, 0, 2113, 0, 0, 0, 0, 32, 0, 0, 4226, 4258, 0, 0, 0, 2113, 2113, 0, 0, 0, 0, 2113, 12678, 16904, 6339, 0, 0, 0, 0, 32, 2145, 2145, 2145, 2113, 0, 8484, 19049, 19049, 2145, 0, 0, 0, 0, 32, 2145, 32, 6339, 12710, 21130, 25356, 21130, 16904, 12710, 12710, 4226, 4258, 32, 4226, 4226, 2113, 32, 2145, 0, 32, 2113, 6339, 4258, 6339, 4226, 32, 4258, 2145}, {4226, 32, 6339, 2145, 4258, 4226, 4226, 2145, 4258, 2113, 4226, 2145, 2145, 4226, 2145, 2113, 2145, 2145, 4258, 0, 0, 0, 0, 32, 4226, 6339, 2113, 2145, 4226, 0, 0, 0, 0, 0, 0, 4226, 6339, 6371, 12678, 10565, 2145, 4258, 2113, 2113, 0, 0, 0, 0, 0, 2145, 8452, 0, 4226, 16936, 8484, 4226, 4258, 0, 0, 2113, 10565, 19017, 19017, 32, 6339, 0, 0, 12710, 10565, 4226, 4226, 4226, 4226, 21130, 38066, 42292, 16936, 32, 2145, 0, 32, 12710, 4258, 4258, 6339, 2145, 8484, 27501, 27469, 23275, 8452, 8484, 10597, 12710, 19049, 14823, 2145, 32, 32, 2113, 6371, 25356, 29582, 23275, 4258, 16904, 23243, 6339, 2113, 6371, 2145, 4258, 4258, 0, 0, 19049, 33808, 38034, 25388, 29582, 12678, 0, 0, 0, 8452, 2145, 2145, 0, 0, 10597, 25388, 40179, 40179, 19017, 0, 0, 0, 0, 4258, 2145, 2145, 0, 0, 0, 8484, 23243, 21130, 14791, 32, 0, 0, 12678, 6339, 4226, 4258, 0, 0, 0, 2113, 6339, 16936, 12678, 4258, 2113, 8452, 4258, 4258, 4258, 2145, 32, 0, 0, 0, 0, 0, 0, 8452, 6371, 6371, 2145, 32, 2145, 4226, 
4226, 32, 32, 4258, 4226, 6371, 4226, 4226, 2145, 2113, 2113, 0, 2145}, {0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 2113, 32, 0, 0, 32, 2113, 8484, 14791, 12678, 8452, 4258, 0, 32, 0, 0, 0, 0, 0, 32, 8452, 19017, 23275, 19049, 8484, 4226, 0, 0, 0, 0, 0, 0, 0, 2145, 10597, 16904, 27469, 44373, 29582, 12678, 2145, 0, 0, 0, 32, 32, 2145, 8452, 10565, 21162, 23275, 50744, 29614, 14791, 8484, 6371, 8452, 2113, 0, 2145, 8452, 14823, 2113, 0, 8452, 44373, 2113, 14823, 16936, 16936, 4258, 0, 32, 4226, 4226, 0, 0, 0, 8452, 25388, 0, 4226, 6339, 8452, 2113, 0, 0, 0, 0, 0, 0, 0, 10597, 8484, 0, 0, 0, 2113, 2113, 32, 0, 32, 2113, 0, 0, 12710, 2145, 0, 0, 0, 6371, 12678, 4258, 32, 0, 0, 12710, 23243, 25356, 40147, 10597, 0, 6339, 23275, 21130, 16936, 2145, 0, 0, 0, 14823, 29582, 40147, 35953, 31695, 21130, 16904, 12678, 10597, 4258, 32, 0, 0, 32, 32, 16904, 16904, 14791, 21130, 14791, 2145, 0, 0, 32, 0, 0, 0, 0, 32, 32, 6371, 12710, 19017, 12710, 16904, 6339, 32, 32, 32, 2113, 0, 2113, 0, 0, 2113, 4258, 8484, 10597, 6371, 6339, 2113, 32, 0, 0}, {4226, 4226, 2113, 32, 2113, 32, 0, 4226, 0, 4258, 32, 2113, 32, 0, 32, 32, 2113, 32, 4226, 4258, 4258, 10597, 6339, 4226, 4226, 0, 32, 4226, 2113, 4226, 6339, 10597, 14791, 16904, 19049, 21130, 8484, 6371, 2145, 0, 0, 2113, 32, 2145, 10597, 12710, 8484, 2113, 8484, 19017, 16936, 2145, 0, 0, 0, 0, 4258, 4226, 12678, 6371, 2113, 0, 8452, 27501, 27501, 12678, 0, 0, 0, 0, 32, 4258, 32, 2113, 0, 0, 0, 19017, 33840, 33808, 27469, 6371, 0, 0, 0, 32, 4226, 0, 6339, 0, 0, 29582, 40179, 27501, 31727, 35953, 10565, 32, 2145, 2113, 8484, 14823, 8484, 2113, 10597, 33808, 38034, 19017, 10597, 16936, 10565, 0, 32, 2145, 6339, 16904, 31727, 31727, 33808, 31695, 21130, 8484, 6339, 8484, 4258, 4226, 32, 2113, 0, 0, 10565, 27469, 16936, 10597, 8452, 6339, 6339, 32, 4258, 4226, 0, 4226, 2145, 0, 0, 4226, 10597, 16936, 8452, 6371, 4258, 32, 2113, 2145, 4258, 2113, 6339, 4226, 0, 32, 6371, 4258, 2145, 2145, 0, 0, 0, 4226, 2113, 2145, 0, 4258, 2145, 0, 0, 6339, 
6339, 6371, 2145, 2113, 32, 2113, 32, 2113, 4226, 4226, 6339, 2113, 4258, 4226, 2145, 2113, 0, 32, 4258, 4226}, {0, 0, 0, 0, 0, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2145, 0, 0, 4258, 4226, 12710, 12710, 10597, 0, 0, 0, 0, 32, 32, 0, 32, 4226, 6339, 14823, 19017, 21162, 29614, 23243, 14791, 16904, 16904, 6339, 0, 0, 2113, 4258, 12678, 12678, 19017, 23243, 27469, 46518, 46486, 33808, 29614, 14823, 32, 2113, 2113, 4226, 4226, 4258, 12678, 8452, 38034, 42260, 31695, 19049, 8452, 8452, 2145, 32, 2113, 6339, 2145, 8484, 8484, 6339, 35921, 19017, 21130, 32, 0, 0, 2113, 0, 32, 4226, 2113, 2145, 0, 12678, 6371, 23243, 21130, 0, 0, 0, 0, 2113, 0, 12710, 4258, 0, 0, 4226, 16936, 21130, 2113, 2113, 32, 8452, 6339, 0, 2145, 16936, 16936, 0, 0, 0, 0, 0, 0, 2145, 12678, 10597, 0, 0, 32, 14823, 25356, 4226, 0, 0, 0, 0, 0, 14791, 12678, 4258, 0, 0, 32, 2145, 14791, 25388, 25356, 16936, 16936, 19049, 19049, 12678, 4258, 32, 0, 2113, 0, 0, 6339, 10597, 16936, 21130, 19049, 12710, 4258, 2145, 2113, 0, 0, 0, 0, 0, 2113, 0, 0, 2113, 0, 0, 0, 2113, 0, 0, 32}, {0, 0, 0, 0, 0, 2113, 2113, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6339, 4226, 2113, 32, 2113, 0, 0, 0, 0, 32, 0, 32, 10565, 21162, 31727, 35953, 29614, 19017, 12678, 4258, 0, 0, 0, 0, 0, 0, 0, 0, 8484, 27469, 14823, 8452, 12678, 14823, 8484, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8484, 6339, 0, 0, 0, 0, 32, 6339, 8484, 4226, 0, 0, 0, 0, 10597, 2145, 2113, 0, 0, 0, 10565, 14791, 38034, 59196, 31695, 0, 2145, 32, 6371, 2145, 0, 2145, 0, 2113, 14823, 29582, 46518, 54970, 27501, 8484, 0, 0, 32, 2113, 0, 0, 0, 8452, 19049, 27469, 35953, 25388, 12710, 4258, 0, 0, 6339, 32, 0, 0, 0,8484, 29582, 38034, 31727, 14791, 12678, 21162, 23243, 25356, 16904, 4226, 0, 0, 0, 8484, 27469, 31695, 25388, 14823, 19049, 23275, 31727, 25388, 10565, 2113, 0, 32, 2113, 32, 2113, 6371, 10597, 21162, 8452, 10597, 19049, 14823, 6339, 32, 2113, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 2145, 2113, 0, 0, 0, 0, 2113, 0, 0, 0, 0, 0, 0, 0, 
32, 0, 0, 0}, {4226, 6339, 4226, 4226, 2145, 4226, 4258, 2145, 4258, 2113, 2113, 4258, 4226, 4258, 6339, 2145, 2113, 2145, 32, 4226, 4258, 8484, 10565, 8484, 4226, 4258, 4258, 2145, 0, 2145, 6339, 4258, 10565, 12678, 2113, 0, 0, 4258, 4258, 10565, 4258, 4226, 2145, 4258, 8484, 10565, 8452, 4226, 6339, 8452, 6371, 4226, 4226, 4258, 4226, 4226, 2145, 2145, 2145, 0, 0, 2113, 8452, 29582, 25356, 6371, 2113, 32, 2113, 2145, 2113, 2145, 0, 0, 0, 2113, 0, 21162, 25356, 8484, 0, 0, 0, 4226, 2113, 2145, 2145, 2145, 8452, 16936, 0, 6339, 19017, 12678, 0, 0, 32, 2145, 2113, 2145, 6371, 23275, 35953, 14791, 0, 0, 23275, 23243, 23243, 14791, 4258, 2145, 4258, 4226, 10565, 21130, 6339, 4226, 0, 12678, 19049, 23243, 16904, 14791, 6371, 4258, 4226, 6371, 12678, 12678, 0, 4226, 12678, 19017, 12678, 12710, 12678, 6371, 6371, 4226, 4258, 6339, 12678, 4226, 10597, 21130, 23275, 19049, 19017, 8484, 8484, 6371, 6339, 2113, 2113, 4226, 12710, 10565, 8452, 32, 0, 0, 0, 0, 2145, 2113, 4258, 2145, 4226, 4258, 6339, 10565, 14791, 8452, 32, 0, 0, 0, 4226, 4226, 4226, 4226, 2145, 2113, 6339, 6339, 4258, 8452, 4226, 2145, 4258, 2145, 4258, 2145, 2145, 6339}, {2145, 32, 0, 32, 0, 32, 2113, 0, 0, 2113, 2113, 0, 2113, 0, 32, 32, 2113, 32, 0, 2145, 2145, 4258, 6339, 2145, 0, 32, 32, 0, 0, 0, 2113, 2145, 8484, 16904, 19049, 29582, 33808, 23243, 14823, 8452, 32, 32, 0, 32, 4258, 14791, 19017, 10597, 0, 0, 0, 2113, 14823, 23243, 8484, 32, 0, 32, 12710, 12678, 8484, 8452, 32, 0, 10565, 10597, 10597, 21162, 14791, 0, 32, 4226, 0, 0, 0, 0, 19017, 27469, 12678, 0, 0, 4226, 8452, 2145, 32, 2113, 0, 0, 32, 16936, 21162, 4258, 0, 0, 0, 2145, 4226, 2113, 0, 2113, 0, 4226, 14823, 14823, 21130, 14823, 0, 2113, 10565, 14791, 8452, 2145, 0, 0, 4258, 12678, 8484, 6371, 25388, 16904, 4258, 14823, 21162, 16904, 4258, 32, 32, 2113, 14791, 23275, 25356, 27501, 25356, 8484, 16904, 14823, 14823, 16904, 4258, 32, 0, 32, 8452, 21162, 33840, 35953, 31695, 23275, 21162, 19017, 10565, 6371, 4226, 0, 0, 0, 4258, 12710, 14791, 
16904, 25388, 25388, 16936, 8484, 0, 0, 0, 0, 0, 2113, 0, 32, 0, 0, 2145, 4226, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2113, 0, 2113}};
float output[CLS];  // averaged sensor reading for each class, filled in loop()

// GPIO pins driving the seven segments. The column order of num_array follows
// this pin order.
const int out_pins[7] = {21, 22, 16, 17, 2, 5, 26 };
// Seven-segment pattern for each digit 0-9, written to out_pins in order.
// A 0 appears to light a segment (digit 8 is all zeros), which suggests a
// common-anode display - confirm against the wiring.
// `bool` replaces the Arduino-only `boolean` alias (a typedef of bool).
const bool num_array[10][7] = {
  {0, 0, 0, 0, 0, 0, 1}, //0
  {1, 0, 0, 1, 1, 1, 1}, //1
  {0, 0, 1, 0, 0, 1, 0}, //2
  {0, 0, 0, 0, 1, 1, 0}, //3
  {1, 0, 0, 1, 1, 0, 0}, //4
  {0, 1, 0, 0, 1, 0, 0}, //5
  {0, 1, 0, 0, 0, 0, 0}, //6
  {0, 0, 0, 1, 1, 0, 1}, //7
  {0, 0, 0, 0, 0, 0, 0}, //8
  {0, 0, 0, 0, 1, 0, 0}  //9
};

// Draw mask `mask_idx` (one row of wmask) on the LCD as a 14x14 grid of
// cell x cell squares whose top-left corner is (ox, oy). Pixels are consumed
// in row-major order: left to right, then top to bottom.
void dispMask(int mask_idx) {
  int p = 0;
  for (int i = oy; i < oy + len; i += cell) {
    for (int j = ox; j < ox + len; j += cell) {
      // Advance p for every cell; otherwise every square would use pixel 0.
      uint16_t color = wmask[mask_idx][p++];
      M5.Lcd.fillRect(j, i, cell, cell, color);
    }
  }
}
// One-time initialisation:
//  - the M5Stack board/LCD,
//  - the 7-segment output pins,
//  - DAC pin 25,
//  - the white frame around the mask area.
void setup() {
  M5.begin();  // the LCD must be initialised before any M5.Lcd drawing
  for (int i = 0; i < 7; i++) {
    pinMode(out_pins[i], OUTPUT);
  }

  // Drive GPIO25's DAC output to 0. This is presumably to silence the
  // M5Stack speaker wired to that pin - confirm.
  dacWrite(25, 0);

  //disp grid
  M5.Lcd.drawRect(ox - 1, oy - 1, len + 2, len + 2, WHITE);

  //  Serial.begin(115200);
}

void loop() {

  // capture
  for (int i = 0; i < CLS; i++) {
    output[i] = 0.0;
    for (int j = 0; j < rep; j++) {
      output[i] += float(analogRead(adc_pin));
    output[i] /= rep;

  float omax = -5000.0;
  int argmax = -1;
  for (int i = 0; i < CLS; i++) {
    if (omax < output[i]) {
      omax = output[i];
      argmax = i;
  //  Serial.println(argmax);

  // 7seg output
  for (int i = 0; i < 7; i++) {
    digitalWrite(out_pins[i], num_array[argmax][i]);




  • 暗いところでしか使えないので、光学系を工夫するなどして明るいところでも使えるようにする。
  • キャリブレーションやネットワーク構造を工夫して多層ネットワークを利用し、精度を向上させる。

