Source code for src.gridmind.feature_construction.cnn_feature_extractor

import torch.nn as nn
from torchvision.models import resnet18


class ResNetFeatureExtractor(nn.Module):
    """Extract a globally pooled feature vector per image from a ResNet-18 backbone.

    The backbone is torchvision's ResNet-18 with its final average-pool and
    fully-connected layers removed. The remaining convolutional feature map is
    reduced with global average pooling and flattened, so each input sample
    becomes a vector of ``BACKBONE_DIM`` (512) features.

    Args:
        output_dim: Expected feature dimension. Kept for backward compatibility.
            It must equal ``BACKBONE_DIM`` because this extractor has no
            projection layer that could change the width.
        pretrained: If True (default), load torchvision's default ImageNet
            weights for ResNet-18; this may download them on first use. If
            False, the backbone is randomly initialised.

    Raises:
        ValueError: If ``output_dim`` does not match ``BACKBONE_DIM``.
    """

    # Width of ResNet-18's last convolutional stage, i.e. the pooled feature size.
    BACKBONE_DIM = 512

    def __init__(self, output_dim=512, *, pretrained=True):
        super().__init__()
        # Previously any value was accepted and stored, even though the output
        # is always BACKBONE_DIM wide; reject mismatches before any weights are
        # built so downstream layers are not sized from a wrong attribute.
        if output_dim != self.BACKBONE_DIM:
            raise ValueError(
                f"output_dim={output_dim} does not match the ResNet-18 feature "
                f"dimension ({self.BACKBONE_DIM}); this extractor has no "
                f"projection layer to change it."
            )

        # torchvision >= 0.13 replaced the boolean `pretrained` flag with the
        # `weights=` enum API (the flag is deprecated). ResNet18_Weights.DEFAULT
        # selects the same ImageNet weights that pretrained=True used.
        try:
            from torchvision.models import ResNet18_Weights
        except ImportError:  # older torchvision: only the boolean flag exists
            resnet = resnet18(pretrained=pretrained)
        else:
            weights = ResNet18_Weights.DEFAULT if pretrained else None
            resnet = resnet18(weights=weights)

        # Drop the last two children (avgpool & fc), keeping the conv trunk.
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # global average pooling
        self.flatten = nn.Flatten()
        self.output_dim = output_dim

    def forward(self, x):
        """Return pooled features of shape ``(B, output_dim)``.

        Args:
            x: Batch of images. NOTE(review): presumably ``(B, 3, H, W)``,
                normalised as the ImageNet weights expect; confirm at the
                call site.
        """
        x = self.features(x)
        x = self.pool(x)  # -> (B, C, 1, 1)
        x = self.flatten(x)  # -> (B, C)
        return x