ketos · Issues · #147 · Closed
Created Dec 24, 2021 by Oliver Kirsebom (@kirsebom, Owner)

How to add a new frontend layer to an existing neural net

The code example below illustrates how we could go about adding a new frontend layer to an existing neural net in ketos. The idea would be to use this approach to add e.g. a PCEN layer or a spectrogram-computation layer in front of e.g. a ResNet. I'm not sure if this is a good solution. Comments are welcome! Thanks

import numpy as np
import tensorflow as tf
from ketos.neural_networks.dev_utils.nn_interface import RecipeCompat, NNInterface

# here we define a simple feedforward network
class MLP(tf.keras.Model):
    def __init__(self, n_neurons=128, activation='relu'):
        super(MLP, self).__init__()
        self.dense = tf.keras.layers.Dense(n_neurons, activation=activation)
        self.final_node = tf.keras.layers.Dense(1)

    # here we implement a method specific to this network, it could be anything
    def print_hello(self):
        print('hello')

    def call(self, inputs):
        # note: the same dense layer is applied twice, so its weights are shared
        output = self.dense(inputs)
        output = self.dense(output)
        output = self.final_node(output)
        return output
            
# and here we create a ketos interface for this simple network
class MLPInterface(NNInterface):
    def __init__(self, n_neurons, activation, optimizer, loss_function, metrics):
        super(MLPInterface, self).__init__(optimizer, loss_function, metrics)
        self.n_neurons = n_neurons
        self.activation = activation
        self.model = MLP(n_neurons=n_neurons, activation=activation)
    
    # this method allows the user to add a new (frontend) layer 
    # in front of the network by passing an instance of 
    # tf.keras.layers.Layer. The idea would be to implement this 
    # method in NNInterface so that all derived classes would 
    # automatically inherit it.
    def add_frontend(self, frontend, input_shape):                
        inputs = tf.keras.Input(shape=input_shape)
        x = frontend(inputs)
        outputs = self.model(x)
        self.model = tf.keras.Model(inputs=inputs, outputs=outputs)
        

# let's create an instance of the interface class
adam = RecipeCompat("adam", tf.keras.optimizers.Adam)
bce = RecipeCompat("bce", tf.keras.losses.BinaryCrossentropy)
acc = RecipeCompat("acc", tf.keras.metrics.Accuracy)
nn = MLPInterface(n_neurons=128, 
    activation='relu', optimizer=adam, loss_function=bce, metrics=[acc])

# let's process a single input sample
x = np.ones(shape=(1,128))
y = nn.model(x)
print(y)

# note that at this point we can still use the 'print_hello' method 
# of the MLP class
nn.model.print_hello()

# now, let's add a new frontend layer
frontend = tf.keras.layers.Dense(128, activation='relu')
nn.add_frontend(frontend, input_shape=(128,))

# using the modified network, we process another input sample
y = nn.model(x)
print(y)  # note that the output value is different, as it should be

# however, since the self.model attribute is no longer an instance of 
# the MLP class, we can no longer use any of the class' methods, such 
# as the print_hello method
nn.model.print_hello()  #this results in an error :-(
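One possible way around the lost-methods problem would be to keep a reference to the original network and forward unknown attribute lookups to it. This is just a sketch of the general pattern, not something ketos implements; the `ComposedModel` class and `base` attribute are made-up names, and the stand-in classes below avoid TensorFlow so the idea is easy to see in isolation. Whether this interacts cleanly with `tf.keras.Model`'s own attribute tracking would need testing.

```python
class MLP:
    """Stand-in for the tf.keras MLP defined above."""
    def print_hello(self):
        print('hello')

    def __call__(self, x):
        return x  # identity, just so there is something callable


class ComposedModel:
    """Stand-in for the composed model that add_frontend builds."""
    def __init__(self, base):
        self.base = base  # retain a handle to the original network

    def __call__(self, x):
        return self.base(x)

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails; fall back
        # to the wrapped base model, so e.g. print_hello stays reachable
        return getattr(self.base, name)


composed = ComposedModel(MLP())
composed.print_hello()  # delegated to the wrapped MLP: prints 'hello'
```

A simpler (if less transparent) alternative would be for `add_frontend` to stash the original network in a second attribute, e.g. `self.base_model = self.model` before reassigning `self.model`, so users could still call `nn.base_model.print_hello()`.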