Commit 6e2c8d88 authored by Oliver Kirsebom

part III working

parent f2336e1d
@@ -629,29 +629,28 @@
<a id='step_10'></a>
## 10. Configure batch generator
%% Cell type:markdown id: tags:
We load the training/validation dataset from the HDF5 table that we created previously.
%% Cell type:code id: tags:
``` python
import numpy as np
from ketos.data_handling.data_feeding import BatchGenerator  # A helper class to read data from disk in batches
```
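%% Cell type:markdown id: tags:
For reference, the `table` object passed to the batch generators below is simply the PyTables table stored in the HDF5 database. A minimal sketch of how it could be opened is shown here; the file name "database.h5" and the node path "/train/data" are placeholders, not the tutorial's actual values, and should be replaced by the database and group created in the earlier steps.
%% Cell type:code id: tags:
``` python
import tables  # PyTables, the library used to store the spectrograms in HDF5

# NOTE: file name and node path are placeholders for this sketch
db = tables.open_file("database.h5", mode="r")
table = db.get_node("/train/data")  # PyTables table holding the spectrograms and their labels
print(table.nrows)                  # expect 40 samples in this example
```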
%% Cell type:markdown id: tags:
We will split the dataset into a training set of 30 (randomly selected) samples and a validation set consisting of the 10 remaining samples.
%% Cell type:code id: tags:
``` python
train_indices = np.random.choice(np.arange(40), 30, replace=False)  # select 30 random indices out of 0...39
val_indices = [i for i in range(40) if i not in train_indices]  # 10 indices that were not selected
print(train_indices)
print(val_indices)
```
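%% Cell type:markdown id: tags:
Note that a purely random split does not guarantee that positive (upcall) and negative (no upcall) examples are equally represented in the two sets. If class balance matters, a stratified split such as [scikit-learn's StratifiedShuffleSplit](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html) preserves the label proportions. The sketch below assumes the labels can be read from the "labels" column of the table.
%% Cell type:code id: tags:
``` python
from sklearn.model_selection import StratifiedShuffleSplit

# Assumption: `table` has a "labels" column with one 0/1 entry per sample
labels = table.col("labels")

# One stratified 30/10 split that keeps the class proportions in both sets
splitter = StratifiedShuffleSplit(n_splits=1, train_size=30, test_size=10, random_state=0)
train_indices, val_indices = next(splitter.split(np.zeros((len(labels), 1)), labels))
print(train_indices)
print(val_indices)
```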
@@ -674,11 +673,11 @@
val_generator = BatchGenerator(hdf5_table=table, batch_size=10, instance_function=transform_batch, y_field="labels", shuffle=True, refresh_on_epoch_end=True, indices=val_indices)
```
%% Cell type:markdown id: tags:
Now we have configured two batch generators, which will load 10 spectrograms and associated labels at a time during the training process. After attaching these generators to the new_resnet_model, we can run the training loop for a couple of epochs.
%% Cell type:code id: tags:
``` python
new_resnet_model.set_train_generator(train_generator)
...
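%% Cell type:markdown id: tags:
The rest of this cell is truncated in the diff. As a rough, hypothetical sketch of the remaining steps: only `set_train_generator` is visible above, so the names of the companion setter for the validation generator and of the training-loop method are assumptions and may differ in this version of ketos.
%% Cell type:code id: tags:
``` python
# NOTE: hypothetical sketch; only set_train_generator appears in the visible diff.
# The two calls below use assumed method names.
new_resnet_model.set_train_generator(train_generator)
new_resnet_model.set_validation_generator(val_generator)  # assumed companion setter
new_resnet_model.train(num_epochs=2)                      # assumed training-loop entry point
```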