TVSSL/linear_classifier.py

47 lines
1.3 KiB
Python
Raw Normal View History

2023-09-12 13:32:15 +00:00
from __future__ import print_function
from torch import channel_shuffle
import torchvision
import torch.nn as nn
import torch
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, feat):
return feat.view(feat.size(0), -1)
class LinearClassifierResNet(nn.Module):
def __init__(self, layer=6, n_label=1000,p=0.5):
super(LinearClassifierResNet, self).__init__()
self.layer = layer
if layer == 1:
nChannels = 64
elif layer == 2:
nChannels = 64
elif layer == 3:
nChannels = 128
elif layer == 4:
nChannels = 256
elif layer == 5:
nChannels = 512
elif layer == 6:
nChannels = 512
else:
raise NotImplementedError('layer not supported: {}'.format(layer))
self.classifier = nn.Sequential()
self.classifier.add_module('Dropout', nn.Dropout(p=p))
self.classifier.add_module('LiniearClassifier', nn.Linear(nChannels, n_label))
self.initilize()
def initilize(self):
for m in self.modules():
if isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.fill_(0.0)
def forward(self, x):
return self.classifier(x)