マルチプロセス内でTensorflow2.Xが動かない。
Q&A
Closed
解決したいこと
マルチプロセス内でTensorflow2.Xが動かしたいです。
エラーメッセージは出ていません。
解決方法を教えて下さい。
該当するソースコード
import time
import tensorflow as tf
import tensorflow.keras.layers as kl
import numpy as np
import random
from multiprocessing import Process, Pool
import multiprocessing
class NeuralNetwork(tf.keras.Model):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.action_space = 1
self.dense1 = kl.Dense(128, activation="relu")
self.dense2 = kl.Dense(128, activation="relu")
self.out = kl.Dense(self.action_space)
self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
@tf.function
def call(self, x):
x = self.dense1(x)
x = self.dense2(x)
out = self.out(x)
return out
def get_data():
images = np.random.randn(12, 3)
labels = np.array([])
for m in images:
num = 4 * m[0] ** 3 + 5 * m[1] ** 2 + 6 * m[2] + random.random()
labels = np.append(labels, num)
labels = labels.astype(np.float32)
return images, labels
def worker(i):
net = NeuralNetwork()
images, labels = get_data()
arr = tf.constant([[1 ,2 ,3]], dtype = tf.float32)
print(arr)
net(arr)
with multiprocessing.Pool(processes=4) as pool:
pool.map(worker, range(4))
1