import warnings
from .interfaces import Monitor
from .exceptions import SingleDimensionalLayerWarning, LayerRegisterException
[docs]class ObserverHookObject:
"""
AN object used to store:
1. The observer object which is being hooked
2. Parameters being monitored
3. Handler of the hooked layer
"""
def __init__(self, kwargs):
self.__dict__.update(**kwargs)
[docs]class FeatureMapMonitor(Monitor):
"""
A monitor type class for visualizing the feature maps of a neural network.
"""
def __init__(self):
self._layer_observers = dict() # Maintains a mapping of layers/observers being monitored.
[docs] def add_observer(self, layer_observer):
"""
1. Creates a layer observer object.
2. Hooks the layer to capture the activation map of the layer.
3. Adds the observer object to the list of monitored observers.
:param layer_observer: Observer object
"""
if not (hasattr(layer_observer, "get_layer") and hasattr(layer_observer, "get_layer_name")):
raise AttributeError(
"Layer Observer objects must have a get_layer and get_layer_name function defined. "
)
# Get the pyTorch layer object and layer from the observer.
layer = layer_observer.get_layer()
layer_name = layer_observer.get_layer_name()
# If layer is already being monitored then return.
if layer_name in self._layer_observers:
return
# Creates an observer hook object and store it in the list of monitored observers.
_observer_hook_object = ObserverHookObject(
{"object": layer_observer, "parameters": None, "handler": None}
)
self._layer_observers[layer_name] = _observer_hook_object
# Create a forward hook to capture the activation map for that layer.
handler = layer.register_forward_hook(self._get_activation_map(layer_name))
# Store the handler of the hook.
self._layer_observers[layer_name].handler = handler
[docs] def remove_observer(self, layer_observer=None, layer_name=None):
"""
If the layer observer/layer name is in the list of monitored observers:
1. Unhook the layer.
2. Remove the layer observer from the list of monitored observers.
:param layer_observer: Layer Observer.
:param layer_name: Name of the layer being monitored.
:return: True/False based on whether the observer by deleted or not.
"""
# Either layer name or the layer observer must be provided.
if not layer_name and not layer_observer:
raise NameError("Either layer_name or observer_object must be provided.")
# If layer name is not specified then extract the layer name from the observer object.
elif not layer_name:
layer_name = layer_observer.get_layer_name()
# Get the hook object from the layer name.
hook = self._layer_observers.get(layer_name, None)
# If the hook was present for the layer name, then remove the unhook its handler and
# delete the layer observer from the list of observers being monitored.
if hook:
hook.handler.remove()
del self._layer_observers[layer_name]
return True
return False # Return false if layer is not present
@staticmethod
[docs] def _is_layer_single_dim(layer):
"""Checks if the layer is a single dimensional layer"""
return len(layer.squeeze().shape[1:]) == 1
[docs] def _get_activation_map(self, layer_name):
"""Hooks the layer to capture activation maps for the given layer and return the handler to the hook"""
def hook(model, inp, out):
try:
observer = self._layer_observers[layer_name]
observer.parameters = out.detach()
except NameError:
raise LayerRegisterException(
layer_name=layer_name
) # Raise an error if the layer fails to register
return hook
[docs] def notify_observers(self):
"""Updates all the observers being monitored with the new parameters"""
for observer_name, observer in self._layer_observers.items():
if observer.parameters is None:
continue
if self._is_layer_single_dim(observer.parameters):
# If layer is a single dimensional layer, then raise a warning as an image needs
# to be at least of two dimensions in order to be plotted on a graph.
warnings.warn(SingleDimensionalLayerWarning(observer_name))
continue
# Retrieve the new parameters for an observer and
# update the observers object with the new parameters.
parameters = observer.parameters
observer.object.update(parameters)
[docs] def get_registered_observers(self):
"""Returns the list of observers being monitored."""
return self._layer_observers