Commit 6e2c8d88 authored by Oliver Kirsebom

part III working

parent f2336e1d
@@ -629,29 +629,28 @@
<a id='step_10'></a>
## 10. Configure batch generator
%% Cell type:markdown id: tags:
-We load the training/validation dataset from the same HDF5 database that contains the test dataset used previously.
+We load the training/validation dataset from the HDF5 table that we created previously.
%% Cell type:code id: tags:
``` python
import numpy as np
from ketos.data_handling.data_feeding import BatchGenerator # A helper class to read data from disk in batches
```
%% Cell type:markdown id: tags:
-We'll split this dataset into training and validation using an stratified sampling algorithm ([scikit-learn's StratifiedShuffleSplit](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html)). This yields training and validation datasets with the same proportions of positive (upcall) and negative (no upcall) examples.
+We will split the dataset into a training set of 30 (randomly selected) samples and a validation set consisting of the 10 remaining samples.
%% Cell type:code id: tags:
``` python
-train_indices = np.random.choice(np.arange(40), 30, replace=False)
-val_indices = [i for i in range(40) if i not in train_indices]
+train_indices = np.random.choice(np.arange(40), 30, replace=False) # select 30 random indices out of 0...39
+val_indices = [i for i in range(40) if i not in train_indices] # 10 indices that were not selected
print(train_indices)
print(val_indices)
```
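As a sanity check on the split in the cell above, the following sketch (not part of this commit) reproduces the 30/10 split and verifies that the two index sets are disjoint and together cover all 40 samples. The seeded `np.random.default_rng` generator and the use of `np.setdiff1d` in place of the list comprehension are assumptions made here for reproducibility and brevity.

```python
import numpy as np

# Assumption: a fixed seed, used only so that the split is reproducible.
rng = np.random.default_rng(0)

n_samples = 40  # number of samples in the table used in this tutorial

# Draw 30 distinct indices for training; the remaining 10 form the validation set.
train_indices = rng.choice(n_samples, size=30, replace=False)
val_indices = np.setdiff1d(np.arange(n_samples), train_indices)

# The two sets must not overlap and must cover every sample exactly once.
assert np.intersect1d(train_indices, val_indices).size == 0
assert len(train_indices) + len(val_indices) == n_samples
```

Note that a purely random split like this one does not guarantee equal class proportions in the two sets.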
@@ -674,11 +673,11 @@
val_generator = BatchGenerator(hdf5_table=table, batch_size=10, instance_function=transform_batch, y_field="labels", shuffle=True, refresh_on_epoch_end=True, indices=val_indices)
```
%% Cell type:markdown id: tags:
-Now we have configured two batch generators, which will load 32 spectrograms and associated labels at a time during the training process. After attaching these generators to the new_resnet_model, we can run the training loop for a couple of epochs.
+Now we have configured two batch generators, which will load 10 spectrograms and associated labels at a time during the training process. After attaching these generators to the new_resnet_model, we can run the training loop for a couple of epochs.
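To illustrate what a batch size of 10 means for the 30/10 split, here is a minimal sketch of batching over a subset of indices. `iterate_batches` is a hypothetical helper written for this illustration; it is not the ketos `BatchGenerator` API and it only yields indices, not spectrograms or labels.

```python
import numpy as np

def iterate_batches(indices, batch_size, shuffle=True, rng=None):
    """Yield successive batches of indices (hypothetical helper, not ketos)."""
    indices = np.asarray(indices)
    if shuffle:
        if rng is None:
            rng = np.random.default_rng()
        # Shuffle a copy so the caller's array is left untouched.
        indices = rng.permutation(indices)
    for start in range(0, len(indices), batch_size):
        yield indices[start:start + batch_size]

# Stand-in index sets with the same sizes as the split above.
train_batches = list(iterate_batches(np.arange(30), batch_size=10))
val_batches = list(iterate_batches(np.arange(30, 40), batch_size=10))

print(len(train_batches), len(val_batches))  # 3 1
```

With these sizes, one epoch consists of three training batches and a single validation batch.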
%% Cell type:code id: tags:
``` python
new_resnet_model.set_train_generator(train_generator)
......