[docs]classResNetFeatureExtractor(nn.Module):def__init__(self,output_dim=512):# 512 for resnet18, 2048 for resnet50super().__init__()resnet=resnet18(pretrained=True)
[docs]self.features=nn.Sequential(*list(resnet.children())[:-2])# Remove avgpool & fc
[docs]self.pool=nn.AdaptiveAvgPool2d((1,1))# Global average pooling