auxiliary_loss_base.py 253 Bytes
Newer Older
Paul Primus's avatar
Paul Primus committed
1
2
3
4
5
6
7
8
9
10
11
12
from abc import ABC, abstractmethod


class AuxiliaryLossBase(ABC):
    """Abstract interface for an auxiliary loss term.

    Concrete subclasses must override :meth:`auxiliary_loss`. The
    ``weight`` given at construction is kept on the instance as
    ``self.weight``. How (or whether) it scales the loss is decided by
    the subclass or the caller; this base class only stores it.
    """

    def __init__(self, weight: float = 1.0):
        # Cooperative initialisation so mixins further along the MRO run too.
        super(AuxiliaryLossBase, self).__init__()
        self.weight = weight

    @abstractmethod
    def auxiliary_loss(self, batch):
        """Return the auxiliary loss for ``batch``; subclasses must implement.

        The structure of ``batch`` is defined by the concrete subclass and
        its callers. It is not constrained here.

        Raises:
            NotImplementedError: if a subclass delegates here via ``super()``.
        """
        raise NotImplementedError()