import torch.nn as nn
from .monitor import FeatureMapMonitor
from .observers import FeatureMapObserverFactory
class FeatureMapMonitoring(object):
    """
    A wrapper class for registering a model, or individual model layers, so that
    their feature maps can be monitored during training of the model.
    """

    def __init__(self):
        """
        Initialises:
        1. An observer factory for creating an observer for a given layer.
        2. A monitor that keeps track of the observers of the monitored layers.
        """
        self.observer_factory = FeatureMapObserverFactory()
        self.monitor = FeatureMapMonitor()

    def add_layer(self, layer, layer_name):
        """
        Adds the layer whose feature maps are to be monitored.

        :param layer: pyTorch layer (an instance of nn.Module)
        :param layer_name: (str) name of the layer; must be non-empty
        :return: the observer object of the layer being monitored
        :raises NameError: if layer_name is empty or None
        :raises ReferenceError: if layer is not an instance of nn.Module
        """
        if not layer_name:
            raise NameError("Specify the name of the layer to be monitored.")

        # the layer to be monitored should be an instance of nn.Module (or a subclass)
        if not isinstance(layer, nn.Module):
            raise ReferenceError("Layer should be a subclass of nn.Module")

        return self._add_layer(layer=layer, layer_name=layer_name)

    def _add_layer(self, layer, layer_name):
        """
        Creates an observer for the layer to be monitored and adds it to the
        monitor's observers. Performs no validation; see add_layer().

        :param layer: pyTorch layer
        :param layer_name: (str) name of the layer
        :return: Observer object of the layer
        """
        layer_observer = self.observer_factory.create(layer=layer, layer_name=layer_name)
        self.monitor.add_observer(layer_observer=layer_observer)
        return layer_observer

    def _remove_layer(self, layer_name):
        """
        Removes an observer from the observers being monitored.

        :param layer_name: name of the observer
        :return: whatever the monitor's remove_observer() returns; expected to be
                 True if the observer was deleted, else False (TODO confirm
                 against FeatureMapMonitor.remove_observer)
        """
        # Propagate the monitor's result: it used to be dropped, which made
        # remove_layer() always return None despite its documented contract.
        return self.monitor.remove_observer(layer_name=layer_name)

    def remove_layer(self, layer_name):
        """
        Removes the layer from the layers being monitored.

        :param layer_name: (str) name of the layer
        :return: result of the removal as reported by the monitor; expected to be
                 True if the layer was removed from the monitored objects, else False
        """
        return self._remove_layer(layer_name=layer_name)

    def add_model(self, model):
        """
        Registers the layers of a pyTorch model whose activation maps are to be
        monitored. Only the layers yielded by model.named_children() are
        registered, each under the name that call reports for it.

        :param model: pyTorch model (an instance of nn.Module)
        :return: None
        :raises AttributeError: if model is not an instance of nn.Module
        """
        if not isinstance(model, nn.Module):
            raise AttributeError("Model should be an instance of nn.Module")

        for layer_name, layer in model.named_children():
            self.add_layer(layer=layer, layer_name=layer_name)

    def start(self):
        """Starts monitoring by asking the monitor to notify its registered observers."""
        self.monitor.notify_observers()