Example 1
def batch_iter(data, n_classes, img_size, batch_size, shuffle=True,
               *, img_dir=None, mask_dir=None):
    """Build an endless generator of (image batch, one-hot mask batch) pairs.

    Each element of ``data`` is a file stem. The image is read from
    ``img_dir + stem + '.jpg'`` and its class mask from
    ``mask_dir + stem + '.png'``. Both are resized to ``img_size`` x
    ``img_size`` before being stacked into a batch.

    Args:
        data: Sequence of file stems (list, numpy array, ...). It is converted
            with ``np.asarray`` so it can be shuffled by fancy indexing.
        n_classes: Number of output classes. Mask pixels equal to 255 are
            remapped to the last class, ``n_classes - 1``.
        img_size: Side length of the square resize target.
        batch_size: Maximum samples per batch. The last batch of an epoch may
            be smaller.
        shuffle: If true, reshuffle ``data`` at the start of every epoch.
        img_dir: Prefix prepended to each stem for images. Defaults to the
            module-level ``path_img`` (looked up when the generator starts).
        mask_dir: Prefix prepended to each stem for masks. Defaults to the
            module-level ``path_class``.

    Returns:
        ``(num_batches_per_epoch, generator)``. The generator never stops.
        It yields ``(x, y_oh)``:

        - ``x`` is the stacked images divided by 255.
        - ``y_oh`` is the stacked masks one-hot encoded along a new trailing
          axis of length ``n_classes``.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``data`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    data = np.asarray(data)
    data_size = len(data)
    if data_size == 0:
        # Without this guard the generator would fail on its first batch
        # (an empty mask stack cannot index the one-hot table).
        raise ValueError("data must contain at least one file stem")
    # Ceiling division: a trailing partial batch still counts as a batch.
    num_batches_per_epoch = -(-data_size // batch_size)

    def data_generator():
        # Resolve directories lazily so the module-level fallbacks are only
        # required once batches are actually drawn.
        img_root = path_img if img_dir is None else img_dir
        mask_root = path_class if mask_dir is None else mask_dir
        # Row i of the identity matrix is the one-hot vector for class i.
        one_hot = np.identity(n_classes)
        target = (img_size, img_size)
        while True:
            # Draw a fresh random order at the start of every epoch.
            epoch_data = data[np.random.permutation(data_size)] if shuffle else data
            for start in range(0, data_size, batch_size):
                images = []
                masks = []
                for stem in epoch_data[start:start + batch_size]:
                    # `with` closes the underlying file handle deterministically.
                    with Image.open(img_root + stem + '.jpg') as img:
                        images.append(np.array(img.resize(target)))
                    # NEAREST keeps mask pixels as exact class ids. Smoothing
                    # filters (PIL's default for non-palette images) would
                    # blend neighbouring labels into invalid values.
                    with Image.open(mask_root + stem + '.png') as mask:
                        masks.append(np.array(mask.resize(target, Image.NEAREST)))
                x = np.array(images) / 255.
                y = np.array(masks)
                # NOTE(review): 255 looks like an "ignore/void" label that is
                # folded into the last class. Confirm against the dataset.
                y_m = np.where(y == 255, n_classes - 1, y)
                yield x, one_hot[y_m]

    return num_batches_per_epoch, data_generator()
Example 2
# Hold out 20% of `df` for validation, stratified on 'category' so both
# splits keep the same class proportions; random_state makes it reproducible.
train_df, validate_df = train_test_split(
    df, test_size=0.20, stratify=df['category'], random_state=100
)
train_df = train_df.reset_index(drop=True)
validate_df = validate_df.reset_index(drop=True)
total_train = len(train_df)
total_validate = len(validate_df)

# Training images are rescaled to [0, 1] and lightly augmented (zoom +
# horizontal flip); the commented-out transforms are left disabled.
train_datagen = ImageDataGenerator(
    # rotation_range=15,
    rescale=1./255,
    # shear_range=0.1,
    zoom_range=0.2,
    horizontal_flip=True,
    # width_shift_range=0.1,
    # height_shift_range=0.1
)

# Iterator settings shared by training and validation so both read images
# and labels identically.
# flow_from_dataframe with class_mode='binary' expects the y_col values to be
# strings, and `classes` is matched against those values, so the label column
# is cast to str below and the classes are given as '0'/'1'. Integer classes
# [0, 1] would either be rejected or match no rows.
# NOTE(review): assumes 'category' holds 0/1 (int or str) — confirm.
flow_kwargs = dict(
    x_col='filename',
    y_col='category',
    target_size=(img_size, img_size),
    class_mode='binary',
    batch_size=batch_size,
    classes=['0', '1'],
)
# astype returns a copy, so train_df/validate_df keep their original dtypes.
train_generator = train_datagen.flow_from_dataframe(
    train_df.astype({'category': str}),
    path_img,
    **flow_kwargs
)

# Validation images are only rescaled, never augmented.
validation_datagen = ImageDataGenerator(rescale=1./255)
validation_generator = validation_datagen.flow_from_dataframe(
    validate_df.astype({'category': str}),
    path_img,
    **flow_kwargs
)

# len(iterator) is the number of batches the iterator produces (rounded up),
# counted over the images it actually found. Floor division on the row counts
# dropped the final partial batch every epoch and gave 0 validation steps
# whenever total_validate < batch_size.
# NOTE(review): fit_generator is deprecated in newer tf.keras, where
# model.fit accepts these iterators directly — switch once the Keras version
# in use is confirmed.
history = model.fit_generator(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    callbacks=callbacks,
    verbose=1,
)