From f46aedbf006b9c0bf334379ce803af8ffcd9d36c Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 2 Aug 2022 02:06:21 +0800 Subject: [PATCH 01/46] =?UTF-8?q?merge=20contrastive=20baseline=20?= =?UTF-8?q?=E9=80=9A=20lateset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- federatedscope/cl/__init__.py | 3 + .../cl/baseline/fedavg_lr_on_twitter.yaml | 32 + .../baseline/fedavg_lstm_on_shakespeare.yaml | 35 + .../cl/baseline/fedavg_lstm_on_subreddit.yaml | 32 + .../baseline/fedavg_transformer_on_imdb.yaml | 36 + .../fedsimclr_linearprob_on_cifar10.yaml | 33 + .../cl/baseline/fedsimclr_on_cifar10.yaml | 33 + federatedscope/cl/dataloader/Cifar10.py | 187 + federatedscope/cl/dataloader/__init__.py | 3 + federatedscope/cl/loss/NT_xentloss.py | 41 + federatedscope/cl/loss/__init__.py | 5 + federatedscope/cl/model/SimCLR.py | 225 + federatedscope/cl/model/__init__.py | 7 + federatedscope/cl/test.ipynb | 4425 +++++++++++++++++ federatedscope/cl/trainer/__init__.py | 8 + federatedscope/cl/trainer/trainer.py | 144 + .../core/auxiliaries/criterion_builder.py | 1 + .../core/auxiliaries/data_builder.py | 3 + federatedscope/core/auxiliaries/eunms.py | 30 + .../core/auxiliaries/model_builder.py | 9 +- .../core/auxiliaries/optimizer_builder.py | 2 +- .../core/auxiliaries/trainer_builder.py | 3 + 22 files changed, 5295 insertions(+), 2 deletions(-) create mode 100644 federatedscope/cl/__init__.py create mode 100644 federatedscope/cl/baseline/fedavg_lr_on_twitter.yaml create mode 100644 federatedscope/cl/baseline/fedavg_lstm_on_shakespeare.yaml create mode 100644 federatedscope/cl/baseline/fedavg_lstm_on_subreddit.yaml create mode 100644 federatedscope/cl/baseline/fedavg_transformer_on_imdb.yaml create mode 100644 federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml create mode 100644 federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml create mode 100644 federatedscope/cl/dataloader/Cifar10.py create mode 
100644 federatedscope/cl/dataloader/__init__.py create mode 100644 federatedscope/cl/loss/NT_xentloss.py create mode 100644 federatedscope/cl/loss/__init__.py create mode 100644 federatedscope/cl/model/SimCLR.py create mode 100644 federatedscope/cl/model/__init__.py create mode 100644 federatedscope/cl/test.ipynb create mode 100644 federatedscope/cl/trainer/__init__.py create mode 100644 federatedscope/cl/trainer/trainer.py create mode 100644 federatedscope/core/auxiliaries/eunms.py diff --git a/federatedscope/cl/__init__.py b/federatedscope/cl/__init__.py new file mode 100644 index 000000000..f8e91f237 --- /dev/null +++ b/federatedscope/cl/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division diff --git a/federatedscope/cl/baseline/fedavg_lr_on_twitter.yaml b/federatedscope/cl/baseline/fedavg_lr_on_twitter.yaml new file mode 100644 index 000000000..4f0656cdf --- /dev/null +++ b/federatedscope/cl/baseline/fedavg_lr_on_twitter.yaml @@ -0,0 +1,32 @@ +use_gpu: True +device: 0 +early_stop: + patience: 5 +federate: + mode: standalone + total_round_num: 100 + sample_client_num: 10 +data: + root: data/ + type: twitter + batch_size: 5 + subsample: 0.005 + num_workers: 0 +model: + type: lr + out_channels: 2 + dropout: 0.0 +train: + local_update_steps: 10 + optimizer: + lr: 0.0003 + weight_decay: 0.0 +criterion: + type: CrossEntropyLoss +trainer: + type: nlptrainer +eval: + freq: 10 + metrics: ['acc', 'correct'] + split: ['train'] + best_res_update_round_wise_key: 'train_loss' \ No newline at end of file diff --git a/federatedscope/cl/baseline/fedavg_lstm_on_shakespeare.yaml b/federatedscope/cl/baseline/fedavg_lstm_on_shakespeare.yaml new file mode 100644 index 000000000..86d5aca27 --- /dev/null +++ b/federatedscope/cl/baseline/fedavg_lstm_on_shakespeare.yaml @@ -0,0 +1,35 @@ +use_gpu: True +device: 0 +early_stop: + patience: 10 +federate: + mode: standalone + total_round_num: 1000 + 
sample_client_rate: 0.2 +data: + root: data/ + type: shakespeare + batch_size: 64 + subsample: 0.2 + num_workers: 0 + splits: [0.6,0.2,0.2] +model: + type: lstm + in_channels: 80 + out_channels: 80 + embed_size: 8 + hidden: 256 + dropout: 0.0 +train: + local_update_steps: 1 + batch_or_epoch: epoch + optimizer: + lr: 0.8 + weight_decay: 0.0 +criterion: + type: character_loss +trainer: + type: nlptrainer +eval: + freq: 10 + metrics: ['acc', 'correct'] diff --git a/federatedscope/cl/baseline/fedavg_lstm_on_subreddit.yaml b/federatedscope/cl/baseline/fedavg_lstm_on_subreddit.yaml new file mode 100644 index 000000000..1080bb591 --- /dev/null +++ b/federatedscope/cl/baseline/fedavg_lstm_on_subreddit.yaml @@ -0,0 +1,32 @@ +use_gpu: True +device: 0 +early_stop: + patience: 10 +federate: + mode: standalone + total_round_num: 100 + sample_client_num: 10 +data: + root: data/ + type: subreddit + batch_size: 5 + subsample: 1.0 +model: + type: lstm + in_channels: 10000 + out_channels: 10000 + hidden: 256 + embed_size: 200 + dropout: 0.0 +train: + local_update_steps: 10 + optimizer: + lr: 8.0 + weight_decay: 0.0 +criterion: + type: CrossEntropyLoss +trainer: + type: nlptrainer +eval: + freq: 10 + metrics: ['acc', 'correct'] \ No newline at end of file diff --git a/federatedscope/cl/baseline/fedavg_transformer_on_imdb.yaml b/federatedscope/cl/baseline/fedavg_transformer_on_imdb.yaml new file mode 100644 index 000000000..a9e818aa1 --- /dev/null +++ b/federatedscope/cl/baseline/fedavg_transformer_on_imdb.yaml @@ -0,0 +1,36 @@ +use_gpu: True +device: 2 +federate: + mode: standalone + total_round_num: 400 + client_num: 5 + share_local_model: True + online_aggr: True + sample_client_rate: 1.0 +data: + root: 'data' + type: 'IMDB@torchtext' + args: [{'max_len': 512}] + splits: [0.8, 0.2, 0.0] # test is fixed + batch_size: 128 + splitter: 'lda' + splitter_args: [{'alpha': 0.5}] + num_workers: 0 +model: + type: 'google/bert_uncased_L-2_H-128_A-2@transformers' + task: 
'SequenceClassification' + out_channels: 2 +train: + local_update_steps: 1 + batch_or_epoch: 'epoch' + optimizer: + lr: 0.0001 + weight_decay: 0.0 +criterion: + type: 'CrossEntropyLoss' +trainer: + type: 'nlptrainer' +eval: + freq: 2 + metrics: ['acc', 'correct', 'f1'] + split: ['test', 'val', 'train'] \ No newline at end of file diff --git a/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml b/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml new file mode 100644 index 000000000..11d70ce36 --- /dev/null +++ b/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml @@ -0,0 +1,33 @@ +use_gpu: True +device: 2 +federate: + mode: standalone + total_round_num: 50 + client_num: 5 + sample_client_rate: 1.0 + method: local + restore_from: 'checkpoint/SimCLR_on_Cifar4CL_lr0.1_lstep5_rn100.ckpt' +data: + root: 'data' + type: 'Cifar4LP' + batch_size: 256 + num_workers: 2 +model: + type: 'SimCLR_linear' +train: + local_update_steps: 1 + batch_or_epoch: 'epoch' + optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0 +early_stop: + patience: 0 +criterion: + type: CrossEntropyLoss +trainer: + type: general +eval: + freq: 2 + metrics: ['acc'] + split: ['val', 'test'] \ No newline at end of file diff --git a/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml new file mode 100644 index 000000000..3372c874d --- /dev/null +++ b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml @@ -0,0 +1,33 @@ +use_gpu: True +device: 2 +federate: + mode: standalone + total_round_num: 100 + client_num: 5 + share_local_model: True + online_aggr: True + sample_client_rate: 1.0 + save_to: 'checkpoint/SimCLR_on_Cifar4CL_lr0.05_lus5_rn100.ckpt' +data: + root: 'data' + type: 'Cifar4CL' + batch_size: 256 + num_workers: 2 +model: + type: 'SimCLR' +train: + local_update_steps: 5 + batch_or_epoch: 'epoch' + optimizer: + lr: 0.05 + momentum: 0.1 +early_stop: + patience: 0 +criterion: + type: 'NT_xentloss' +trainer: + 
class SimCLRTransform():
    """SimCLR-style random augmentation pipeline.

    The pipeline applies, in order: random resized crop, horizontal flip,
    colour jitter (applied with p=0.8), random grayscale, Gaussian blur
    (applied with p=0.5), tensor conversion and normalisation.

    Args:
        is_sup: When truthy, calling the transform returns a single
            augmented view; otherwise it returns a pair of views produced
            by two separate passes through the same pipeline.
        image_size: Output size passed to ``RandomResizedCrop``.
    """

    def __init__(self, is_sup, image_size=32):
        colour_jitter = T.RandomApply(
            [T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8)
        gaussian_blur = T.RandomApply(
            [T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.5)
        steps = [
            T.RandomResizedCrop(
                image_size,
                scale=(0.5, 1.0),
                interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(p=0.5),
            colour_jitter,
            T.RandomGrayscale(p=0.2),
            gaussian_blur,
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = T.Compose(steps)

        self.mode = is_sup

    def __call__(self, x):
        """Return one augmented view, or a tuple of two, depending on mode."""
        first_view = self.transform(x)
        if self.mode:
            return first_view
        # Second pass through the same random pipeline gives the paired view.
        second_view = self.transform(x)
        return first_view, second_view
def Cifar4LP(config):
    """Build per-client CIFAR-10 loaders for linear probing.

    The training split uses a random resized crop and horizontal flip; the
    validation and test splits only convert to tensor and normalise. Each
    split is cut into ``config.federate.client_num`` contiguous, equally
    sized index ranges (any remainder is dropped) and client ``k``
    (1-based) receives the ``k``-th range.

    Each client's slice is a lazy ``Subset`` view, so the random training
    augmentations are re-sampled on every access. The previous version built
    a Python list of already-transformed samples, which froze one
    augmentation per image for the whole run and decoded every image up
    front.

    NOTE(review): the ``val`` split is read from the CIFAR-10 *training* set
    with the same index ranges as ``train``, so validation images overlap
    the training images. Confirm this is intended.

    Args:
        config: Configuration node; reads ``data.root``, ``data.batch_size``,
            ``data.shuffle`` and ``federate.client_num``.

    Returns:
        A tuple ``(data_dict, config)`` where ``data_dict`` maps each client
        id to ``{'train': DataLoader, 'val': DataLoader,
        'test': DataLoader}`` and ``config`` is returned unchanged.
    """
    # Function-scope import: the module header only imports DataLoader/Dataset.
    from torch.utils.data import Subset

    normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform_train = T.Compose([
        T.RandomResizedCrop(32,
                            scale=(0.5, 1.0),
                            interpolation=T.InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(p=0.5),
        T.ToTensor(),
        normalize,
    ])
    transform_test = T.Compose([T.ToTensor(), normalize])

    root = config.data.root
    # split name -> (dataset, shuffle flag); the test loader is never shuffled.
    splits = {
        'train': (CIFAR10(root,
                          train=True,
                          download=True,
                          transform=transform_train), config.data.shuffle),
        'val': (CIFAR10(root,
                        train=True,
                        download=True,
                        transform=transform_test), config.data.shuffle),
        'test': (CIFAR10(root,
                         train=False,
                         download=True,
                         transform=transform_test), False),
    }

    client_num = config.federate.client_num
    data_dict = dict()
    for client_idx in range(1, client_num + 1):
        dataloader_dict = {}
        for split_name, (dataset, shuffle) in splits.items():
            per_client = len(dataset) // client_num
            indices = range((client_idx - 1) * per_client,
                            client_idx * per_client)
            dataloader_dict[split_name] = DataLoader(
                Subset(dataset, indices),
                config.data.batch_size,
                shuffle=shuffle)
        data_dict[client_idx] = dataloader_dict

    return data_dict, config
def create_NT_xentloss(type, device):
    """Criterion factory registered under the name ``'NT_xentloss'``.

    Args:
        type: Requested criterion name. (The parameter name shadows the
            builtin but is kept so keyword callers keep working.)
        device: Device the criterion module is moved to.

    Returns:
        An ``NT_xentloss`` instance on ``device`` when ``type`` is
        ``'NT_xentloss'``; otherwise ``None``. Previously a non-matching
        name raised ``UnboundLocalError`` because ``criterion`` was only
        assigned inside the ``if``.
    """
    if type != 'NT_xentloss':
        return None
    return NT_xentloss().to(device)
class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1).

    The block outputs ``expansion * planes`` channels. A 1x1 convolution
    followed by batch norm is used on the skip path whenever the stride is
    not 1 or the input channel count differs from the output channel count;
    otherwise the input is added back unchanged.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.use_shortcut = not (stride == 1 and in_planes == out_planes)

        # Main path; only the 3x3 convolution carries the stride.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, affine=True)
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes, affine=True)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes, affine=True)

        # Skip path: projection when shapes differ, identity otherwise.
        if self.use_shortcut:
            self.shortcut_conv = nn.Conv2d(in_planes,
                                           out_planes,
                                           kernel_size=1,
                                           stride=stride,
                                           bias=False)
            self.shortcut_bn = nn.BatchNorm2d(out_planes, affine=True)
        else:
            self.shortcut_conv = nn.Sequential()

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))

        residual = self.shortcut_conv(x)
        if self.use_shortcut:
            residual = self.shortcut_bn(residual)

        hidden += residual
        return F.relu(hidden)
class ResNet_basic(nn.Module):
    """Three-stage ResNet with 16/32/64 base channels.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and be constructible as
            ``block(in_planes, planes, stride)``.
        num_blocks: Sequence of three ints, the number of blocks per stage.
        num_classes: Size of the classification head. When it is not
            positive, no head is created and ``forward`` returns the pooled
            features.
        cfg: Unused; kept for signature compatibility.

    Attributes:
        output_dim: Width of the pooled feature vector, i.e.
            ``64 * block.expansion``. This was previously set to
            ``512 * block.expansion``, which does not match what the
            network produces and breaks any head sized from it.
    """

    def __init__(self, block, num_blocks, num_classes=10, cfg=None):
        super().__init__()
        self.train_sup = (num_classes > 0)

        self.in_planes = 16
        self.conv1 = nn.Conv2d(3,
                               16,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16, affine=True)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # The last stage emits 64 * expansion channels, so that is the
        # feature width seen after global average pooling.
        self.output_dim = 64 * block.expansion
        if self.train_sup:
            self.linear = nn.Linear(self.output_dim, num_classes, bias=True)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            # Later blocks consume the expanded output of the previous one.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits if a head exists, else pooled features."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)
        if self.train_sup:
            out = self.linear(out)
        return out
def ModelBuilder(model_config, local_data):
    """Build the model selected by ``model_config.type``.

    ``local_data`` is accepted for interface compatibility but is not read.
    The previous version pulled one batch from ``local_data['train']`` and
    never used it, which failed when ``local_data`` was ``None`` and wasted
    a batch otherwise.

    Args:
        model_config: Object exposing a ``type`` attribute.
        local_data: Unused.

    Returns:
        * ``"SimCLR"``: a ``simclr`` model with a ``'res18'`` backbone.
        * ``"SimCLR_linear"``: a ``'res18'`` backbone with a 10-way head,
          initialised from the ``backbone.*`` entries of a saved checkpoint.
          No parameters are frozen here.
        * any other type: ``None``.
    """
    if model_config.type == "SimCLR":
        return simclr(bbone_arch='res18')

    if model_config.type == "SimCLR_linear":
        model = create_backbone(name='res18', num_classes=10)
        # NOTE(review): the checkpoint path is hard-coded; consider taking it
        # from the configuration instead. torch.load unpickles the file, so
        # only point this at trusted checkpoints.
        pretrained_model = torch.load(
            'checkpoint/SimCLR_on_Cifar4CL_lr0.5_lstep1_rn100.ckpt',
            map_location='cpu')
        prefix = 'backbone.'
        # Keep only backbone weights and strip the prefix so the keys match
        # the bare backbone. strict=False tolerates keys present on one side
        # only, such as the new 'linear' head.
        backbone_state = {
            key[len(prefix):]: value
            for key, value in pretrained_model['model'].items()
            if key.startswith(prefix)
        }
        model.load_state_dict(backbone_state, strict=False)
        return model

    return None
b/federatedscope/cl/test.ipynb @@ -0,0 +1,4425 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 13, + "id": "2abe03ca", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3009a354", + "metadata": {}, + "outputs": [], + "source": [ + "def NT_xentloss_(z1, z2, temperature=0.5): \n", + " N, Z = z1.shape \n", + " device = z1.device \n", + " representations = torch.cat([z1, z2], dim=0)\n", + " similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)\n", + "\n", + " l_pos = torch.diag(similarity_matrix, N)\n", + " r_pos = torch.diag(similarity_matrix, -N)\n", + " positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)\n", + "\n", + " diag = torch.eye(2*N, dtype=torch.bool, device=device)\n", + " diag[N:,:N] = diag[:N,N:] = diag[:N,:N]\n", + " negatives = similarity_matrix[~diag].view(2*N, -1)\n", + "\n", + " logits = torch.cat([positives, negatives], dim=1) / temperature\n", + " labels = torch.zeros(2*N, device=device, dtype=torch.int64) # scalar label per sample\n", + " loss = F.cross_entropy(logits, labels, reduction='sum')\n", + "\n", + " return loss / (2 * N)\n", + "\n", + "class NT_xentloss(nn.Module):\n", + " def __init__(self, temperature=0.5):\n", + " super(NT_xentloss, self).__init__()\n", + " self.temperature = temperature\n", + " \n", + " def forward(self, z1, z2):\n", + " N, Z = z1.shape \n", + " device = z1.device \n", + " representations = torch.cat([z1, z2], dim=0)\n", + " similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)\n", + "\n", + " l_pos = torch.diag(similarity_matrix, N)\n", + " r_pos = torch.diag(similarity_matrix, -N)\n", + " positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)\n", + 
"\n", + " diag = torch.eye(2*N, dtype=torch.bool, device=device)\n", + " diag[N:,:N] = diag[:N,N:] = diag[:N,:N]\n", + " negatives = similarity_matrix[~diag].view(2*N, -1)\n", + "\n", + " logits = torch.cat([positives, negatives], dim=1) / temperature\n", + " labels = torch.zeros(2*N, device=device, dtype=torch.int64) # scalar label per sample\n", + " loss = F.cross_entropy(logits, labels, reduction='sum')\n", + " \n", + " return loss / (2 * N)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "89177767", + "metadata": {}, + "outputs": [], + "source": [ + "from model.SimCLR import simclr, create_backbone\n", + "import random" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "a0c608b3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "conv1.weight Parameter containing:\n", + "tensor([[[[-0.0183, -0.0765, 0.2014],\n", + " [-0.0594, -0.0160, -0.0777],\n", + " [-0.1222, -0.0284, 0.1164]],\n", + "\n", + " [[ 0.3198, 0.2831, 0.3139],\n", + " [ 0.3030, -0.0359, 0.2107],\n", + " [ 0.3482, -0.0047, 0.4345]],\n", + "\n", + " [[-0.0811, -0.0312, -0.0877],\n", + " [-0.4400, -0.6757, -0.4297],\n", + " [ 0.0313, -0.4326, 0.0159]]],\n", + "\n", + "\n", + " [[[-0.0643, 0.2237, 0.4458],\n", + " [-0.3193, -0.0773, 0.0561],\n", + " [-0.4273, -0.0085, 0.1814]],\n", + "\n", + " [[-0.0256, -0.1148, 0.1742],\n", + " [-0.4455, -0.1189, 0.2433],\n", + " [-0.4229, -0.2738, -0.1874]],\n", + "\n", + " [[-0.0910, 0.0338, 0.2989],\n", + " [-0.1994, 0.0990, 0.2593],\n", + " [-0.0121, 0.1451, -0.0939]]],\n", + "\n", + "\n", + " [[[-0.3403, -0.4584, -0.3110],\n", + " [ 0.1626, 0.2688, 0.1822],\n", + " [ 0.3948, 0.3835, 0.1284]],\n", + "\n", + " [[-0.4745, -0.1487, -0.1274],\n", + " [-0.1165, -0.2197, -0.0441],\n", + " [ 0.1622, 0.2059, -0.1887]],\n", + "\n", + " [[-0.0929, -0.2441, 0.0493],\n", + " [-0.0427, -0.0754, -0.0753],\n", + " [ 0.3795, 0.3901, 0.2711]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + 
"\n", + " [[[-0.0482, 0.2177, 0.2685],\n", + " [-0.1009, -0.1820, 0.0026],\n", + " [-0.2059, -0.1552, -0.0143]],\n", + "\n", + " [[ 0.0179, -0.0478, -0.0131],\n", + " [ 0.1983, -0.2138, 0.2272],\n", + " [ 0.0263, -0.2440, -0.1631]],\n", + "\n", + " [[-0.0075, 0.1251, 0.0224],\n", + " [ 0.1411, -0.1159, -0.2196],\n", + " [ 0.0806, -0.3194, -0.2587]]],\n", + "\n", + "\n", + " [[[-0.0033, 0.2087, 0.0977],\n", + " [-0.0519, 0.2213, 0.2196],\n", + " [ 0.1723, 0.0177, 0.2410]],\n", + "\n", + " [[-0.1598, 0.1688, 0.1923],\n", + " [ 0.0352, 0.0465, 0.1310],\n", + " [ 0.1351, 0.2713, 0.0726]],\n", + "\n", + " [[-0.1874, 0.1088, -0.0098],\n", + " [-0.0707, 0.2780, 0.0114],\n", + " [-0.0553, 0.1157, 0.0258]]],\n", + "\n", + "\n", + " [[[-0.2044, -0.2053, -0.0668],\n", + " [-0.2899, -0.1816, -0.1537],\n", + " [-0.2074, -0.1091, -0.1762]],\n", + "\n", + " [[-0.1399, -0.2275, -0.0407],\n", + " [-0.1803, -0.0934, -0.0654],\n", + " [ 0.0834, -0.1998, 0.0623]],\n", + "\n", + " [[ 0.2251, 0.0599, 0.1248],\n", + " [ 0.1365, 0.3196, 0.1363],\n", + " [ 0.0486, 0.2550, 0.2294]]]])\n", + "bn1.weight Parameter containing:\n", + "tensor([1.1764, 1.1640, 1.0995, 1.1407, 1.1942, 0.9670, 1.8451, 1.1208, 1.0389,\n", + " 0.9228, 0.8822, 1.3220, 0.8472, 1.0401, 1.3100, 0.9914, 1.0612, 1.2423,\n", + " 1.0459, 1.0871, 0.9735, 0.9283, 1.0515, 0.9037, 1.0665, 0.9346, 1.5852,\n", + " 1.1763, 1.0359, 0.9605, 1.2617, 0.9576, 0.9276, 1.5612, 1.4488, 1.4455,\n", + " 1.0734, 1.0145, 1.0641, 1.1351, 0.7940, 1.5344, 1.0264, 0.9425, 0.9327,\n", + " 0.8563, 0.8882, 1.1045, 1.1882, 1.0556, 1.2136, 0.9525, 0.9602, 0.9024,\n", + " 1.0303, 0.9192, 0.8604, 1.1484, 1.9566, 0.9445, 1.1349, 0.9673, 0.8733,\n", + " 0.8793])\n", + "bn1.bias Parameter containing:\n", + "tensor([ 0.2469, 0.1301, 0.2798, 0.2101, 0.2512, -0.0913, 0.1757, -0.1319,\n", + " 0.0583, 0.1632, 0.0488, 0.4202, -0.0444, 0.0935, 0.2882, -0.0337,\n", + " 0.0455, 0.3127, 0.0145, 0.2779, -0.0525, 0.0980, 0.2111, -0.2265,\n", + " 0.3618, 0.0158, 0.4913, 
0.3805, -0.1700, -0.0907, 0.3175, -0.0993,\n", + " 0.0133, 0.5292, 0.6708, -0.1043, 0.0500, 0.1086, 0.1487, 0.1016,\n", + " 0.1520, 0.6627, 0.0677, 0.1277, -0.1703, 0.2342, -0.0021, 0.0561,\n", + " 0.2777, 0.1305, 0.3191, 0.0372, -0.1303, 0.2400, -0.1451, -0.1098,\n", + " -0.1606, 0.2635, 0.5051, 0.1788, -0.0703, -0.1441, -0.0713, -0.1347])\n", + "layer1.0.conv1.weight Parameter containing:\n", + "tensor([[[[ 2.6822e-02, 2.7844e-02, 5.7835e-02],\n", + " [ 5.3032e-02, -1.9933e-02, 4.4936e-03],\n", + " [-7.0999e-03, -3.9780e-02, 8.8800e-03]],\n", + "\n", + " [[-4.3002e-02, -1.6001e-02, 2.9426e-02],\n", + " [-7.6142e-03, 7.4012e-02, 6.8242e-02],\n", + " [ 4.2627e-02, 9.5829e-02, 2.2561e-02]],\n", + "\n", + " [[ 2.0417e-02, 2.1891e-02, 7.8952e-02],\n", + " [-9.9509e-03, 8.9533e-02, 1.3280e-01],\n", + " [ 5.9915e-03, 1.0857e-01, 4.9066e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-4.3741e-02, -1.6707e-04, -8.0110e-02],\n", + " [-7.0618e-02, -5.8168e-02, -1.4400e-02],\n", + " [-7.1466e-02, -3.5159e-02, -1.8935e-02]],\n", + "\n", + " [[-5.9944e-02, -2.6373e-02, 6.4819e-02],\n", + " [-3.7174e-02, 4.6251e-02, 5.9955e-02],\n", + " [ 2.8329e-02, 7.1514e-02, 5.7975e-02]],\n", + "\n", + " [[-2.5611e-02, 3.2862e-02, 2.4414e-02],\n", + " [ 6.2642e-02, 9.8129e-03, 7.5067e-02],\n", + " [ 9.4383e-02, 3.7375e-02, 5.1015e-02]]],\n", + "\n", + "\n", + " [[[ 2.6016e-02, 3.3790e-03, 4.1034e-02],\n", + " [ 2.0423e-02, 3.9459e-02, 5.0560e-02],\n", + " [ 7.3355e-02, 5.9024e-02, 1.0244e-01]],\n", + "\n", + " [[ 9.3631e-02, -9.1128e-03, -2.5420e-03],\n", + " [ 1.2080e-01, 4.7242e-02, 4.0391e-02],\n", + " [ 6.4514e-02, 8.8472e-02, 4.4476e-02]],\n", + "\n", + " [[ 1.5563e-02, 6.9989e-02, 8.5023e-02],\n", + " [-6.7027e-04, 1.0338e-02, 6.1742e-02],\n", + " [-5.4143e-03, 2.5595e-02, 6.4282e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 3.9204e-02, -2.7143e-02, -2.1611e-02],\n", + " [-1.0465e-02, 6.9390e-03, 1.8000e-02],\n", + " [-2.1076e-02, 3.1111e-02, -2.4030e-02]],\n", + "\n", + " [[ 
9.1457e-02, 5.4195e-02, 3.1970e-03],\n", + " [ 7.9195e-02, 3.1202e-02, 3.8534e-02],\n", + " [ 7.0259e-02, 5.2701e-02, 4.9941e-02]],\n", + "\n", + " [[ 2.0938e-02, 4.2828e-02, -1.8668e-02],\n", + " [ 6.0307e-02, 6.1750e-02, -1.7545e-02],\n", + " [ 7.9255e-02, 1.0614e-03, -2.1844e-02]]],\n", + "\n", + "\n", + " [[[-9.9994e-03, 5.3842e-02, 1.6255e-02],\n", + " [ 3.0797e-02, 7.7342e-02, 1.1806e-01],\n", + " [ 1.1641e-01, 8.0531e-02, 1.3195e-01]],\n", + "\n", + " [[ 3.2614e-02, 4.4080e-02, -9.1549e-02],\n", + " [ 3.5523e-02, 2.9408e-03, -3.7003e-02],\n", + " [ 6.3749e-02, 3.2515e-02, 4.9107e-02]],\n", + "\n", + " [[-1.5121e-03, 1.2614e-02, 4.8092e-02],\n", + " [-5.1469e-02, -6.2799e-03, -5.0903e-02],\n", + " [-2.3207e-02, -3.2161e-02, -4.1511e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.3646e-02, -7.9403e-02, -8.0696e-02],\n", + " [-5.3833e-02, -5.3324e-02, -4.2518e-02],\n", + " [-5.0440e-03, 2.9950e-02, -3.2976e-02]],\n", + "\n", + " [[-2.6310e-02, 5.8497e-02, 3.2514e-02],\n", + " [-4.3040e-02, -3.1510e-02, -9.4330e-03],\n", + " [-7.6044e-02, -1.1833e-02, -6.2875e-02]],\n", + "\n", + " [[ 1.9541e-02, 3.7317e-02, 4.0267e-02],\n", + " [-7.7281e-03, 5.9121e-03, -2.1681e-02],\n", + " [-2.5764e-02, 6.2579e-03, 9.6221e-04]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[ 1.1449e-02, 5.3609e-02, -3.0909e-03],\n", + " [ 6.0895e-02, 4.4686e-02, 4.9324e-02],\n", + " [-1.6167e-02, -3.9159e-02, 1.5942e-02]],\n", + "\n", + " [[-3.4361e-02, 3.7426e-02, 3.8636e-02],\n", + " [ 2.5543e-03, -3.0453e-02, -1.5798e-02],\n", + " [ 2.1582e-05, -8.5651e-02, -3.5822e-02]],\n", + "\n", + " [[-4.8913e-02, -6.9554e-02, -2.0775e-03],\n", + " [-1.4599e-02, -5.2028e-02, -3.5874e-02],\n", + " [-9.9485e-02, -1.4734e-03, -2.0859e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.8716e-02, 4.1136e-04, 3.3992e-02],\n", + " [-3.7199e-02, -4.4121e-03, 3.3526e-02],\n", + " [-1.1656e-02, -3.2170e-02, 2.0356e-02]],\n", + "\n", + " [[ 2.5727e-02, 5.4506e-02, -2.8860e-02],\n", + " [-2.2427e-02, 
-1.3907e-02, -2.6238e-02],\n", + " [-2.0807e-03, 6.2578e-03, 1.1508e-02]],\n", + "\n", + " [[-1.3074e-02, 6.6535e-02, 3.0361e-02],\n", + " [ 3.9119e-02, 3.1634e-02, 2.4559e-02],\n", + " [-2.3985e-02, -3.8256e-02, -6.4856e-02]]],\n", + "\n", + "\n", + " [[[-4.8634e-02, -7.1717e-02, -3.4695e-02],\n", + " [-7.3348e-04, -3.3484e-02, 4.1865e-03],\n", + " [-3.7279e-02, -3.6906e-02, -1.5438e-02]],\n", + "\n", + " [[-8.8918e-02, -5.0840e-02, -3.6884e-02],\n", + " [-5.6966e-02, -8.8720e-02, -3.8041e-02],\n", + " [ 9.9085e-03, -8.7135e-02, -1.0016e-01]],\n", + "\n", + " [[-1.7784e-02, -8.5942e-03, -4.1659e-03],\n", + " [ 2.7800e-02, 2.1069e-02, -1.0392e-03],\n", + " [-4.0285e-02, -2.9785e-03, 3.1919e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.0661e-02, 1.6509e-02, 5.3129e-02],\n", + " [-4.6699e-03, 7.5870e-03, 1.6611e-02],\n", + " [-6.6152e-02, -4.0273e-02, 4.6010e-02]],\n", + "\n", + " [[-9.7571e-03, -2.1538e-02, -3.4537e-04],\n", + " [ 5.6689e-02, 1.5892e-02, -4.1756e-02],\n", + " [ 5.6645e-02, 3.2210e-02, -3.3068e-03]],\n", + "\n", + " [[ 1.7798e-02, -3.1845e-02, -2.4694e-02],\n", + " [-6.8033e-02, -4.6940e-02, -2.4673e-02],\n", + " [-6.2354e-02, -4.0748e-02, 2.6249e-02]]],\n", + "\n", + "\n", + " [[[-6.6686e-02, -4.4081e-02, 4.6792e-02],\n", + " [-3.4211e-02, -6.6569e-02, 5.2638e-02],\n", + " [-5.5514e-02, -7.6365e-02, -4.6954e-03]],\n", + "\n", + " [[ 5.5348e-03, -4.3402e-02, -9.7526e-03],\n", + " [ 5.4384e-02, 2.8204e-02, 1.9586e-02],\n", + " [-4.1902e-02, -1.1975e-02, -4.1381e-03]],\n", + "\n", + " [[-7.1795e-03, -8.9517e-03, -4.6672e-02],\n", + " [-7.6845e-02, -8.6473e-02, -9.3363e-02],\n", + " [-4.8047e-03, -6.1196e-02, -5.3406e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-5.5690e-02, -3.6102e-02, 3.2068e-02],\n", + " [-4.7767e-02, -3.4354e-02, -2.1960e-02],\n", + " [-1.8896e-02, -2.4625e-02, -1.4877e-02]],\n", + "\n", + " [[-7.4994e-03, -1.6231e-03, -7.5342e-02],\n", + " [-8.7427e-03, -1.9412e-02, -9.5213e-02],\n", + " [-3.1043e-02, 3.6101e-03, 
-3.2579e-02]],\n", + "\n", + " [[ 9.4366e-02, 1.0742e-02, -7.4900e-03],\n", + " [ 7.4800e-02, 2.6321e-02, 5.5021e-02],\n", + " [ 2.0209e-02, 8.3471e-02, 9.2501e-02]]]])\n", + "layer1.0.bn1.weight Parameter containing:\n", + "tensor([1.0136, 0.9218, 0.9679, 1.0025, 0.9043, 0.8406, 1.2799, 1.0009, 1.0661,\n", + " 1.0704, 0.9351, 1.3376, 1.1102, 0.9364, 0.8600, 1.2370, 0.9971, 1.0767,\n", + " 1.3791, 0.9053, 0.9246, 1.0014, 1.0106, 0.9543, 0.8708, 0.9405, 1.0075,\n", + " 1.0012, 1.0660, 1.3460, 1.0126, 0.9835, 0.9491, 1.0240, 0.9015, 0.9719,\n", + " 0.8948, 1.3173, 0.9648, 0.9181, 0.9111, 0.8892, 0.8699, 0.9280, 1.0618,\n", + " 0.8207, 0.8929, 1.0177, 0.9283, 0.8522, 0.9356, 0.9263, 0.8903, 0.8923,\n", + " 0.9502, 0.9133, 1.1011, 1.0126, 0.8934, 0.9556, 1.1305, 0.9480, 1.1295,\n", + " 0.8721])\n", + "layer1.0.bn1.bias Parameter containing:\n", + "tensor([-0.0695, -0.1471, 0.0139, -0.0895, -0.0935, -0.1319, 0.2851, 0.0956,\n", + " 0.0141, -0.0588, -0.0934, 0.1814, 0.0189, -0.1585, -0.1160, 0.2663,\n", + " -0.0549, -0.0755, 0.2770, -0.2165, -0.1437, -0.0044, -0.0689, -0.2161,\n", + " -0.1360, 0.0215, -0.0459, 0.0850, -0.0693, 0.2729, -0.0163, 0.0720,\n", + " -0.0108, -0.0080, -0.0584, -0.0065, -0.0323, 0.0700, -0.0114, 0.0549,\n", + " 0.0068, -0.0173, -0.0960, -0.0148, 0.0259, -0.1088, -0.0520, -0.0149,\n", + " -0.0885, 0.0346, 0.1636, 0.0960, -0.0292, -0.0059, 0.1584, -0.2165,\n", + " -0.0836, -0.0022, -0.1441, -0.0754, 0.2161, -0.0256, 0.1242, -0.0562])\n", + "layer1.0.conv2.weight Parameter containing:\n", + "tensor([[[[ 1.1813e-02, -2.9603e-02, 8.2939e-03],\n", + " [-3.9731e-02, -4.8097e-02, 9.8338e-03],\n", + " [-4.8729e-02, -2.6406e-02, 3.8050e-02]],\n", + "\n", + " [[-1.3992e-01, -5.0215e-02, -3.7980e-02],\n", + " [-7.6246e-02, -8.1081e-02, -7.5231e-02],\n", + " [-1.0410e-01, -6.5782e-03, -2.2425e-02]],\n", + "\n", + " [[ 1.9414e-02, 3.7513e-02, 1.0480e-02],\n", + " [ 3.1074e-02, 4.5602e-02, 3.8310e-02],\n", + " [ 8.8012e-03, 2.3819e-03, 4.6176e-03]],\n", + "\n", 
+ " ...,\n", + "\n", + " [[-3.1133e-02, -2.8605e-03, -5.2947e-02],\n", + " [ 3.7506e-02, 6.3998e-04, -1.2003e-03],\n", + " [ 1.6662e-02, -3.0633e-02, -1.1023e-02]],\n", + "\n", + " [[-7.1672e-02, -3.3559e-02, 2.7411e-02],\n", + " [-1.2149e-02, -3.2008e-02, 5.0320e-02],\n", + " [-4.2444e-02, -6.7399e-02, 2.0852e-02]],\n", + "\n", + " [[ 1.1870e-04, -9.6183e-02, -8.2698e-02],\n", + " [ 1.0616e-02, -8.2367e-02, -6.7266e-02],\n", + " [-8.5162e-02, -2.5697e-02, -3.6014e-02]]],\n", + "\n", + "\n", + " [[[-2.4725e-02, 1.3686e-02, 4.1195e-02],\n", + " [ 1.0350e-02, 1.4489e-02, -3.2543e-02],\n", + " [-3.9483e-02, 4.3526e-03, -6.4983e-02]],\n", + "\n", + " [[-8.8199e-02, -6.7863e-02, -8.9248e-02],\n", + " [-9.4662e-02, -8.7976e-02, -1.0673e-01],\n", + " [-4.4556e-02, -1.0821e-01, -8.3664e-02]],\n", + "\n", + " [[-4.2361e-02, 2.5027e-02, 7.9239e-02],\n", + " [-7.9938e-02, 8.8911e-03, 3.7579e-02],\n", + " [ 2.5401e-02, 4.2908e-02, 8.5704e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.8075e-02, 2.0473e-02, -9.1930e-03],\n", + " [ 3.9986e-02, 9.4368e-03, 1.3627e-02],\n", + " [ 4.2203e-02, 3.4609e-02, 3.5526e-02]],\n", + "\n", + " [[-5.1567e-02, -1.1104e-01, -2.5351e-02],\n", + " [-7.1857e-02, -1.0935e-01, -5.1078e-02],\n", + " [-4.2563e-02, -1.3174e-01, -1.1441e-01]],\n", + "\n", + " [[-2.1442e-02, -1.0601e-02, 2.1566e-02],\n", + " [ 4.4194e-02, -2.4603e-03, -2.8993e-02],\n", + " [ 4.4593e-02, 2.4336e-02, 4.1785e-02]]],\n", + "\n", + "\n", + " [[[-3.7110e-02, -2.1384e-02, -8.8386e-02],\n", + " [ 8.9586e-03, -2.1911e-02, -1.9369e-02],\n", + " [ 2.1264e-05, 1.0292e-02, -7.6875e-03]],\n", + "\n", + " [[-6.9804e-02, -2.7804e-02, -3.7626e-02],\n", + " [-1.0934e-04, -5.0990e-02, -1.1562e-01],\n", + " [-4.4917e-02, -3.5801e-02, -7.5220e-02]],\n", + "\n", + " [[-7.0519e-02, -1.0848e-02, 2.0123e-02],\n", + " [-3.8126e-02, -3.2042e-02, -2.0311e-02],\n", + " [-6.6819e-02, 3.2167e-04, -5.7737e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 3.5913e-03, 7.8448e-02, 8.0313e-02],\n", + " [ 
1.4435e-02, -2.4045e-03, 6.7934e-03],\n", + " [-9.2509e-03, -9.8345e-03, 4.7002e-02]],\n", + "\n", + " [[ 1.6466e-02, 4.2463e-02, 7.0508e-02],\n", + " [ 3.4950e-02, 1.0808e-01, 7.7231e-02],\n", + " [-6.9534e-02, 4.1800e-02, 7.8877e-02]],\n", + "\n", + " [[-9.4433e-04, -6.5545e-02, 1.0643e-02],\n", + " [-1.2707e-02, -2.0921e-02, -7.0154e-02],\n", + " [-3.4571e-02, -3.0644e-02, -4.7103e-02]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[-7.4073e-02, -5.6490e-02, -7.8497e-03],\n", + " [-3.7687e-02, -7.1005e-02, -1.6297e-02],\n", + " [-4.7876e-02, -5.8232e-02, -6.4043e-02]],\n", + "\n", + " [[ 1.3346e-01, 9.8550e-02, 9.0538e-02],\n", + " [ 1.6279e-01, 4.8807e-02, 1.0326e-01],\n", + " [ 6.8256e-02, 8.8236e-02, 2.5214e-02]],\n", + "\n", + " [[ 5.0286e-02, -2.8655e-02, -3.9797e-02],\n", + " [ 8.1590e-02, 8.8670e-02, 3.2581e-02],\n", + " [ 6.6629e-02, 5.5634e-02, 6.7553e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-4.8289e-02, -1.8756e-02, -1.9947e-02],\n", + " [ 1.5246e-03, 2.4550e-02, -1.8470e-02],\n", + " [ 7.0791e-02, 5.5384e-02, 4.3716e-02]],\n", + "\n", + " [[-3.0825e-02, -2.1276e-02, -3.7002e-02],\n", + " [-6.0327e-02, -7.0949e-02, -8.2225e-02],\n", + " [-3.8889e-02, -3.2270e-02, -2.6290e-02]],\n", + "\n", + " [[-1.0067e-01, -1.1138e-02, -6.7776e-02],\n", + " [-9.9491e-02, -4.8505e-02, -5.5871e-02],\n", + " [-1.2496e-01, -1.1068e-01, -6.7283e-02]]],\n", + "\n", + "\n", + " [[[-1.5475e-03, -7.9660e-03, -3.8476e-02],\n", + " [ 5.9998e-02, 5.8792e-02, 1.5293e-02],\n", + " [ 9.0893e-02, -2.4707e-03, -1.3622e-02]],\n", + "\n", + " [[-2.2771e-02, 2.6777e-02, -2.0016e-03],\n", + " [-1.1872e-03, 7.3358e-02, -4.8738e-02],\n", + " [-3.2554e-03, -2.7650e-02, -3.6791e-03]],\n", + "\n", + " [[ 1.1710e-01, 6.0996e-03, -2.3733e-02],\n", + " [-6.7923e-04, 1.2859e-03, 2.7628e-03],\n", + " [-4.2115e-02, -2.6475e-02, 7.1679e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[ 3.7477e-02, 1.1140e-03, 6.3222e-02],\n", + " [ 4.8929e-02, 5.5243e-02, 1.1429e-03],\n", + " 
[-3.3385e-02, -4.7582e-02, -3.9440e-02]],\n", + "\n", + " [[-4.3734e-02, 1.0802e-02, 1.5586e-02],\n", + " [-2.0318e-02, 3.2691e-02, 2.9043e-03],\n", + " [-5.3472e-03, 5.9829e-03, 3.4189e-02]],\n", + "\n", + " [[-3.2688e-03, -5.7450e-02, 3.6852e-02],\n", + " [ 3.2528e-03, -1.5297e-03, 5.3320e-03],\n", + " [ 1.7151e-02, -1.1268e-03, 8.1022e-02]]],\n", + "\n", + "\n", + " [[[ 4.3947e-02, 4.7477e-02, -1.7206e-02],\n", + " [-5.8523e-02, -6.2340e-03, -4.0691e-02],\n", + " [-8.4128e-02, -2.6183e-02, 1.9188e-02]],\n", + "\n", + " [[-3.2490e-02, -9.6267e-02, -7.1499e-02],\n", + " [-7.1048e-04, -8.0449e-02, -5.1654e-02],\n", + " [ 2.3208e-02, -2.1737e-02, 4.9298e-02]],\n", + "\n", + " [[-4.4389e-02, -5.6683e-02, -2.3320e-02],\n", + " [-6.4499e-02, -3.0856e-02, 6.9859e-03],\n", + " [-5.6850e-02, -3.9121e-02, -4.1350e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.7911e-02, 4.6162e-02, 1.9657e-02],\n", + " [ 7.5653e-03, -4.2511e-02, -2.7971e-02],\n", + " [-3.6817e-02, -9.0424e-02, -1.6890e-02]],\n", + "\n", + " [[-2.2699e-02, 1.6050e-02, -2.7669e-02],\n", + " [-2.6073e-02, -5.1785e-02, -3.8797e-02],\n", + " [-6.3523e-02, -5.3364e-02, -9.2193e-02]],\n", + "\n", + " [[-5.3270e-02, -4.1235e-03, -5.8335e-03],\n", + " [-1.9613e-02, 5.4143e-02, -1.6930e-02],\n", + " [ 3.3865e-02, -2.8973e-03, -1.6124e-02]]]])\n", + "layer1.0.bn2.weight Parameter containing:\n", + "tensor([0.9079, 0.9314, 0.9273, 0.9928, 0.9441, 0.8893, 1.3586, 0.9319, 0.9911,\n", + " 0.8918, 0.9668, 1.2491, 0.9299, 0.8372, 0.9112, 1.1834, 0.8022, 0.9417,\n", + " 0.8730, 1.4600, 1.0404, 1.0169, 0.9842, 0.8882, 0.9578, 1.0128, 0.9548,\n", + " 0.8057, 0.7646, 0.9008, 1.0668, 0.8646, 0.9026, 1.0237, 0.7285, 1.0507,\n", + " 0.9604, 0.9052, 0.8586, 0.7162, 0.9041, 0.8687, 0.9544, 0.8602, 0.9671,\n", + " 1.0384, 0.9835, 0.9222, 0.9371, 0.7231, 0.8584, 0.9737, 0.9182, 0.7807,\n", + " 0.9644, 0.9615, 0.9434, 0.8674, 1.1694, 1.0835, 0.7759, 0.8638, 1.0058,\n", + " 0.9209])\n", + "layer1.0.bn2.bias Parameter containing:\n", + 
"tensor([-0.0163, -0.0560, -0.0696, -0.2069, -0.0330, -0.0317, 0.0248, -0.0275,\n", + " -0.0437, -0.0688, -0.0788, 0.1443, -0.1303, -0.0720, -0.0025, 0.0403,\n", + " 0.0716, -0.0432, -0.0148, 0.2234, -0.0015, 0.0852, 0.0106, -0.0835,\n", + " -0.0506, 0.0035, 0.0694, 0.0313, -0.0780, -0.0077, 0.0154, 0.0233,\n", + " -0.0140, 0.0316, 0.0475, -0.0478, -0.0771, 0.0058, -0.0953, 0.0017,\n", + " 0.0255, 0.0810, -0.0305, 0.0693, -0.1069, -0.0252, -0.0649, -0.0238,\n", + " -0.0523, 0.0308, 0.0118, 0.0199, -0.0296, 0.0458, -0.0450, -0.0723,\n", + " -0.0193, 0.0517, 0.0489, -0.0302, -0.1371, -0.1570, 0.0574, -0.1233])\n", + "layer1.1.conv1.weight Parameter containing:\n", + "tensor([[[[ 2.6593e-02, -3.4198e-02, 6.7904e-02],\n", + " [ 4.9502e-02, 1.7229e-02, 5.6829e-02],\n", + " [-8.7818e-03, 6.0911e-02, 4.0403e-02]],\n", + "\n", + " [[-7.6700e-02, -9.6916e-02, -5.3278e-02],\n", + " [-5.9268e-03, -2.2775e-02, -8.3431e-02],\n", + " [ 9.8249e-03, 4.8766e-02, -5.0661e-02]],\n", + "\n", + " [[-4.0173e-02, 2.9130e-02, 3.3362e-02],\n", + " [-1.5308e-02, 2.6540e-02, -9.1486e-03],\n", + " [-8.1136e-02, -7.1668e-02, 1.7501e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 4.9759e-02, 6.7114e-02, 9.5368e-03],\n", + " [ 2.4895e-02, 3.7341e-02, -9.1378e-03],\n", + " [ 1.3741e-02, -2.4999e-02, -9.5631e-03]],\n", + "\n", + " [[-2.3029e-02, -5.6797e-02, -7.2388e-02],\n", + " [ 2.0829e-02, 3.3307e-02, -4.1836e-02],\n", + " [-2.4193e-02, 1.8719e-02, -3.0852e-02]],\n", + "\n", + " [[-4.1146e-02, -5.9768e-02, -1.3481e-02],\n", + " [-2.4279e-02, -2.0338e-02, -4.9791e-02],\n", + " [ 1.4678e-02, -2.1303e-02, 3.9306e-02]]],\n", + "\n", + "\n", + " [[[-2.4511e-02, -1.0791e-02, -5.8488e-02],\n", + " [-3.8544e-02, -6.9196e-03, -4.3451e-02],\n", + " [-3.9745e-02, -4.2166e-02, -4.0130e-02]],\n", + "\n", + " [[ 2.2328e-03, 2.0923e-02, 7.5556e-02],\n", + " [ 1.6712e-02, 6.0325e-02, 2.8951e-02],\n", + " [ 5.7430e-02, 7.4455e-03, 1.1946e-02]],\n", + "\n", + " [[ 2.1712e-02, -1.3877e-02, -2.8872e-03],\n", + " 
[-6.3337e-02, -6.5354e-02, -4.3473e-02],\n", + " [-7.9494e-03, -4.2597e-02, -3.6345e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.3227e-02, 1.3381e-02, 1.2781e-02],\n", + " [ 1.1976e-02, 1.7176e-03, 7.5455e-02],\n", + " [ 3.3403e-03, 4.9971e-02, 2.0079e-02]],\n", + "\n", + " [[-7.2390e-03, -3.0566e-02, 1.6535e-02],\n", + " [ 2.0263e-02, -1.2138e-02, 4.2431e-02],\n", + " [ 2.3446e-02, 4.1872e-04, -8.2290e-03]],\n", + "\n", + " [[-2.5881e-03, -3.2973e-02, 4.8453e-02],\n", + " [ 1.2530e-02, 2.2892e-02, 4.6676e-02],\n", + " [-4.0720e-02, 4.1090e-02, 3.1536e-02]]],\n", + "\n", + "\n", + " [[[-4.4590e-02, 3.2132e-02, 5.0317e-02],\n", + " [ 5.8723e-02, 2.7518e-02, 5.2837e-02],\n", + " [ 4.5984e-02, -4.8859e-02, 2.8074e-02]],\n", + "\n", + " [[ 3.5439e-02, 4.2154e-02, 6.2620e-03],\n", + " [ 1.2440e-02, 9.4835e-02, -1.4920e-02],\n", + " [ 9.3768e-02, 6.6834e-02, -2.9676e-02]],\n", + "\n", + " [[-1.0067e-02, -2.6803e-02, -3.5374e-02],\n", + " [-8.0500e-03, 2.7707e-02, -4.1172e-02],\n", + " [-4.7432e-02, -3.9620e-02, -4.4891e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 4.3233e-02, 2.7552e-02, 1.3127e-02],\n", + " [ 6.8530e-02, 4.6166e-02, 7.0684e-03],\n", + " [ 5.6551e-02, -3.3509e-02, -5.4552e-02]],\n", + "\n", + " [[-1.7251e-02, 3.8119e-02, 1.4739e-02],\n", + " [-3.9602e-02, 3.8198e-02, 7.9852e-02],\n", + " [-2.2415e-02, 1.7422e-02, 7.2721e-02]],\n", + "\n", + " [[-2.1743e-02, -2.1620e-02, -2.9930e-02],\n", + " [ 5.2186e-02, 4.4009e-02, -1.4880e-02],\n", + " [ 8.2744e-02, 4.0110e-02, 6.9476e-02]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[ 5.4301e-02, 5.8203e-02, 5.1051e-02],\n", + " [ 6.0607e-03, -2.2593e-03, 4.8936e-02],\n", + " [ 2.2871e-02, 9.6164e-03, 1.8444e-02]],\n", + "\n", + " [[ 5.7914e-02, 6.4484e-02, -2.3707e-03],\n", + " [ 2.6293e-02, 4.1130e-02, 1.4373e-02],\n", + " [ 4.6521e-02, 6.0853e-02, 3.2299e-02]],\n", + "\n", + " [[ 4.5852e-02, 2.4686e-02, 2.4303e-02],\n", + " [ 3.1945e-02, 3.4674e-02, -1.7316e-02],\n", + " [-2.4934e-02, 5.7327e-02, 
2.0576e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.4928e-02, 9.5490e-03, -6.2237e-02],\n", + " [-2.0978e-02, 1.9526e-02, 6.8305e-03],\n", + " [ 4.8388e-03, -3.7300e-02, -3.6364e-02]],\n", + "\n", + " [[-3.6045e-02, 1.8570e-02, 4.3668e-02],\n", + " [-2.9774e-02, -1.0319e-02, 1.3255e-02],\n", + " [-2.2707e-02, 1.9602e-02, 7.9169e-04]],\n", + "\n", + " [[ 2.3576e-03, -1.9243e-03, 1.2709e-02],\n", + " [ 3.2156e-03, 3.6086e-02, 3.8457e-02],\n", + " [-2.2701e-02, 1.5551e-02, -4.0218e-02]]],\n", + "\n", + "\n", + " [[[ 4.7197e-02, 2.8875e-02, 5.8667e-03],\n", + " [-2.7055e-03, 3.3594e-02, 1.2208e-02],\n", + " [ 9.2712e-04, -1.1782e-02, -1.3044e-03]],\n", + "\n", + " [[-8.2965e-03, -7.0715e-03, -4.0541e-02],\n", + " [-1.0041e-02, -3.4150e-02, -2.7035e-02],\n", + " [ 1.0155e-02, 3.3201e-02, -1.4504e-02]],\n", + "\n", + " [[-7.8979e-02, -7.4035e-02, -8.0417e-02],\n", + " [-3.3371e-03, -4.3919e-04, -9.0875e-02],\n", + " [-4.7493e-02, -5.0132e-02, -2.4897e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-4.9027e-02, -3.2874e-02, -1.1205e-02],\n", + " [-4.4780e-05, -1.0628e-02, 1.3734e-04],\n", + " [-2.6351e-02, -6.2976e-03, -3.4885e-02]],\n", + "\n", + " [[-3.6896e-02, -5.0080e-02, -1.3959e-02],\n", + " [-3.0033e-02, -1.8250e-02, -1.6171e-02],\n", + " [-1.2968e-02, -7.3765e-02, -3.8385e-02]],\n", + "\n", + " [[-1.1514e-02, -1.0031e-02, -4.4923e-02],\n", + " [-5.2885e-02, -8.0522e-02, 2.2228e-04],\n", + " [-8.0144e-03, -5.1980e-02, -6.4364e-02]]],\n", + "\n", + "\n", + " [[[ 2.0123e-02, 4.2277e-02, -8.7214e-03],\n", + " [-1.0368e-02, 3.4194e-02, 4.5498e-02],\n", + " [ 9.1709e-02, 9.3509e-02, 9.3661e-03]],\n", + "\n", + " [[ 1.2534e-02, 5.9424e-03, 1.4623e-02],\n", + " [ 3.2083e-02, 1.1867e-02, 6.8766e-02],\n", + " [ 1.9279e-02, 1.5177e-03, 1.6413e-02]],\n", + "\n", + " [[ 4.6366e-03, 8.8012e-02, 3.8882e-02],\n", + " [ 6.3008e-02, 9.7451e-02, 2.7030e-03],\n", + " [ 7.2544e-02, 4.6869e-02, 7.6242e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 8.5150e-03, -1.8553e-02, 
-4.8858e-02],\n", + " [-2.0450e-02, -4.0801e-03, -4.4920e-02],\n", + " [-1.4873e-02, 1.2599e-02, -2.8559e-02]],\n", + "\n", + " [[ 2.6483e-02, -2.5318e-03, 2.3260e-02],\n", + " [ 1.4355e-02, 2.9071e-03, 3.7970e-02],\n", + " [ 1.7436e-02, 1.5663e-02, 2.8179e-02]],\n", + "\n", + " [[ 5.3114e-02, 4.9337e-02, 1.1110e-01],\n", + " [ 4.0582e-02, 7.0569e-02, 3.7544e-02],\n", + " [ 5.0659e-02, 6.7754e-02, 6.0954e-02]]]])\n", + "layer1.1.bn1.weight Parameter containing:\n", + "tensor([1.0599, 0.9466, 1.0883, 1.0216, 0.9549, 1.0047, 0.9343, 1.1112, 0.9535,\n", + " 0.9474, 0.9471, 1.0289, 1.0615, 0.9349, 0.9411, 0.8893, 1.0767, 1.0939,\n", + " 1.0169, 0.9414, 1.1864, 0.9885, 1.0232, 1.3587, 0.9706, 0.9471, 0.9786,\n", + " 0.9206, 0.9999, 0.9639, 1.0253, 1.0520, 1.1145, 0.9727, 0.9409, 0.9877,\n", + " 0.9381, 0.9588, 1.0095, 1.0971, 0.9343, 1.0072, 0.9908, 0.9240, 0.9385,\n", + " 0.9874, 0.8669, 1.1432, 0.8880, 0.9495, 0.9635, 0.9576, 1.1564, 0.9145,\n", + " 0.9872, 0.9669, 0.9895, 1.0291, 0.9717, 0.8540, 1.0361, 0.9782, 0.9716,\n", + " 1.0026])\n", + "layer1.1.bn1.bias Parameter containing:\n", + "tensor([-0.0409, -0.0949, 0.0122, -0.0208, 0.0112, 0.0394, -0.0883, -0.0212,\n", + " -0.0387, -0.0367, -0.1526, -0.0442, -0.0031, -0.0585, 0.0076, -0.1048,\n", + " -0.0662, -0.0024, -0.0183, 0.0024, 0.0064, -0.0465, -0.0546, 0.2693,\n", + " -0.0167, -0.0446, -0.0386, -0.0651, -0.0097, 0.0304, 0.0280, 0.0169,\n", + " 0.0934, -0.1040, -0.0748, -0.0120, -0.0615, -0.0757, -0.0773, -0.0143,\n", + " -0.0449, -0.0480, -0.0719, -0.1599, -0.0103, -0.0578, -0.1456, 0.0326,\n", + " -0.1470, -0.1311, -0.0437, -0.0365, 0.0418, -0.0848, -0.1881, -0.0249,\n", + " 0.0098, -0.0431, 0.0236, -0.1328, -0.0497, -0.0338, -0.0273, -0.0168])\n", + "layer1.1.conv2.weight Parameter containing:\n", + "tensor([[[[-0.0425, 0.0580, 0.0541],\n", + " [-0.0597, 0.0107, -0.0336],\n", + " [ 0.0133, -0.0100, -0.0221]],\n", + "\n", + " [[-0.0303, -0.0345, -0.0381],\n", + " [-0.0578, -0.0366, 0.0113],\n", + " [-0.0443, 
-0.0016, 0.0039]],\n", + "\n", + " [[-0.0271, -0.0357, -0.0787],\n", + " [ 0.0236, -0.0049, -0.0762],\n", + " [-0.0060, -0.0611, -0.0399]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0652, -0.0172, 0.0024],\n", + " [-0.0414, -0.0345, 0.0072],\n", + " [ 0.0390, 0.0039, -0.0515]],\n", + "\n", + " [[-0.0112, -0.1150, -0.0668],\n", + " [-0.0464, -0.0827, -0.0895],\n", + " [-0.0056, -0.0743, -0.0646]],\n", + "\n", + " [[ 0.0104, 0.0106, 0.0315],\n", + " [ 0.0361, -0.0349, -0.0232],\n", + " [ 0.0040, -0.0481, -0.0696]]],\n", + "\n", + "\n", + " [[[-0.0087, 0.0416, 0.0631],\n", + " [ 0.0012, -0.0020, 0.0648],\n", + " [-0.0254, 0.0522, 0.0958]],\n", + "\n", + " [[ 0.0291, -0.0027, 0.0749],\n", + " [ 0.0146, 0.0427, 0.0398],\n", + " [ 0.0216, 0.0151, 0.0305]],\n", + "\n", + " [[-0.0142, -0.0535, -0.0562],\n", + " [-0.0465, 0.0428, -0.0320],\n", + " [ 0.0533, 0.0121, -0.0514]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0050, 0.0087, -0.0373],\n", + " [-0.0023, -0.0032, -0.0080],\n", + " [-0.0019, 0.0070, -0.0359]],\n", + "\n", + " [[ 0.0295, 0.0137, -0.0127],\n", + " [ 0.0067, -0.0188, -0.0028],\n", + " [-0.0058, 0.0057, -0.0498]],\n", + "\n", + " [[-0.0124, -0.0034, 0.0735],\n", + " [ 0.0373, 0.0344, 0.0147],\n", + " [-0.0111, -0.0272, -0.0573]]],\n", + "\n", + "\n", + " [[[-0.0064, -0.0499, -0.0069],\n", + " [ 0.0320, 0.0582, 0.0055],\n", + " [ 0.0213, 0.0547, 0.0568]],\n", + "\n", + " [[-0.0671, -0.0392, -0.0588],\n", + " [-0.0143, -0.0592, -0.0417],\n", + " [-0.0555, -0.0098, 0.0159]],\n", + "\n", + " [[-0.0233, 0.0063, 0.0173],\n", + " [ 0.0356, 0.0272, 0.0058],\n", + " [ 0.0783, -0.0183, 0.0077]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0018, -0.0190, -0.0682],\n", + " [ 0.0239, -0.0067, -0.0636],\n", + " [ 0.0168, 0.0153, -0.0345]],\n", + "\n", + " [[ 0.0611, -0.0152, -0.0178],\n", + " [ 0.0012, -0.0079, 0.0055],\n", + " [-0.0247, 0.0308, -0.0334]],\n", + "\n", + " [[ 0.0216, -0.0141, -0.0028],\n", + " [ 0.0329, 0.0348, -0.0168],\n", + " [ 0.0369, 0.0396, 
0.0111]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[ 0.0227, -0.0292, 0.0172],\n", + " [ 0.0439, -0.0371, -0.0349],\n", + " [-0.0452, -0.0315, -0.0630]],\n", + "\n", + " [[ 0.0070, 0.0408, 0.0525],\n", + " [ 0.0316, 0.0162, 0.0199],\n", + " [-0.0101, 0.0084, 0.0361]],\n", + "\n", + " [[-0.0197, -0.0297, 0.0734],\n", + " [-0.0344, 0.0061, 0.0714],\n", + " [-0.0281, 0.0119, 0.0195]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0260, -0.0386, -0.0111],\n", + " [-0.0640, -0.0643, -0.0123],\n", + " [-0.0149, -0.0794, -0.0550]],\n", + "\n", + " [[ 0.0072, -0.0051, 0.0791],\n", + " [ 0.0199, 0.0463, 0.0407],\n", + " [ 0.0329, 0.0058, -0.0410]],\n", + "\n", + " [[ 0.0060, -0.0138, 0.0508],\n", + " [-0.0337, 0.0088, 0.0093],\n", + " [ 0.0649, 0.0187, 0.0765]]],\n", + "\n", + "\n", + " [[[-0.0246, -0.0090, 0.0458],\n", + " [-0.0393, 0.0483, 0.0357],\n", + " [-0.0376, 0.0400, 0.0190]],\n", + "\n", + " [[-0.0568, 0.0053, -0.0201],\n", + " [ 0.0060, -0.0178, -0.0189],\n", + " [ 0.0246, 0.0467, 0.0395]],\n", + "\n", + " [[-0.0069, 0.0289, 0.0062],\n", + " [ 0.0210, -0.0279, 0.0359],\n", + " [ 0.0072, -0.0665, -0.0325]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0284, 0.0276, -0.0132],\n", + " [-0.0086, 0.0012, -0.0279],\n", + " [ 0.0037, -0.0484, -0.0750]],\n", + "\n", + " [[-0.0388, -0.0717, -0.0088],\n", + " [ 0.0191, -0.0185, -0.0208],\n", + " [-0.0333, -0.0268, 0.0093]],\n", + "\n", + " [[ 0.0156, -0.0465, -0.0360],\n", + " [ 0.0242, 0.0537, 0.0047],\n", + " [ 0.0500, -0.0474, -0.0115]]],\n", + "\n", + "\n", + " [[[-0.0350, -0.0504, -0.0429],\n", + " [-0.0300, -0.0106, -0.0264],\n", + " [-0.0137, -0.0023, 0.0497]],\n", + "\n", + " [[-0.0818, -0.0710, -0.0014],\n", + " [-0.0329, -0.0905, -0.0708],\n", + " [-0.0799, -0.0574, -0.0885]],\n", + "\n", + " [[-0.0414, -0.0102, 0.0219],\n", + " [-0.0278, -0.0462, -0.0012],\n", + " [ 0.0137, -0.0312, -0.0250]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0034, 0.0267, 0.0137],\n", + " [ 0.0144, 0.0389, 0.0243],\n", + " 
[-0.0171, 0.0232, 0.0449]],\n", + "\n", + " [[-0.0145, -0.0584, -0.0150],\n", + " [-0.0247, 0.0208, -0.0215],\n", + " [ 0.0675, 0.0202, 0.0110]],\n", + "\n", + " [[ 0.0119, -0.0197, 0.0159],\n", + " [ 0.0221, -0.0354, -0.0394],\n", + " [ 0.0256, -0.0396, -0.0515]]]])\n", + "layer1.1.bn2.weight Parameter containing:\n", + "tensor([0.8659, 0.9570, 0.8891, 0.9045, 1.0087, 1.0288, 0.9630, 0.8839, 0.9474,\n", + " 0.8482, 0.9522, 1.0135, 0.9082, 0.7224, 0.7214, 0.9890, 0.8385, 0.9520,\n", + " 0.9855, 1.1083, 0.9786, 0.8302, 0.9394, 1.0564, 0.8630, 0.9151, 1.0819,\n", + " 0.8662, 0.7969, 0.9081, 0.8537, 0.8439, 0.9650, 0.9277, 0.7604, 0.8773,\n", + " 1.0263, 0.9107, 0.8561, 0.7252, 0.9475, 1.0708, 0.8672, 0.9553, 0.9996,\n", + " 0.9246, 0.9422, 0.9534, 0.7877, 0.7036, 0.9022, 0.9321, 0.8827, 0.7797,\n", + " 0.9681, 0.9808, 0.9166, 0.8894, 0.7313, 1.0279, 0.8668, 0.7882, 0.9513,\n", + " 0.9528])\n", + "layer1.1.bn2.bias Parameter containing:\n", + "tensor([-0.0263, -0.0246, -0.0380, -0.1532, 0.0098, 0.0709, -0.0320, -0.0192,\n", + " -0.0470, -0.0412, -0.0918, -0.1333, -0.0979, -0.1266, -0.0348, -0.0114,\n", + " -0.1043, -0.0150, -0.0406, 0.0311, 0.0046, -0.0594, 0.0503, 0.0359,\n", + " -0.0266, -0.0136, -0.0352, -0.0149, -0.1466, 0.0198, -0.0174, 0.0008,\n", + " -0.0475, 0.0132, -0.0096, -0.1214, -0.0484, -0.0154, -0.0565, -0.0403,\n", + " 0.0498, 0.0305, -0.0371, 0.0178, -0.0004, -0.0951, -0.0434, -0.0246,\n", + " -0.0990, -0.0080, -0.0547, -0.0164, -0.0435, 0.0026, -0.0411, -0.0411,\n", + " 0.0899, -0.0145, 0.0086, -0.0031, -0.0222, -0.0939, 0.1107, -0.0317])\n", + "layer2.0.conv1.weight Parameter containing:\n", + "tensor([[[[ 0.0158, 0.0590, 0.0726],\n", + " [ 0.0895, 0.0307, 0.0540],\n", + " [ 0.0257, 0.0603, 0.0361]],\n", + "\n", + " [[-0.0074, 0.0527, 0.0438],\n", + " [-0.0517, 0.0088, 0.0551],\n", + " [ 0.0119, 0.0110, 0.0201]],\n", + "\n", + " [[-0.0075, 0.0335, 0.0092],\n", + " [ 0.0356, -0.0221, -0.0678],\n", + " [ 0.0104, 0.0493, -0.0018]],\n", + "\n", + " 
...,\n", + "\n", + " [[-0.0101, 0.0331, 0.0198],\n", + " [-0.0509, 0.0216, 0.0288],\n", + " [ 0.0267, 0.0396, 0.0091]],\n", + "\n", + " [[ 0.0025, 0.0212, 0.0050],\n", + " [-0.0509, 0.0145, -0.0062],\n", + " [-0.0173, 0.0092, 0.0103]],\n", + "\n", + " [[ 0.0235, 0.0254, -0.0029],\n", + " [ 0.0238, -0.0455, -0.0025],\n", + " [-0.0164, -0.0220, -0.0331]]],\n", + "\n", + "\n", + " [[[-0.0046, 0.0073, -0.0030],\n", + " [ 0.0025, 0.0385, 0.0005],\n", + " [-0.0113, 0.0183, -0.0256]],\n", + "\n", + " [[-0.0707, -0.0429, 0.0003],\n", + " [-0.0114, -0.0280, 0.0282],\n", + " [-0.0454, -0.0544, -0.0074]],\n", + "\n", + " [[-0.0286, 0.0242, 0.0237],\n", + " [ 0.0335, -0.0229, -0.0133],\n", + " [ 0.0259, 0.0158, 0.0215]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0169, -0.0294, -0.0425],\n", + " [-0.0587, -0.0123, 0.0025],\n", + " [-0.0308, -0.0002, 0.0036]],\n", + "\n", + " [[ 0.0214, -0.0275, -0.0157],\n", + " [-0.0303, -0.0238, 0.0302],\n", + " [-0.0468, 0.0252, -0.0138]],\n", + "\n", + " [[ 0.0201, -0.0123, -0.0140],\n", + " [ 0.0291, -0.0242, 0.0228],\n", + " [-0.0296, -0.0166, 0.0087]]],\n", + "\n", + "\n", + " [[[-0.0199, 0.0162, -0.0597],\n", + " [-0.0074, 0.0055, 0.0203],\n", + " [ 0.0311, -0.0217, 0.0030]],\n", + "\n", + " [[ 0.0209, 0.0524, -0.0269],\n", + " [ 0.0419, 0.0137, -0.0015],\n", + " [ 0.0096, 0.0223, -0.0197]],\n", + "\n", + " [[-0.0443, -0.0454, -0.0844],\n", + " [ 0.0445, 0.0140, -0.0364],\n", + " [ 0.0462, 0.0051, 0.0261]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0373, -0.0513, -0.0283],\n", + " [ 0.0204, -0.0505, -0.0076],\n", + " [-0.0416, -0.0456, -0.0535]],\n", + "\n", + " [[-0.0130, -0.0107, 0.0106],\n", + " [-0.0009, 0.0307, 0.0062],\n", + " [ 0.0187, 0.0160, -0.0581]],\n", + "\n", + " [[-0.0268, -0.0571, -0.0176],\n", + " [ 0.0213, -0.0509, -0.0372],\n", + " [ 0.0095, 0.0063, 0.0298]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[-0.0013, -0.0230, 0.0109],\n", + " [ 0.0346, 0.0035, 0.0074],\n", + " [ 0.0350, 0.0549, 0.0220]],\n", + 
"\n", + " [[-0.0379, 0.0252, -0.0286],\n", + " [ 0.0210, -0.0278, 0.0077],\n", + " [ 0.0042, -0.0041, 0.0536]],\n", + "\n", + " [[ 0.0404, -0.0401, -0.0298],\n", + " [-0.0400, -0.0587, -0.0317],\n", + " [-0.0152, -0.0491, -0.0283]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0188, -0.0144, -0.0377],\n", + " [-0.0201, -0.0266, -0.0110],\n", + " [-0.0168, 0.0445, -0.0206]],\n", + "\n", + " [[-0.0047, -0.0460, 0.0059],\n", + " [-0.0314, -0.0195, -0.0260],\n", + " [-0.0873, -0.0550, -0.0475]],\n", + "\n", + " [[ 0.0206, -0.0238, -0.0055],\n", + " [-0.0567, 0.0145, 0.0048],\n", + " [ 0.0315, -0.0404, -0.0358]]],\n", + "\n", + "\n", + " [[[ 0.0376, 0.0002, 0.0187],\n", + " [-0.0714, -0.0410, -0.0170],\n", + " [-0.0173, -0.0522, -0.0071]],\n", + "\n", + " [[ 0.0622, 0.0299, -0.0092],\n", + " [ 0.0005, -0.0267, 0.0519],\n", + " [ 0.0598, 0.0151, 0.0693]],\n", + "\n", + " [[-0.0858, -0.0869, -0.0007],\n", + " [-0.0152, -0.0226, 0.0009],\n", + " [-0.0550, 0.0011, -0.0550]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0283, 0.0049, 0.0274],\n", + " [ 0.0528, 0.0098, -0.0040],\n", + " [ 0.0354, 0.0302, 0.0458]],\n", + "\n", + " [[-0.0104, 0.0141, -0.0294],\n", + " [-0.0277, -0.0633, 0.0178],\n", + " [ 0.0387, 0.0451, -0.0141]],\n", + "\n", + " [[-0.0456, -0.0294, -0.0607],\n", + " [ 0.0065, -0.0081, -0.0275],\n", + " [-0.0255, -0.0203, -0.0485]]],\n", + "\n", + "\n", + " [[[-0.0105, 0.0096, -0.0166],\n", + " [-0.0391, -0.0354, -0.0003],\n", + " [-0.0502, 0.0183, 0.0312]],\n", + "\n", + " [[ 0.0169, -0.0024, -0.0177],\n", + " [-0.0148, 0.0153, 0.0100],\n", + " [-0.0366, 0.0393, 0.0128]],\n", + "\n", + " [[ 0.0077, 0.0533, 0.0733],\n", + " [ 0.0163, -0.0186, -0.0017],\n", + " [-0.0240, 0.0063, 0.0089]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0015, 0.0222, -0.0031],\n", + " [-0.0379, -0.0089, 0.0327],\n", + " [-0.0173, -0.0029, 0.0263]],\n", + "\n", + " [[-0.0621, 0.0345, 0.0175],\n", + " [-0.0643, 0.0059, 0.0182],\n", + " [ 0.0182, 0.0213, 0.0743]],\n", + "\n", + " [[ 0.0029, 
0.0253, 0.0361],\n", + " [-0.0325, 0.0326, 0.0458],\n", + " [ 0.0121, -0.0050, 0.0581]]]])\n", + "layer2.0.bn1.weight Parameter containing:\n", + "tensor([0.9856, 0.9561, 1.0298, 0.9849, 1.0096, 1.0142, 0.9732, 1.0974, 0.9764,\n", + " 0.9295, 0.9870, 0.9391, 0.9860, 0.9508, 0.9131, 0.9831, 1.0883, 1.0482,\n", + " 1.0211, 1.0542, 0.9074, 0.9631, 0.9923, 1.0211, 0.9294, 0.9730, 0.9817,\n", + " 1.1598, 0.9025, 0.9847, 1.0318, 1.0859, 1.0417, 1.0281, 1.0176, 1.0079,\n", + " 1.0154, 1.0588, 1.0517, 1.1073, 1.0900, 0.9344, 0.9602, 1.0996, 1.0282,\n", + " 1.0062, 1.0235, 1.0970, 0.9816, 1.0375, 0.9065, 0.9550, 1.0763, 0.8873,\n", + " 0.9492, 0.9854, 0.9797, 0.9520, 1.0777, 1.0084, 0.9395, 0.9863, 0.8887,\n", + " 0.9469, 1.1168, 0.9005, 1.1311, 0.9260, 0.9452, 0.9881, 0.8533, 0.9687,\n", + " 0.9784, 0.9794, 1.0136, 0.9528, 0.9522, 0.9981, 0.8862, 1.0042, 0.9056,\n", + " 0.8985, 1.1169, 0.9522, 1.0590, 0.8980, 0.9304, 1.0703, 1.0742, 0.9503,\n", + " 1.0123, 1.0671, 0.9040, 0.9954, 1.1495, 1.0520, 1.0443, 0.9952, 0.9958,\n", + " 0.9039, 1.0339, 0.9977, 0.9204, 0.9406, 0.9400, 1.0673, 0.9102, 1.0681,\n", + " 0.8994, 0.9404, 0.9741, 0.9414, 0.9928, 1.0733, 1.0143, 1.0218, 1.2354,\n", + " 1.1403, 1.0164, 1.0249, 0.9757, 0.9149, 0.9944, 1.1335, 0.8665, 1.1207,\n", + " 0.9400, 0.9708])\n", + "layer2.0.bn1.bias Parameter containing:\n", + "tensor([-4.4408e-02, -6.3460e-02, -5.9487e-02, 3.9959e-02, 1.3809e-02,\n", + " 5.7571e-03, -6.3010e-02, 1.1827e-02, -4.0396e-02, -6.0959e-02,\n", + " -6.8656e-02, -5.4246e-02, 1.2467e-03, -7.8717e-02, -8.8926e-02,\n", + " -4.7148e-03, -6.2139e-02, -1.5113e-02, -8.7445e-02, -4.2238e-02,\n", + " -5.4938e-02, -3.2369e-02, -7.4581e-02, -3.3462e-02, -4.9141e-02,\n", + " -4.9546e-02, -7.5260e-02, 5.9013e-02, -1.0577e-01, -1.0421e-01,\n", + " -2.5820e-04, -2.0431e-02, -5.1903e-02, -1.8563e-02, -6.3990e-02,\n", + " -2.6435e-02, -1.3204e-02, -3.7876e-02, -2.3634e-02, -5.6133e-02,\n", + " -4.8304e-02, -8.8873e-02, -2.3738e-02, -1.1452e-02, 
-4.6369e-02,\n", + " -2.9358e-03, 8.4575e-03, 2.1221e-02, -4.8208e-02, 1.7766e-02,\n", + " -1.3707e-01, -5.6345e-02, 2.1062e-02, -1.4223e-01, -6.9033e-03,\n", + " -6.6139e-02, -7.3860e-02, -5.2747e-02, 7.9480e-02, -4.0997e-02,\n", + " -1.0225e-02, 1.3592e-02, -8.0994e-02, -6.3811e-02, -1.8184e-02,\n", + " -6.4378e-02, 3.6648e-02, -8.9201e-02, -6.7080e-02, -3.2439e-02,\n", + " -1.4202e-01, -1.4953e-02, -5.7625e-02, -4.2923e-02, -2.5661e-02,\n", + " -3.7217e-02, -2.7865e-02, -5.7743e-02, -4.3572e-02, 9.9375e-03,\n", + " -6.7939e-02, -9.3802e-02, -2.8521e-02, -5.0249e-02, 5.2249e-02,\n", + " -1.2405e-01, -5.6959e-02, -3.5813e-02, -1.4070e-02, -1.1760e-01,\n", + " 1.4158e-03, -1.2018e-02, -8.2153e-02, -1.9360e-02, 1.9967e-02,\n", + " -3.5138e-02, -8.8690e-02, 5.2727e-03, -1.1467e-02, -6.7700e-02,\n", + " 4.8841e-02, -3.8626e-02, -1.0003e-01, -4.1781e-02, -5.3253e-02,\n", + " -3.4413e-02, -7.0329e-02, -9.5402e-05, -1.0329e-01, -5.2985e-02,\n", + " 1.8760e-02, -2.0112e-02, -6.0653e-02, -5.8835e-02, -7.4210e-02,\n", + " 3.0239e-02, -7.5367e-02, -2.2648e-02, -4.0464e-02, -6.7306e-02,\n", + " 4.4584e-02, -6.6090e-02, 2.0774e-02, -4.0260e-02, -1.1678e-01,\n", + " 6.9928e-02, -5.5673e-02, -3.2468e-02])\n", + "layer2.0.conv2.weight Parameter containing:\n", + "tensor([[[[-2.9326e-02, 1.4776e-04, 1.3197e-02],\n", + " [-1.6475e-02, -4.4652e-02, -2.1389e-03],\n", + " [-4.7745e-04, 1.1555e-02, 1.1161e-02]],\n", + "\n", + " [[-8.2114e-03, 7.7707e-03, 3.8702e-02],\n", + " [ 4.8851e-02, 2.2664e-02, -3.6508e-04],\n", + " [-1.8164e-02, -1.9086e-02, -7.5711e-04]],\n", + "\n", + " [[-1.2012e-02, 5.1844e-03, -1.6460e-02],\n", + " [-4.1903e-02, 1.6704e-02, 1.1525e-02],\n", + " [-2.8750e-02, -1.8928e-02, -2.8496e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-3.4732e-02, -2.8028e-02, -4.9673e-02],\n", + " [ 2.3525e-02, -7.4910e-03, -1.2319e-02],\n", + " [-1.1767e-02, -2.2846e-03, 1.4591e-02]],\n", + "\n", + " [[ 2.5598e-02, 3.6546e-02, 3.5433e-02],\n", + " [ 3.2633e-02, 4.1257e-03, 
-2.0384e-02],\n", + " [-1.6832e-02, -2.4976e-02, -1.0932e-02]],\n", + "\n", + " [[-2.4550e-02, 1.4871e-02, -2.1896e-03],\n", + " [-3.5753e-02, 3.3367e-04, 3.5690e-03],\n", + " [ 9.3313e-03, -1.0965e-02, -3.1038e-03]]],\n", + "\n", + "\n", + " [[[-6.1483e-02, -1.5989e-02, -5.0814e-02],\n", + " [-7.6276e-03, 1.1751e-02, -2.3630e-02],\n", + " [-1.1072e-02, -8.0032e-03, 1.2706e-02]],\n", + "\n", + " [[ 3.1783e-02, 2.4595e-02, 4.2654e-02],\n", + " [ 3.6219e-02, 1.1338e-02, 4.2761e-02],\n", + " [-1.9528e-03, 2.3325e-02, -2.9664e-02]],\n", + "\n", + " [[ 1.7203e-02, 2.9008e-02, 1.6512e-02],\n", + " [-2.1788e-02, -1.5973e-02, 3.8330e-03],\n", + " [-6.3857e-02, -4.3809e-02, -1.8129e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.3763e-02, -1.1545e-02, -1.4335e-02],\n", + " [-3.4292e-02, -3.5164e-03, -3.0173e-02],\n", + " [-4.4377e-02, -1.6386e-02, -2.8011e-02]],\n", + "\n", + " [[-3.7781e-02, -3.0508e-02, -9.6721e-04],\n", + " [ 1.6204e-02, 2.7993e-02, 5.1552e-02],\n", + " [ 5.3977e-03, 1.0185e-02, -3.6914e-02]],\n", + "\n", + " [[ 3.8443e-02, 3.5896e-02, 4.3196e-02],\n", + " [-4.8815e-04, -3.2587e-02, -7.3151e-03],\n", + " [ 1.6852e-02, -3.9039e-02, -3.5906e-02]]],\n", + "\n", + "\n", + " [[[ 1.6662e-02, 3.7526e-02, 3.4958e-02],\n", + " [ 4.3510e-02, 3.0857e-02, 1.3001e-02],\n", + " [-5.7000e-03, -7.4834e-03, -4.0627e-02]],\n", + "\n", + " [[-3.2959e-02, -2.2397e-02, 1.4952e-02],\n", + " [-2.7921e-02, 1.1206e-02, 5.2647e-03],\n", + " [-5.5147e-03, 3.3321e-02, 6.1338e-02]],\n", + "\n", + " [[-1.4455e-02, 1.6981e-02, 2.7939e-02],\n", + " [ 4.5812e-02, 6.6108e-02, 5.7073e-02],\n", + " [-1.3593e-02, -1.7858e-03, 1.6507e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-5.2863e-02, -1.2127e-02, 3.7099e-02],\n", + " [-4.4639e-02, 2.2652e-02, -3.6863e-05],\n", + " [-4.7820e-02, 1.3688e-02, 3.1525e-02]],\n", + "\n", + " [[-3.2835e-02, -4.7681e-02, 2.5508e-02],\n", + " [-2.3645e-02, -7.2019e-03, -2.9590e-02],\n", + " [-1.4003e-02, -2.3828e-02, -2.2238e-02]],\n", + "\n", + " 
[[-2.7537e-04, 1.6434e-02, 3.0733e-02],\n", + " [ 7.4750e-03, -5.3029e-03, 2.6011e-02],\n", + " [ 4.8595e-02, 2.9357e-03, 5.6132e-02]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[-3.5815e-02, -1.3693e-02, -2.7741e-02],\n", + " [-3.0585e-02, -4.0935e-02, -3.9116e-02],\n", + " [-2.7633e-02, -2.2705e-02, 1.9924e-02]],\n", + "\n", + " [[ 2.8751e-02, 1.9905e-02, 2.9578e-02],\n", + " [ 3.6153e-02, -4.5332e-03, 3.3908e-02],\n", + " [ 3.7232e-03, 3.7734e-04, 2.3319e-02]],\n", + "\n", + " [[-7.5011e-03, -1.1992e-02, -1.8491e-02],\n", + " [-5.2015e-03, 1.5962e-02, -6.3561e-03],\n", + " [ 2.9083e-03, 1.9739e-02, -1.9642e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.7470e-02, -9.8844e-03, 8.6894e-03],\n", + " [ 1.1370e-02, -2.8649e-02, -1.7614e-02],\n", + " [ 7.5035e-03, -2.8190e-02, 2.3285e-02]],\n", + "\n", + " [[ 2.2195e-02, 1.1344e-02, -1.1869e-02],\n", + " [ 1.4061e-03, 3.4397e-02, 2.2290e-02],\n", + " [-8.9835e-03, -3.6462e-04, 4.3276e-03]],\n", + "\n", + " [[ 1.3146e-02, -1.4805e-03, -3.4839e-02],\n", + " [ 9.8965e-03, 1.2123e-03, -1.2450e-02],\n", + " [-2.6127e-02, -1.2640e-03, -2.4437e-02]]],\n", + "\n", + "\n", + " [[[-3.1953e-02, -4.1924e-02, -1.0290e-02],\n", + " [-3.3615e-03, -3.2658e-02, -8.7536e-03],\n", + " [-1.8593e-02, 4.7878e-04, 4.5571e-02]],\n", + "\n", + " [[-1.6380e-02, 9.0561e-03, 3.3792e-02],\n", + " [-4.8492e-03, -3.1668e-02, -1.1819e-02],\n", + " [-2.6978e-02, -4.0084e-03, -7.3130e-03]],\n", + "\n", + " [[-4.7184e-03, -6.5889e-02, -6.8211e-02],\n", + " [ 1.6105e-02, 3.9260e-02, -2.7005e-02],\n", + " [ 2.3144e-02, 3.0311e-02, 1.1514e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-5.8818e-03, -3.2534e-02, 2.3114e-02],\n", + " [ 1.2380e-02, 1.3630e-02, -4.6785e-02],\n", + " [ 1.2566e-02, 3.2790e-02, 2.4120e-02]],\n", + "\n", + " [[-4.9905e-02, -4.5534e-03, -4.3831e-02],\n", + " [-4.3033e-02, -1.5961e-02, -2.9293e-02],\n", + " [-1.2634e-02, 4.5129e-03, 3.3744e-02]],\n", + "\n", + " [[-1.0258e-02, 1.6650e-02, -1.9402e-02],\n", + " [ 
4.0746e-02, 3.2129e-02, 3.6650e-02],\n", + " [-4.0866e-02, -4.9628e-02, -2.9688e-02]]],\n", + "\n", + "\n", + " [[[ 2.0279e-02, 8.4315e-03, -2.2303e-02],\n", + " [ 2.8924e-03, -3.4284e-02, -2.5508e-02],\n", + " [ 7.2640e-03, -1.9162e-02, -3.7586e-02]],\n", + "\n", + " [[ 1.6705e-02, 1.3032e-02, -8.9937e-03],\n", + " [ 4.2467e-02, 8.9482e-03, 2.3359e-02],\n", + " [ 6.0444e-02, 1.8572e-02, -2.0024e-02]],\n", + "\n", + " [[ 5.3415e-02, -1.9900e-02, -2.3915e-02],\n", + " [ 9.5328e-03, 6.5805e-03, -2.5049e-03],\n", + " [ 4.5412e-03, 1.4086e-02, 3.1413e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.6861e-02, -5.0322e-03, -5.3216e-03],\n", + " [ 4.4522e-02, 3.5843e-02, 2.8520e-02],\n", + " [-2.2439e-02, 4.2831e-02, 2.1473e-02]],\n", + "\n", + " [[-1.4825e-02, -1.6843e-02, -3.1605e-02],\n", + " [-1.7422e-02, 1.0006e-02, -1.6374e-02],\n", + " [ 1.0959e-02, -2.0312e-02, 2.9210e-02]],\n", + "\n", + " [[ 2.0450e-02, -1.3104e-03, -2.7232e-02],\n", + " [ 1.8129e-02, 9.4467e-03, 5.8069e-05],\n", + " [ 2.1227e-02, -3.1317e-02, 1.4729e-02]]]])\n", + "layer2.0.bn2.weight Parameter containing:\n", + "tensor([0.9931, 1.0149, 1.0834, 1.3062, 0.9436, 1.0593, 1.1705, 0.8661, 1.2560,\n", + " 1.0011, 1.0543, 1.0866, 1.0357, 0.9802, 1.0470, 1.0681, 0.9577, 1.0531,\n", + " 1.1537, 0.9731, 0.8574, 1.3457, 0.9703, 0.8292, 1.1049, 1.0989, 1.0087,\n", + " 0.9208, 1.0278, 0.9294, 0.9599, 1.0768, 1.0131, 0.9518, 1.0358, 1.1543,\n", + " 1.0567, 1.1767, 1.0570, 0.9420, 1.1126, 0.9607, 0.9808, 1.0174, 1.1352,\n", + " 1.0871, 0.9884, 1.0924, 1.0687, 0.9786, 0.8349, 0.9612, 1.1718, 1.0225,\n", + " 1.0773, 0.9246, 0.9866, 1.1723, 0.8953, 1.0282, 1.0341, 0.9762, 0.8887,\n", + " 0.8342, 0.8589, 0.9796, 1.0104, 1.0486, 1.0344, 1.0154, 0.8721, 0.9595,\n", + " 1.0041, 1.2363, 1.0264, 1.0738, 0.8888, 0.7846, 0.9974, 1.0654, 1.0377,\n", + " 0.7588, 0.8791, 0.9104, 0.8667, 0.8919, 1.0725, 1.1285, 0.9482, 0.9912,\n", + " 1.0816, 0.9307, 1.1489, 0.9555, 0.9685, 0.9760, 0.9407, 0.9562, 1.0751,\n", + " 0.9656, 
1.0497, 1.3493, 1.0319, 1.0755, 1.1694, 0.9989, 0.9907, 0.9567,\n", + " 0.8660, 1.1067, 1.2941, 1.0724, 0.8705, 1.0477, 0.8153, 1.1535, 0.8960,\n", + " 0.9657, 1.0321, 1.0597, 0.9351, 1.0514, 0.9853, 0.9279, 1.0190, 1.0215,\n", + " 0.9553, 0.9368])\n", + "layer2.0.bn2.bias Parameter containing:\n", + "tensor([-1.2981e-02, -8.6542e-02, -2.7465e-02, -8.9233e-02, -4.8785e-02,\n", + " -5.1820e-02, -7.4506e-02, -8.4871e-03, -8.0788e-02, -4.5926e-02,\n", + " -2.8059e-02, -3.0307e-02, -2.2497e-02, -5.4316e-02, -1.2210e-01,\n", + " -2.3084e-02, -1.5159e-02, -3.7485e-02, -3.9592e-02, -1.7644e-02,\n", + " -9.2507e-02, -6.1548e-02, -4.2106e-02, -2.7228e-02, -4.7395e-02,\n", + " 4.0894e-02, -8.1798e-02, 7.9138e-03, -2.8363e-02, -2.0816e-02,\n", + " 4.5710e-02, -4.7065e-02, -6.4959e-02, -5.2531e-02, -1.4958e-02,\n", + " -2.6834e-02, -7.6824e-02, -4.7256e-02, -2.8755e-02, -3.9290e-02,\n", + " 2.5110e-03, 6.5448e-02, -2.4467e-02, -6.3492e-02, 4.6999e-03,\n", + " -4.1704e-02, -3.9343e-02, 4.9145e-05, -6.0200e-03, -3.4718e-02,\n", + " -5.2317e-02, -5.0237e-02, -4.9025e-02, -2.7047e-02, -5.2278e-02,\n", + " -4.9766e-02, -5.2599e-02, -1.2772e-01, -1.0851e-02, -6.3345e-02,\n", + " -5.3914e-02, -1.0681e-02, -2.0828e-02, -4.9643e-02, -4.6539e-02,\n", + " 2.7344e-02, -2.4344e-02, -3.5232e-02, -3.5927e-02, -4.2018e-02,\n", + " -5.2375e-02, -4.9287e-02, -1.5841e-03, -4.7609e-02, -6.0663e-02,\n", + " -4.4702e-02, -7.3340e-02, -1.0380e-01, -5.8219e-02, -9.6217e-03,\n", + " -6.6633e-03, -3.2506e-02, 4.1905e-02, -5.3879e-02, -3.6673e-02,\n", + " -4.4338e-02, -5.6134e-02, -6.1860e-03, -1.4108e-02, -7.2612e-02,\n", + " -4.0473e-02, -2.9741e-03, -9.7863e-02, -1.9554e-02, -8.3787e-02,\n", + " -3.0235e-02, -8.5696e-02, -7.1485e-02, -4.9228e-02, 1.4736e-02,\n", + " 2.4039e-02, -6.1848e-02, -5.7710e-02, -9.6235e-02, -1.1919e-02,\n", + " -2.6103e-02, -4.3060e-02, 1.3986e-02, -5.9900e-02, -2.8015e-02,\n", + " -1.5239e-02, -2.3644e-02, -5.2308e-02, -7.7621e-02, -1.2540e-01,\n", + " -1.4630e-02, 
-5.4057e-02, -3.7307e-02, -5.4744e-02, -2.6498e-02,\n", + " -1.0680e-01, -1.1601e-02, -1.6164e-02, -3.0841e-02, -5.8472e-02,\n", + " -1.2627e-02, -8.5479e-02, -4.6238e-02])\n", + "layer2.0.shortcut_conv.weight Parameter containing:\n", + "tensor([[[[ 0.2162]],\n", + "\n", + " [[ 0.0339]],\n", + "\n", + " [[ 0.1120]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.1301]],\n", + "\n", + " [[ 0.0351]],\n", + "\n", + " [[-0.1245]]],\n", + "\n", + "\n", + " [[[-0.1042]],\n", + "\n", + " [[ 0.1120]],\n", + "\n", + " [[-0.0587]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.1505]],\n", + "\n", + " [[ 0.0478]],\n", + "\n", + " [[-0.0350]]],\n", + "\n", + "\n", + " [[[ 0.0636]],\n", + "\n", + " [[-0.0045]],\n", + "\n", + " [[-0.0074]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0249]],\n", + "\n", + " [[-0.0809]],\n", + "\n", + " [[ 0.0892]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[-0.0323]],\n", + "\n", + " [[ 0.1375]],\n", + "\n", + " [[ 0.0807]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0398]],\n", + "\n", + " [[ 0.0997]],\n", + "\n", + " [[ 0.0951]]],\n", + "\n", + "\n", + " [[[-0.0811]],\n", + "\n", + " [[ 0.1213]],\n", + "\n", + " [[ 0.0251]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.1344]],\n", + "\n", + " [[ 0.1247]],\n", + "\n", + " [[-0.0439]]],\n", + "\n", + "\n", + " [[[ 0.0143]],\n", + "\n", + " [[-0.0505]],\n", + "\n", + " [[-0.0858]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0293]],\n", + "\n", + " [[-0.0258]],\n", + "\n", + " [[-0.1046]]]])\n", + "layer2.0.shortcut_bn.weight Parameter containing:\n", + "tensor([1.0056, 1.0531, 1.0043, 0.9518, 0.9567, 1.0751, 1.1187, 1.0293, 1.0514,\n", + " 1.0665, 1.0102, 1.0254, 1.0789, 1.0322, 1.1177, 1.0840, 1.0340, 1.0115,\n", + " 1.1417, 1.0085, 0.9390, 1.1613, 1.0394, 1.0185, 0.9537, 1.0575, 1.1024,\n", + " 1.0102, 1.0730, 1.0289, 1.0484, 1.1239, 1.0378, 1.0818, 0.9914, 1.0452,\n", + " 1.0040, 1.0889, 1.1251, 1.0409, 1.0309, 1.0383, 1.0509, 0.9937, 1.0173,\n", + " 1.0388, 1.1697, 1.0248, 1.0797, 0.9914, 
0.9138, 0.9931, 0.9795, 1.0041,\n", + " 0.9696, 0.9805, 0.9996, 1.0714, 0.9691, 1.0337, 1.0087, 1.1008, 1.0351,\n", + " 0.9816, 1.0039, 1.0413, 1.0856, 1.0510, 0.9705, 1.0622, 0.9942, 1.0241,\n", + " 1.0249, 1.0274, 1.0713, 1.0870, 0.9671, 0.9346, 1.0485, 1.0752, 1.0161,\n", + " 0.9574, 1.0789, 1.0132, 0.9525, 1.0304, 0.9617, 1.0545, 0.9895, 1.0222,\n", + " 0.9965, 0.9874, 1.0452, 0.9885, 0.9668, 1.0647, 0.9278, 0.9529, 1.0227,\n", + " 1.0354, 1.0489, 1.1118, 0.9953, 1.0224, 1.0967, 1.0673, 1.0372, 1.0253,\n", + " 0.9920, 1.1055, 1.0380, 1.0451, 0.9972, 1.0180, 0.9287, 1.0507, 0.9703,\n", + " 1.0450, 1.0503, 1.0136, 1.0655, 1.0190, 1.0881, 1.0169, 1.0723, 0.9594,\n", + " 0.9919, 1.0379])\n", + "layer2.0.shortcut_bn.bias Parameter containing:\n", + "tensor([-1.2981e-02, -8.6542e-02, -2.7465e-02, -8.9233e-02, -4.8785e-02,\n", + " -5.1820e-02, -7.4506e-02, -8.4871e-03, -8.0788e-02, -4.5926e-02,\n", + " -2.8059e-02, -3.0307e-02, -2.2497e-02, -5.4316e-02, -1.2210e-01,\n", + " -2.3084e-02, -1.5159e-02, -3.7485e-02, -3.9592e-02, -1.7644e-02,\n", + " -9.2507e-02, -6.1548e-02, -4.2106e-02, -2.7228e-02, -4.7395e-02,\n", + " 4.0894e-02, -8.1798e-02, 7.9138e-03, -2.8363e-02, -2.0816e-02,\n", + " 4.5710e-02, -4.7065e-02, -6.4959e-02, -5.2531e-02, -1.4958e-02,\n", + " -2.6834e-02, -7.6824e-02, -4.7256e-02, -2.8755e-02, -3.9290e-02,\n", + " 2.5110e-03, 6.5448e-02, -2.4467e-02, -6.3492e-02, 4.6999e-03,\n", + " -4.1704e-02, -3.9343e-02, 4.9145e-05, -6.0200e-03, -3.4718e-02,\n", + " -5.2317e-02, -5.0237e-02, -4.9025e-02, -2.7047e-02, -5.2278e-02,\n", + " -4.9766e-02, -5.2599e-02, -1.2772e-01, -1.0851e-02, -6.3345e-02,\n", + " -5.3914e-02, -1.0681e-02, -2.0828e-02, -4.9643e-02, -4.6539e-02,\n", + " 2.7344e-02, -2.4344e-02, -3.5232e-02, -3.5927e-02, -4.2018e-02,\n", + " -5.2375e-02, -4.9287e-02, -1.5841e-03, -4.7609e-02, -6.0663e-02,\n", + " -4.4702e-02, -7.3340e-02, -1.0380e-01, -5.8219e-02, -9.6217e-03,\n", + " -6.6633e-03, -3.2506e-02, 4.1905e-02, -5.3879e-02, -3.6673e-02,\n", + " 
-4.4338e-02, -5.6134e-02, -6.1860e-03, -1.4108e-02, -7.2612e-02,\n", + " -4.0473e-02, -2.9741e-03, -9.7863e-02, -1.9554e-02, -8.3787e-02,\n", + " -3.0235e-02, -8.5696e-02, -7.1485e-02, -4.9228e-02, 1.4736e-02,\n", + " 2.4039e-02, -6.1848e-02, -5.7710e-02, -9.6235e-02, -1.1919e-02,\n", + " -2.6103e-02, -4.3060e-02, 1.3986e-02, -5.9900e-02, -2.8015e-02,\n", + " -1.5239e-02, -2.3644e-02, -5.2308e-02, -7.7621e-02, -1.2540e-01,\n", + " -1.4630e-02, -5.4057e-02, -3.7307e-02, -5.4744e-02, -2.6498e-02,\n", + " -1.0680e-01, -1.1601e-02, -1.6164e-02, -3.0841e-02, -5.8472e-02,\n", + " -1.2627e-02, -8.5479e-02, -4.6238e-02])\n", + "layer2.1.conv1.weight Parameter containing:\n", + "tensor([[[[-1.2671e-02, -2.3938e-02, -1.9099e-02],\n", + " [ 1.1139e-02, -2.8324e-02, -2.8830e-02],\n", + " [-2.3069e-02, -1.9006e-02, 2.0482e-02]],\n", + "\n", + " [[-1.5592e-03, -3.6956e-02, -4.0356e-02],\n", + " [ 1.4791e-02, -3.0037e-02, -3.2506e-02],\n", + " [ 5.5356e-03, 1.8473e-02, -1.3575e-02]],\n", + "\n", + " [[-1.5707e-02, 2.3170e-02, 9.3615e-04],\n", + " [-2.3199e-02, -3.8213e-03, -1.6252e-02],\n", + " [ 3.6089e-02, -3.3433e-03, 2.4761e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[ 2.0869e-02, 2.5472e-02, 1.1077e-02],\n", + " [ 4.3677e-02, 4.7397e-02, 5.2356e-04],\n", + " [-2.2244e-02, 3.8463e-02, 1.2060e-04]],\n", + "\n", + " [[ 3.4758e-02, 5.2871e-02, 4.4993e-02],\n", + " [ 4.9931e-02, 4.9706e-02, 2.1665e-02],\n", + " [ 4.7596e-02, 1.3522e-02, 2.3822e-02]],\n", + "\n", + " [[ 2.2113e-02, -4.2537e-03, -4.7076e-02],\n", + " [ 3.1440e-02, 5.0481e-03, 3.0701e-03],\n", + " [ 1.2286e-02, 2.8508e-02, 1.2381e-03]]],\n", + "\n", + "\n", + " [[[-2.3738e-02, -1.3931e-02, -2.1727e-02],\n", + " [-2.1146e-02, -5.2254e-03, -8.0258e-05],\n", + " [-3.2895e-02, -1.6613e-02, -4.7330e-02]],\n", + "\n", + " [[ 3.0972e-02, -1.8397e-02, 2.1567e-02],\n", + " [ 3.8275e-02, 3.8623e-02, 2.3262e-02],\n", + " [ 1.7491e-02, 2.6040e-02, -9.0549e-03]],\n", + "\n", + " [[ 2.3262e-02, 7.3751e-04, 1.5063e-02],\n", + " [ 
4.1470e-02, 2.8551e-02, 3.1405e-02],\n", + " [ 2.1048e-02, 1.2375e-02, 4.5953e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.8296e-02, -9.1726e-03, -7.3690e-03],\n", + " [ 1.7543e-02, 1.2311e-03, 1.9880e-02],\n", + " [ 1.0991e-02, -5.1386e-02, 3.1852e-03]],\n", + "\n", + " [[-2.8715e-02, -2.7118e-02, -7.1705e-03],\n", + " [ 4.2610e-02, 2.4501e-02, 1.3790e-02],\n", + " [-6.7407e-03, -2.6416e-03, -7.7918e-03]],\n", + "\n", + " [[ 1.2130e-02, 1.8906e-02, 1.1645e-03],\n", + " [ 3.3207e-02, 1.5823e-02, -3.7893e-03],\n", + " [ 4.3907e-02, 3.4477e-02, 3.8621e-02]]],\n", + "\n", + "\n", + " [[[ 4.2056e-02, 5.3050e-03, 7.1378e-02],\n", + " [ 2.3840e-02, 3.1197e-02, 6.1677e-02],\n", + " [ 4.0494e-02, -2.5194e-02, -6.7998e-03]],\n", + "\n", + " [[-3.6769e-02, -1.3199e-02, -3.2012e-02],\n", + " [-7.7791e-03, -6.0538e-02, -1.0908e-02],\n", + " [ 3.7365e-02, 5.9451e-02, 1.1005e-02]],\n", + "\n", + " [[-3.7870e-02, 3.5183e-02, -2.5808e-02],\n", + " [-2.6994e-02, -3.9472e-02, -2.0497e-02],\n", + " [-5.8050e-02, -4.7024e-03, -1.8292e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-7.0749e-03, -3.6702e-02, -2.4704e-02],\n", + " [ 2.5994e-02, 9.4172e-03, -2.3231e-02],\n", + " [ 2.3903e-02, 2.8629e-02, 1.9264e-02]],\n", + "\n", + " [[-6.7792e-03, -7.7618e-03, -6.9538e-02],\n", + " [ 4.0581e-02, 5.6283e-02, 2.9467e-02],\n", + " [-3.3667e-03, -6.5002e-02, -4.8670e-02]],\n", + "\n", + " [[ 1.8321e-02, -2.4471e-02, 9.8842e-03],\n", + " [ 3.9916e-03, 2.2426e-02, -3.6026e-03],\n", + " [ 3.8083e-02, 3.0053e-02, 1.9913e-02]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[ 2.5227e-03, 5.2476e-02, 2.3908e-02],\n", + " [-3.4207e-02, -3.0952e-02, -3.6447e-02],\n", + " [ 1.3552e-02, 2.7308e-02, 5.1622e-02]],\n", + "\n", + " [[-2.4969e-02, -2.1576e-02, -4.5853e-02],\n", + " [ 1.3879e-02, 6.2150e-02, 1.4080e-02],\n", + " [ 1.1243e-02, -4.2055e-02, 7.8067e-03]],\n", + "\n", + " [[-2.3491e-02, -5.1456e-03, -1.2547e-02],\n", + " [-1.6141e-02, 1.0643e-02, -2.5370e-02],\n", + " [-1.9203e-03, 
-5.2159e-02, 5.9291e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.1122e-02, -1.4899e-02, 2.1300e-02],\n", + " [ 3.6331e-03, 6.5192e-03, 4.7576e-02],\n", + " [ 2.3239e-02, -2.0316e-02, 2.7265e-02]],\n", + "\n", + " [[ 1.8296e-02, 4.2794e-02, 4.8028e-02],\n", + " [-8.7513e-03, -2.8567e-02, -4.4760e-02],\n", + " [ 5.6493e-03, -1.7444e-02, -2.9112e-02]],\n", + "\n", + " [[-2.2753e-03, -2.9479e-03, 2.8816e-02],\n", + " [ 1.9996e-03, 2.7665e-03, 4.4159e-02],\n", + " [-3.1448e-02, -4.9231e-02, 1.4012e-02]]],\n", + "\n", + "\n", + " [[[ 3.5169e-02, 3.1395e-02, -8.4316e-03],\n", + " [ 5.0173e-02, -1.6380e-02, 3.1588e-03],\n", + " [ 9.7549e-03, -1.6155e-03, 2.5622e-02]],\n", + "\n", + " [[ 2.9783e-03, -1.0025e-02, 4.0323e-03],\n", + " [ 2.4023e-02, 1.8000e-02, -2.0841e-02],\n", + " [ 1.7581e-03, 2.7513e-02, 3.2596e-02]],\n", + "\n", + " [[ 8.9868e-03, 1.2796e-02, -1.1676e-02],\n", + " [-8.7491e-03, 3.1234e-03, -8.2818e-03],\n", + " [ 2.3207e-02, 7.2682e-03, 4.6091e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 5.1432e-02, 3.8542e-02, 1.0702e-02],\n", + " [ 1.3778e-03, 1.6634e-02, -2.2224e-02],\n", + " [ 2.5061e-02, 3.7500e-02, -1.4531e-02]],\n", + "\n", + " [[ 3.4381e-02, -4.6145e-05, -2.3151e-02],\n", + " [ 4.0692e-02, 1.4137e-02, 3.8377e-03],\n", + " [ 6.2407e-03, -4.0193e-02, -4.6756e-02]],\n", + "\n", + " [[ 1.8292e-02, 3.8774e-03, -2.9178e-02],\n", + " [-3.8458e-02, 9.2545e-03, 2.3120e-02],\n", + " [ 1.1178e-02, 4.1157e-02, -2.7000e-03]]],\n", + "\n", + "\n", + " [[[ 4.6502e-04, 2.7367e-02, -1.8018e-02],\n", + " [-8.8800e-04, 3.2783e-02, 8.7174e-03],\n", + " [-1.3216e-02, 3.0975e-02, -1.3285e-03]],\n", + "\n", + " [[-1.9450e-02, 2.7712e-03, 2.3936e-02],\n", + " [-2.4075e-02, -2.3674e-03, -3.2410e-02],\n", + " [-1.9093e-02, -2.5896e-02, -7.2411e-03]],\n", + "\n", + " [[-1.0581e-02, 1.2100e-02, -2.1283e-02],\n", + " [-1.3377e-02, 2.6153e-04, -3.6690e-02],\n", + " [-2.5659e-02, -2.2961e-02, -4.4005e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-3.8530e-02, -1.6303e-02, 
-2.1558e-02],\n", + " [ 2.9950e-02, -1.8199e-02, -2.5115e-04],\n", + " [-2.1965e-02, -1.9656e-02, -2.8129e-03]],\n", + "\n", + " [[-6.2657e-03, -6.5198e-02, -3.4879e-02],\n", + " [-1.6735e-02, -2.4723e-02, -1.7410e-02],\n", + " [ 1.5641e-02, -1.8898e-02, -2.1710e-02]],\n", + "\n", + " [[ 1.3554e-02, 5.3783e-03, 2.0453e-03],\n", + " [ 2.4260e-02, -2.5494e-02, 1.7761e-02],\n", + " [-2.7732e-02, -2.5852e-02, -9.7575e-03]]]])\n", + "layer2.1.bn1.weight Parameter containing:\n", + "tensor([0.9469, 0.9599, 1.0408, 1.0384, 0.9564, 0.9625, 0.9672, 0.9335, 0.9827,\n", + " 1.0143, 1.1081, 0.9712, 0.9803, 1.0061, 1.0102, 0.9363, 1.0412, 0.9488,\n", + " 1.0403, 1.0078, 0.9938, 0.9206, 1.0497, 0.9752, 0.9866, 0.9762, 1.0210,\n", + " 0.9545, 0.9374, 1.0035, 0.9426, 0.9650, 0.9559, 0.9546, 1.1425, 0.9578,\n", + " 1.0181, 0.9881, 1.0366, 1.0273, 1.0111, 0.9961, 1.0984, 0.9692, 0.9251,\n", + " 1.0338, 0.9482, 1.0294, 0.9981, 0.9767, 0.9457, 0.9172, 1.0938, 1.0430,\n", + " 1.1256, 1.0880, 1.0017, 1.0415, 0.9938, 1.1456, 1.0022, 0.9620, 0.9321,\n", + " 1.0135, 1.0086, 0.9947, 0.8943, 1.0304, 1.1431, 0.9540, 1.0271, 0.9725,\n", + " 1.0664, 1.0102, 1.0321, 1.0395, 1.0043, 1.0035, 0.9879, 1.0392, 1.0987,\n", + " 0.9446, 0.9369, 0.9410, 1.0415, 1.0260, 0.9566, 0.9980, 0.8972, 1.0363,\n", + " 0.9065, 0.9226, 0.9759, 1.0496, 0.9694, 0.9439, 0.9542, 0.9537, 1.0676,\n", + " 1.0573, 1.0023, 1.0358, 1.0215, 0.8992, 0.9458, 0.9192, 0.8680, 1.0038,\n", + " 1.0073, 0.9660, 1.0339, 1.0393, 0.9894, 1.0238, 1.0558, 0.9541, 0.9365,\n", + " 1.0021, 0.9507, 1.0065, 1.0754, 0.9610, 0.9913, 0.9653, 1.0200, 1.0461,\n", + " 1.0000, 0.9818])\n", + "layer2.1.bn1.bias Parameter containing:\n", + "tensor([-0.0673, -0.0338, -0.0798, -0.1658, -0.0523, -0.0830, -0.1435, -0.0267,\n", + " -0.0379, -0.0789, -0.1152, -0.0774, -0.0791, -0.0604, -0.0548, -0.0850,\n", + " -0.0389, -0.1077, -0.0098, -0.0761, -0.0727, -0.1523, -0.0636, -0.0924,\n", + " -0.0278, -0.0917, -0.0908, -0.0543, -0.0616, -0.0411, -0.0775, 
-0.0810,\n", + " -0.0747, -0.0256, -0.1169, -0.0433, -0.0451, -0.1127, -0.0723, -0.1305,\n", + " -0.1152, -0.0542, -0.0730, -0.0952, -0.1198, -0.0611, -0.0426, -0.0912,\n", + " -0.0861, -0.0774, -0.1032, -0.0950, -0.0548, -0.0913, -0.0961, -0.0706,\n", + " -0.0829, -0.0349, -0.0849, -0.1585, -0.0403, -0.0749, -0.0919, -0.0914,\n", + " -0.0395, -0.0244, -0.0733, -0.0359, -0.1861, -0.0820, -0.0695, -0.0439,\n", + " -0.0523, -0.0789, -0.0342, -0.1350, -0.0378, -0.0595, -0.0199, -0.0518,\n", + " -0.0643, -0.0442, -0.0830, -0.0320, -0.0403, -0.0410, -0.1096, -0.0638,\n", + " -0.0641, -0.0460, -0.0740, -0.0923, -0.0401, -0.0058, -0.0803, -0.0584,\n", + " -0.0683, -0.0805, -0.0518, -0.0837, -0.0866, -0.1192, -0.0260, -0.0665,\n", + " -0.0630, -0.1184, -0.0660, -0.0332, -0.0414, -0.1052, -0.0006, -0.0555,\n", + " -0.0702, -0.0166, -0.1229, -0.0376, -0.0829, -0.0606, -0.0979, -0.0913,\n", + " -0.1539, -0.0495, -0.0757, -0.1031, -0.0495, -0.0982, -0.1310, -0.0706])\n", + "layer2.1.conv2.weight Parameter containing:\n", + "tensor([[[[ 9.0620e-03, 1.3316e-02, 1.0495e-02],\n", + " [ 2.3512e-02, -1.3791e-03, 1.4756e-02],\n", + " [ 3.8027e-02, 1.9696e-02, -2.3292e-03]],\n", + "\n", + " [[-1.1708e-02, 1.3712e-02, -5.3562e-03],\n", + " [-3.9382e-02, -3.5332e-02, 2.0547e-02],\n", + " [ 6.4936e-03, 9.7651e-04, -3.3298e-03]],\n", + "\n", + " [[-2.7199e-02, -3.8818e-02, -2.6458e-02],\n", + " [-2.8592e-02, -4.5702e-02, -5.0769e-03],\n", + " [ 2.2421e-02, -1.0246e-02, 6.8318e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.4165e-02, -6.6939e-03, 2.7444e-02],\n", + " [-1.9542e-02, -4.5656e-02, 6.5853e-03],\n", + " [-3.9443e-02, -2.1822e-03, -3.6982e-02]],\n", + "\n", + " [[ 1.1827e-02, -1.2058e-02, -6.3863e-03],\n", + " [-3.9754e-02, -7.1089e-02, -6.8766e-02],\n", + " [-1.2733e-02, -3.8078e-02, -2.0078e-02]],\n", + "\n", + " [[ 1.2081e-02, 6.0684e-02, -1.3995e-02],\n", + " [-2.9236e-03, 2.1464e-03, -2.5115e-02],\n", + " [ 4.3272e-02, 7.2460e-02, 1.9598e-02]]],\n", + "\n", + "\n", + " [[[ 
1.3322e-02, 3.8556e-02, 2.1243e-02],\n", + " [ 1.8053e-02, 2.3932e-02, -1.5429e-02],\n", + " [ 1.7548e-02, -7.2034e-03, 7.1181e-03]],\n", + "\n", + " [[-5.8165e-03, -4.9490e-04, -3.5662e-02],\n", + " [-1.0573e-02, -1.0955e-02, -3.1837e-02],\n", + " [ 1.2337e-02, -1.6183e-02, 3.5962e-04]],\n", + "\n", + " [[-1.6080e-02, 9.4185e-03, -1.9492e-02],\n", + " [ 5.1737e-02, 4.8596e-02, 5.4428e-02],\n", + " [-7.2937e-03, 5.3389e-03, 3.1466e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-4.3382e-02, -5.6812e-02, -2.2632e-02],\n", + " [-3.3871e-03, -1.2200e-02, -6.9654e-03],\n", + " [ 1.3559e-02, 1.6472e-02, -1.0585e-02]],\n", + "\n", + " [[ 1.4523e-03, 1.7408e-02, 6.0656e-03],\n", + " [ 3.9184e-02, 1.5487e-02, 1.1933e-02],\n", + " [ 1.2125e-02, -6.5983e-03, 4.9426e-02]],\n", + "\n", + " [[ 7.4308e-03, -1.3106e-02, 1.6033e-02],\n", + " [ 2.4728e-02, -7.3725e-03, 2.4851e-02],\n", + " [-9.9803e-03, -4.9287e-02, 7.5864e-03]]],\n", + "\n", + "\n", + " [[[-4.9769e-02, -1.4696e-02, -2.3030e-02],\n", + " [-4.7186e-02, -6.6357e-02, -2.2492e-02],\n", + " [-5.2339e-02, -4.6180e-02, -5.6023e-03]],\n", + "\n", + " [[ 3.9187e-03, 1.5691e-02, 7.6873e-03],\n", + " [-3.1093e-02, -4.1031e-02, 2.3911e-02],\n", + " [-2.7188e-02, 7.6569e-03, -2.5654e-02]],\n", + "\n", + " [[ 1.9282e-02, -1.3629e-02, 1.8839e-02],\n", + " [-5.8584e-03, 2.7730e-02, 7.9821e-03],\n", + " [ 8.8666e-03, -3.1551e-02, 1.9405e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 2.4015e-02, -3.3378e-02, -5.3954e-04],\n", + " [-4.5346e-03, -2.1869e-03, 3.7101e-02],\n", + " [-3.1134e-02, -1.5882e-02, -2.8024e-02]],\n", + "\n", + " [[ 2.4350e-02, 1.7124e-02, -3.3614e-02],\n", + " [ 2.1010e-02, -2.8533e-02, -2.6490e-02],\n", + " [ 3.0194e-02, -3.2890e-02, 4.3761e-03]],\n", + "\n", + " [[-3.4936e-02, -1.5606e-02, 1.6964e-04],\n", + " [-3.0659e-02, -8.4506e-03, 2.2456e-02],\n", + " [ 2.6422e-02, 2.0289e-02, 9.1066e-03]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[-2.6236e-02, -3.4802e-02, -5.1656e-02],\n", + " 
[-2.8055e-02, -8.5368e-03, -2.1985e-02],\n", + " [-2.5813e-02, -1.0077e-02, -2.8292e-02]],\n", + "\n", + " [[-4.4758e-02, -5.8627e-02, -3.4626e-03],\n", + " [ 4.3765e-03, 2.5680e-02, 2.2966e-02],\n", + " [-1.1803e-02, -1.6649e-02, -1.0021e-02]],\n", + "\n", + " [[ 8.3297e-03, -1.7979e-03, -2.6646e-02],\n", + " [ 5.8005e-03, -4.2564e-02, 2.8873e-02],\n", + " [-1.8539e-02, -3.3966e-02, 1.7152e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-4.1487e-02, -3.3227e-02, -4.9985e-03],\n", + " [-1.9580e-02, 7.5260e-03, 2.4756e-02],\n", + " [ 2.8643e-03, -1.1592e-02, -1.8972e-02]],\n", + "\n", + " [[-5.8783e-02, -2.4775e-02, 6.1720e-04],\n", + " [-3.2235e-02, -5.0035e-03, 1.1334e-02],\n", + " [ 1.2860e-02, 2.2570e-03, 3.1842e-02]],\n", + "\n", + " [[ 5.6097e-03, -1.5921e-02, -1.5595e-02],\n", + " [-2.3235e-02, -2.1884e-02, 1.2557e-03],\n", + " [ 3.4493e-03, 7.8864e-03, -2.4334e-02]]],\n", + "\n", + "\n", + " [[[-3.9457e-03, 8.6690e-03, -2.3602e-02],\n", + " [ 8.4130e-03, -7.7556e-03, 2.1163e-02],\n", + " [-3.4924e-05, 1.1466e-02, 4.7752e-03]],\n", + "\n", + " [[-7.1314e-02, -2.9460e-02, -5.2409e-02],\n", + " [-2.2031e-02, -3.2676e-02, -2.9026e-02],\n", + " [-2.9801e-02, -3.6596e-02, -5.7815e-02]],\n", + "\n", + " [[-3.4363e-02, -1.0897e-03, -5.6942e-03],\n", + " [ 8.1054e-03, 4.0558e-02, 6.3196e-03],\n", + " [ 5.4653e-02, -4.7664e-03, 1.6212e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 2.5112e-02, 3.1311e-02, -1.5590e-02],\n", + " [ 7.0375e-02, 6.6667e-02, 6.6758e-02],\n", + " [-2.0276e-02, 2.5930e-02, 7.1680e-03]],\n", + "\n", + " [[ 5.9966e-04, -3.9549e-03, 3.0694e-02],\n", + " [ 2.4426e-02, 1.7874e-02, -2.1178e-02],\n", + " [-2.7890e-03, -7.3431e-03, -2.7056e-02]],\n", + "\n", + " [[-5.6897e-02, 1.9620e-02, -5.0968e-02],\n", + " [-2.3845e-02, 2.4067e-03, -3.8237e-02],\n", + " [-2.7367e-02, -4.8455e-02, -4.9325e-02]]],\n", + "\n", + "\n", + " [[[-2.0665e-02, -2.9083e-02, -2.0078e-02],\n", + " [-1.2909e-02, -4.0212e-02, 1.2505e-02],\n", + " [ 3.1596e-02, 2.8582e-02, 
5.2544e-02]],\n", + "\n", + " [[-1.4806e-02, -3.6197e-02, -3.2830e-03],\n", + " [-3.7018e-02, 3.2018e-02, 4.3647e-02],\n", + " [ 1.0543e-02, 4.0195e-03, 5.5371e-02]],\n", + "\n", + " [[-1.9200e-02, -6.9304e-02, -5.1689e-02],\n", + " [-1.2329e-03, -3.0719e-02, -5.0884e-03],\n", + " [ 1.5455e-02, -2.5330e-03, 4.8852e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-6.2182e-03, 1.1251e-02, -3.3474e-02],\n", + " [-1.6241e-02, -1.3244e-02, -7.0312e-02],\n", + " [ 2.9194e-02, -9.2372e-04, 1.0911e-02]],\n", + "\n", + " [[-1.2244e-02, -1.1533e-02, 2.8962e-02],\n", + " [ 3.9788e-02, 2.6832e-02, 1.3362e-02],\n", + " [ 3.6984e-02, -2.2891e-02, -3.2071e-02]],\n", + "\n", + " [[ 1.0560e-02, 1.4959e-02, 3.4440e-02],\n", + " [-3.0914e-03, 7.2294e-04, 4.0611e-03],\n", + " [ 5.2256e-03, 2.2626e-02, 5.8988e-03]]]])\n", + "layer2.1.bn2.weight Parameter containing:\n", + "tensor([0.9669, 0.9217, 0.9698, 1.0776, 0.8574, 0.9466, 0.9596, 0.8409, 0.9237,\n", + " 0.9604, 0.9425, 0.9107, 0.9738, 0.9234, 1.0417, 0.9799, 0.8988, 0.9993,\n", + " 0.8893, 0.9971, 0.7985, 0.9634, 0.9848, 0.8414, 0.9464, 0.8792, 1.0549,\n", + " 0.8567, 0.9816, 0.8909, 0.8655, 0.9836, 0.9409, 0.9513, 1.0185, 1.0638,\n", + " 0.9734, 0.9440, 0.9009, 0.9552, 0.9628, 0.9267, 0.9959, 0.8419, 1.0082,\n", + " 0.9657, 0.9144, 1.0230, 0.9963, 0.8837, 0.8098, 0.9780, 1.0287, 0.9983,\n", + " 1.0574, 0.8762, 0.9858, 0.9770, 0.9042, 0.8879, 0.9769, 0.9435, 0.8430,\n", + " 0.8512, 0.8121, 0.9104, 1.1214, 1.0722, 0.9728, 1.0472, 0.8562, 0.9894,\n", + " 0.9725, 1.0077, 0.9654, 1.0169, 0.8221, 0.7715, 0.9867, 0.9983, 0.9642,\n", + " 0.7586, 0.8606, 0.9114, 0.8180, 0.9168, 0.8913, 1.0396, 0.9633, 1.0078,\n", + " 1.0352, 0.8088, 0.9190, 0.9156, 0.9024, 0.9893, 0.9497, 0.9773, 1.0055,\n", + " 0.9963, 0.9586, 0.9777, 1.0456, 0.9290, 1.0590, 0.8944, 1.0062, 0.9626,\n", + " 0.9110, 0.9895, 0.9059, 1.0135, 0.8311, 0.9899, 0.8107, 0.9532, 1.0057,\n", + " 0.9510, 0.9788, 1.0119, 0.8398, 0.9277, 0.9727, 0.9752, 0.9373, 0.9794,\n", + " 0.8550, 
0.9457])\n", + "layer2.1.bn2.bias Parameter containing:\n", + "tensor([-0.0631, -0.1395, -0.0694, -0.1654, -0.0476, -0.1273, -0.0678, -0.1002,\n", + " -0.0325, -0.0740, -0.1281, -0.1002, -0.0175, -0.0615, -0.0975, -0.0528,\n", + " -0.0580, -0.0978, -0.0452, -0.0390, -0.0917, -0.1588, -0.0876, -0.0832,\n", + " -0.0911, -0.0724, -0.1224, -0.0939, -0.0113, -0.0898, -0.0591, -0.0562,\n", + " -0.0751, -0.0556, -0.0274, -0.1637, -0.1727, -0.0791, -0.0973, -0.0137,\n", + " -0.0436, -0.0590, -0.0689, -0.0996, -0.0387, -0.0633, -0.1175, -0.0700,\n", + " -0.0791, -0.1303, -0.0326, -0.1016, -0.0775, -0.0637, -0.1030, -0.0783,\n", + " -0.0339, -0.1278, -0.0296, -0.0876, -0.0457, -0.0382, -0.0050, -0.0750,\n", + " -0.0725, -0.0017, -0.0492, -0.0742, -0.0510, -0.0604, -0.0823, -0.0722,\n", + " -0.0629, -0.0963, -0.0659, -0.0422, -0.1470, -0.0926, -0.0610, -0.0476,\n", + " -0.0626, -0.1338, -0.0497, -0.0442, -0.0355, -0.0271, -0.0796, -0.1819,\n", + " -0.0443, -0.1882, -0.0896, -0.0531, -0.1230, -0.0728, -0.1342, -0.0824,\n", + " -0.0452, -0.0723, -0.0767, -0.0520, -0.0523, -0.0455, -0.0583, -0.1160,\n", + " -0.0359, -0.0998, -0.0496, -0.0457, -0.0740, -0.0542, -0.0391, -0.0431,\n", + " -0.1116, -0.1315, -0.1144, -0.0804, -0.0101, -0.0375, -0.0806, -0.0234,\n", + " -0.0276, -0.0227, -0.0626, 0.0198, -0.0983, -0.0464, -0.1192, -0.0449])\n", + "layer3.0.conv1.weight Parameter containing:\n", + "tensor([[[[-2.8938e-03, -3.0210e-02, 2.4601e-02],\n", + " [ 6.5808e-03, -1.7178e-02, 2.3446e-03],\n", + " [-1.6810e-02, 1.1131e-02, -3.5020e-03]],\n", + "\n", + " [[ 2.4755e-02, -3.2800e-02, 3.2009e-02],\n", + " [ 1.9366e-02, -1.1100e-03, 4.9418e-03],\n", + " [ 2.5708e-02, -1.1415e-02, 1.0658e-02]],\n", + "\n", + " [[ 1.6535e-02, -5.5367e-03, 2.4412e-02],\n", + " [ 4.0183e-03, -2.3662e-02, -4.1649e-02],\n", + " [ 1.5828e-02, -3.0302e-02, -2.3527e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 3.0539e-03, -3.9847e-03, -3.5681e-02],\n", + " [ 5.4152e-03, -1.0178e-02, 1.4431e-02],\n", + " 
[-7.5026e-02, -5.5198e-02, 4.9839e-02]],\n", + "\n", + " [[ 3.7400e-02, 6.3145e-03, 4.9693e-02],\n", + " [ 2.0811e-03, 4.8901e-03, -1.7024e-02],\n", + " [ 3.5467e-02, -4.9344e-03, 1.6703e-02]],\n", + "\n", + " [[-4.6327e-03, 2.8507e-02, 4.8495e-02],\n", + " [ 1.7117e-02, -2.4190e-03, 3.9677e-02],\n", + " [-1.7740e-02, 1.2457e-02, 6.9044e-02]]],\n", + "\n", + "\n", + " [[[ 2.1504e-02, 5.2210e-02, 1.2908e-02],\n", + " [-2.6065e-02, -1.5139e-02, 3.4843e-02],\n", + " [-1.7559e-02, -5.9843e-02, 1.8697e-02]],\n", + "\n", + " [[ 3.6117e-02, -2.0767e-02, -1.6192e-02],\n", + " [ 4.7451e-02, 2.7762e-02, -2.7552e-02],\n", + " [-2.8505e-03, -5.8121e-02, -5.2318e-03]],\n", + "\n", + " [[-3.4431e-02, -2.0220e-02, 1.0149e-02],\n", + " [-1.5946e-02, 9.6492e-03, 1.6029e-04],\n", + " [-3.6388e-02, 1.3628e-02, -7.8374e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-4.6442e-02, -2.9462e-02, 5.9477e-05],\n", + " [ 7.8429e-03, 2.8139e-02, 3.6842e-02],\n", + " [ 8.4957e-04, 1.0251e-03, 1.3028e-02]],\n", + "\n", + " [[-8.7456e-03, 4.2044e-02, -2.0930e-02],\n", + " [-2.2964e-02, 8.5618e-03, -3.3057e-02],\n", + " [ 1.2229e-02, -2.6854e-02, -1.5350e-03]],\n", + "\n", + " [[-2.5714e-02, 8.3247e-03, -1.1987e-02],\n", + " [-1.9772e-02, -9.5624e-03, 4.1201e-02],\n", + " [ 2.0758e-02, 1.6308e-02, 1.8109e-02]]],\n", + "\n", + "\n", + " [[[ 2.0949e-02, 1.8669e-02, -1.4822e-02],\n", + " [ 3.7938e-02, -3.4789e-02, 1.0094e-02],\n", + " [-2.0915e-03, -4.5284e-03, 1.6016e-02]],\n", + "\n", + " [[ 2.8699e-02, -3.7645e-02, -5.9033e-03],\n", + " [ 5.7764e-02, -7.0500e-04, 5.1310e-02],\n", + " [-2.5892e-02, 3.4094e-02, -1.7359e-02]],\n", + "\n", + " [[-2.8955e-02, -1.7045e-02, -9.4707e-03],\n", + " [-9.5516e-03, -8.9949e-03, 1.2905e-02],\n", + " [-3.1667e-02, 1.7581e-02, 2.1029e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-8.9293e-04, 8.4387e-03, 1.9757e-02],\n", + " [-2.4149e-02, -1.4277e-02, 5.7769e-03],\n", + " [-3.3923e-02, 2.1901e-02, -1.3470e-02]],\n", + "\n", + " [[ 2.6607e-02, -1.6824e-03, 
1.5170e-02],\n", + " [ 3.2252e-02, -2.1156e-02, 7.6330e-04],\n", + " [ 1.6243e-02, 1.0417e-02, -3.2937e-02]],\n", + "\n", + " [[-8.1500e-03, 6.0820e-03, -2.3481e-02],\n", + " [ 7.4835e-03, 2.2905e-02, 2.9158e-02],\n", + " [-5.4513e-02, -8.0323e-02, 1.0160e-02]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[-1.7320e-02, -1.1638e-02, 3.1909e-02],\n", + " [-8.7472e-03, -1.5413e-02, -1.9191e-02],\n", + " [-1.7741e-03, 4.8668e-03, 9.9375e-03]],\n", + "\n", + " [[ 5.7990e-03, 2.6241e-03, 4.7727e-03],\n", + " [-6.8299e-03, 1.6355e-02, 1.5993e-02],\n", + " [-3.5001e-02, 4.5646e-02, -1.3821e-02]],\n", + "\n", + " [[-2.1266e-02, -5.8349e-03, -3.3456e-03],\n", + " [ 1.5735e-02, -5.0056e-02, -3.8845e-02],\n", + " [ 1.7596e-02, -1.8245e-03, 1.6151e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.0309e-02, -1.3287e-02, 3.8571e-03],\n", + " [ 6.3270e-03, -3.5081e-02, -2.1760e-02],\n", + " [ 2.4283e-02, -3.5668e-02, 1.6085e-02]],\n", + "\n", + " [[ 2.4923e-02, 1.6274e-02, 5.0354e-02],\n", + " [ 2.7379e-03, 1.3269e-02, 3.0485e-02],\n", + " [-4.1538e-02, -1.9868e-02, -4.0987e-02]],\n", + "\n", + " [[-6.5920e-03, 9.8657e-03, -1.8862e-02],\n", + " [ 2.4254e-02, -2.3374e-02, -2.2944e-02],\n", + " [-5.6579e-03, 5.4663e-03, 1.4742e-02]]],\n", + "\n", + "\n", + " [[[ 1.6981e-03, 1.7682e-02, 2.2609e-02],\n", + " [ 1.6195e-02, -3.6837e-03, -9.3915e-03],\n", + " [ 1.2224e-03, -2.6607e-02, -2.7341e-03]],\n", + "\n", + " [[ 2.0577e-02, -2.1010e-02, 4.7610e-02],\n", + " [ 1.2225e-02, -1.6388e-02, -4.9970e-03],\n", + " [ 2.0274e-02, -1.1799e-02, 3.1679e-02]],\n", + "\n", + " [[-2.8022e-02, -2.3188e-02, -1.6347e-02],\n", + " [-1.5444e-02, 3.4093e-02, 1.0499e-02],\n", + " [ 9.3748e-03, -1.0355e-02, 5.8578e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.1788e-02, 1.5121e-02, 3.4046e-02],\n", + " [-1.4173e-02, -9.8143e-03, -1.2345e-02],\n", + " [ 3.0159e-02, 3.3297e-02, 1.0358e-02]],\n", + "\n", + " [[-2.5139e-03, 2.3068e-03, 2.6721e-02],\n", + " [-2.3739e-02, 2.3444e-02, 
-1.1328e-02],\n", + " [ 6.5631e-03, 4.0345e-03, -1.9840e-02]],\n", + "\n", + " [[ 1.7692e-02, -2.0572e-02, -2.8111e-04],\n", + " [-1.2023e-02, 2.4996e-02, 2.4255e-02],\n", + " [-1.9400e-02, 3.1629e-02, -2.4052e-02]]],\n", + "\n", + "\n", + " [[[-2.4696e-02, 2.3890e-02, 1.3941e-03],\n", + " [ 6.2046e-03, 7.4803e-03, -1.8126e-02],\n", + " [-1.8920e-02, -3.0955e-02, -8.0770e-03]],\n", + "\n", + " [[ 7.2004e-03, -4.3122e-02, -3.4520e-02],\n", + " [ 6.0169e-03, -1.7975e-02, 1.6013e-02],\n", + " [-1.2754e-02, 7.1421e-03, -2.0949e-02]],\n", + "\n", + " [[-9.5524e-03, -1.2836e-02, -3.4424e-03],\n", + " [ 5.2468e-04, -2.8843e-02, -6.9024e-03],\n", + " [-3.4978e-03, 7.0781e-02, 6.6465e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.7539e-02, -2.4638e-02, -5.0424e-03],\n", + " [ 6.8030e-03, 2.3334e-02, 1.4998e-02],\n", + " [ 2.7096e-02, 3.0525e-02, -7.3968e-04]],\n", + "\n", + " [[-3.0516e-02, -3.9977e-02, -1.0588e-02],\n", + " [-2.5134e-03, 8.7634e-05, -2.6856e-02],\n", + " [ 1.7962e-02, 6.0497e-03, -6.6210e-05]],\n", + "\n", + " [[ 6.9240e-03, -1.9080e-02, 1.1087e-02],\n", + " [ 7.3493e-03, 4.4861e-03, 1.1068e-02],\n", + " [-2.2379e-02, 1.1944e-02, -1.2094e-02]]]])\n", + "layer3.0.bn1.weight Parameter containing:\n", + "tensor([0.9061, 1.0347, 1.0760, 1.0512, 0.9397, 1.0302, 0.9768, 1.0078, 1.0257,\n", + " 0.9997, 1.0039, 1.0391, 1.1620, 1.0561, 0.9218, 1.0240, 0.9866, 1.0253,\n", + " 0.9590, 0.9844, 1.0767, 1.0360, 1.0178, 1.0186, 1.0044, 1.0089, 1.0140,\n", + " 1.0252, 1.0301, 1.0595, 1.0679, 0.9882, 0.9082, 1.0731, 1.0115, 0.9792,\n", + " 0.9640, 1.0047, 0.8549, 1.0938, 0.8506, 0.9995, 0.9996, 0.9056, 0.9211,\n", + " 0.8706, 0.9308, 1.0217, 1.0128, 1.0225, 0.9260, 0.9987, 1.0239, 1.0860,\n", + " 0.9033, 0.9618, 1.0172, 1.0064, 1.0140, 1.0230, 0.8893, 1.0679, 0.9589,\n", + " 0.8701, 1.0907, 0.9620, 1.0249, 1.0191, 1.0027, 1.0092, 0.9996, 1.0263,\n", + " 0.9019, 0.9809, 1.0138, 0.9840, 1.0444, 1.0324, 1.0631, 0.9025, 1.0824,\n", + " 0.9573, 0.9810, 1.0726, 1.0022, 1.0539, 
1.0418, 1.0293, 1.0278, 0.9352,\n", + " 0.9852, 0.9486, 1.0826, 1.0651, 1.0367, 1.0072, 1.0827, 1.0474, 1.0641,\n", + " 1.0313, 0.9774, 0.9143, 0.9635, 1.0805, 0.9259, 1.0173, 0.8997, 1.0448,\n", + " 0.9693, 1.0160, 1.0576, 0.9509, 1.0058, 0.9660, 1.0458, 1.0454, 0.8917,\n", + " 0.8641, 1.0203, 1.0884, 0.8875, 0.8684, 1.0861, 1.0081, 0.9273, 1.0426,\n", + " 1.1006, 0.9325, 1.0276, 0.9169, 0.9800, 1.0925, 1.0594, 0.8335, 1.0556,\n", + " 1.0944, 1.0630, 0.9928, 1.0146, 0.9829, 0.8950, 0.9355, 0.9619, 1.0186,\n", + " 0.9641, 1.0488, 0.9760, 1.0066, 0.9680, 1.0791, 1.0448, 0.9901, 1.0138,\n", + " 0.9991, 1.0693, 1.0020, 1.0520, 0.9433, 1.0379, 0.9295, 0.9748, 1.0369,\n", + " 1.0888, 1.0418, 0.9144, 0.9162, 0.9564, 0.9888, 0.8738, 0.9832, 0.9302,\n", + " 0.9893, 0.8909, 0.9891, 0.8922, 0.9782, 0.9909, 1.0133, 0.9086, 0.9760,\n", + " 0.9767, 1.0158, 0.9915, 1.0143, 0.9966, 0.9907, 0.9776, 0.9047, 1.0013,\n", + " 0.9864, 1.0067, 0.9248, 0.9688, 0.9908, 0.8471, 0.8832, 1.0943, 1.0505,\n", + " 1.0315, 1.0334, 0.8834, 0.9405, 1.0267, 0.9765, 1.0431, 0.9543, 0.9125,\n", + " 1.0020, 1.0517, 0.9656, 1.0035, 0.9053, 0.9510, 1.0689, 1.0049, 0.9757,\n", + " 1.0286, 0.8977, 1.0217, 1.0082, 1.0706, 1.0439, 0.9330, 1.0338, 1.1586,\n", + " 0.9350, 0.9765, 0.8640, 1.0371, 1.0314, 0.9612, 0.9800, 1.0228, 1.0002,\n", + " 1.0706, 1.0041, 1.0228, 0.9090, 0.9900, 1.0428, 0.9930, 1.0704, 0.9801,\n", + " 0.9030, 0.9475, 0.9210, 1.0132, 0.9513, 1.0110, 1.0776, 1.0226, 1.0207,\n", + " 0.9829, 0.9680, 0.9230, 1.0265])\n", + "layer3.0.bn1.bias Parameter containing:\n", + "tensor([-0.1787, -0.0670, -0.1104, -0.0911, -0.0801, -0.1031, -0.1063, -0.1099,\n", + " -0.0839, -0.1178, -0.0499, -0.1453, -0.0722, -0.0695, -0.1185, -0.1001,\n", + " -0.0961, -0.1003, -0.0839, -0.0936, -0.0897, -0.0942, -0.0798, -0.0622,\n", + " -0.0943, -0.0427, -0.0825, -0.0613, -0.0811, -0.0732, -0.0727, -0.0806,\n", + " -0.1232, -0.0963, -0.0856, -0.1087, -0.0833, -0.0848, -0.1352, -0.0787,\n", + " -0.1479, -0.0362, 
-0.1301, -0.1184, -0.1073, -0.1560, -0.0911, -0.0814,\n", + " -0.0877, -0.1103, -0.1217, -0.0934, -0.1284, -0.0638, -0.1504, -0.0730,\n", + " -0.0922, -0.1440, -0.0523, -0.0886, -0.1526, -0.0309, -0.1076, -0.1189,\n", + " -0.0806, -0.1403, -0.0696, -0.0722, -0.0710, -0.0701, -0.1449, -0.1027,\n", + " -0.1150, -0.0578, -0.0902, -0.0958, -0.0891, -0.0910, -0.0921, -0.0914,\n", + " -0.0470, -0.0853, -0.1400, -0.0813, -0.0946, -0.0688, -0.0459, -0.0683,\n", + " -0.0741, -0.0944, -0.1044, -0.0858, -0.0640, -0.0608, -0.1008, -0.0600,\n", + " -0.0892, -0.1165, -0.1024, -0.1160, -0.1608, -0.1190, -0.1343, -0.1135,\n", + " -0.1093, -0.0494, -0.0885, -0.0808, -0.0432, -0.0850, -0.0874, -0.0886,\n", + " -0.0916, -0.0610, -0.1053, -0.0913, -0.1357, -0.0879, -0.1158, -0.0854,\n", + " -0.1429, -0.1255, -0.0469, -0.0949, -0.0887, -0.1348, -0.0951, -0.1010,\n", + " -0.0553, -0.1506, -0.0773, -0.0961, -0.1000, -0.1271, -0.0111, -0.0747,\n", + " -0.0983, -0.0901, -0.1035, -0.0523, -0.1340, -0.1469, -0.1036, -0.1320,\n", + " -0.0788, -0.0828, -0.0806, -0.0947, -0.1260, -0.1188, -0.0600, -0.0903,\n", + " -0.1103, -0.0842, -0.0682, -0.0880, -0.0689, -0.1146, -0.0663, -0.1001,\n", + " -0.0970, -0.1130, -0.0439, -0.0618, -0.1329, -0.0918, -0.1008, -0.1022,\n", + " -0.1366, -0.0886, -0.0923, -0.1017, -0.0743, -0.1107, -0.1248, -0.0922,\n", + " -0.0982, -0.1354, -0.1418, -0.0818, -0.0899, -0.0553, -0.0871, -0.1282,\n", + " -0.0706, -0.0805, -0.1150, -0.1173, -0.0729, -0.0961, -0.0577, -0.0929,\n", + " -0.0875, -0.1348, -0.1249, -0.1365, -0.0689, -0.0694, -0.0860, -0.0885,\n", + " -0.1044, -0.0753, -0.0606, -0.1009, -0.1671, -0.1205, -0.1800, -0.0734,\n", + " -0.0862, -0.1019, -0.0975, -0.1568, -0.1117, -0.1548, -0.0654, -0.1274,\n", + " -0.0855, -0.1067, -0.1046, -0.0634, -0.0671, -0.0478, -0.1219, -0.0808,\n", + " -0.0356, -0.0614, -0.0897, -0.1689, -0.1237, -0.1068, -0.0926, -0.0682,\n", + " -0.0673, -0.0980, -0.0866, -0.1098, -0.0595, -0.1140, -0.0806, -0.1057,\n", + " -0.1030, -0.0569, 
-0.0721, -0.1279, -0.1018, -0.1139, -0.0615, -0.1067,\n", + " -0.0735, -0.1219, -0.1145, -0.0392, -0.1084, -0.0808, -0.0675, -0.0412])\n", + "layer3.0.conv2.weight Parameter containing:\n", + "tensor([[[[-0.0214, -0.0107, -0.0109],\n", + " [-0.0276, -0.0046, -0.0019],\n", + " [-0.0285, 0.0086, 0.0139]],\n", + "\n", + " [[ 0.0022, 0.0144, -0.0277],\n", + " [ 0.0292, 0.0160, -0.0108],\n", + " [ 0.0132, 0.0574, 0.0289]],\n", + "\n", + " [[-0.0278, 0.0145, -0.0049],\n", + " [ 0.0079, 0.0042, 0.0348],\n", + " [ 0.0134, 0.0028, -0.0223]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0053, -0.0094, 0.0096],\n", + " [ 0.0040, -0.0227, -0.0190],\n", + " [ 0.0357, 0.0016, 0.0191]],\n", + "\n", + " [[ 0.0082, -0.0114, -0.0138],\n", + " [ 0.0073, -0.0244, -0.0040],\n", + " [ 0.0254, 0.0098, 0.0303]],\n", + "\n", + " [[-0.0097, 0.0010, -0.0121],\n", + " [-0.0180, 0.0130, -0.0219],\n", + " [ 0.0072, 0.0129, -0.0006]]],\n", + "\n", + "\n", + " [[[ 0.0099, 0.0353, 0.0159],\n", + " [-0.0033, -0.0035, -0.0017],\n", + " [-0.0143, -0.0117, -0.0098]],\n", + "\n", + " [[ 0.0281, -0.0012, 0.0075],\n", + " [ 0.0288, 0.0100, -0.0020],\n", + " [ 0.0016, 0.0260, -0.0053]],\n", + "\n", + " [[ 0.0293, -0.0065, -0.0090],\n", + " [-0.0288, -0.0237, -0.0069],\n", + " [-0.0452, -0.0268, -0.0010]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0031, -0.0219, 0.0129],\n", + " [-0.0017, 0.0037, 0.0417],\n", + " [-0.0069, 0.0084, -0.0066]],\n", + "\n", + " [[ 0.0249, 0.0099, 0.0340],\n", + " [ 0.0362, 0.0196, 0.0393],\n", + " [ 0.0078, 0.0027, 0.0078]],\n", + "\n", + " [[ 0.0255, 0.0275, 0.0331],\n", + " [ 0.0168, 0.0004, 0.0083],\n", + " [-0.0081, -0.0078, -0.0197]]],\n", + "\n", + "\n", + " [[[ 0.0481, 0.0492, 0.0372],\n", + " [ 0.0172, -0.0387, -0.0074],\n", + " [-0.0222, -0.0242, -0.0616]],\n", + "\n", + " [[ 0.0019, 0.0307, 0.0023],\n", + " [ 0.0091, 0.0094, -0.0030],\n", + " [ 0.0219, -0.0130, -0.0093]],\n", + "\n", + " [[ 0.0172, 0.0382, 0.0198],\n", + " [ 0.0220, 0.0401, 0.0123],\n", + " [-0.0001, 
-0.0152, 0.0177]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0011, 0.0025, 0.0049],\n", + " [-0.0053, -0.0176, -0.0262],\n", + " [ 0.0189, 0.0165, -0.0020]],\n", + "\n", + " [[ 0.0171, 0.0041, -0.0016],\n", + " [-0.0089, -0.0153, -0.0008],\n", + " [ 0.0189, 0.0374, 0.0036]],\n", + "\n", + " [[ 0.0264, -0.0059, -0.0073],\n", + " [-0.0090, 0.0259, 0.0117],\n", + " [ 0.0016, -0.0017, 0.0198]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[ 0.0057, -0.0139, -0.0260],\n", + " [-0.0335, -0.0236, -0.0194],\n", + " [-0.0076, 0.0078, -0.0027]],\n", + "\n", + " [[ 0.0255, -0.0182, -0.0244],\n", + " [ 0.0128, -0.0183, -0.0070],\n", + " [ 0.0053, -0.0110, 0.0198]],\n", + "\n", + " [[ 0.0052, -0.0080, 0.0055],\n", + " [-0.0020, -0.0061, -0.0151],\n", + " [ 0.0397, 0.0304, 0.0184]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0044, -0.0023, -0.0192],\n", + " [ 0.0090, -0.0075, 0.0229],\n", + " [-0.0066, 0.0129, 0.0241]],\n", + "\n", + " [[-0.0238, -0.0174, 0.0229],\n", + " [ 0.0057, -0.0127, -0.0277],\n", + " [-0.0239, -0.0150, 0.0047]],\n", + "\n", + " [[-0.0127, -0.0376, -0.0095],\n", + " [-0.0089, -0.0277, -0.0007],\n", + " [ 0.0142, 0.0022, -0.0197]]],\n", + "\n", + "\n", + " [[[ 0.0211, -0.0185, -0.0099],\n", + " [-0.0169, 0.0075, 0.0242],\n", + " [-0.0085, -0.0022, 0.0176]],\n", + "\n", + " [[-0.0230, -0.0001, 0.0112],\n", + " [ 0.0105, -0.0100, 0.0319],\n", + " [ 0.0436, 0.0082, 0.0280]],\n", + "\n", + " [[-0.0323, -0.0262, -0.0027],\n", + " [ 0.0013, 0.0182, -0.0145],\n", + " [-0.0491, -0.0368, 0.0062]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0272, 0.0284, 0.0171],\n", + " [ 0.0156, 0.0173, 0.0069],\n", + " [-0.0398, 0.0368, 0.0365]],\n", + "\n", + " [[-0.0210, -0.0316, -0.0054],\n", + " [-0.0044, 0.0173, -0.0213],\n", + " [ 0.0172, 0.0107, 0.0009]],\n", + "\n", + " [[-0.0052, -0.0406, 0.0017],\n", + " [ 0.0289, 0.0095, -0.0076],\n", + " [-0.0059, 0.0249, -0.0022]]],\n", + "\n", + "\n", + " [[[-0.0250, -0.0216, 0.0087],\n", + " [ 0.0114, -0.0257, 
-0.0030],\n", + " [-0.0314, -0.0055, -0.0429]],\n", + "\n", + " [[-0.0375, -0.0081, -0.0144],\n", + " [ 0.0251, 0.0122, -0.0046],\n", + " [ 0.0348, 0.0360, 0.0058]],\n", + "\n", + " [[ 0.0087, 0.0354, 0.0022],\n", + " [ 0.0136, 0.0033, -0.0153],\n", + " [-0.0095, 0.0035, -0.0282]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0125, -0.0045, 0.0161],\n", + " [ 0.0001, 0.0047, -0.0010],\n", + " [ 0.0173, 0.0316, 0.0081]],\n", + "\n", + " [[-0.0102, -0.0060, 0.0126],\n", + " [-0.0195, -0.0114, 0.0203],\n", + " [-0.0016, 0.0137, 0.0091]],\n", + "\n", + " [[ 0.0342, 0.0114, 0.0099],\n", + " [ 0.0082, -0.0026, -0.0128],\n", + " [-0.0224, 0.0479, -0.0131]]]])\n", + "layer3.0.bn2.weight Parameter containing:\n", + "tensor([1.0119, 0.9384, 1.0827, 1.0648, 1.0254, 1.0561, 1.0390, 1.0293, 1.0542,\n", + " 1.1104, 1.0157, 1.0315, 1.0969, 1.1143, 1.1288, 1.0071, 1.0915, 1.0707,\n", + " 0.9861, 1.0539, 1.0197, 1.0023, 1.0544, 1.0703, 1.0772, 1.0126, 1.0308,\n", + " 0.9459, 1.0869, 1.0696, 1.0180, 1.0500, 0.9495, 0.9905, 0.9891, 0.9669,\n", + " 1.1350, 1.0054, 1.0277, 1.1071, 1.0835, 1.0175, 1.1351, 0.9666, 1.0443,\n", + " 1.0653, 0.9898, 1.0989, 1.1159, 1.0458, 1.1057, 0.9943, 0.9539, 0.9514,\n", + " 1.1175, 0.8948, 1.1319, 1.1366, 1.0852, 1.0669, 1.0453, 0.9932, 1.0121,\n", + " 1.0484, 1.0055, 1.0050, 1.0672, 1.0756, 1.0323, 0.9517, 1.0874, 0.9885,\n", + " 0.9905, 1.0063, 1.0442, 0.9888, 1.0702, 1.0889, 0.9140, 0.9522, 1.0362,\n", + " 1.0711, 1.0422, 1.1315, 1.0167, 1.1342, 0.8848, 1.0353, 1.0933, 1.0097,\n", + " 1.0541, 1.0233, 1.0751, 1.0267, 0.9831, 0.9609, 1.0753, 1.0302, 1.0146,\n", + " 1.0836, 1.1139, 1.0072, 1.0257, 0.9531, 1.0608, 1.0504, 1.0427, 0.9754,\n", + " 1.1409, 1.0762, 1.1007, 1.0515, 1.0316, 1.0814, 1.0715, 1.0389, 1.0805,\n", + " 0.9638, 1.0554, 0.9259, 1.0075, 1.0366, 0.9909, 1.0715, 1.0229, 1.0265,\n", + " 0.9960, 1.0285, 1.0408, 1.0881, 1.0641, 1.1841, 0.9831, 1.1062, 1.0657,\n", + " 1.0388, 1.1276, 1.0904, 1.0004, 1.0091, 0.9559, 1.0164, 0.9498, 0.9922,\n", + 
" 1.1715, 1.0042, 1.0876, 1.1105, 1.0552, 1.0394, 1.0504, 0.9417, 1.0226,\n", + " 0.9802, 1.0365, 0.9911, 0.9943, 0.9997, 0.9738, 0.9340, 1.1164, 1.0912,\n", + " 1.0402, 1.0733, 1.0423, 0.9550, 1.0115, 0.9951, 1.0020, 0.9947, 1.0187,\n", + " 1.0672, 0.9908, 1.0988, 0.9963, 1.0964, 1.0812, 1.1380, 1.0009, 0.9751,\n", + " 1.0534, 1.1093, 1.0231, 1.0245, 0.8350, 1.0952, 1.0597, 0.9853, 1.0292,\n", + " 1.0843, 1.0527, 1.0730, 0.9615, 1.1262, 0.9904, 1.0345, 1.0790, 1.0669,\n", + " 1.0724, 1.0305, 0.9846, 1.0156, 1.0134, 0.9018, 1.1278, 1.0555, 1.0536,\n", + " 1.0175, 0.9649, 1.0734, 1.0182, 0.9771, 0.9521, 1.0517, 0.9989, 0.9906,\n", + " 1.0523, 0.9449, 1.0905, 1.0786, 1.1407, 1.0546, 0.9064, 1.0434, 0.9211,\n", + " 1.0347, 0.8981, 1.0382, 1.0147, 1.0623, 0.9599, 1.0316, 1.0457, 1.0186,\n", + " 1.1254, 1.1296, 1.1115, 0.9257, 0.9933, 1.0522, 0.9944, 1.0506, 0.9876,\n", + " 0.9906, 1.0165, 0.8943, 0.9334, 1.0913, 0.9658, 1.0258, 1.0106, 0.9817,\n", + " 1.0933, 0.9433, 1.0289, 1.0667])\n", + "layer3.0.bn2.bias Parameter containing:\n", + "tensor([-0.0755, -0.0719, -0.0956, -0.1168, -0.1131, -0.0915, -0.0598, -0.1081,\n", + " -0.1179, -0.0952, -0.0843, -0.0633, -0.0956, -0.0895, -0.0859, -0.0657,\n", + " -0.0718, -0.0734, -0.0954, -0.0896, -0.0917, -0.0895, -0.0989, -0.0541,\n", + " -0.0895, -0.0864, -0.0950, -0.0644, -0.1172, -0.0876, -0.1271, -0.1238,\n", + " -0.0690, -0.1027, -0.0964, -0.0774, -0.0782, -0.1324, -0.0958, -0.0867,\n", + " -0.0766, -0.0599, -0.1177, -0.1013, -0.0862, -0.0843, -0.0909, -0.0763,\n", + " -0.0594, -0.0976, -0.0880, -0.0501, -0.0928, -0.0974, -0.1033, -0.1063,\n", + " -0.0736, -0.0788, -0.0619, -0.0717, -0.0822, -0.0713, -0.0758, -0.0622,\n", + " -0.0840, -0.0806, -0.1197, -0.0879, -0.0915, -0.1236, -0.0803, -0.0861,\n", + " -0.0713, -0.1014, -0.1262, -0.0877, -0.0946, -0.0900, -0.0831, -0.1069,\n", + " -0.0975, -0.1071, -0.0907, -0.1050, -0.0708, -0.0908, -0.1228, -0.0650,\n", + " -0.1774, -0.1430, -0.0872, -0.0692, -0.1001, -0.0892, -0.0924, 
-0.0851,\n", + " -0.1070, -0.0903, -0.0953, -0.0785, -0.0789, -0.1276, -0.0917, -0.0950,\n", + " -0.1073, -0.0781, -0.1001, -0.0755, -0.0976, -0.0857, -0.1092, -0.1065,\n", + " -0.1036, -0.0960, -0.0764, -0.0968, -0.1196, -0.0907, -0.0702, -0.0925,\n", + " -0.1132, -0.0908, -0.0933, -0.0812, -0.0916, -0.1031, -0.0769, -0.1044,\n", + " -0.1095, -0.0834, -0.0466, -0.0939, -0.1198, -0.0800, -0.1381, -0.0655,\n", + " -0.1258, -0.0981, -0.0926, -0.0740, -0.1012, -0.1057, -0.1134, -0.0877,\n", + " -0.1252, -0.0627, -0.1136, -0.0758, -0.1047, -0.0449, -0.0940, -0.1118,\n", + " -0.1121, -0.0740, -0.0836, -0.0855, -0.0764, -0.1345, -0.1051, -0.1177,\n", + " -0.0892, -0.0686, -0.0428, -0.0641, -0.0841, -0.0900, -0.1286, -0.0907,\n", + " -0.0964, -0.0649, -0.0511, -0.1129, -0.0960, -0.0700, -0.0757, -0.0907,\n", + " -0.0469, -0.0820, -0.1011, -0.0797, -0.1031, -0.0720, -0.0966, -0.0867,\n", + " -0.1202, -0.0668, -0.0907, -0.1095, -0.0925, -0.0882, -0.1421, -0.0695,\n", + " -0.1184, -0.0824, -0.0652, -0.1343, -0.1453, -0.0800, -0.0881, -0.0690,\n", + " -0.0882, -0.0614, -0.1213, -0.1122, -0.0952, -0.0945, -0.0591, -0.0854,\n", + " -0.1119, -0.0706, -0.0883, -0.1093, -0.0873, -0.0830, -0.0958, -0.0559,\n", + " -0.1080, -0.1117, -0.0335, -0.0701, -0.1163, -0.0996, -0.0995, -0.0798,\n", + " -0.0977, -0.0878, -0.0828, -0.1044, -0.0683, -0.0817, -0.0601, -0.0860,\n", + " -0.1298, -0.0825, -0.1011, -0.0311, -0.0796, -0.1116, -0.1409, -0.0530,\n", + " -0.0907, -0.1049, -0.0740, -0.1359, -0.0716, -0.0704, -0.1097, -0.0979,\n", + " -0.0780, -0.0969, -0.0487, -0.1053, -0.0492, -0.1102, -0.0792, -0.0921])\n", + "layer3.0.shortcut_conv.weight Parameter containing:\n", + "tensor([[[[ 0.1287]],\n", + "\n", + " [[-0.0332]],\n", + "\n", + " [[-0.0495]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0439]],\n", + "\n", + " [[ 0.0160]],\n", + "\n", + " [[ 0.0147]]],\n", + "\n", + "\n", + " [[[ 0.0304]],\n", + "\n", + " [[-0.0746]],\n", + "\n", + " [[-0.0246]],\n", + "\n", + " ...,\n", + "\n", + " 
[[-0.0141]],\n", + "\n", + " [[-0.0442]],\n", + "\n", + " [[ 0.0355]]],\n", + "\n", + "\n", + " [[[ 0.1153]],\n", + "\n", + " [[-0.0623]],\n", + "\n", + " [[-0.0273]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0472]],\n", + "\n", + " [[-0.0022]],\n", + "\n", + " [[-0.0559]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[-0.0353]],\n", + "\n", + " [[-0.0721]],\n", + "\n", + " [[ 0.0526]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0233]],\n", + "\n", + " [[ 0.1033]],\n", + "\n", + " [[-0.0018]]],\n", + "\n", + "\n", + " [[[-0.0268]],\n", + "\n", + " [[ 0.1036]],\n", + "\n", + " [[ 0.0685]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0291]],\n", + "\n", + " [[-0.0261]],\n", + "\n", + " [[-0.0402]]],\n", + "\n", + "\n", + " [[[-0.0194]],\n", + "\n", + " [[-0.0895]],\n", + "\n", + " [[ 0.0330]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0139]],\n", + "\n", + " [[-0.0454]],\n", + "\n", + " [[-0.0404]]]])\n", + "layer3.0.shortcut_bn.weight Parameter containing:\n", + "tensor([0.9961, 1.0297, 1.0394, 0.9886, 0.9974, 0.9935, 1.0201, 0.9893, 1.0292,\n", + " 1.0017, 1.0291, 0.9921, 0.9860, 1.0360, 1.0332, 0.9975, 0.9960, 0.9750,\n", + " 1.0053, 0.9879, 0.9756, 0.9804, 1.0613, 1.0357, 1.0245, 1.0222, 0.9865,\n", + " 1.0318, 0.9853, 1.0022, 0.9951, 0.9852, 0.9804, 0.9720, 1.0117, 1.0068,\n", + " 1.0167, 0.9779, 0.9935, 1.0345, 0.9953, 1.0129, 0.9515, 0.9842, 1.0005,\n", + " 1.0093, 1.0047, 1.0619, 1.0454, 1.0080, 1.0235, 0.9593, 0.9769, 0.9948,\n", + " 1.0086, 0.9753, 1.0169, 1.0139, 1.0104, 1.0031, 1.0335, 0.9674, 0.9881,\n", + " 1.0314, 0.9951, 1.0025, 0.9884, 0.9770, 1.0086, 0.9939, 0.9998, 0.9796,\n", + " 0.9894, 1.0292, 1.0134, 0.9811, 1.0466, 1.0405, 0.9304, 0.9861, 0.9896,\n", + " 0.9120, 1.0059, 1.0098, 0.9744, 0.9987, 0.8989, 1.0100, 0.9922, 0.9780,\n", + " 0.9932, 1.0067, 0.9855, 0.9746, 0.9926, 0.9818, 1.0324, 0.9936, 1.0296,\n", + " 1.0343, 0.9481, 0.9449, 0.9791, 0.9461, 1.0283, 1.0003, 1.0000, 0.9881,\n", + " 0.9747, 0.9710, 0.9883, 1.0095, 0.9782, 
1.0363, 0.9810, 1.0152, 1.0161,\n", + " 0.9391, 1.0096, 0.9620, 0.9905, 1.0231, 1.0099, 1.0045, 1.0248, 0.9783,\n", + " 0.9516, 1.0396, 1.0492, 1.0304, 0.9702, 0.9914, 1.0251, 1.0117, 1.0149,\n", + " 1.0228, 1.0268, 0.9980, 1.0049, 0.9876, 1.0151, 0.9991, 0.9651, 0.9457,\n", + " 1.0144, 0.9495, 0.9432, 1.0230, 1.0877, 1.0209, 0.9865, 0.9554, 1.0235,\n", + " 0.9386, 1.0033, 1.0058, 0.9780, 1.0165, 0.9841, 0.9669, 1.0396, 1.0317,\n", + " 1.0194, 0.9639, 1.0583, 0.9585, 1.0002, 1.0167, 0.9869, 1.0233, 1.0380,\n", + " 0.9895, 1.0125, 0.9881, 1.0050, 0.9928, 1.0075, 1.0085, 1.0002, 0.9457,\n", + " 1.0325, 1.0370, 1.0178, 1.0055, 0.9116, 1.0270, 1.0444, 0.9693, 0.9844,\n", + " 0.9828, 0.9783, 0.9802, 0.9674, 1.0374, 0.9531, 1.0550, 1.0066, 0.9793,\n", + " 1.0020, 0.9739, 0.9739, 1.0216, 1.0090, 0.9594, 0.9608, 1.0677, 1.0088,\n", + " 0.9768, 0.9741, 0.9988, 1.0017, 0.9587, 0.9666, 0.9970, 0.9987, 1.0117,\n", + " 0.9786, 0.9544, 0.9625, 1.0127, 0.9780, 1.0227, 0.9385, 1.0324, 0.9408,\n", + " 1.0003, 0.9506, 1.0368, 1.0347, 0.9962, 0.9938, 1.0976, 1.0253, 1.0174,\n", + " 0.9370, 0.9770, 1.0025, 0.9594, 0.9978, 1.0000, 1.0066, 1.0002, 0.9581,\n", + " 0.9819, 0.9930, 0.9400, 0.9619, 1.0309, 0.9531, 0.9920, 0.9624, 1.0137,\n", + " 1.0539, 0.9746, 0.9789, 0.9881])\n", + "layer3.0.shortcut_bn.bias Parameter containing:\n", + "tensor([-0.0755, -0.0719, -0.0956, -0.1168, -0.1131, -0.0915, -0.0598, -0.1081,\n", + " -0.1179, -0.0952, -0.0843, -0.0633, -0.0956, -0.0895, -0.0859, -0.0657,\n", + " -0.0718, -0.0734, -0.0954, -0.0896, -0.0917, -0.0895, -0.0989, -0.0541,\n", + " -0.0895, -0.0864, -0.0950, -0.0644, -0.1172, -0.0876, -0.1271, -0.1238,\n", + " -0.0690, -0.1027, -0.0964, -0.0774, -0.0782, -0.1324, -0.0958, -0.0867,\n", + " -0.0766, -0.0599, -0.1177, -0.1013, -0.0862, -0.0843, -0.0909, -0.0763,\n", + " -0.0594, -0.0976, -0.0880, -0.0501, -0.0928, -0.0974, -0.1033, -0.1063,\n", + " -0.0736, -0.0788, -0.0619, -0.0717, -0.0822, -0.0713, -0.0758, -0.0622,\n", + " -0.0840, 
-0.0806, -0.1197, -0.0879, -0.0915, -0.1236, -0.0803, -0.0861,\n", + " -0.0713, -0.1014, -0.1262, -0.0877, -0.0946, -0.0900, -0.0831, -0.1069,\n", + " -0.0975, -0.1071, -0.0907, -0.1050, -0.0708, -0.0908, -0.1228, -0.0650,\n", + " -0.1774, -0.1430, -0.0872, -0.0692, -0.1001, -0.0892, -0.0924, -0.0851,\n", + " -0.1070, -0.0903, -0.0953, -0.0785, -0.0789, -0.1276, -0.0917, -0.0950,\n", + " -0.1073, -0.0781, -0.1001, -0.0755, -0.0976, -0.0857, -0.1092, -0.1065,\n", + " -0.1036, -0.0960, -0.0764, -0.0968, -0.1196, -0.0907, -0.0702, -0.0925,\n", + " -0.1132, -0.0908, -0.0933, -0.0812, -0.0916, -0.1031, -0.0769, -0.1044,\n", + " -0.1095, -0.0834, -0.0466, -0.0939, -0.1198, -0.0800, -0.1381, -0.0655,\n", + " -0.1258, -0.0981, -0.0926, -0.0740, -0.1012, -0.1057, -0.1134, -0.0877,\n", + " -0.1252, -0.0627, -0.1136, -0.0758, -0.1047, -0.0449, -0.0940, -0.1118,\n", + " -0.1121, -0.0740, -0.0836, -0.0855, -0.0764, -0.1345, -0.1051, -0.1177,\n", + " -0.0892, -0.0686, -0.0428, -0.0641, -0.0841, -0.0900, -0.1286, -0.0907,\n", + " -0.0964, -0.0649, -0.0511, -0.1129, -0.0960, -0.0700, -0.0757, -0.0907,\n", + " -0.0469, -0.0820, -0.1011, -0.0797, -0.1031, -0.0720, -0.0966, -0.0867,\n", + " -0.1202, -0.0668, -0.0907, -0.1095, -0.0925, -0.0882, -0.1421, -0.0695,\n", + " -0.1184, -0.0824, -0.0652, -0.1343, -0.1453, -0.0800, -0.0881, -0.0690,\n", + " -0.0882, -0.0614, -0.1213, -0.1122, -0.0952, -0.0945, -0.0591, -0.0854,\n", + " -0.1119, -0.0706, -0.0883, -0.1093, -0.0873, -0.0830, -0.0958, -0.0559,\n", + " -0.1080, -0.1117, -0.0335, -0.0701, -0.1163, -0.0996, -0.0995, -0.0798,\n", + " -0.0977, -0.0878, -0.0828, -0.1044, -0.0683, -0.0817, -0.0601, -0.0860,\n", + " -0.1298, -0.0825, -0.1011, -0.0311, -0.0796, -0.1116, -0.1409, -0.0530,\n", + " -0.0907, -0.1049, -0.0740, -0.1359, -0.0716, -0.0704, -0.1097, -0.0979,\n", + " -0.0780, -0.0969, -0.0487, -0.1053, -0.0492, -0.1102, -0.0792, -0.0921])\n", + "layer3.1.conv1.weight Parameter containing:\n", + "tensor([[[[-0.0089, -0.0060, 
-0.0078],\n", + " [-0.0083, -0.0144, -0.0087],\n", + " [ 0.0079, -0.0106, -0.0188]],\n", + "\n", + " [[ 0.0221, 0.0075, -0.0010],\n", + " [ 0.0078, 0.0238, 0.0237],\n", + " [-0.0274, -0.0009, -0.0108]],\n", + "\n", + " [[-0.0164, -0.0243, 0.0176],\n", + " [-0.0234, 0.0124, 0.0029],\n", + " [ 0.0187, 0.0050, 0.0236]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0223, 0.0203, 0.0211],\n", + " [ 0.0131, 0.0264, 0.0292],\n", + " [ 0.0029, 0.0059, -0.0041]],\n", + "\n", + " [[ 0.0133, 0.0358, 0.0329],\n", + " [-0.0023, -0.0191, -0.0157],\n", + " [ 0.0123, -0.0018, 0.0313]],\n", + "\n", + " [[ 0.0187, -0.0173, -0.0203],\n", + " [-0.0146, -0.0170, 0.0135],\n", + " [-0.0026, -0.0255, 0.0016]]],\n", + "\n", + "\n", + " [[[ 0.0035, -0.0063, -0.0086],\n", + " [ 0.0185, -0.0081, -0.0105],\n", + " [ 0.0020, 0.0077, 0.0059]],\n", + "\n", + " [[-0.0036, 0.0112, -0.0225],\n", + " [ 0.0080, 0.0294, 0.0012],\n", + " [-0.0424, -0.0239, 0.0066]],\n", + "\n", + " [[ 0.0141, 0.0135, -0.0167],\n", + " [-0.0067, 0.0298, -0.0052],\n", + " [-0.0146, 0.0024, -0.0051]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0113, -0.0088, 0.0199],\n", + " [ 0.0106, 0.0112, 0.0165],\n", + " [ 0.0094, 0.0180, -0.0111]],\n", + "\n", + " [[-0.0123, -0.0243, 0.0136],\n", + " [-0.0087, 0.0113, 0.0470],\n", + " [-0.0176, -0.0201, -0.0065]],\n", + "\n", + " [[ 0.0094, -0.0160, -0.0185],\n", + " [-0.0041, 0.0019, -0.0137],\n", + " [-0.0054, -0.0282, 0.0160]]],\n", + "\n", + "\n", + " [[[ 0.0223, 0.0120, 0.0172],\n", + " [-0.0038, -0.0191, -0.0039],\n", + " [-0.0211, -0.0171, -0.0283]],\n", + "\n", + " [[-0.0338, -0.0060, -0.0046],\n", + " [-0.0102, 0.0018, -0.0045],\n", + " [ 0.0404, 0.0651, 0.0263]],\n", + "\n", + " [[-0.0213, -0.0207, 0.0050],\n", + " [-0.0230, -0.0233, 0.0018],\n", + " [-0.0216, -0.0057, -0.0030]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0011, 0.0098, -0.0103],\n", + " [-0.0096, 0.0110, 0.0143],\n", + " [-0.0016, -0.0168, -0.0251]],\n", + "\n", + " [[ 0.0131, 0.0102, 0.0009],\n", + " [-0.0201, 
-0.0286, -0.0234],\n", + " [ 0.0084, 0.0084, 0.0085]],\n", + "\n", + " [[ 0.0129, 0.0274, 0.0181],\n", + " [-0.0245, -0.0168, -0.0007],\n", + " [ 0.0059, -0.0039, -0.0160]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[ 0.0009, 0.0100, -0.0076],\n", + " [-0.0015, -0.0030, 0.0290],\n", + " [ 0.0140, 0.0312, -0.0070]],\n", + "\n", + " [[ 0.0038, 0.0222, 0.0296],\n", + " [-0.0521, 0.0081, 0.0183],\n", + " [-0.0018, 0.0102, 0.0051]],\n", + "\n", + " [[ 0.0273, 0.0253, 0.0158],\n", + " [ 0.0034, 0.0326, 0.0349],\n", + " [ 0.0144, 0.0285, -0.0126]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0185, 0.0185, -0.0245],\n", + " [-0.0177, 0.0276, 0.0237],\n", + " [ 0.0372, 0.0134, 0.0167]],\n", + "\n", + " [[ 0.0060, -0.0221, -0.0462],\n", + " [-0.0235, -0.0135, 0.0108],\n", + " [ 0.0101, 0.0221, 0.0232]],\n", + "\n", + " [[-0.0022, 0.0149, 0.0309],\n", + " [-0.0073, -0.0145, -0.0246],\n", + " [-0.0053, 0.0137, -0.0079]]],\n", + "\n", + "\n", + " [[[-0.0138, 0.0086, -0.0118],\n", + " [-0.0012, -0.0124, 0.0198],\n", + " [-0.0138, -0.0054, -0.0149]],\n", + "\n", + " [[-0.0191, 0.0080, 0.0100],\n", + " [ 0.0100, -0.0196, -0.0140],\n", + " [ 0.0006, 0.0038, 0.0065]],\n", + "\n", + " [[ 0.0015, -0.0038, 0.0016],\n", + " [ 0.0130, 0.0152, -0.0030],\n", + " [-0.0230, 0.0007, 0.0087]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0062, 0.0096, -0.0146],\n", + " [ 0.0090, -0.0076, 0.0036],\n", + " [ 0.0204, -0.0055, 0.0151]],\n", + "\n", + " [[-0.0062, 0.0229, 0.0011],\n", + " [ 0.0150, 0.0210, -0.0116],\n", + " [ 0.0360, 0.0070, 0.0110]],\n", + "\n", + " [[ 0.0160, 0.0141, 0.0097],\n", + " [ 0.0261, 0.0006, 0.0071],\n", + " [ 0.0121, 0.0004, -0.0189]]],\n", + "\n", + "\n", + " [[[-0.0238, -0.0239, -0.0015],\n", + " [ 0.0065, -0.0170, 0.0101],\n", + " [ 0.0099, -0.0033, 0.0196]],\n", + "\n", + " [[ 0.0159, 0.0363, -0.0085],\n", + " [-0.0036, 0.0383, 0.0210],\n", + " [-0.0322, 0.0326, -0.0029]],\n", + "\n", + " [[-0.0424, -0.0062, 0.0063],\n", + " [ 0.0033, -0.0216, 
-0.0250],\n", + " [ 0.0295, -0.0276, -0.0038]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0210, -0.0169, -0.0170],\n", + " [ 0.0021, 0.0172, -0.0203],\n", + " [-0.0038, 0.0303, -0.0193]],\n", + "\n", + " [[-0.0305, -0.0032, -0.0458],\n", + " [-0.0235, -0.0152, -0.0123],\n", + " [ 0.0110, -0.0191, -0.0239]],\n", + "\n", + " [[-0.0303, -0.0097, -0.0278],\n", + " [-0.0252, 0.0002, -0.0171],\n", + " [-0.0116, -0.0104, 0.0082]]]])\n", + "layer3.1.bn1.weight Parameter containing:\n", + "tensor([0.8894, 1.0014, 0.9618, 1.0368, 0.9632, 1.0074, 0.9507, 1.0399, 1.0297,\n", + " 0.9589, 0.9957, 0.9815, 1.0905, 0.9845, 0.9782, 0.9698, 1.0131, 0.9468,\n", + " 1.0746, 1.0122, 0.9194, 1.0240, 0.9088, 1.0045, 0.9597, 0.9893, 1.0038,\n", + " 0.9854, 0.9923, 0.9757, 0.9893, 0.9965, 1.0263, 0.9475, 1.0418, 1.0299,\n", + " 0.9838, 0.9242, 0.9551, 1.0142, 1.0137, 0.9776, 0.9908, 0.9987, 1.0010,\n", + " 1.0373, 0.9471, 0.9570, 1.0478, 0.9621, 0.9507, 0.9834, 1.0207, 1.0203,\n", + " 1.0108, 0.9961, 0.9667, 1.0613, 0.9601, 0.9387, 1.0048, 0.9592, 1.0594,\n", + " 1.0004, 1.0143, 0.9796, 1.0497, 0.9597, 0.9803, 0.9279, 0.9861, 1.0094,\n", + " 0.9469, 0.9917, 0.9735, 0.9667, 0.9915, 1.0499, 0.9900, 1.0731, 1.0523,\n", + " 0.9240, 1.0704, 1.0292, 0.9927, 1.0082, 0.9503, 1.0278, 0.9857, 0.9680,\n", + " 0.9988, 0.9566, 1.0067, 1.0650, 1.0072, 1.0185, 0.9237, 1.0230, 0.9879,\n", + " 0.9544, 0.9825, 1.0734, 0.9985, 0.9623, 1.0024, 1.0158, 1.0302, 1.0588,\n", + " 0.9510, 0.9827, 1.0133, 0.9888, 0.9917, 0.9738, 1.0048, 0.9558, 0.9368,\n", + " 0.9419, 0.9548, 0.9739, 0.9351, 1.0554, 0.9564, 0.9554, 0.9746, 0.9900,\n", + " 0.9743, 1.0129, 1.0568, 1.0261, 0.8503, 0.9825, 1.0072, 1.0561, 1.0295,\n", + " 0.9250, 0.9449, 0.9652, 1.0274, 0.9811, 0.9868, 0.9987, 0.9758, 0.9687,\n", + " 0.9447, 1.0006, 0.9603, 1.0075, 1.0120, 0.9941, 1.0805, 1.0201, 0.9765,\n", + " 0.9931, 1.0093, 1.0358, 1.0294, 1.0004, 1.0088, 1.0204, 0.9912, 0.9467,\n", + " 1.0267, 0.9717, 0.9932, 1.0262, 0.9624, 0.9232, 1.0127, 0.9702, 
1.0690,\n", + " 0.9871, 0.9572, 0.9898, 1.0525, 1.0083, 1.0368, 0.9964, 1.0338, 0.8748,\n", + " 0.9964, 0.9670, 0.9804, 1.0126, 0.9944, 1.0274, 0.9512, 0.9921, 0.9697,\n", + " 1.0625, 0.9837, 0.9659, 1.0131, 0.9245, 1.0110, 0.9741, 0.9294, 0.9803,\n", + " 1.0149, 1.0334, 0.9848, 1.0037, 0.9699, 1.0264, 0.9817, 0.9894, 0.8620,\n", + " 1.0600, 1.0214, 0.9951, 1.0127, 0.9452, 0.9733, 1.0047, 1.0568, 1.0307,\n", + " 1.0134, 1.0109, 0.9813, 0.9758, 0.9878, 0.9932, 0.9886, 0.9710, 1.0007,\n", + " 1.0168, 0.9869, 0.9622, 1.0478, 1.0258, 0.9697, 1.0451, 1.0032, 0.9455,\n", + " 0.9796, 1.0273, 1.0156, 1.0052, 0.9761, 1.0013, 0.9775, 0.9848, 0.9489,\n", + " 0.9581, 1.0368, 0.9701, 1.0158, 1.0756, 1.0068, 0.9919, 0.9778, 0.9810,\n", + " 1.0233, 1.0328, 0.9609, 1.0135])\n", + "layer3.1.bn1.bias Parameter containing:\n", + "tensor([-0.2024, -0.0811, -0.0922, -0.0835, -0.1042, -0.0950, -0.1535, -0.1058,\n", + " -0.0725, -0.1229, -0.1334, -0.1228, -0.1481, -0.1424, -0.0705, -0.1125,\n", + " -0.1025, -0.1020, -0.1740, -0.1615, -0.1094, -0.1909, -0.0887, -0.0977,\n", + " -0.1232, -0.1152, -0.1137, -0.1736, -0.0889, -0.1433, -0.1535, -0.1320,\n", + " -0.1768, -0.1870, -0.0638, -0.0644, -0.1462, -0.1610, -0.0984, -0.0944,\n", + " -0.1171, -0.1277, -0.0961, -0.1230, -0.0759, -0.1204, -0.2193, -0.1107,\n", + " -0.1224, -0.0838, -0.0812, -0.1262, -0.1052, -0.0962, -0.1050, -0.1323,\n", + " -0.1272, -0.0985, -0.0897, -0.1146, -0.0901, -0.0966, -0.0778, -0.0866,\n", + " -0.1539, -0.0671, -0.0921, -0.0988, -0.1000, -0.1371, -0.1845, -0.1234,\n", + " -0.1169, -0.1379, -0.1190, -0.0677, -0.1166, -0.0890, -0.0833, -0.0995,\n", + " -0.1120, -0.1498, -0.0999, -0.1339, -0.0877, -0.1662, -0.0699, -0.1084,\n", + " -0.1234, -0.1270, -0.1432, -0.1244, -0.0884, -0.1719, -0.0971, -0.1034,\n", + " -0.0922, -0.0886, -0.1187, -0.1214, -0.0703, -0.1064, -0.0996, -0.1092,\n", + " -0.1004, -0.0719, -0.1040, -0.1137, -0.1492, -0.1240, -0.1093, -0.0972,\n", + " -0.1311, -0.0610, -0.0975, -0.1565, -0.1471, 
-0.1279, -0.1107, -0.0912,\n", + " -0.1073, -0.1538, -0.1031, -0.1418, -0.1104, -0.1107, -0.0868, -0.1587,\n", + " -0.0774, -0.1199, -0.1776, -0.0625, -0.1020, -0.1542, -0.1789, -0.1120,\n", + " -0.1138, -0.1263, -0.1202, -0.0816, -0.1033, -0.1082, -0.1073, -0.1751,\n", + " -0.1201, -0.0667, -0.1008, -0.1173, -0.1198, -0.1258, -0.0948, -0.1198,\n", + " -0.0755, -0.0960, -0.1107, -0.1115, -0.0655, -0.1048, -0.1052, -0.1628,\n", + " -0.1215, -0.0877, -0.1073, -0.1295, -0.1083, -0.1357, -0.1008, -0.0685,\n", + " -0.1130, -0.1215, -0.1278, -0.1115, -0.0825, -0.0946, -0.0917, -0.1774,\n", + " -0.1626, -0.1593, -0.1224, -0.1393, -0.0884, -0.1989, -0.0981, -0.1358,\n", + " -0.1323, -0.1139, -0.1106, -0.1173, -0.0777, -0.1334, -0.0942, -0.1504,\n", + " -0.0807, -0.1390, -0.0976, -0.1683, -0.2453, -0.1118, -0.1057, -0.1144,\n", + " -0.1056, -0.0634, -0.1000, -0.0946, -0.0647, -0.1244, -0.2231, -0.1076,\n", + " -0.1368, -0.1047, -0.1290, -0.0922, -0.1252, -0.1696, -0.0970, -0.1317,\n", + " -0.1719, -0.0943, -0.1631, -0.1203, -0.1298, -0.1731, -0.0776, -0.0930,\n", + " -0.1405, -0.0987, -0.1073, -0.1154, -0.1397, -0.1594, -0.0856, -0.0981,\n", + " -0.0991, -0.0981, -0.0937, -0.1345, -0.0902, -0.0947, -0.1005, -0.0755,\n", + " -0.0743, -0.1312, -0.1276, -0.1481, -0.1267, -0.1051, -0.1152, -0.0936,\n", + " -0.1153, -0.1397, -0.0979, -0.1554, -0.0985, -0.1164, -0.1008, -0.1003])\n", + "layer3.1.conv2.weight Parameter containing:\n", + "tensor([[[[ 0.0057, 0.0059, 0.0141],\n", + " [ 0.0202, 0.0084, -0.0168],\n", + " [ 0.0042, -0.0093, -0.0161]],\n", + "\n", + " [[-0.0071, -0.0256, -0.0123],\n", + " [ 0.0112, 0.0096, -0.0154],\n", + " [ 0.0074, 0.0087, 0.0290]],\n", + "\n", + " [[ 0.0193, 0.0200, 0.0237],\n", + " [ 0.0042, 0.0013, -0.0049],\n", + " [ 0.0255, 0.0222, 0.0247]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0049, 0.0150, 0.0204],\n", + " [ 0.0048, 0.0343, 0.0375],\n", + " [ 0.0155, -0.0052, 0.0337]],\n", + "\n", + " [[ 0.0311, -0.0015, -0.0184],\n", + " [ 0.0050, 0.0187, 
-0.0070],\n", + " [ 0.0176, 0.0203, 0.0322]],\n", + "\n", + " [[ 0.0042, 0.0106, -0.0159],\n", + " [-0.0078, 0.0319, -0.0009],\n", + " [ 0.0051, 0.0336, 0.0013]]],\n", + "\n", + "\n", + " [[[ 0.0060, -0.0394, -0.0219],\n", + " [-0.0232, -0.0402, -0.0123],\n", + " [-0.0100, -0.0346, -0.0340]],\n", + "\n", + " [[ 0.0050, 0.0313, 0.0401],\n", + " [-0.0034, -0.0222, 0.0018],\n", + " [ 0.0058, 0.0137, -0.0026]],\n", + "\n", + " [[ 0.0069, 0.0209, -0.0171],\n", + " [-0.0181, -0.0249, -0.0093],\n", + " [ 0.0046, -0.0241, 0.0079]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0178, 0.0215, -0.0115],\n", + " [ 0.0114, -0.0228, 0.0059],\n", + " [-0.0155, -0.0046, -0.0166]],\n", + "\n", + " [[ 0.0264, 0.0007, 0.0135],\n", + " [ 0.0048, -0.0157, 0.0191],\n", + " [-0.0309, -0.0115, -0.0107]],\n", + "\n", + " [[ 0.0129, 0.0396, 0.0171],\n", + " [ 0.0243, 0.0272, -0.0214],\n", + " [ 0.0051, 0.0171, -0.0188]]],\n", + "\n", + "\n", + " [[[ 0.0133, -0.0035, -0.0081],\n", + " [-0.0374, -0.0194, 0.0275],\n", + " [-0.0025, -0.0212, 0.0177]],\n", + "\n", + " [[-0.0197, -0.0340, -0.0122],\n", + " [ 0.0075, 0.0080, -0.0352],\n", + " [-0.0108, -0.0002, -0.0069]],\n", + "\n", + " [[ 0.0178, 0.0160, -0.0038],\n", + " [ 0.0167, -0.0179, -0.0075],\n", + " [ 0.0003, -0.0214, -0.0138]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0011, 0.0236, 0.0067],\n", + " [ 0.0328, -0.0104, 0.0123],\n", + " [-0.0324, -0.0213, -0.0151]],\n", + "\n", + " [[ 0.0477, 0.0054, 0.0267],\n", + " [ 0.0214, 0.0166, 0.0255],\n", + " [ 0.0306, 0.0288, 0.0306]],\n", + "\n", + " [[-0.0001, 0.0285, 0.0163],\n", + " [-0.0139, 0.0104, -0.0082],\n", + " [ 0.0160, 0.0019, -0.0099]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[ 0.0320, 0.0127, -0.0041],\n", + " [ 0.0126, 0.0071, -0.0102],\n", + " [ 0.0449, -0.0208, 0.0264]],\n", + "\n", + " [[-0.0218, -0.0197, 0.0082],\n", + " [ 0.0212, -0.0167, -0.0232],\n", + " [-0.0246, -0.0144, -0.0146]],\n", + "\n", + " [[ 0.0196, 0.0360, 0.0037],\n", + " [ 0.0174, 0.0118, 
-0.0163],\n", + " [ 0.0149, -0.0078, 0.0053]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0014, -0.0044, -0.0235],\n", + " [-0.0170, -0.0056, -0.0082],\n", + " [-0.0130, 0.0023, 0.0069]],\n", + "\n", + " [[ 0.0082, 0.0044, 0.0025],\n", + " [ 0.0071, -0.0022, -0.0127],\n", + " [ 0.0079, -0.0118, 0.0401]],\n", + "\n", + " [[ 0.0127, 0.0259, 0.0020],\n", + " [-0.0149, -0.0037, -0.0306],\n", + " [ 0.0252, 0.0114, -0.0398]]],\n", + "\n", + "\n", + " [[[-0.0067, 0.0159, -0.0116],\n", + " [-0.0255, 0.0141, -0.0340],\n", + " [-0.0379, -0.0369, -0.0058]],\n", + "\n", + " [[-0.0109, -0.0326, -0.0156],\n", + " [-0.0299, -0.0425, -0.0245],\n", + " [-0.0656, -0.0205, 0.0051]],\n", + "\n", + " [[-0.0484, -0.0066, -0.0070],\n", + " [-0.0176, -0.0068, -0.0352],\n", + " [-0.0163, 0.0053, -0.0278]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0009, 0.0022, -0.0277],\n", + " [ 0.0112, 0.0170, -0.0237],\n", + " [-0.0069, 0.0070, -0.0239]],\n", + "\n", + " [[ 0.0105, 0.0199, -0.0199],\n", + " [-0.0018, 0.0375, 0.0256],\n", + " [ 0.0020, -0.0027, 0.0162]],\n", + "\n", + " [[-0.0299, 0.0109, -0.0016],\n", + " [-0.0151, 0.0335, 0.0003],\n", + " [-0.0382, 0.0336, 0.0335]]],\n", + "\n", + "\n", + " [[[ 0.0120, -0.0307, -0.0122],\n", + " [-0.0485, 0.0204, -0.0066],\n", + " [-0.0292, 0.0023, -0.0102]],\n", + "\n", + " [[ 0.0103, -0.0036, -0.0009],\n", + " [-0.0090, -0.0189, -0.0218],\n", + " [-0.0022, 0.0280, 0.0087]],\n", + "\n", + " [[ 0.0105, -0.0090, 0.0102],\n", + " [-0.0376, -0.0115, -0.0124],\n", + " [ 0.0302, 0.0081, 0.0010]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0258, 0.0199, 0.0164],\n", + " [-0.0088, -0.0327, 0.0286],\n", + " [-0.0255, -0.0063, 0.0088]],\n", + "\n", + " [[ 0.0173, -0.0106, -0.0176],\n", + " [-0.0141, 0.0150, -0.0110],\n", + " [-0.0014, 0.0298, 0.0401]],\n", + "\n", + " [[-0.0097, 0.0207, 0.0122],\n", + " [ 0.0098, 0.0361, 0.0345],\n", + " [ 0.0262, 0.0573, -0.0103]]]])\n", + "layer3.1.bn2.weight Parameter containing:\n", + "tensor([0.9815, 0.9771, 0.9937, 0.9977, 
0.9733, 0.9092, 0.9829, 0.9169, 0.9609,\n", + " 0.9104, 0.8757, 0.9897, 0.9988, 1.0001, 1.0297, 0.8942, 0.9350, 0.9427,\n", + " 0.9168, 0.9627, 0.9683, 0.9354, 0.9495, 0.9964, 0.9769, 0.9127, 1.0025,\n", + " 0.9153, 0.9689, 0.9639, 0.9897, 1.0123, 0.8282, 0.9270, 0.9236, 0.9664,\n", + " 0.9928, 0.9019, 0.9426, 0.9955, 0.9358, 0.9749, 0.9828, 0.9415, 1.0004,\n", + " 0.9123, 0.9367, 0.9801, 0.9494, 1.0004, 0.9750, 0.9334, 1.0004, 0.9339,\n", + " 0.8686, 0.8470, 0.9903, 1.0119, 1.0165, 0.9926, 0.9256, 0.9330, 0.9856,\n", + " 0.9424, 1.0011, 0.9819, 0.9341, 1.0159, 0.9502, 0.9183, 0.9090, 0.9548,\n", + " 1.0113, 0.9721, 0.9436, 0.9371, 0.9473, 0.9046, 0.9444, 0.9053, 0.9450,\n", + " 0.9036, 1.0273, 0.9669, 0.9598, 0.9456, 0.8899, 0.9520, 0.9313, 0.9089,\n", + " 0.9674, 0.9503, 0.8887, 0.9139, 0.9624, 0.9340, 1.0025, 0.8952, 0.9185,\n", + " 0.9119, 0.9541, 0.9354, 0.9323, 0.9323, 1.0323, 0.9860, 0.9632, 0.9852,\n", + " 0.9095, 0.9466, 0.9968, 1.0321, 0.8953, 0.9952, 0.9438, 0.9600, 1.0027,\n", + " 0.9396, 0.9181, 0.9470, 0.9278, 0.9733, 1.0050, 0.9730, 0.9877, 0.9527,\n", + " 0.9141, 0.8995, 0.9542, 0.9364, 0.9155, 0.9720, 0.9578, 1.0013, 1.0498,\n", + " 0.9335, 0.9727, 0.9802, 0.9157, 0.8526, 0.8498, 0.9313, 0.9292, 0.9360,\n", + " 0.9503, 0.9614, 1.0159, 0.9742, 0.9591, 0.9834, 0.9546, 0.8601, 0.9151,\n", + " 0.9533, 0.9906, 0.8816, 0.9180, 0.8996, 0.9657, 0.8390, 0.9566, 0.9444,\n", + " 0.9687, 0.9363, 1.0103, 0.8784, 0.9648, 0.9563, 0.9561, 1.0161, 0.9506,\n", + " 0.9709, 0.9270, 0.9539, 0.9587, 0.9816, 1.0350, 0.9984, 0.9343, 0.9864,\n", + " 0.9699, 1.0122, 0.9517, 0.9404, 0.9368, 0.9919, 0.9200, 0.9307, 0.9481,\n", + " 0.9009, 0.9305, 0.9460, 0.9792, 0.9991, 0.8814, 0.9784, 0.9849, 0.9834,\n", + " 0.9355, 0.9520, 0.9651, 0.9303, 0.9693, 0.9200, 0.9991, 0.9569, 0.9554,\n", + " 0.9542, 0.8686, 1.0286, 0.9254, 0.8529, 0.9538, 0.9630, 0.8646, 0.9576,\n", + " 0.9614, 0.9150, 0.9488, 1.0792, 0.9474, 0.9442, 0.9534, 0.8712, 0.9632,\n", + " 0.9681, 0.9126, 0.9659, 0.9386, 
0.9330, 0.9911, 0.9612, 0.9779, 0.9494,\n", + " 0.8655, 0.9691, 0.9830, 1.0095, 0.9631, 1.0426, 1.0039, 0.9078, 0.9335,\n", + " 0.9570, 0.9473, 0.9557, 0.9158, 0.9798, 0.9176, 0.8368, 0.9515, 0.9396,\n", + " 0.9669, 0.9546, 0.9310, 0.8975])\n", + "layer3.1.bn2.bias Parameter containing:\n", + "tensor([-0.1031, -0.1217, -0.1326, -0.1258, -0.1390, -0.0922, -0.0841, -0.1330,\n", + " -0.1585, -0.1332, -0.1653, -0.0609, -0.0702, -0.1586, -0.1070, -0.1527,\n", + " -0.1529, -0.1246, -0.1743, -0.1270, -0.1165, -0.1431, -0.1666, -0.0825,\n", + " -0.1168, -0.1519, -0.1148, -0.0939, -0.1390, -0.1273, -0.1161, -0.1486,\n", + " -0.1854, -0.1222, -0.1213, -0.1401, -0.1382, -0.1273, -0.1230, -0.1030,\n", + " -0.1122, -0.1197, -0.1796, -0.1384, -0.0926, -0.1447, -0.1543, -0.1356,\n", + " -0.0909, -0.1386, -0.1104, -0.1318, -0.1326, -0.1224, -0.1635, -0.1458,\n", + " -0.0994, -0.1207, -0.0871, -0.1215, -0.1321, -0.1024, -0.1104, -0.1533,\n", + " -0.0775, -0.1287, -0.1569, -0.1287, -0.1081, -0.0918, -0.1163, -0.1112,\n", + " -0.0950, -0.1562, -0.1089, -0.0883, -0.1193, -0.0902, -0.1262, -0.1601,\n", + " -0.1588, -0.2315, -0.1030, -0.1270, -0.0776, -0.1719, -0.1275, -0.1466,\n", + " -0.1918, -0.2018, -0.1322, -0.1183, -0.1523, -0.1425, -0.1505, -0.1293,\n", + " -0.1456, -0.1323, -0.1759, -0.1544, -0.1577, -0.1870, -0.1309, -0.0995,\n", + " -0.1034, -0.1288, -0.1301, -0.1242, -0.1568, -0.1373, -0.1541, -0.1013,\n", + " -0.1291, -0.0844, -0.1278, -0.1105, -0.1554, -0.1066, -0.0992, -0.1333,\n", + " -0.1511, -0.1137, -0.0950, -0.1285, -0.1438, -0.1257, -0.1515, -0.1425,\n", + " -0.1811, -0.1336, -0.1211, -0.1363, -0.1193, -0.1516, -0.1115, -0.1189,\n", + " -0.1272, -0.1287, -0.1752, -0.1583, -0.2094, -0.1342, -0.1100, -0.1305,\n", + " -0.1064, -0.1320, -0.1358, -0.1208, -0.1355, -0.1005, -0.1263, -0.1729,\n", + " -0.1199, -0.0826, -0.0964, -0.1561, -0.1532, -0.1950, -0.1397, -0.1617,\n", + " -0.1192, -0.1215, -0.0973, -0.1043, -0.1371, -0.1510, -0.1281, -0.1020,\n", + " -0.1469, 
-0.0536, -0.0947, -0.1230, -0.1114, -0.1059, -0.1259, -0.1463,\n", + " -0.0779, -0.1328, -0.1282, -0.1450, -0.1135, -0.1125, -0.1291, -0.1261,\n", + " -0.1648, -0.1404, -0.1629, -0.1324, -0.1645, -0.1539, -0.1878, -0.1183,\n", + " -0.1369, -0.1324, -0.1300, -0.1355, -0.1158, -0.1134, -0.1212, -0.1176,\n", + " -0.1827, -0.0990, -0.1417, -0.1345, -0.1632, -0.1474, -0.1211, -0.1373,\n", + " -0.1854, -0.1219, -0.1227, -0.1246, -0.1039, -0.0776, -0.1508, -0.0964,\n", + " -0.1037, -0.1521, -0.0945, -0.0795, -0.1662, -0.1484, -0.0996, -0.1285,\n", + " -0.1003, -0.1070, -0.1675, -0.1326, -0.1186, -0.1332, -0.0866, -0.1046,\n", + " -0.1211, -0.0580, -0.1816, -0.0841, -0.1563, -0.1192, -0.1486, -0.0603,\n", + " -0.0866, -0.1133, -0.1044, -0.2205, -0.1373, -0.0997, -0.1647, -0.1118,\n", + " -0.0883, -0.1750, -0.0968, -0.1916, -0.1200, -0.1369, -0.1829, -0.1554])\n", + "layer4.0.conv1.weight Parameter containing:\n", + "tensor([[[[-1.6195e-02, -4.3060e-02, -3.5322e-02],\n", + " [ 1.3287e-02, -3.1615e-02, 6.6042e-03],\n", + " [-4.0014e-04, -1.5452e-02, -9.1024e-03]],\n", + "\n", + " [[-1.4157e-02, -2.3567e-02, 1.4208e-02],\n", + " [-1.2353e-02, 1.6866e-02, 1.7598e-02],\n", + " [-2.0796e-03, 2.8303e-02, 1.2474e-02]],\n", + "\n", + " [[-1.2214e-02, -2.6098e-03, -3.1828e-02],\n", + " [-1.0792e-02, -2.9335e-02, 3.1974e-03],\n", + " [-6.6637e-03, -4.7699e-03, -1.9124e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-5.4387e-03, -7.4101e-03, -1.2255e-02],\n", + " [-9.4177e-03, -3.5547e-02, 8.5545e-03],\n", + " [-4.9719e-03, 4.6679e-03, 1.3081e-02]],\n", + "\n", + " [[ 9.7803e-03, -9.9373e-03, -1.0429e-02],\n", + " [ 5.0466e-03, -2.2508e-03, 1.5473e-02],\n", + " [-8.3493e-03, 8.0397e-03, -6.2141e-03]],\n", + "\n", + " [[-3.4990e-02, -1.2068e-02, -4.8616e-03],\n", + " [ 1.5458e-02, 3.9716e-05, 1.8083e-02],\n", + " [-2.0388e-02, -2.6572e-02, -3.2755e-02]]],\n", + "\n", + "\n", + " [[[ 2.7451e-02, -1.6241e-02, -5.4663e-03],\n", + " [-2.5122e-03, 9.7752e-03, 1.9901e-03],\n", + " 
[-1.1972e-02, 6.6506e-03, 2.7320e-02]],\n", + "\n", + " [[ 8.9008e-03, -3.0568e-03, 1.7271e-02],\n", + " [-2.8983e-03, 2.2242e-02, 1.4016e-02],\n", + " [ 1.3805e-02, 1.3811e-02, 1.8592e-02]],\n", + "\n", + " [[-1.1898e-02, 1.1536e-02, -1.6407e-02],\n", + " [ 1.8869e-02, 1.0717e-02, -4.6789e-03],\n", + " [-3.3365e-03, 1.0226e-02, 7.3902e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-6.0078e-03, 2.0500e-02, 3.9849e-03],\n", + " [ 1.4595e-02, 1.4047e-02, -2.0757e-02],\n", + " [ 1.1118e-02, 3.0116e-02, 5.9856e-03]],\n", + "\n", + " [[-1.8926e-02, 2.3320e-02, 2.0495e-03],\n", + " [-9.0993e-03, 1.3444e-02, 1.7396e-02],\n", + " [ 5.3953e-03, -1.1074e-02, -1.0029e-02]],\n", + "\n", + " [[ 3.5316e-02, -1.0975e-02, 2.5410e-02],\n", + " [ 1.1206e-02, 3.4463e-02, 8.5281e-03],\n", + " [ 5.2337e-03, -1.8434e-02, 2.2219e-02]]],\n", + "\n", + "\n", + " [[[ 1.5784e-02, -2.8938e-02, -2.7249e-02],\n", + " [ 3.6373e-04, 1.7282e-03, 1.1883e-02],\n", + " [ 1.6549e-02, -2.4496e-02, -1.4036e-02]],\n", + "\n", + " [[-9.9012e-03, 1.3803e-02, -2.6475e-02],\n", + " [ 2.0324e-03, -3.2427e-02, 2.6785e-03],\n", + " [-2.3451e-02, -1.3071e-02, -1.3452e-02]],\n", + "\n", + " [[ 8.0489e-03, -8.7835e-03, 2.3109e-02],\n", + " [ 5.7620e-03, 5.4874e-03, 7.9275e-03],\n", + " [ 2.2045e-02, -7.0428e-03, -2.5761e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-6.4942e-05, -9.1919e-03, 1.5764e-02],\n", + " [ 1.1611e-02, -8.8280e-03, 8.2890e-03],\n", + " [ 9.9335e-03, 1.7437e-02, -2.8565e-03]],\n", + "\n", + " [[-2.0067e-02, -8.6341e-03, -1.6160e-02],\n", + " [-2.8909e-02, -2.4014e-03, 2.9025e-02],\n", + " [ 1.4045e-02, -2.3337e-02, -7.0309e-03]],\n", + "\n", + " [[-2.1758e-02, 1.3736e-03, 8.1537e-03],\n", + " [-6.1176e-03, 1.5924e-02, -5.0352e-03],\n", + " [ 1.9466e-02, 2.9483e-02, 1.9495e-02]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[-2.3250e-02, 2.0028e-02, 2.9691e-02],\n", + " [ 4.9181e-03, 6.5347e-04, -3.8907e-03],\n", + " [-5.9318e-03, 2.8015e-02, -4.9769e-02]],\n", + "\n", + " 
[[-6.7370e-03, -2.8322e-02, -2.8513e-02],\n", + " [ 1.4032e-02, -4.3892e-04, -7.2716e-03],\n", + " [ 2.4244e-02, -1.1580e-02, -1.0610e-04]],\n", + "\n", + " [[-2.0042e-02, -4.6635e-02, -8.6490e-03],\n", + " [-6.1748e-03, 2.5440e-02, 3.8151e-02],\n", + " [-1.4358e-02, 3.1131e-03, -1.8387e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-4.9713e-02, -1.6618e-02, -2.0875e-02],\n", + " [ 1.5922e-02, -1.2795e-02, 1.7373e-02],\n", + " [ 2.1518e-02, 1.2405e-02, -2.8119e-03]],\n", + "\n", + " [[ 1.7024e-02, -1.3888e-03, 1.4596e-02],\n", + " [-2.7721e-02, 1.4081e-02, -4.4708e-03],\n", + " [ 1.6595e-02, 7.6307e-03, 1.7189e-03]],\n", + "\n", + " [[-3.4501e-02, -2.7804e-02, 3.3934e-02],\n", + " [-3.6700e-02, -2.7948e-03, -2.7233e-02],\n", + " [ 3.1503e-02, 2.9363e-03, -1.7693e-02]]],\n", + "\n", + "\n", + " [[[-7.1497e-03, -9.2680e-03, -1.6912e-02],\n", + " [-2.5124e-02, 1.3490e-02, -3.0191e-03],\n", + " [-3.4414e-02, -1.7793e-02, -6.2751e-04]],\n", + "\n", + " [[-1.0422e-02, 7.3811e-03, 4.8907e-04],\n", + " [-3.3726e-02, -3.7883e-02, 6.4931e-03],\n", + " [ 1.0105e-02, -1.0586e-02, 8.0654e-03]],\n", + "\n", + " [[ 1.3769e-02, -2.7422e-02, 1.1539e-02],\n", + " [ 3.9909e-02, 1.7185e-02, -9.6350e-03],\n", + " [ 2.6662e-02, 2.6606e-02, 3.8052e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.3738e-02, -2.4817e-02, 1.7280e-02],\n", + " [ 6.5198e-03, 3.2310e-02, 6.7140e-03],\n", + " [ 1.7105e-03, 1.2700e-02, 1.7214e-02]],\n", + "\n", + " [[-1.8066e-02, -2.9962e-02, -1.9691e-02],\n", + " [-1.8091e-03, 4.9750e-03, 3.6765e-02],\n", + " [-6.4435e-04, -1.4589e-02, 3.2546e-02]],\n", + "\n", + " [[-3.8230e-02, 7.1454e-03, -1.3799e-02],\n", + " [-1.3242e-02, 3.7163e-03, 4.9353e-03],\n", + " [ 2.6065e-03, -8.7617e-03, 1.9123e-04]]],\n", + "\n", + "\n", + " [[[ 2.7979e-02, 2.8406e-02, 1.1925e-03],\n", + " [ 2.6382e-03, -1.0972e-02, -2.5136e-02],\n", + " [ 4.6399e-03, -1.9916e-02, -3.2441e-02]],\n", + "\n", + " [[-5.8886e-03, -2.7692e-02, -3.8166e-04],\n", + " [-1.0334e-02, -1.1055e-02, 
-2.7850e-02],\n", + " [-8.6016e-05, -2.0821e-02, -2.4210e-02]],\n", + "\n", + " [[-1.0134e-02, 2.7711e-02, 1.8510e-02],\n", + " [ 2.5513e-02, 4.7151e-02, 2.7833e-02],\n", + " [-1.1608e-02, -1.1251e-03, 4.3334e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.5657e-02, -1.6549e-02, 2.1792e-02],\n", + " [-2.4143e-02, -7.6517e-03, -1.7215e-02],\n", + " [-1.3531e-02, 1.7204e-02, 2.0492e-02]],\n", + "\n", + " [[ 1.7322e-03, -2.6014e-02, -1.5783e-02],\n", + " [-1.8252e-02, -3.1492e-02, 2.1263e-02],\n", + " [-4.6617e-03, -4.9437e-02, 1.2522e-02]],\n", + "\n", + " [[-5.0840e-03, -1.1199e-02, 1.7259e-02],\n", + " [-2.2503e-03, -2.5246e-02, -1.4658e-02],\n", + " [ 2.6199e-03, -2.5215e-02, -1.4805e-02]]]])\n", + "layer4.0.bn1.weight Parameter containing:\n", + "tensor([0.9534, 0.9690, 0.9779, 0.9481, 0.9661, 1.0123, 0.9933, 0.9159, 0.9811,\n", + " 0.9206, 1.0421, 1.0466, 1.0091, 0.9920, 0.9846, 0.9807, 0.9996, 1.0425,\n", + " 0.9745, 0.9941, 1.0109, 0.9837, 1.0157, 0.9821, 1.0174, 0.9381, 1.0259,\n", + " 0.9520, 0.9610, 0.9995, 1.0986, 1.0051, 0.9452, 0.9813, 1.0712, 0.9902,\n", + " 1.0526, 1.0237, 1.0363, 0.9657, 0.9260, 1.0403, 1.0011, 0.9876, 0.9901,\n", + " 0.9819, 1.0163, 0.9651, 1.0024, 1.0359, 0.9798, 0.9556, 0.9099, 0.9861,\n", + " 1.0158, 0.9633, 1.0052, 1.0268, 0.8805, 0.9295, 0.9899, 1.0083, 1.0109,\n", + " 1.0104, 1.0468, 1.0115, 1.0763, 0.9179, 0.9844, 0.9861, 0.9508, 0.9012,\n", + " 0.9999, 0.9355, 0.9764, 0.9445, 0.9730, 1.0513, 1.0292, 1.0491, 1.0474,\n", + " 0.9812, 1.0326, 1.0261, 1.0150, 0.9971, 0.9672, 1.0045, 0.9567, 1.0128,\n", + " 1.0227, 1.0236, 1.0422, 0.9762, 0.9627, 1.0149, 1.0167, 1.0260, 1.0085,\n", + " 1.0187, 0.9780, 1.0015, 0.9958, 0.9427, 0.9465, 1.0131, 0.9742, 0.9766,\n", + " 0.9988, 0.9577, 1.0082, 0.9634, 0.9935, 1.0256, 0.9242, 1.0261, 0.9989,\n", + " 0.9949, 1.0067, 1.0322, 0.9711, 0.9497, 1.0195, 0.9055, 1.0064, 0.9844,\n", + " 0.9814, 0.9728, 1.0224, 0.9679, 0.9614, 1.0349, 1.0423, 0.9596, 0.9673,\n", + " 0.9739, 0.9891, 0.9619, 0.9872, 
1.0058, 0.9381, 1.0493, 0.9425, 0.9895,\n", + " 1.0559, 1.0504, 1.0078, 0.9718, 0.9642, 0.9505, 1.0015, 0.9588, 0.9706,\n", + " 0.9973, 1.0377, 0.9871, 0.9899, 0.9458, 1.0140, 0.9517, 1.0294, 0.9697,\n", + " 1.0190, 1.0207, 0.9097, 0.9653, 0.9344, 1.0160, 0.9589, 0.9415, 0.9669,\n", + " 1.0531, 0.9557, 0.9938, 0.9876, 1.0481, 0.9405, 0.9630, 1.0038, 0.9767,\n", + " 1.0061, 1.0094, 1.0396, 0.9756, 0.9531, 1.0328, 0.9250, 0.9481, 0.9716,\n", + " 0.9042, 0.9550, 1.0123, 0.9906, 0.9989, 0.9321, 1.0071, 1.0029, 0.9474,\n", + " 0.9859, 1.0049, 1.0039, 1.0362, 1.0165, 1.0032, 0.9968, 1.0036, 1.0069,\n", + " 0.9967, 0.9749, 0.9296, 0.9297, 1.0247, 0.9740, 1.0138, 0.9648, 0.9531,\n", + " 1.0081, 1.0393, 0.9835, 1.0207, 1.0476, 0.9793, 1.0733, 1.0040, 1.0099,\n", + " 0.9120, 0.9656, 0.9239, 0.9971, 1.0198, 0.9772, 0.9728, 1.0036, 0.9730,\n", + " 1.0232, 0.9582, 0.9908, 0.9818, 0.9472, 0.9884, 1.0278, 1.0084, 1.0249,\n", + " 0.9170, 0.9236, 1.0222, 1.0302, 0.9916, 0.9527, 0.9615, 0.9906, 0.9947,\n", + " 0.9456, 0.9808, 0.9915, 0.9550, 1.0177, 0.9746, 0.9998, 0.9447, 0.9958,\n", + " 0.9704, 1.0024, 0.9991, 1.0001, 0.9900, 0.9976, 0.9804, 1.0698, 0.9950,\n", + " 0.9954, 0.9723, 0.9784, 0.9220, 1.0127, 1.0142, 0.9826, 1.0532, 1.0166,\n", + " 1.0256, 0.9422, 0.9629, 1.0517, 0.9915, 1.0066, 1.0042, 0.9927, 1.0177,\n", + " 1.0036, 0.9713, 0.9261, 0.9843, 0.9568, 0.8978, 1.0141, 1.0160, 0.9223,\n", + " 0.9694, 0.9814, 0.9412, 0.9750, 1.0287, 1.0105, 0.9912, 1.0344, 0.9992,\n", + " 0.9417, 1.0092, 0.9744, 0.9925, 0.9815, 0.9724, 0.9699, 0.9620, 0.9853,\n", + " 0.9886, 0.9213, 0.9863, 1.0109, 0.9926, 0.8956, 0.9822, 0.9506, 0.9584,\n", + " 0.9641, 1.0227, 0.9630, 1.0442, 0.8903, 1.0128, 0.9640, 1.0305, 0.9512,\n", + " 0.9702, 1.0634, 0.9853, 1.0459, 1.0792, 0.9543, 0.9602, 1.0012, 1.0700,\n", + " 1.0283, 1.0327, 1.0158, 1.0004, 0.9954, 1.0422, 0.9440, 1.0295, 0.9736,\n", + " 0.9947, 1.0174, 0.9618, 1.0027, 0.9854, 1.0047, 1.0212, 0.9810, 1.0320,\n", + " 1.0045, 1.0151, 0.9683, 0.9965, 
1.0195, 0.9927, 0.9648, 0.9768, 1.0195,\n", + " 0.9632, 1.0413, 0.9660, 0.9954, 0.9986, 0.9687, 0.9979, 0.9959, 0.9528,\n", + " 0.9811, 0.9361, 0.9754, 0.9515, 1.0222, 0.9065, 0.9596, 1.0143, 0.9654,\n", + " 1.0097, 1.0257, 0.9529, 0.9987, 1.0023, 1.0235, 0.9968, 0.9391, 1.0246,\n", + " 1.0415, 0.8944, 1.0037, 0.9884, 0.9838, 1.0100, 0.9934, 1.0158, 0.9992,\n", + " 0.9945, 1.0269, 0.9988, 0.9746, 0.9849, 0.9837, 0.9853, 0.9961, 0.9961,\n", + " 0.9367, 1.0434, 1.0180, 0.9880, 1.0014, 0.9977, 1.0152, 0.9922, 0.9402,\n", + " 0.9401, 0.9328, 0.9936, 0.9267, 0.9819, 0.9703, 1.0273, 1.0675, 0.9921,\n", + " 1.0293, 0.9936, 0.9824, 0.9465, 0.9486, 0.9946, 0.9840, 1.0305, 1.0158,\n", + " 1.0217, 1.0431, 0.9639, 1.0107, 1.0006, 1.0016, 0.9956, 1.0443, 0.9965,\n", + " 1.0049, 1.0160, 1.0023, 0.9719, 1.0077, 0.9105, 1.0384, 1.0231, 1.0132,\n", + " 1.0314, 0.9939, 0.9616, 1.0067, 1.0498, 0.9761, 1.0339, 0.9703, 1.0237,\n", + " 0.9617, 0.9809, 1.0074, 0.9658, 0.9654, 0.9655, 0.9909, 0.9582, 0.9765,\n", + " 1.0340, 0.9436, 0.9821, 0.9701, 1.0471, 0.9938, 1.0092, 1.0402, 1.0519,\n", + " 0.9758, 0.9893, 1.0087, 0.9748, 1.0169, 1.0258, 0.9980, 0.9642, 0.9609,\n", + " 1.0336, 1.0385, 1.0035, 0.9217, 1.0398, 0.9457, 0.9696, 0.9964, 1.0307,\n", + " 0.9936, 0.9372, 0.9538, 0.9759, 0.9730, 1.0237, 0.9669, 1.0096])\n", + "layer4.0.bn1.bias Parameter containing:\n", + "tensor([-0.1313, -0.1636, -0.1204, -0.1072, -0.1112, -0.0995, -0.1602, -0.1600,\n", + " -0.1085, -0.1312, -0.1073, -0.1452, -0.0914, -0.1252, -0.1654, -0.1372,\n", + " -0.1688, -0.1468, -0.1604, -0.0992, -0.1659, -0.1345, -0.1312, -0.1844,\n", + " -0.1081, -0.1385, -0.0724, -0.1348, -0.1353, -0.1345, -0.1175, -0.1163,\n", + " -0.1461, -0.1335, -0.1096, -0.1257, -0.1304, -0.1433, -0.1206, -0.1464,\n", + " -0.1422, -0.1156, -0.1401, -0.1501, -0.1286, -0.1440, -0.1155, -0.1654,\n", + " -0.1292, -0.1382, -0.1352, -0.1472, -0.1765, -0.1114, -0.1139, -0.1491,\n", + " -0.1383, -0.1242, -0.1776, -0.1247, -0.1412, -0.1067, -0.1106, 
-0.1121,\n", + " -0.1345, -0.1257, -0.1265, -0.1531, -0.1490, -0.1013, -0.1468, -0.1372,\n", + " -0.1792, -0.1360, -0.1198, -0.1086, -0.1617, -0.1660, -0.1228, -0.1483,\n", + " -0.0945, -0.1110, -0.1253, -0.1089, -0.1319, -0.1490, -0.1170, -0.1364,\n", + " -0.1354, -0.1216, -0.1171, -0.1352, -0.1398, -0.1338, -0.1219, -0.1710,\n", + " -0.1311, -0.1457, -0.1274, -0.1073, -0.1173, -0.1264, -0.1163, -0.1548,\n", + " -0.1421, -0.1454, -0.1357, -0.1657, -0.1303, -0.1507, -0.1211, -0.1941,\n", + " -0.1222, -0.1104, -0.1506, -0.1259, -0.1086, -0.1291, -0.1060, -0.1140,\n", + " -0.1336, -0.1632, -0.1211, -0.1490, -0.1231, -0.1739, -0.1502, -0.1617,\n", + " -0.1118, -0.1397, -0.1093, -0.1283, -0.0986, -0.1305, -0.1638, -0.1666,\n", + " -0.1335, -0.1897, -0.1499, -0.1116, -0.1827, -0.1186, -0.1934, -0.1323,\n", + " -0.1338, -0.1509, -0.1360, -0.1442, -0.1689, -0.1680, -0.1318, -0.1438,\n", + " -0.1223, -0.1627, -0.1273, -0.1161, -0.1623, -0.1373, -0.1423, -0.1192,\n", + " -0.1059, -0.0972, -0.1507, -0.1390, -0.1647, -0.1244, -0.1644, -0.1498,\n", + " -0.1398, -0.1459, -0.1667, -0.1316, -0.1936, -0.1124, -0.1577, -0.1089,\n", + " -0.1290, -0.1378, -0.1233, -0.1526, -0.1156, -0.1673, -0.1438, -0.1725,\n", + " -0.1309, -0.1138, -0.1615, -0.1055, -0.1204, -0.1736, -0.1382, -0.1065,\n", + " -0.1216, -0.1760, -0.1800, -0.0827, -0.0992, -0.1225, -0.1353, -0.1260,\n", + " -0.1208, -0.1600, -0.1191, -0.1677, -0.1315, -0.1472, -0.1356, -0.1218,\n", + " -0.1187, -0.1620, -0.1179, -0.1591, -0.1496, -0.1036, -0.2044, -0.1428,\n", + " -0.1366, -0.1065, -0.1027, -0.1475, -0.1195, -0.1082, -0.1053, -0.1461,\n", + " -0.1691, -0.1706, -0.1300, -0.1864, -0.2089, -0.1170, -0.1480, -0.1576,\n", + " -0.1383, -0.1583, -0.1084, -0.1296, -0.1503, -0.1452, -0.1285, -0.1061,\n", + " -0.1148, -0.1278, -0.1292, -0.1543, -0.1608, -0.0848, -0.1095, -0.1184,\n", + " -0.1345, -0.1086, -0.1362, -0.1644, -0.1193, -0.1414, -0.1349, -0.1336,\n", + " -0.1059, -0.1538, -0.0908, -0.1335, -0.1084, -0.1521, -0.1355, 
-0.0934,\n", + " -0.1003, -0.1576, -0.1681, -0.1774, -0.1273, -0.1339, -0.1386, -0.1601,\n", + " -0.1350, -0.1508, -0.1035, -0.1223, -0.1325, -0.1116, -0.1383, -0.1226,\n", + " -0.1313, -0.1304, -0.1158, -0.1140, -0.1451, -0.1867, -0.1453, -0.1258,\n", + " -0.1335, -0.1790, -0.1468, -0.1700, -0.1195, -0.2057, -0.2011, -0.1837,\n", + " -0.1499, -0.1168, -0.1462, -0.2006, -0.1992, -0.1075, -0.1263, -0.1031,\n", + " -0.1248, -0.1572, -0.1477, -0.1345, -0.1228, -0.1207, -0.1314, -0.1488,\n", + " -0.2015, -0.1650, -0.1175, -0.1292, -0.1483, -0.1341, -0.1506, -0.1534,\n", + " -0.1612, -0.1565, -0.1284, -0.1653, -0.2038, -0.1381, -0.2177, -0.1264,\n", + " -0.1406, -0.1023, -0.1211, -0.1226, -0.1256, -0.1343, -0.1095, -0.1531,\n", + " -0.1050, -0.1293, -0.1657, -0.1386, -0.1151, -0.1445, -0.1054, -0.1113,\n", + " -0.1317, -0.1385, -0.1116, -0.1155, -0.1675, -0.1208, -0.1604, -0.1439,\n", + " -0.1369, -0.1378, -0.1473, -0.1286, -0.0956, -0.0966, -0.1479, -0.1595,\n", + " -0.1180, -0.1170, -0.1364, -0.1524, -0.1400, -0.1881, -0.1347, -0.1378,\n", + " -0.1508, -0.0998, -0.1386, -0.1385, -0.1280, -0.1320, -0.1358, -0.1685,\n", + " -0.1500, -0.1462, -0.1385, -0.1401, -0.1104, -0.1173, -0.1374, -0.1525,\n", + " -0.1615, -0.1578, -0.1641, -0.1375, -0.1990, -0.1294, -0.1518, -0.1708,\n", + " -0.1386, -0.1386, -0.1358, -0.0890, -0.1142, -0.1281, -0.1446, -0.1275,\n", + " -0.1392, -0.1407, -0.1318, -0.1416, -0.1491, -0.1211, -0.1878, -0.1545,\n", + " -0.1248, -0.1439, -0.1822, -0.1450, -0.1298, -0.1392, -0.1355, -0.1164,\n", + " -0.1082, -0.1086, -0.1520, -0.1792, -0.1319, -0.1390, -0.1580, -0.1233,\n", + " -0.1509, -0.1281, -0.1462, -0.1493, -0.1781, -0.1589, -0.1167, -0.1008,\n", + " -0.0953, -0.1499, -0.1674, -0.1244, -0.1516, -0.1418, -0.1308, -0.1808,\n", + " -0.0961, -0.1359, -0.1715, -0.1258, -0.1250, -0.1408, -0.1399, -0.1728,\n", + " -0.1399, -0.1178, -0.1152, -0.1127, -0.1319, -0.1285, -0.1043, -0.1546,\n", + " -0.1249, -0.1142, -0.1647, -0.1136, -0.0991, -0.1799, -0.1017, 
-0.1487,\n", + " -0.1128, -0.1135, -0.1450, -0.1197, -0.1501, -0.1437, -0.1226, -0.1115,\n", + " -0.1573, -0.1187, -0.2205, -0.1382, -0.1460, -0.0959, -0.1433, -0.1307,\n", + " -0.1407, -0.1565, -0.1459, -0.1592, -0.0840, -0.1889, -0.1479, -0.1313,\n", + " -0.1135, -0.1300, -0.1179, -0.0983, -0.1496, -0.1093, -0.1448, -0.1395,\n", + " -0.1300, -0.1570, -0.1539, -0.1475, -0.1276, -0.1268, -0.1325, -0.1324,\n", + " -0.1404, -0.1444, -0.1513, -0.1405, -0.1721, -0.1507, -0.1526, -0.0866])\n", + "layer4.0.conv2.weight Parameter containing:\n", + "tensor([[[[-1.7886e-02, -1.8696e-02, -6.5051e-03],\n", + " [ 7.2588e-03, -2.0011e-02, -6.7659e-03],\n", + " [-5.8594e-03, -2.5900e-02, -1.4812e-02]],\n", + "\n", + " [[-1.9013e-02, 1.5084e-03, 1.4686e-03],\n", + " [-4.1722e-03, 1.4034e-02, 6.9301e-03],\n", + " [ 3.9946e-03, 3.1264e-02, 1.7031e-02]],\n", + "\n", + " [[-4.3595e-03, -1.5497e-02, -7.2879e-04],\n", + " [-2.1562e-02, -1.8119e-02, -2.8660e-02],\n", + " [ 4.3782e-03, 1.1628e-02, 2.1428e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.3216e-02, 1.7132e-04, -1.1600e-02],\n", + " [ 1.0509e-02, 6.1292e-03, 9.3679e-04],\n", + " [ 9.9824e-03, -2.7226e-03, -7.6713e-04]],\n", + "\n", + " [[ 1.9380e-02, 4.5168e-04, -1.6261e-02],\n", + " [ 9.4663e-03, 9.7768e-03, 1.0472e-02],\n", + " [ 1.0579e-03, -1.0838e-02, -1.6651e-02]],\n", + "\n", + " [[ 7.1076e-03, -1.0444e-02, 1.2159e-02],\n", + " [-4.5764e-03, -6.8533e-03, -2.4884e-02],\n", + " [ 3.1004e-03, -5.1720e-03, -9.0846e-03]]],\n", + "\n", + "\n", + " [[[-9.9561e-03, -2.9929e-02, -1.3286e-02],\n", + " [ 1.2054e-02, 1.7240e-02, 3.4074e-02],\n", + " [-1.3281e-02, 2.6545e-02, 5.2089e-03]],\n", + "\n", + " [[ 1.0709e-02, -6.2842e-03, 3.2054e-03],\n", + " [-2.6670e-02, -1.0388e-03, -7.6723e-03],\n", + " [ 6.0955e-03, 2.3389e-02, 4.6433e-03]],\n", + "\n", + " [[ 3.2169e-03, -2.2192e-03, 1.3805e-02],\n", + " [ 1.7897e-02, 1.2070e-02, -1.4670e-02],\n", + " [ 6.3306e-03, 3.5517e-02, 3.2253e-03]],\n", + "\n", + " ...,\n", + "\n", + " 
[[-3.2339e-04, -1.4459e-02, -6.0876e-03],\n", + " [ 1.4088e-02, 4.1493e-03, 2.0326e-03],\n", + " [ 2.6903e-02, 6.3054e-03, 1.0476e-02]],\n", + "\n", + " [[-1.3155e-02, 1.9230e-02, -2.7132e-02],\n", + " [ 1.4182e-02, 1.3118e-02, 6.7906e-04],\n", + " [ 1.9165e-02, -8.0848e-03, 1.1247e-02]],\n", + "\n", + " [[-3.5696e-03, -1.3386e-02, 1.5999e-02],\n", + " [-5.2860e-03, 4.1539e-03, -6.4848e-03],\n", + " [-9.2003e-04, 2.5272e-03, -2.2210e-02]]],\n", + "\n", + "\n", + " [[[-6.4596e-04, 2.6375e-03, 1.9750e-02],\n", + " [ 3.5908e-03, -9.7965e-03, 3.7559e-03],\n", + " [ 2.3296e-02, 1.5638e-02, 4.5195e-03]],\n", + "\n", + " [[-2.1403e-02, -1.8759e-02, 1.0075e-02],\n", + " [-4.0992e-02, -1.0487e-02, 4.4678e-03],\n", + " [-2.7970e-02, -1.7760e-02, -7.7252e-04]],\n", + "\n", + " [[-2.1670e-02, -1.3132e-02, 9.6461e-03],\n", + " [-3.7094e-02, 3.3480e-03, -2.0624e-02],\n", + " [-3.3094e-02, -2.6236e-02, -1.4729e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.3484e-02, 3.4918e-02, 5.6633e-03],\n", + " [ 1.3121e-02, -4.1489e-03, 1.8868e-02],\n", + " [ 5.6241e-03, 2.3452e-02, -1.3444e-02]],\n", + "\n", + " [[ 1.8781e-02, -6.7572e-03, -1.7906e-02],\n", + " [-2.1573e-02, 2.1743e-02, -1.1799e-02],\n", + " [-8.9948e-04, 9.6979e-03, 9.4103e-03]],\n", + "\n", + " [[ 3.0581e-04, -1.2199e-02, 1.0750e-02],\n", + " [-3.6209e-02, -1.1889e-03, 4.2709e-03],\n", + " [ 3.2966e-02, 1.4731e-02, 1.1646e-02]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[ 7.0561e-03, -1.1558e-02, -1.6542e-02],\n", + " [ 7.7660e-03, 2.1384e-02, -2.2963e-02],\n", + " [-2.4396e-02, 6.4466e-03, 2.3426e-02]],\n", + "\n", + " [[ 9.9911e-04, -1.6690e-02, -7.5850e-03],\n", + " [-1.5195e-03, -2.1704e-03, -3.2688e-02],\n", + " [-1.4812e-02, 9.4784e-06, -1.1086e-02]],\n", + "\n", + " [[ 1.3946e-02, 1.2512e-02, -4.8051e-03],\n", + " [ 1.3617e-02, 3.2708e-03, 2.3468e-02],\n", + " [ 7.1597e-03, -3.1580e-04, -1.6377e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-9.0539e-03, 1.7373e-02, -2.9568e-02],\n", + " [-4.2147e-03, 
-1.6024e-02, -1.4590e-02],\n", + " [-2.5991e-02, 1.3760e-02, -4.2937e-03]],\n", + "\n", + " [[ 2.5654e-02, -1.5972e-02, -1.3203e-02],\n", + " [ 2.3070e-03, -2.2005e-03, 8.4325e-03],\n", + " [ 8.5600e-03, -3.0322e-02, -3.4817e-02]],\n", + "\n", + " [[-2.3064e-02, 1.9884e-02, -6.6768e-03],\n", + " [-1.3571e-02, -8.7839e-03, -6.3031e-04],\n", + " [-1.0146e-03, 1.1744e-02, 2.2794e-02]]],\n", + "\n", + "\n", + " [[[-5.9354e-03, -2.4573e-02, -6.1468e-03],\n", + " [-1.7341e-02, 1.6610e-02, -2.3676e-03],\n", + " [-4.0607e-03, 2.3387e-02, -2.7001e-02]],\n", + "\n", + " [[ 2.9649e-02, 1.0233e-03, 1.1201e-02],\n", + " [-8.8892e-03, 3.3015e-03, 9.1620e-03],\n", + " [ 4.9614e-02, 1.6541e-02, 5.4178e-03]],\n", + "\n", + " [[ 5.7061e-03, -2.0687e-04, 1.8121e-02],\n", + " [ 9.9960e-03, -2.2996e-02, -3.8443e-02],\n", + " [ 2.9459e-02, -8.1446e-03, -7.2190e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.1950e-02, -5.8360e-03, -8.6734e-03],\n", + " [-6.4084e-04, 1.6201e-02, -2.3953e-03],\n", + " [-9.0676e-04, 1.2011e-02, -1.3787e-02]],\n", + "\n", + " [[-7.7046e-04, 7.1940e-03, 1.6575e-02],\n", + " [-4.9565e-02, -1.0073e-02, -2.1843e-02],\n", + " [-4.4395e-03, -2.4704e-02, -3.3025e-02]],\n", + "\n", + " [[-2.1877e-02, 2.5838e-02, 1.5783e-02],\n", + " [-4.3521e-02, 4.0280e-03, -1.4476e-04],\n", + " [-1.0410e-03, 2.4805e-03, 7.8111e-03]]],\n", + "\n", + "\n", + " [[[ 6.3897e-03, -1.0151e-02, -7.8149e-03],\n", + " [ 3.2832e-03, -3.7016e-03, -1.6792e-03],\n", + " [ 3.0218e-04, -7.0481e-03, 1.4635e-02]],\n", + "\n", + " [[-6.8840e-03, -4.6635e-03, -1.7792e-02],\n", + " [-5.9637e-03, -1.8775e-02, 7.1388e-03],\n", + " [-1.8852e-02, 4.5739e-03, -8.0146e-04]],\n", + "\n", + " [[-8.8560e-03, 9.4655e-03, -2.5622e-03],\n", + " [-1.2442e-02, -5.4622e-03, 4.5498e-03],\n", + " [-1.0570e-02, -7.3137e-04, -2.4904e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.5443e-02, -8.4710e-03, -4.8651e-05],\n", + " [ 2.7532e-02, -7.2788e-03, 3.8603e-02],\n", + " [ 1.2313e-02, 3.8152e-03, 6.0188e-03]],\n", + 
"\n", + " [[ 4.4933e-03, 7.1814e-03, 1.5395e-02],\n", + " [-1.4873e-02, 1.3530e-02, -7.5244e-03],\n", + " [-5.2787e-03, -5.8126e-04, -2.5811e-02]],\n", + "\n", + " [[-2.2218e-02, -7.7884e-03, 2.2616e-03],\n", + " [-1.0254e-02, -2.6434e-02, 1.5597e-02],\n", + " [-1.9715e-02, 1.5517e-02, 1.6434e-02]]]])\n", + "layer4.0.bn2.weight Parameter containing:\n", + "tensor([0.9555, 0.9785, 0.9496, 0.9762, 0.9477, 0.9735, 0.9830, 0.9709, 0.9828,\n", + " 0.9800, 0.9667, 0.9463, 0.9382, 0.9838, 0.9651, 0.9838, 0.9959, 0.9737,\n", + " 0.9975, 0.9388, 1.0001, 0.9637, 0.9640, 0.9840, 0.9687, 0.9413, 0.9266,\n", + " 0.9813, 0.9410, 0.9559, 0.9835, 0.9626, 0.9521, 0.9745, 1.0076, 0.9600,\n", + " 0.9822, 0.9820, 0.9353, 0.9054, 0.9364, 0.9649, 0.9596, 0.9963, 0.9624,\n", + " 0.9268, 0.9765, 1.0038, 0.9494, 0.9595, 0.9965, 1.0301, 0.9637, 0.9209,\n", + " 0.9634, 0.9810, 1.0096, 0.9736, 0.9506, 0.9889, 0.9523, 0.9745, 0.9406,\n", + " 0.9872, 0.9760, 1.0253, 0.9682, 0.9641, 0.9850, 1.0172, 0.9645, 0.9324,\n", + " 0.9707, 0.9516, 0.9803, 0.8969, 0.9797, 0.9510, 0.9722, 0.9290, 0.9546,\n", + " 0.9707, 0.9955, 0.9561, 0.9662, 1.0180, 0.9996, 0.9805, 0.9548, 0.9760,\n", + " 0.8972, 0.9846, 0.9154, 0.9770, 0.9695, 0.9210, 0.9686, 0.9508, 0.9365,\n", + " 0.9962, 0.9534, 0.9397, 0.9752, 0.9787, 0.9633, 0.9820, 0.9829, 0.9563,\n", + " 1.0000, 0.9692, 1.0029, 0.9389, 0.9895, 0.9848, 0.9618, 0.9606, 0.8920,\n", + " 0.9950, 1.0084, 1.0011, 0.9852, 0.9696, 0.9810, 0.9393, 1.0241, 0.9766,\n", + " 0.9407, 1.0100, 0.9954, 1.0044, 0.9674, 0.9553, 0.9904, 0.9750, 1.0066,\n", + " 0.9549, 0.9426, 0.9569, 0.9396, 0.9754, 0.9547, 1.0431, 0.9609, 0.9645,\n", + " 0.9803, 0.9494, 0.9675, 1.0050, 1.0065, 0.9358, 0.9977, 1.0043, 1.0053,\n", + " 0.9672, 0.9477, 0.9334, 0.9452, 0.9477, 0.9697, 0.9561, 0.9601, 0.9548,\n", + " 1.0286, 0.9487, 0.9529, 0.9548, 0.9512, 0.9689, 0.9678, 0.9416, 0.9688,\n", + " 0.9717, 0.9849, 0.9832, 0.9674, 0.8993, 0.9652, 0.9835, 0.9444, 0.9778,\n", + " 0.9206, 0.9684, 0.9371, 0.9898, 
0.9586, 1.0066, 0.9892, 0.9753, 0.9446,\n", + " 0.9486, 0.9731, 0.9771, 0.9552, 0.9496, 0.9854, 0.8991, 0.9244, 0.9771,\n", + " 0.9557, 0.9823, 0.9617, 0.9556, 0.9713, 0.9306, 0.9830, 0.9637, 0.9615,\n", + " 0.9432, 0.9936, 0.9965, 0.9939, 0.9662, 0.9847, 0.9459, 1.0106, 0.9978,\n", + " 0.9864, 0.9836, 0.9472, 0.9943, 0.9877, 0.9788, 0.9589, 0.9004, 0.9368,\n", + " 0.9213, 0.9875, 0.9673, 0.9724, 0.9362, 0.9608, 0.9686, 0.9624, 0.9685,\n", + " 0.9486, 0.9572, 0.9623, 0.9320, 0.9418, 0.9778, 0.9820, 0.9436, 0.9621,\n", + " 0.9683, 0.9400, 0.9803, 0.9238, 0.9475, 0.9266, 0.9434, 0.9181, 1.0136,\n", + " 0.9859, 0.9299, 0.9896, 0.9835, 0.9883, 0.9865, 0.9275, 1.0005, 0.9368,\n", + " 0.9942, 0.9573, 0.9808, 0.9787, 0.9534, 1.0137, 0.8962, 0.9409, 0.9807,\n", + " 0.9453, 1.0381, 0.9634, 0.9773, 0.9643, 0.9484, 0.9605, 0.9253, 0.9943,\n", + " 0.9505, 0.9833, 0.9996, 0.9519, 0.9952, 1.0069, 0.9695, 0.9791, 0.9849,\n", + " 0.9753, 1.0021, 0.9570, 0.9735, 0.9800, 1.0057, 0.9682, 0.9981, 0.9178,\n", + " 0.9467, 1.0186, 0.9955, 0.9593, 0.9700, 0.9298, 0.9841, 0.9576, 0.9683,\n", + " 0.9715, 0.9853, 0.9751, 0.9591, 0.9580, 0.9272, 0.9904, 0.9850, 0.9661,\n", + " 0.9751, 1.0004, 0.9607, 0.9932, 0.9582, 0.9322, 0.9509, 1.0128, 0.9531,\n", + " 0.9501, 0.9875, 0.9256, 0.9406, 0.9517, 0.9849, 0.9961, 0.9599, 0.9165,\n", + " 0.9653, 0.9585, 0.9558, 0.9775, 0.9540, 0.9849, 0.9814, 0.9843, 0.9834,\n", + " 0.9572, 0.9750, 1.0024, 0.9440, 0.9530, 0.9244, 0.9580, 0.9388, 0.9935,\n", + " 0.9627, 0.9777, 0.9621, 0.9731, 0.9123, 0.9389, 0.9417, 0.9847, 0.9396,\n", + " 0.9555, 0.9516, 1.0058, 0.9725, 0.8901, 0.9674, 0.9561, 0.9598, 1.0162,\n", + " 0.9572, 0.9861, 0.9811, 0.9995, 1.0045, 0.9723, 0.9668, 0.9831, 0.9109,\n", + " 0.9407, 0.9635, 0.9981, 0.9481, 0.9573, 0.9463, 0.9792, 0.9070, 0.9721,\n", + " 0.9638, 0.9737, 0.9175, 1.0011, 0.9757, 1.0017, 0.9957, 0.9562, 0.9789,\n", + " 0.9631, 0.9546, 0.9913, 0.9549, 1.0033, 0.9550, 0.9639, 0.9441, 0.9702,\n", + " 0.9904, 0.9402, 0.9744, 1.0061, 
0.9798, 1.0017, 1.0137, 1.0099, 0.9962,\n", + " 0.9511, 0.9597, 1.0050, 0.9719, 0.9704, 0.9813, 0.9900, 0.9667, 0.9582,\n", + " 1.0062, 0.9669, 1.0110, 0.9397, 1.0132, 0.9442, 0.9814, 0.8946, 0.9956,\n", + " 0.9420, 0.9621, 0.9445, 0.9682, 0.9952, 0.9799, 0.9466, 0.9718, 0.9509,\n", + " 0.9710, 0.9798, 0.9832, 0.9664, 0.9887, 0.9617, 0.9754, 1.0175, 0.9922,\n", + " 1.0063, 0.9900, 0.9609, 0.9710, 0.9622, 0.9798, 0.9414, 1.0162, 0.9600,\n", + " 0.9837, 1.0000, 0.9270, 0.9603, 0.9866, 0.9681, 1.0009, 0.9595, 0.9672,\n", + " 0.9514, 0.9584, 0.9382, 0.9560, 0.9736, 0.9953, 0.9513, 0.9762, 0.9818,\n", + " 0.9787, 0.9892, 0.9558, 0.9719, 0.9303, 0.9718, 0.9597, 0.9713, 0.9641,\n", + " 0.9278, 0.9673, 0.9914, 0.9579, 0.9361, 1.0325, 0.9826, 0.9888, 0.9810,\n", + " 0.9229, 1.0155, 1.0057, 0.9610, 0.9168, 0.9803, 0.9771, 0.9940, 0.9315,\n", + " 0.9625, 0.9365, 0.9514, 0.9863, 0.9985, 1.0196, 1.0129, 0.9435])\n", + "layer4.0.bn2.bias Parameter containing:\n", + "tensor([-0.1077, -0.1421, -0.1821, -0.1668, -0.1509, -0.1312, -0.1380, -0.1309,\n", + " -0.1376, -0.0952, -0.1164, -0.1412, -0.1481, -0.1340, -0.0982, -0.1111,\n", + " -0.1355, -0.1359, -0.0763, -0.1281, -0.1032, -0.1161, -0.1639, -0.1614,\n", + " -0.1359, -0.1743, -0.1449, -0.1202, -0.1613, -0.1218, -0.0912, -0.1359,\n", + " -0.1473, -0.1254, -0.1260, -0.1315, -0.1130, -0.1370, -0.1554, -0.1598,\n", + " -0.1412, -0.1671, -0.1722, -0.1153, -0.1522, -0.1593, -0.1292, -0.1383,\n", + " -0.1708, -0.1354, -0.1689, -0.1445, -0.1057, -0.1420, -0.1475, -0.1008,\n", + " -0.1661, -0.1452, -0.1685, -0.0975, -0.1722, -0.1182, -0.0939, -0.1553,\n", + " -0.1652, -0.1251, -0.1249, -0.1043, -0.1387, -0.1344, -0.1790, -0.1616,\n", + " -0.1750, -0.1710, -0.1289, -0.1422, -0.1303, -0.1413, -0.1527, -0.1282,\n", + " -0.1279, -0.1226, -0.1151, -0.1201, -0.1322, -0.1368, -0.1222, -0.1290,\n", + " -0.1492, -0.1467, -0.1736, -0.1057, -0.1849, -0.0941, -0.1291, -0.1398,\n", + " -0.1382, -0.1133, -0.1222, -0.1267, -0.1743, -0.1167, -0.1169, 
-0.1327,\n", + " -0.1233, -0.1442, -0.1083, -0.1543, -0.1278, -0.1602, -0.1623, -0.1351,\n", + " -0.1327, -0.1430, -0.1688, -0.1055, -0.1419, -0.1694, -0.1434, -0.1472,\n", + " -0.1157, -0.1287, -0.1160, -0.0923, -0.1165, -0.1386, -0.1188, -0.1284,\n", + " -0.1347, -0.1452, -0.1594, -0.1360, -0.1470, -0.1863, -0.1705, -0.1492,\n", + " -0.1368, -0.1487, -0.1645, -0.1125, -0.1228, -0.1313, -0.1517, -0.1273,\n", + " -0.1473, -0.1441, -0.1204, -0.1410, -0.1254, -0.1298, -0.1175, -0.1477,\n", + " -0.1687, -0.1658, -0.1260, -0.1066, -0.1634, -0.1069, -0.1038, -0.1724,\n", + " -0.1575, -0.1456, -0.1222, -0.1235, -0.1357, -0.1030, -0.1516, -0.1180,\n", + " -0.2283, -0.1081, -0.0908, -0.1186, -0.1629, -0.1338, -0.1590, -0.1619,\n", + " -0.1434, -0.1391, -0.1531, -0.1400, -0.1208, -0.1121, -0.1173, -0.1413,\n", + " -0.1178, -0.1539, -0.1020, -0.1621, -0.1607, -0.1491, -0.1088, -0.0963,\n", + " -0.1369, -0.1579, -0.1558, -0.1364, -0.1629, -0.1167, -0.1064, -0.1564,\n", + " -0.1016, -0.1513, -0.1344, -0.1452, -0.1330, -0.1143, -0.1711, -0.1544,\n", + " -0.1010, -0.1073, -0.1114, -0.1060, -0.2010, -0.1259, -0.1225, -0.1472,\n", + " -0.1328, -0.2026, -0.1359, -0.1276, -0.1133, -0.1622, -0.1322, -0.1666,\n", + " -0.1334, -0.1379, -0.1284, -0.1683, -0.1491, -0.1184, -0.1172, -0.1208,\n", + " -0.1506, -0.1204, -0.1467, -0.1479, -0.1494, -0.0968, -0.1067, -0.1380,\n", + " -0.1635, -0.0869, -0.1374, -0.1196, -0.1598, -0.1280, -0.1609, -0.1608,\n", + " -0.1404, -0.1249, -0.2054, -0.1497, -0.1637, -0.1352, -0.1547, -0.1162,\n", + " -0.1721, -0.1412, -0.1409, -0.1169, -0.1034, -0.1353, -0.1230, -0.1313,\n", + " -0.1110, -0.1485, -0.1141, -0.1649, -0.1388, -0.0882, -0.1378, -0.1768,\n", + " -0.1178, -0.1274, -0.1642, -0.1661, -0.1437, -0.1595, -0.1388, -0.2104,\n", + " -0.1090, -0.1467, -0.1431, -0.1173, -0.1046, -0.1935, -0.1478, -0.1360,\n", + " -0.1500, -0.1175, -0.1217, -0.1539, -0.1516, -0.1312, -0.1663, -0.0784,\n", + " -0.1408, -0.1435, -0.1334, -0.1537, -0.1486, -0.1203, -0.1474, 
-0.1198,\n", + " -0.1518, -0.1653, -0.1304, -0.1242, -0.1329, -0.1106, -0.1312, -0.1558,\n", + " -0.1402, -0.0926, -0.1435, -0.0997, -0.1235, -0.1709, -0.1523, -0.1424,\n", + " -0.1248, -0.1423, -0.1372, -0.1385, -0.1349, -0.1320, -0.1480, -0.1883,\n", + " -0.1528, -0.1106, -0.1115, -0.1545, -0.2491, -0.1437, -0.1366, -0.1148,\n", + " -0.1526, -0.0981, -0.1287, -0.0970, -0.1171, -0.1253, -0.1360, -0.1191,\n", + " -0.1091, -0.1361, -0.1243, -0.1733, -0.1326, -0.1444, -0.1465, -0.1514,\n", + " -0.1156, -0.1539, -0.1099, -0.1416, -0.1336, -0.1725, -0.1333, -0.1748,\n", + " -0.1506, -0.1016, -0.0820, -0.1540, -0.1404, -0.1233, -0.1787, -0.1224,\n", + " -0.1389, -0.1248, -0.1288, -0.1181, -0.1983, -0.1307, -0.1433, -0.1397,\n", + " -0.1570, -0.1298, -0.1285, -0.1475, -0.1472, -0.0788, -0.0871, -0.1186,\n", + " -0.1054, -0.1655, -0.1623, -0.1446, -0.0858, -0.1977, -0.1148, -0.1272,\n", + " -0.1156, -0.1302, -0.0837, -0.1272, -0.1368, -0.1431, -0.1257, -0.1206,\n", + " -0.1128, -0.1089, -0.1071, -0.1777, -0.1622, -0.1102, -0.2165, -0.1409,\n", + " -0.2035, -0.0944, -0.1122, -0.1559, -0.1246, -0.1746, -0.1242, -0.1357,\n", + " -0.1038, -0.1866, -0.1053, -0.1582, -0.1073, -0.1082, -0.0941, -0.1234,\n", + " -0.1571, -0.1522, -0.1285, -0.1717, -0.1760, -0.1226, -0.1873, -0.1178,\n", + " -0.1140, -0.1283, -0.1302, -0.1645, -0.1375, -0.1337, -0.1517, -0.1147,\n", + " -0.1548, -0.1192, -0.1427, -0.1613, -0.1634, -0.1307, -0.1150, -0.1227,\n", + " -0.1003, -0.1405, -0.1071, -0.1345, -0.1354, -0.1312, -0.0875, -0.1288,\n", + " -0.1407, -0.1009, -0.1498, -0.1397, -0.1114, -0.1694, -0.1349, -0.1294,\n", + " -0.1948, -0.1227, -0.1455, -0.1091, -0.1289, -0.0721, -0.1536, -0.1525,\n", + " -0.1659, -0.1372, -0.1063, -0.1048, -0.1303, -0.1184, -0.2030, -0.1300,\n", + " -0.1413, -0.1604, -0.1453, -0.1992, -0.1359, -0.1073, -0.1755, -0.1209,\n", + " -0.1073, -0.1238, -0.1732, -0.1200, -0.0910, -0.1428, -0.1074, -0.1454,\n", + " -0.1542, -0.1509, -0.1703, -0.1571, -0.1164, -0.1268, -0.1080, 
-0.1576,\n", + " -0.1055, -0.1228, -0.1551, -0.1254, -0.1390, -0.1149, -0.1284, -0.1413])\n", + "layer4.0.shortcut_conv.weight Parameter containing:\n", + "tensor([[[[-0.0372]],\n", + "\n", + " [[-0.0450]],\n", + "\n", + " [[-0.0141]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0531]],\n", + "\n", + " [[ 0.0417]],\n", + "\n", + " [[ 0.0791]]],\n", + "\n", + "\n", + " [[[-0.0139]],\n", + "\n", + " [[ 0.0392]],\n", + "\n", + " [[-0.0129]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0276]],\n", + "\n", + " [[ 0.0577]],\n", + "\n", + " [[ 0.0175]]],\n", + "\n", + "\n", + " [[[-0.0635]],\n", + "\n", + " [[ 0.0519]],\n", + "\n", + " [[-0.0476]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0614]],\n", + "\n", + " [[-0.0259]],\n", + "\n", + " [[-0.0251]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[ 0.0758]],\n", + "\n", + " [[ 0.0010]],\n", + "\n", + " [[-0.0365]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0223]],\n", + "\n", + " [[-0.0071]],\n", + "\n", + " [[ 0.0088]]],\n", + "\n", + "\n", + " [[[-0.0509]],\n", + "\n", + " [[-0.0467]],\n", + "\n", + " [[ 0.0209]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0022]],\n", + "\n", + " [[-0.0226]],\n", + "\n", + " [[ 0.0770]]],\n", + "\n", + "\n", + " [[[-0.0103]],\n", + "\n", + " [[ 0.0359]],\n", + "\n", + " [[-0.0126]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.0311]],\n", + "\n", + " [[ 0.0278]],\n", + "\n", + " [[-0.0614]]]])\n", + "layer4.0.shortcut_bn.weight Parameter containing:\n", + "tensor([0.9454, 0.9364, 0.8905, 0.9216, 0.9349, 0.9505, 0.9566, 0.9267, 0.9410,\n", + " 0.9398, 0.9523, 0.9230, 0.9259, 0.9375, 0.9117, 0.9437, 0.9479, 0.9363,\n", + " 0.9458, 0.9379, 0.9648, 0.9435, 0.9325, 0.9490, 0.9359, 0.9109, 0.9319,\n", + " 0.9613, 0.9306, 0.9377, 0.9287, 0.9642, 0.9484, 0.9523, 0.9468, 0.9423,\n", + " 0.9534, 0.9389, 0.9448, 0.8943, 0.9367, 0.9419, 0.9420, 0.9698, 0.9606,\n", + " 0.9161, 0.9627, 0.9337, 0.9162, 0.9507, 0.9376, 0.9306, 0.9499, 0.9384,\n", + " 0.9663, 0.9517, 0.9150, 0.9380, 0.9580, 0.9460, 
0.9379, 0.9400, 0.9419,\n", + " 0.9045, 0.9396, 0.9372, 0.9189, 0.9340, 0.9214, 0.9442, 0.9297, 0.9199,\n", + " 0.9387, 0.9416, 0.9435, 0.9209, 0.9542, 0.9221, 0.9461, 0.9202, 0.9453,\n", + " 0.9545, 0.9388, 0.9386, 0.9114, 0.9444, 0.9671, 0.9598, 0.8984, 0.9551,\n", + " 0.9159, 0.9347, 0.9070, 0.9479, 0.9400, 0.9062, 0.9460, 0.9422, 0.9532,\n", + " 0.9363, 0.9118, 0.9303, 0.9167, 0.9567, 0.9494, 0.9324, 0.9519, 0.9788,\n", + " 0.9430, 0.9440, 0.8834, 0.9436, 0.9480, 0.9337, 0.9357, 0.9402, 0.8856,\n", + " 0.9098, 0.9426, 0.9636, 0.9661, 0.9518, 0.9420, 0.9275, 0.9252, 0.9419,\n", + " 0.9410, 0.9383, 0.9162, 0.9629, 0.9515, 0.9156, 0.9188, 0.9282, 0.9260,\n", + " 0.9329, 0.9382, 0.9005, 0.9181, 0.9358, 0.9311, 0.9387, 0.9482, 0.9345,\n", + " 0.9523, 0.9490, 0.9399, 0.9148, 0.9448, 0.9227, 0.9735, 0.9258, 0.9416,\n", + " 0.9587, 0.9017, 0.9389, 0.9152, 0.9303, 0.9709, 0.9155, 0.9271, 0.9396,\n", + " 0.9448, 0.9264, 0.9333, 0.9276, 0.9324, 0.9544, 0.9016, 0.9318, 0.9596,\n", + " 0.9114, 0.9449, 0.9452, 0.9285, 0.9156, 0.9248, 0.9164, 0.9393, 0.9621,\n", + " 0.9354, 0.9452, 0.9382, 0.9522, 0.9404, 0.9453, 0.9509, 0.9473, 0.9435,\n", + " 0.9179, 0.9595, 0.9643, 0.9439, 0.9360, 0.9280, 0.9137, 0.9217, 0.9421,\n", + " 0.9285, 0.9196, 0.9618, 0.9617, 0.9308, 0.9222, 0.9328, 0.9397, 0.9418,\n", + " 0.9508, 0.9493, 0.9743, 0.9421, 0.9489, 0.9166, 0.9135, 0.9594, 0.9322,\n", + " 0.9255, 0.8959, 0.9352, 0.9526, 0.9536, 0.9314, 0.9578, 0.8772, 0.9699,\n", + " 0.9241, 0.9499, 0.9198, 0.9376, 0.9467, 0.9394, 0.9376, 0.9464, 0.9573,\n", + " 0.9343, 0.9149, 0.9295, 0.9281, 0.9246, 0.9344, 0.9389, 0.9547, 0.9402,\n", + " 0.9658, 0.9411, 0.9382, 0.9156, 0.9437, 0.9275, 0.9152, 0.9184, 0.9217,\n", + " 0.9339, 0.9276, 0.9419, 0.9479, 0.9600, 0.9451, 0.9488, 0.9270, 0.9495,\n", + " 0.9479, 0.9402, 0.9447, 0.9458, 0.9371, 0.9498, 0.8992, 0.9151, 0.9553,\n", + " 0.9536, 0.9450, 0.9511, 0.9409, 0.9393, 0.9152, 0.9384, 0.9335, 0.9241,\n", + " 0.8940, 0.9487, 0.9354, 0.9424, 0.9438, 0.9535, 
0.9028, 0.9661, 0.9582,\n", + " 0.9399, 0.9600, 0.9415, 0.9428, 0.9680, 0.9553, 0.9248, 0.9557, 0.9113,\n", + " 0.9374, 0.9374, 0.9461, 0.9443, 0.9416, 0.9279, 0.9388, 0.9399, 0.9466,\n", + " 0.9435, 0.9422, 0.9609, 0.9509, 0.9310, 0.9209, 0.9452, 0.9577, 0.9346,\n", + " 0.9400, 0.9428, 0.9525, 0.9346, 0.9641, 0.9349, 0.9470, 0.9137, 0.9350,\n", + " 0.9322, 0.9196, 0.9232, 0.9252, 0.9330, 0.9467, 0.9543, 0.9248, 0.8629,\n", + " 0.9180, 0.9414, 0.9429, 0.9231, 0.9056, 0.9382, 0.9187, 0.9495, 0.9341,\n", + " 0.9322, 0.9511, 0.9465, 0.9153, 0.9365, 0.9653, 0.9359, 0.9252, 0.9117,\n", + " 0.9598, 0.9450, 0.9524, 0.9153, 0.9032, 0.9297, 0.8921, 0.9509, 0.9215,\n", + " 0.9579, 0.9169, 0.9369, 0.9061, 0.9091, 0.9617, 0.9450, 0.9466, 0.9391,\n", + " 0.9113, 0.9307, 0.9253, 0.9217, 0.9238, 0.9481, 0.9549, 0.9279, 0.9138,\n", + " 0.9555, 0.9042, 0.9309, 0.9231, 0.9194, 0.9409, 0.9385, 0.9175, 0.9201,\n", + " 0.9591, 0.9504, 0.9133, 0.9520, 0.9449, 0.9599, 0.9433, 0.9160, 0.9297,\n", + " 0.9570, 0.9110, 0.9338, 0.9370, 0.9589, 0.9375, 0.9261, 0.9169, 0.9310,\n", + " 0.9374, 0.9363, 0.9514, 0.9272, 0.9535, 0.9502, 0.9374, 0.9295, 0.9511,\n", + " 0.9599, 0.9449, 0.9497, 0.8779, 0.9328, 0.9454, 0.9495, 0.9331, 0.9647,\n", + " 0.9698, 0.9484, 0.9300, 0.9280, 0.9345, 0.9237, 0.9337, 0.9030, 0.9565,\n", + " 0.9279, 0.9271, 0.9723, 0.9468, 0.9512, 0.9526, 0.9374, 0.9558, 0.9394,\n", + " 0.9587, 0.9495, 0.9402, 0.9312, 0.9395, 0.9419, 0.9758, 0.9342, 0.9564,\n", + " 0.9520, 0.9487, 0.9542, 0.9476, 0.9504, 0.9469, 0.9274, 0.9562, 0.9309,\n", + " 0.9451, 0.9384, 0.9386, 0.9480, 0.9413, 0.9142, 0.9645, 0.9315, 0.9449,\n", + " 0.9300, 0.9518, 0.9041, 0.9299, 0.9298, 0.9482, 0.9564, 0.9643, 0.9353,\n", + " 0.9628, 0.9437, 0.9412, 0.9463, 0.9266, 0.9471, 0.9221, 0.9278, 0.9626,\n", + " 0.9441, 0.9537, 0.9464, 0.9413, 0.9305, 0.9370, 0.9566, 0.9322, 0.9464,\n", + " 0.9445, 0.9437, 0.9459, 0.9764, 0.9289, 0.9477, 0.9661, 0.9630, 0.9282,\n", + " 0.9523, 0.9287, 0.9353, 0.9453, 0.9447, 0.9469, 
0.9324, 0.9427])\n", + "layer4.0.shortcut_bn.bias Parameter containing:\n", + "tensor([-0.1077, -0.1421, -0.1821, -0.1668, -0.1509, -0.1312, -0.1380, -0.1309,\n", + " -0.1376, -0.0952, -0.1164, -0.1412, -0.1481, -0.1340, -0.0982, -0.1111,\n", + " -0.1355, -0.1359, -0.0763, -0.1281, -0.1032, -0.1161, -0.1639, -0.1614,\n", + " -0.1359, -0.1743, -0.1449, -0.1202, -0.1613, -0.1218, -0.0912, -0.1359,\n", + " -0.1473, -0.1254, -0.1260, -0.1315, -0.1130, -0.1370, -0.1554, -0.1598,\n", + " -0.1412, -0.1671, -0.1722, -0.1153, -0.1522, -0.1593, -0.1292, -0.1383,\n", + " -0.1708, -0.1354, -0.1689, -0.1445, -0.1057, -0.1420, -0.1475, -0.1008,\n", + " -0.1661, -0.1452, -0.1685, -0.0975, -0.1722, -0.1182, -0.0939, -0.1553,\n", + " -0.1652, -0.1251, -0.1249, -0.1043, -0.1387, -0.1344, -0.1790, -0.1616,\n", + " -0.1750, -0.1710, -0.1289, -0.1422, -0.1303, -0.1413, -0.1527, -0.1282,\n", + " -0.1279, -0.1226, -0.1151, -0.1201, -0.1322, -0.1368, -0.1222, -0.1290,\n", + " -0.1492, -0.1467, -0.1736, -0.1057, -0.1849, -0.0941, -0.1291, -0.1398,\n", + " -0.1382, -0.1133, -0.1222, -0.1267, -0.1743, -0.1167, -0.1169, -0.1327,\n", + " -0.1233, -0.1442, -0.1083, -0.1543, -0.1278, -0.1602, -0.1623, -0.1351,\n", + " -0.1327, -0.1430, -0.1688, -0.1055, -0.1419, -0.1694, -0.1434, -0.1472,\n", + " -0.1157, -0.1287, -0.1160, -0.0923, -0.1165, -0.1386, -0.1188, -0.1284,\n", + " -0.1347, -0.1452, -0.1594, -0.1360, -0.1470, -0.1863, -0.1705, -0.1492,\n", + " -0.1368, -0.1487, -0.1645, -0.1125, -0.1228, -0.1313, -0.1517, -0.1273,\n", + " -0.1473, -0.1441, -0.1204, -0.1410, -0.1254, -0.1298, -0.1175, -0.1477,\n", + " -0.1687, -0.1658, -0.1260, -0.1066, -0.1634, -0.1069, -0.1038, -0.1724,\n", + " -0.1575, -0.1456, -0.1222, -0.1235, -0.1357, -0.1030, -0.1516, -0.1180,\n", + " -0.2283, -0.1081, -0.0908, -0.1186, -0.1629, -0.1338, -0.1590, -0.1619,\n", + " -0.1434, -0.1391, -0.1531, -0.1400, -0.1208, -0.1121, -0.1173, -0.1413,\n", + " -0.1178, -0.1539, -0.1020, -0.1621, -0.1607, -0.1491, -0.1088, 
-0.0963,\n", + " -0.1369, -0.1579, -0.1558, -0.1364, -0.1629, -0.1167, -0.1064, -0.1564,\n", + " -0.1016, -0.1513, -0.1344, -0.1452, -0.1330, -0.1143, -0.1711, -0.1544,\n", + " -0.1010, -0.1073, -0.1114, -0.1060, -0.2010, -0.1259, -0.1225, -0.1472,\n", + " -0.1328, -0.2026, -0.1359, -0.1276, -0.1133, -0.1622, -0.1322, -0.1666,\n", + " -0.1334, -0.1379, -0.1284, -0.1683, -0.1491, -0.1184, -0.1172, -0.1208,\n", + " -0.1506, -0.1204, -0.1467, -0.1479, -0.1494, -0.0968, -0.1067, -0.1380,\n", + " -0.1635, -0.0869, -0.1374, -0.1196, -0.1598, -0.1280, -0.1609, -0.1608,\n", + " -0.1404, -0.1249, -0.2054, -0.1497, -0.1637, -0.1352, -0.1547, -0.1162,\n", + " -0.1721, -0.1412, -0.1409, -0.1169, -0.1034, -0.1353, -0.1230, -0.1313,\n", + " -0.1110, -0.1485, -0.1141, -0.1649, -0.1388, -0.0882, -0.1378, -0.1768,\n", + " -0.1178, -0.1274, -0.1642, -0.1661, -0.1437, -0.1595, -0.1388, -0.2104,\n", + " -0.1090, -0.1467, -0.1431, -0.1173, -0.1046, -0.1935, -0.1478, -0.1360,\n", + " -0.1500, -0.1175, -0.1217, -0.1539, -0.1516, -0.1312, -0.1663, -0.0784,\n", + " -0.1408, -0.1435, -0.1334, -0.1537, -0.1486, -0.1203, -0.1474, -0.1198,\n", + " -0.1518, -0.1653, -0.1304, -0.1242, -0.1329, -0.1106, -0.1312, -0.1558,\n", + " -0.1402, -0.0926, -0.1435, -0.0997, -0.1235, -0.1709, -0.1523, -0.1424,\n", + " -0.1248, -0.1423, -0.1372, -0.1385, -0.1349, -0.1320, -0.1480, -0.1883,\n", + " -0.1528, -0.1106, -0.1115, -0.1545, -0.2491, -0.1437, -0.1366, -0.1148,\n", + " -0.1526, -0.0981, -0.1287, -0.0970, -0.1171, -0.1253, -0.1360, -0.1191,\n", + " -0.1091, -0.1361, -0.1243, -0.1733, -0.1326, -0.1444, -0.1465, -0.1514,\n", + " -0.1156, -0.1539, -0.1099, -0.1416, -0.1336, -0.1725, -0.1333, -0.1748,\n", + " -0.1506, -0.1016, -0.0820, -0.1540, -0.1404, -0.1233, -0.1787, -0.1224,\n", + " -0.1389, -0.1248, -0.1288, -0.1181, -0.1983, -0.1307, -0.1433, -0.1397,\n", + " -0.1570, -0.1298, -0.1285, -0.1475, -0.1472, -0.0788, -0.0871, -0.1186,\n", + " -0.1054, -0.1655, -0.1623, -0.1446, -0.0858, -0.1977, -0.1148, 
-0.1272,\n", + " -0.1156, -0.1302, -0.0837, -0.1272, -0.1368, -0.1431, -0.1257, -0.1206,\n", + " -0.1128, -0.1089, -0.1071, -0.1777, -0.1622, -0.1102, -0.2165, -0.1409,\n", + " -0.2035, -0.0944, -0.1122, -0.1559, -0.1246, -0.1746, -0.1242, -0.1357,\n", + " -0.1038, -0.1866, -0.1053, -0.1582, -0.1073, -0.1082, -0.0941, -0.1234,\n", + " -0.1571, -0.1522, -0.1285, -0.1717, -0.1760, -0.1226, -0.1873, -0.1178,\n", + " -0.1140, -0.1283, -0.1302, -0.1645, -0.1375, -0.1337, -0.1517, -0.1147,\n", + " -0.1548, -0.1192, -0.1427, -0.1613, -0.1634, -0.1307, -0.1150, -0.1227,\n", + " -0.1003, -0.1405, -0.1071, -0.1345, -0.1354, -0.1312, -0.0875, -0.1288,\n", + " -0.1407, -0.1009, -0.1498, -0.1397, -0.1114, -0.1694, -0.1349, -0.1294,\n", + " -0.1948, -0.1227, -0.1455, -0.1091, -0.1289, -0.0721, -0.1536, -0.1525,\n", + " -0.1659, -0.1372, -0.1063, -0.1048, -0.1303, -0.1184, -0.2030, -0.1300,\n", + " -0.1413, -0.1604, -0.1453, -0.1992, -0.1359, -0.1073, -0.1755, -0.1209,\n", + " -0.1073, -0.1238, -0.1732, -0.1200, -0.0910, -0.1428, -0.1074, -0.1454,\n", + " -0.1542, -0.1509, -0.1703, -0.1571, -0.1164, -0.1268, -0.1080, -0.1576,\n", + " -0.1055, -0.1228, -0.1551, -0.1254, -0.1390, -0.1149, -0.1284, -0.1413])\n", + "layer4.1.conv1.weight Parameter containing:\n", + "tensor([[[[-1.1948e-02, -3.3142e-03, 1.0257e-02],\n", + " [-1.0413e-02, 4.3039e-03, 3.6476e-03],\n", + " [ 8.4051e-03, 1.0929e-02, -1.2016e-02]],\n", + "\n", + " [[-2.0995e-04, 1.1175e-02, -9.2195e-03],\n", + " [ 1.3481e-03, -1.4488e-02, -2.9663e-02],\n", + " [-4.7732e-03, -1.0124e-02, -1.2327e-02]],\n", + "\n", + " [[-2.1531e-02, -5.3558e-03, -1.5935e-02],\n", + " [-1.2236e-02, -1.0232e-02, -3.2332e-02],\n", + " [-2.7034e-02, -2.6549e-02, -2.6104e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 5.3538e-04, 1.6756e-02, -1.2413e-02],\n", + " [ 1.9543e-03, 3.2035e-03, -1.4134e-02],\n", + " [-2.7545e-04, -2.1481e-03, 2.1559e-02]],\n", + "\n", + " [[-9.9980e-03, 7.8874e-03, -8.9226e-03],\n", + " [ 6.6112e-03, -2.9001e-03, 
-1.2756e-02],\n", + " [ 2.3838e-04, 5.3704e-03, -7.9757e-03]],\n", + "\n", + " [[ 3.8782e-03, -8.8496e-03, 2.8949e-02],\n", + " [-2.1215e-02, 1.0959e-02, -1.3079e-03],\n", + " [ 2.4892e-04, -2.5248e-03, -1.0246e-02]]],\n", + "\n", + "\n", + " [[[ 5.5631e-03, -9.3964e-04, 4.2188e-03],\n", + " [-1.3459e-02, -6.8220e-03, -1.7020e-02],\n", + " [-4.0890e-02, -2.8097e-02, -1.3561e-02]],\n", + "\n", + " [[-6.0051e-03, -7.8186e-03, 4.0130e-03],\n", + " [-4.4814e-04, -2.0547e-02, -1.1539e-02],\n", + " [ 1.2592e-02, -1.8088e-02, -3.7631e-03]],\n", + "\n", + " [[-8.2642e-03, -5.4152e-03, 2.7797e-03],\n", + " [-1.3859e-03, 8.2239e-03, 4.0218e-04],\n", + " [ 8.4035e-03, 1.4251e-02, 4.7447e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[ 2.0392e-02, 7.7357e-03, 4.1268e-03],\n", + " [-6.8086e-03, 9.4403e-03, 5.3178e-03],\n", + " [ 1.9140e-02, -3.3532e-02, 1.0914e-03]],\n", + "\n", + " [[ 2.3713e-03, 2.2750e-02, 7.3886e-03],\n", + " [-1.9896e-02, -8.8388e-03, 1.8477e-02],\n", + " [-1.7475e-02, 1.4251e-03, 1.3589e-02]],\n", + "\n", + " [[ 1.3121e-02, 4.1141e-02, 7.3961e-03],\n", + " [ 1.3370e-02, 6.1834e-03, 2.0109e-02],\n", + " [ 2.5821e-02, 1.9590e-02, 2.4656e-02]]],\n", + "\n", + "\n", + " [[[ 9.9412e-03, 1.5238e-02, 2.0204e-02],\n", + " [ 4.8507e-03, 1.7608e-02, 5.5749e-03],\n", + " [-1.2087e-02, 1.0358e-02, 1.1422e-02]],\n", + "\n", + " [[-2.1776e-02, -5.6003e-03, 5.6663e-03],\n", + " [ 2.1339e-02, 5.1282e-03, 1.7236e-02],\n", + " [-2.4920e-03, 1.2391e-02, 1.8380e-02]],\n", + "\n", + " [[-1.7954e-02, -1.2279e-02, -1.3557e-02],\n", + " [ 2.7206e-03, 9.9728e-04, 1.4261e-02],\n", + " [-1.4279e-03, 2.9682e-04, -7.3328e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.3195e-02, 1.6100e-02, -7.7113e-06],\n", + " [-1.0721e-02, -2.1473e-02, 1.6510e-02],\n", + " [ 4.8150e-03, 1.8678e-03, -2.0047e-02]],\n", + "\n", + " [[-1.3962e-02, -8.0512e-03, -1.5752e-02],\n", + " [ 6.9845e-03, -4.3029e-03, 1.1148e-02],\n", + " [ 3.1242e-03, -1.9239e-02, 4.6783e-03]],\n", + "\n", + " [[ 6.3272e-03, 
-6.4049e-03, -1.9009e-02],\n", + " [-7.0409e-03, -1.3461e-02, -1.2215e-02],\n", + " [ 1.7156e-03, 2.9771e-03, 5.8117e-03]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[-1.0419e-02, -6.0178e-03, -1.4953e-02],\n", + " [ 2.8744e-03, -1.3205e-02, 1.2708e-03],\n", + " [ 6.5012e-04, 1.0050e-03, -9.5456e-03]],\n", + "\n", + " [[-1.1000e-02, -2.2153e-04, -1.5086e-03],\n", + " [ 1.0215e-02, -9.4436e-03, 5.2485e-03],\n", + " [-2.4163e-02, -2.5928e-02, -1.3547e-02]],\n", + "\n", + " [[ 5.4092e-03, -6.3877e-03, -1.0579e-02],\n", + " [-8.7557e-04, 1.9436e-02, -4.6480e-03],\n", + " [-1.2522e-02, -3.8739e-03, -1.0423e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[-7.6105e-03, -3.4862e-03, 1.1735e-02],\n", + " [-2.6591e-02, -9.0118e-03, -1.8351e-02],\n", + " [-1.9584e-02, -4.1147e-03, -1.3313e-02]],\n", + "\n", + " [[ 2.4424e-02, 2.6349e-03, 2.4604e-03],\n", + " [ 1.8698e-03, -6.0177e-03, -1.6583e-02],\n", + " [ 8.0879e-03, -2.2503e-03, -2.3768e-02]],\n", + "\n", + " [[-1.1821e-02, 1.1441e-03, 1.3298e-02],\n", + " [-1.4603e-03, 2.7675e-03, -4.1455e-03],\n", + " [ 1.2778e-02, 1.9607e-02, 5.1929e-03]]],\n", + "\n", + "\n", + " [[[ 1.1794e-02, -1.9382e-03, -1.0644e-02],\n", + " [-1.8139e-02, 1.3198e-03, 2.2425e-03],\n", + " [ 1.0553e-03, 2.1395e-03, -3.7677e-03]],\n", + "\n", + " [[-2.0456e-02, -1.6302e-02, 1.1988e-02],\n", + " [-1.3381e-02, -2.8377e-02, -1.1368e-02],\n", + " [ 1.2046e-02, -9.7604e-03, -1.4571e-03]],\n", + "\n", + " [[-5.2547e-03, 2.3563e-03, -2.8907e-02],\n", + " [-4.2318e-03, 2.1925e-02, 3.8629e-03],\n", + " [ 5.9383e-03, -1.0632e-02, 2.3464e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[ 7.0141e-04, 5.0503e-03, -9.5930e-03],\n", + " [ 2.6526e-02, 1.5373e-02, 1.9513e-02],\n", + " [-4.0049e-03, -2.8385e-02, -2.0793e-02]],\n", + "\n", + " [[ 1.4260e-02, 1.0030e-02, 2.5354e-02],\n", + " [-7.2404e-03, -2.1980e-02, 6.5759e-03],\n", + " [ 1.8276e-02, 6.0751e-03, 1.1317e-02]],\n", + "\n", + " [[ 2.6095e-02, 2.3253e-02, 7.7629e-03],\n", + " [ 1.2630e-02, 
1.5899e-02, -1.1508e-02],\n", + " [-4.3151e-03, -8.6703e-04, 5.9518e-03]]],\n", + "\n", + "\n", + " [[[ 1.7142e-02, 6.6391e-03, -3.0989e-03],\n", + " [-2.6798e-02, -2.7884e-02, -5.9585e-03],\n", + " [-5.1605e-03, -5.6922e-03, -5.1046e-04]],\n", + "\n", + " [[ 1.0339e-02, 1.4060e-02, 8.2361e-03],\n", + " [ 2.2945e-02, 1.6656e-02, 4.4532e-03],\n", + " [-9.3033e-03, -1.7460e-02, -1.4590e-02]],\n", + "\n", + " [[ 6.1705e-03, -1.2017e-02, -1.1784e-02],\n", + " [ 1.8722e-02, 1.7385e-03, 6.6652e-03],\n", + " [-8.8163e-03, 1.6646e-02, 1.8388e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 3.6180e-02, 2.2804e-02, 8.0334e-03],\n", + " [ 7.5858e-03, -3.0590e-03, -1.0771e-02],\n", + " [ 2.6141e-02, 6.5954e-04, 3.5041e-03]],\n", + "\n", + " [[ 1.1716e-02, 2.4431e-03, 1.5212e-02],\n", + " [-8.8621e-03, 1.3009e-02, 8.3022e-03],\n", + " [-1.0354e-02, -3.3272e-02, 4.8393e-03]],\n", + "\n", + " [[-6.0150e-06, 8.0425e-03, 2.5490e-02],\n", + " [ 1.2990e-02, -1.5889e-02, 2.3843e-02],\n", + " [ 2.6896e-03, 1.3003e-02, 1.4933e-02]]]])\n", + "layer4.1.bn1.weight Parameter containing:\n", + "tensor([0.9805, 0.9846, 0.9759, 0.9868, 0.9540, 0.9773, 0.9970, 1.0088, 0.9459,\n", + " 1.0135, 0.9817, 1.0045, 0.9881, 0.9783, 1.0104, 0.9747, 0.9724, 1.0108,\n", + " 1.0235, 1.0103, 0.9986, 0.9730, 0.9935, 0.9986, 1.0021, 1.0400, 0.9961,\n", + " 0.9527, 1.0376, 1.0096, 0.9515, 0.9864, 0.9696, 0.9796, 0.9870, 0.9595,\n", + " 0.9694, 0.9824, 0.9672, 1.0294, 1.0250, 0.9979, 0.9760, 0.9887, 0.9431,\n", + " 1.0026, 0.9854, 0.9683, 1.0087, 0.9879, 0.9638, 0.9985, 1.0065, 1.0173,\n", + " 1.0125, 0.9784, 0.9980, 1.0158, 1.0134, 0.9912, 0.9901, 1.0090, 0.9841,\n", + " 0.9555, 0.9996, 0.9847, 0.9812, 1.0173, 1.0262, 1.0204, 1.0023, 1.0182,\n", + " 1.0031, 1.0051, 0.9728, 0.9804, 1.0384, 0.9685, 0.9646, 0.9908, 0.9813,\n", + " 0.9950, 0.9732, 1.0086, 0.9890, 1.0129, 1.0204, 0.9924, 0.9849, 0.9708,\n", + " 1.0072, 1.0065, 0.9797, 1.0123, 0.9609, 0.9864, 1.0090, 0.9948, 0.9661,\n", + " 1.0218, 1.0108, 0.9750, 
0.9779, 0.9953, 0.9782, 0.9789, 0.9951, 0.9922,\n", + " 1.0282, 1.0167, 1.0085, 0.9837, 0.9874, 1.0074, 1.0155, 1.1245, 0.9838,\n", + " 1.0047, 0.9883, 1.0097, 1.0006, 0.9961, 1.0088, 0.9767, 1.0109, 0.9652,\n", + " 1.0224, 0.9591, 0.9531, 1.0059, 1.0103, 0.9987, 0.9877, 1.0158, 0.9940,\n", + " 0.9723, 1.0097, 1.0078, 0.9928, 0.9844, 1.0161, 0.9910, 0.9563, 0.9602,\n", + " 1.0367, 0.9833, 1.0116, 0.9996, 0.9971, 1.0012, 0.9933, 1.0024, 0.9923,\n", + " 0.9714, 0.9900, 1.0053, 1.0053, 1.0154, 0.9536, 1.0031, 0.9785, 0.9632,\n", + " 0.9639, 1.0079, 1.0122, 0.9868, 1.0039, 0.9561, 1.0111, 1.0003, 0.9817,\n", + " 0.9626, 1.0205, 0.9594, 0.9970, 0.9981, 1.0043, 1.0301, 0.9845, 0.9772,\n", + " 0.9690, 0.9833, 0.9810, 0.9739, 0.9751, 0.9919, 1.0352, 1.0192, 0.9927,\n", + " 0.9860, 1.0235, 0.9994, 0.9919, 1.0123, 0.9797, 0.9880, 1.0050, 0.9768,\n", + " 0.9599, 0.9904, 0.9890, 1.0258, 0.9771, 1.0327, 0.9756, 1.0120, 0.9563,\n", + " 1.0169, 1.0041, 0.9974, 0.9641, 0.9446, 0.9659, 0.9797, 0.9486, 0.9336,\n", + " 1.0025, 0.9776, 0.9898, 0.9870, 1.0090, 0.9526, 0.9759, 1.0003, 0.9790,\n", + " 0.9954, 1.0057, 1.0005, 1.0094, 0.9701, 0.9780, 0.9816, 0.9897, 0.9684,\n", + " 0.9891, 1.0065, 1.0226, 0.9636, 0.9637, 0.9261, 1.0008, 1.0651, 0.9930,\n", + " 0.9958, 0.9700, 0.9869, 0.9898, 1.0047, 1.0040, 1.0012, 1.0196, 0.9930,\n", + " 1.0209, 0.9953, 0.9851, 1.0046, 0.9682, 0.9725, 1.0114, 0.9976, 1.0024,\n", + " 1.0266, 1.0152, 0.9489, 0.9990, 0.9975, 0.9912, 1.0132, 0.9978, 0.9912,\n", + " 0.9794, 0.9792, 0.9685, 1.0182, 0.9790, 0.9918, 0.9876, 1.0032, 1.0073,\n", + " 1.0125, 1.0061, 0.9674, 1.0080, 0.9517, 0.9852, 1.0098, 0.9646, 0.9963,\n", + " 0.9974, 0.9805, 0.9975, 1.0254, 0.9624, 0.9958, 0.9880, 1.0272, 1.0107,\n", + " 1.0114, 1.0068, 0.9564, 0.9969, 0.9830, 1.0223, 1.0025, 0.9901, 1.0058,\n", + " 0.9905, 0.9872, 1.0171, 0.9750, 0.9838, 1.0035, 0.9567, 1.0092, 1.0087,\n", + " 0.9281, 0.9962, 0.9900, 0.9797, 0.9860, 1.0171, 0.9944, 0.9739, 1.0253,\n", + " 0.9776, 1.0007, 0.9641, 
0.9815, 1.0056, 0.9797, 1.0091, 0.9975, 0.9990,\n", + " 1.0092, 1.0226, 1.0119, 0.9579, 1.0084, 1.0219, 0.9920, 0.9944, 1.0221,\n", + " 0.9891, 1.0069, 0.9719, 1.0181, 0.9983, 1.0006, 1.0074, 0.9991, 0.9806,\n", + " 0.9645, 0.9950, 0.9887, 1.0078, 0.9995, 1.0178, 0.9971, 0.9745, 0.9846,\n", + " 1.0115, 0.9732, 0.9848, 0.9896, 0.9884, 1.0157, 1.0317, 1.0030, 0.9966,\n", + " 0.9446, 0.9885, 1.0239, 0.9931, 0.9852, 1.0077, 0.9649, 0.9880, 1.0098,\n", + " 0.9700, 1.0080, 1.0383, 0.9895, 0.9959, 0.9903, 1.0276, 0.9585, 0.9994,\n", + " 0.9939, 0.9725, 1.0185, 1.0011, 0.9972, 0.9493, 0.9815, 0.9937, 1.0199,\n", + " 0.9811, 1.0165, 1.0112, 1.0096, 0.9968, 0.9923, 0.9888, 0.9956, 1.0406,\n", + " 0.9957, 1.0087, 0.9953, 0.9961, 0.9790, 0.9766, 0.9792, 0.9992, 0.9895,\n", + " 0.9994, 0.9961, 0.9934, 0.9956, 0.9952, 0.9332, 0.9898, 0.9428, 0.9966,\n", + " 0.9634, 0.9670, 1.0047, 1.0183, 0.9669, 0.9316, 0.9917, 1.0281, 0.9686,\n", + " 0.9995, 1.0016, 1.0001, 0.9854, 0.9899, 0.9941, 0.9777, 0.9883, 0.9760,\n", + " 0.9763, 1.0214, 0.9761, 0.9893, 1.0096, 1.0141, 0.9906, 1.0030, 0.9963,\n", + " 0.9774, 1.0134, 0.9830, 1.0211, 1.0154, 0.9978, 0.9926, 1.0075, 1.0314,\n", + " 0.9650, 1.0000, 1.0233, 1.0141, 1.0137, 0.9987, 0.9872, 0.9927, 1.0081,\n", + " 1.0121, 0.9797, 1.0049, 0.9735, 1.0206, 0.9841, 0.9722, 1.0008, 0.9902,\n", + " 0.9560, 1.0136, 0.9915, 0.9973, 1.0004, 0.9762, 0.9707, 0.9804, 0.9942,\n", + " 0.9903, 1.0008, 0.9939, 1.0084, 0.9891, 0.9804, 0.9916, 0.9853, 0.9925,\n", + " 0.9732, 0.9747, 0.9848, 0.9919, 0.9396, 1.0009, 0.9794, 0.9971, 1.0151,\n", + " 1.0144, 0.9861, 0.9910, 0.9841, 0.9594, 0.9910, 0.9671, 0.9948])\n", + "layer4.1.bn1.bias Parameter containing:\n", + "tensor([-0.0608, -0.1106, -0.1077, -0.0862, -0.0966, -0.1186, -0.0844, -0.1174,\n", + " -0.1269, -0.1427, -0.1088, -0.1491, -0.1391, -0.1073, -0.0685, -0.1510,\n", + " -0.0840, -0.1244, -0.1091, -0.0766, -0.0785, -0.0935, -0.1171, -0.0805,\n", + " -0.1100, -0.1681, -0.1164, -0.1179, -0.0757, -0.1134, 
-0.1257, -0.0916,\n", + " -0.1223, -0.0713, -0.1444, -0.0968, -0.0749, -0.1484, -0.1897, -0.1927,\n", + " -0.1047, -0.1176, -0.0849, -0.0631, -0.0967, -0.0724, -0.0853, -0.1475,\n", + " -0.2147, -0.1051, -0.0995, -0.1090, -0.1502, -0.0868, -0.1374, -0.0629,\n", + " -0.1739, -0.0584, -0.1067, -0.1148, -0.1047, -0.1238, -0.1547, -0.1756,\n", + " -0.1839, -0.1441, -0.1100, -0.1534, -0.1557, -0.1117, -0.0860, -0.1076,\n", + " -0.1687, -0.1198, -0.1197, -0.0775, -0.1608, -0.1607, -0.0424, -0.1094,\n", + " -0.0943, -0.0690, -0.1275, -0.1888, -0.1153, -0.0499, -0.1390, -0.1021,\n", + " -0.1419, -0.0608, -0.0607, -0.1145, -0.1181, -0.1917, -0.1177, -0.1105,\n", + " -0.0889, -0.0545, -0.0642, -0.1006, -0.0981, -0.1013, -0.0728, -0.0794,\n", + " -0.0851, -0.1038, -0.1028, -0.0633, -0.1673, -0.1243, -0.1759, -0.0961,\n", + " -0.1261, -0.1214, -0.1933, -0.3384, -0.0859, -0.1718, -0.1494, -0.1796,\n", + " -0.0852, -0.0871, -0.0846, -0.1100, -0.2103, -0.1524, -0.0914, -0.0854,\n", + " -0.1211, -0.0944, -0.2191, -0.2019, -0.0837, -0.0704, -0.0908, -0.1619,\n", + " -0.0854, -0.1587, -0.0729, -0.1280, -0.1806, -0.1186, -0.1148, -0.0597,\n", + " -0.1322, -0.1507, -0.1482, -0.0712, -0.1087, -0.1658, -0.1355, -0.1157,\n", + " -0.0734, -0.1032, -0.0741, -0.1099, -0.0713, -0.0955, -0.1165, -0.1179,\n", + " -0.0909, -0.1548, -0.0737, -0.1618, -0.1329, -0.1200, -0.0681, -0.0856,\n", + " -0.1546, -0.0685, -0.1340, -0.0716, -0.0849, -0.0728, -0.0803, -0.0587,\n", + " -0.1320, -0.1410, -0.1361, -0.1207, -0.0786, -0.1054, -0.0963, -0.0907,\n", + " -0.0963, -0.1160, -0.0913, -0.1016, -0.1288, -0.1441, -0.1249, -0.1070,\n", + " -0.1004, -0.0959, -0.1122, -0.1121, -0.0655, -0.1088, -0.1389, -0.0901,\n", + " -0.1826, -0.2301, -0.0609, -0.0993, -0.1086, -0.1292, -0.1134, -0.1160,\n", + " -0.1750, -0.0728, -0.1286, -0.0742, -0.1474, -0.1549, -0.0860, -0.0976,\n", + " -0.1396, -0.1356, -0.0913, -0.0766, -0.1393, -0.1175, -0.0719, -0.0979,\n", + " -0.1313, -0.1787, -0.0839, -0.0961, -0.1154, -0.1008, 
-0.1131, -0.0837,\n", + " -0.0736, -0.1305, -0.0767, -0.0568, -0.1916, -0.0585, -0.1332, -0.1342,\n", + " -0.1559, -0.3032, -0.0988, -0.1053, -0.0608, -0.1198, -0.1132, -0.1624,\n", + " -0.1054, -0.0622, -0.1352, -0.1380, -0.1693, -0.1173, -0.0603, -0.1139,\n", + " -0.1178, -0.1277, -0.2131, -0.0784, -0.0519, -0.2277, -0.1055, -0.0931,\n", + " -0.1010, -0.1456, -0.0494, -0.1482, -0.1402, -0.0407, -0.1182, -0.1223,\n", + " -0.1086, -0.1254, -0.1419, -0.0713, -0.1254, -0.0799, -0.0662, -0.1664,\n", + " -0.0995, -0.1176, -0.1106, -0.1328, -0.1268, -0.1084, -0.0495, -0.0686,\n", + " -0.1277, -0.0775, -0.0593, -0.1401, -0.0968, -0.0929, -0.0689, -0.0705,\n", + " -0.1046, -0.1467, -0.1568, -0.0931, -0.1599, -0.0977, -0.1093, -0.0855,\n", + " -0.0871, -0.0872, -0.0939, -0.1416, -0.1200, -0.1911, -0.1143, -0.0808,\n", + " -0.1169, -0.1221, -0.1957, -0.1343, -0.1106, -0.0971, -0.0942, -0.1319,\n", + " -0.1436, -0.0967, -0.0854, -0.1190, -0.1104, -0.1925, -0.0976, -0.1878,\n", + " -0.1187, -0.0819, -0.0849, -0.1193, -0.0944, -0.1071, -0.0700, -0.0783,\n", + " -0.1317, -0.1517, -0.0699, -0.1153, -0.1113, -0.2056, -0.1347, -0.1889,\n", + " -0.1641, -0.0675, -0.1322, -0.1328, -0.1147, -0.1123, -0.1087, -0.1466,\n", + " -0.0859, -0.1492, -0.1310, -0.0981, -0.0566, -0.1827, -0.0625, -0.0621,\n", + " -0.0790, -0.0667, -0.0876, -0.0773, -0.0938, -0.1823, -0.1910, -0.1502,\n", + " -0.1416, -0.0941, -0.1280, -0.1437, -0.1407, -0.0356, -0.1092, -0.1139,\n", + " -0.0943, -0.1434, -0.0862, -0.1915, -0.1062, -0.0828, -0.0734, -0.0818,\n", + " -0.1325, -0.0910, -0.0997, -0.1065, -0.1666, -0.0842, -0.1114, -0.1090,\n", + " -0.1106, -0.1000, -0.1096, -0.2158, -0.1140, -0.1128, -0.0802, -0.0938,\n", + " -0.1751, -0.0998, -0.1030, -0.1405, -0.1131, -0.1083, -0.0755, -0.0839,\n", + " -0.1115, -0.1142, -0.0762, -0.0731, -0.1117, -0.1029, -0.0728, -0.0489,\n", + " -0.0535, -0.1026, -0.1306, -0.1330, -0.1135, -0.0804, -0.2111, -0.1288,\n", + " -0.1213, -0.1006, -0.1309, -0.1047, -0.0553, -0.0531, 
-0.1103, -0.0808,\n", + " -0.0620, -0.0879, -0.1219, -0.1101, -0.1087, -0.0573, -0.1147, -0.0571,\n", + " -0.0626, -0.1239, -0.2353, -0.1909, -0.2048, -0.0762, -0.1088, -0.1020,\n", + " -0.0687, -0.1871, -0.1064, -0.1255, -0.0896, -0.1078, -0.1170, -0.0867,\n", + " -0.0720, -0.1306, -0.1985, -0.0807, -0.1060, -0.1028, -0.2111, -0.1270,\n", + " -0.0658, -0.0956, -0.1027, -0.1447, -0.2264, -0.1177, -0.1414, -0.1139,\n", + " -0.1850, -0.0533, -0.0952, -0.0898, -0.1356, -0.1355, -0.2044, -0.0864,\n", + " -0.0759, -0.0708, -0.0850, -0.1273, -0.1374, -0.1031, -0.0824, -0.1243,\n", + " -0.1191, -0.0874, -0.1535, -0.1514, -0.1384, -0.0627, -0.0857, -0.1217,\n", + " -0.1527, -0.1059, -0.1265, -0.0937, -0.0859, -0.0676, -0.0940, -0.1024,\n", + " -0.1152, -0.0735, -0.1281, -0.1194, -0.0971, -0.0706, -0.1200, -0.0926])\n", + "layer4.1.conv2.weight Parameter containing:\n", + "tensor([[[[ 8.3330e-03, -8.8291e-03, 1.2275e-02],\n", + " [-1.2374e-02, -8.4982e-03, 9.2257e-03],\n", + " [ 1.2488e-02, -3.5636e-03, 4.6247e-04]],\n", + "\n", + " [[-2.5105e-03, -8.5636e-03, -1.5846e-02],\n", + " [-1.0081e-02, -6.2206e-03, 8.4948e-04],\n", + " [ 7.9979e-03, 4.0291e-03, -1.0241e-02]],\n", + "\n", + " [[ 1.0911e-02, -2.9711e-04, -5.6554e-03],\n", + " [ 2.8144e-03, 8.9447e-03, -1.0828e-02],\n", + " [ 4.2974e-03, 1.0024e-02, 1.3600e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-5.2114e-03, -5.0998e-03, -8.2229e-03],\n", + " [-2.2549e-03, -3.9222e-04, -1.6133e-03],\n", + " [-1.4629e-02, 4.1414e-03, 5.7456e-03]],\n", + "\n", + " [[ 1.2655e-02, 2.8707e-03, 3.0632e-03],\n", + " [ 6.5042e-03, 8.0699e-03, 5.7052e-03],\n", + " [ 1.0453e-02, 1.2078e-02, 3.9071e-04]],\n", + "\n", + " [[ 1.7813e-02, -6.4311e-03, 1.0226e-02],\n", + " [ 1.2140e-02, 4.2405e-03, 9.8047e-03],\n", + " [-1.4421e-02, 5.1734e-03, -1.9546e-03]]],\n", + "\n", + "\n", + " [[[ 1.8379e-02, 1.1177e-02, 1.7544e-02],\n", + " [-8.6856e-04, 3.5306e-03, 1.3847e-02],\n", + " [-1.4978e-03, 1.8132e-02, -3.5698e-03]],\n", + "\n", + " [[ 
2.4527e-04, 1.4042e-02, -8.8348e-03],\n", + " [ 4.5480e-03, 4.4421e-03, -7.4060e-03],\n", + " [ 7.7895e-03, 5.7431e-05, 4.0264e-03]],\n", + "\n", + " [[ 5.3564e-03, -1.4270e-02, -1.9118e-02],\n", + " [ 1.3765e-02, 8.8396e-04, 6.0510e-03],\n", + " [ 3.6769e-03, -3.0135e-03, -1.5168e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.7418e-02, -9.8899e-03, -2.1458e-03],\n", + " [ 7.3479e-03, 7.0623e-03, 1.1940e-03],\n", + " [ 3.7772e-03, 2.9783e-04, -2.8873e-03]],\n", + "\n", + " [[-9.7363e-03, 1.1990e-02, 5.8910e-04],\n", + " [-1.3291e-02, -1.3032e-02, 2.7085e-03],\n", + " [-1.2218e-02, -2.0081e-02, -2.5794e-04]],\n", + "\n", + " [[-1.4517e-03, -1.3347e-03, -1.9954e-02],\n", + " [-9.5822e-03, 8.9759e-03, -5.2620e-03],\n", + " [ 1.1365e-02, 6.5043e-03, -1.7677e-03]]],\n", + "\n", + "\n", + " [[[ 2.8814e-03, -1.1459e-02, 7.2866e-03],\n", + " [-1.1217e-02, -6.3151e-03, -4.2912e-04],\n", + " [ 4.4563e-03, -3.7896e-03, -1.9523e-03]],\n", + "\n", + " [[-1.3908e-02, 1.0951e-02, 7.5778e-03],\n", + " [ 9.1734e-05, -2.2147e-03, -4.2200e-03],\n", + " [ 9.6189e-03, -1.1441e-02, -1.4601e-02]],\n", + "\n", + " [[ 1.0115e-02, 1.7201e-02, 6.6180e-03],\n", + " [-4.3699e-03, 1.8709e-03, -1.8911e-02],\n", + " [ 6.8236e-03, 1.1824e-02, -5.8826e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.9188e-02, 8.9656e-03, 3.0848e-02],\n", + " [-5.0819e-03, -3.8729e-03, 1.8700e-02],\n", + " [ 4.7682e-03, 5.8786e-03, 9.7623e-03]],\n", + "\n", + " [[-3.5205e-04, -2.5114e-03, 1.1785e-04],\n", + " [ 8.3362e-03, 7.0353e-03, -1.7875e-03],\n", + " [-1.1702e-03, -8.3737e-03, 1.4237e-02]],\n", + "\n", + " [[-2.8220e-03, -2.0585e-02, -1.9547e-02],\n", + " [-3.0085e-04, -2.0964e-02, -1.9344e-02],\n", + " [-1.0963e-02, -2.8152e-03, -1.9004e-02]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[-1.1269e-02, 1.0234e-02, -1.5817e-03],\n", + " [ 3.7629e-03, -3.9181e-03, -1.3499e-02],\n", + " [-7.1245e-04, -2.1378e-03, 1.7662e-02]],\n", + "\n", + " [[ 1.0952e-02, 6.3739e-03, 1.2541e-02],\n", + " [ 
1.6861e-02, -1.0851e-02, 4.3220e-03],\n", + " [ 1.3744e-02, 3.8529e-03, 5.6580e-04]],\n", + "\n", + " [[-1.3857e-02, -2.3699e-04, -2.2844e-03],\n", + " [-2.0008e-02, 6.1069e-03, -1.1951e-02],\n", + " [ 1.0057e-02, -6.6128e-03, 2.5271e-03]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.3057e-02, 1.1391e-02, 2.4781e-03],\n", + " [ 5.9918e-03, 2.2100e-02, 3.2220e-03],\n", + " [ 1.6557e-02, 1.9436e-02, 4.1398e-03]],\n", + "\n", + " [[ 1.3594e-02, 7.4428e-03, 2.5966e-02],\n", + " [ 1.4304e-02, -4.4501e-03, -4.0079e-03],\n", + " [-1.0729e-02, 9.3449e-03, -7.2407e-03]],\n", + "\n", + " [[ 8.2458e-03, 1.2307e-02, -1.8730e-03],\n", + " [ 3.8755e-03, -1.0899e-02, -1.0243e-02],\n", + " [ 1.4385e-02, 1.4387e-02, -6.9691e-06]]],\n", + "\n", + "\n", + " [[[-1.2377e-03, -2.4913e-02, -6.6713e-03],\n", + " [-1.1979e-02, -3.9882e-02, -3.5081e-02],\n", + " [-3.0059e-02, -1.3726e-02, -1.6513e-02]],\n", + "\n", + " [[-7.6263e-03, 4.6258e-03, 4.1316e-03],\n", + " [-2.0899e-02, 9.7261e-03, 1.7532e-02],\n", + " [ 1.3484e-02, -1.5733e-02, 5.7539e-03]],\n", + "\n", + " [[ 6.4326e-03, 1.1697e-02, -9.0290e-03],\n", + " [ 3.7352e-03, -7.8998e-03, -2.0597e-03],\n", + " [ 1.2484e-02, 6.9576e-03, 1.9936e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-9.8303e-03, -1.0628e-02, -3.4814e-03],\n", + " [ 4.3456e-03, -1.1256e-02, 4.4709e-03],\n", + " [ 5.8977e-03, 1.3371e-03, 1.3130e-03]],\n", + "\n", + " [[ 1.2398e-02, 8.9216e-03, 1.2770e-03],\n", + " [-1.1820e-02, -7.3181e-03, -8.0942e-04],\n", + " [-5.5888e-03, -8.0208e-03, 3.3651e-03]],\n", + "\n", + " [[-4.3322e-03, -9.2950e-03, -1.2784e-02],\n", + " [-6.3165e-03, -1.3445e-03, 8.2466e-03],\n", + " [ 2.8841e-03, -8.8737e-03, 3.1212e-03]]],\n", + "\n", + "\n", + " [[[-6.0997e-03, -1.6026e-02, -1.7981e-02],\n", + " [-1.8672e-02, 3.5800e-03, -1.1836e-02],\n", + " [-2.5707e-03, 6.3123e-03, -3.1627e-03]],\n", + "\n", + " [[ 2.3208e-03, -9.5134e-03, 5.2807e-03],\n", + " [-4.3437e-03, -2.9705e-03, -1.8156e-02],\n", + " [ 1.3944e-02, -1.7958e-02, 
4.8227e-03]],\n", + "\n", + " [[-1.5250e-03, 1.4135e-02, -7.6070e-03],\n", + " [-6.7339e-03, 2.5677e-03, -8.7000e-03],\n", + " [-5.2387e-03, 1.7277e-02, 1.6578e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[-3.9090e-03, -4.0446e-03, 1.0471e-02],\n", + " [ 9.9475e-03, -4.6785e-03, 1.9391e-02],\n", + " [ 1.5370e-02, 4.5712e-03, 8.4700e-03]],\n", + "\n", + " [[ 1.7564e-02, -3.1784e-03, -2.9029e-03],\n", + " [-1.3421e-02, -2.4112e-02, -2.0311e-02],\n", + " [-6.3522e-03, -1.1957e-02, -3.5473e-03]],\n", + "\n", + " [[ 3.6304e-03, 2.8775e-04, -5.7601e-03],\n", + " [-2.3570e-03, 1.9330e-02, -3.9670e-03],\n", + " [-7.8952e-03, -1.0448e-02, -2.4362e-02]]]])\n", + "layer4.1.bn2.weight Parameter containing:\n", + "tensor([1.0242, 1.0146, 1.0980, 1.1073, 1.0217, 1.0784, 1.1271, 1.0397, 1.0641,\n", + " 1.0312, 1.0438, 1.0239, 1.0947, 1.0322, 1.0603, 1.0867, 1.0833, 1.0292,\n", + " 1.0249, 1.0900, 1.0181, 1.0410, 1.1284, 1.0774, 1.0293, 1.0697, 1.0609,\n", + " 1.0724, 1.0830, 1.1673, 1.0062, 1.1154, 1.1345, 1.0555, 1.0160, 1.0482,\n", + " 1.0454, 1.1229, 1.2171, 1.0859, 1.0210, 1.1659, 1.1155, 1.0243, 1.1353,\n", + " 1.0889, 1.0895, 1.1472, 1.0563, 1.0662, 1.1582, 1.0786, 1.0374, 0.9917,\n", + " 1.0358, 1.0606, 1.1683, 1.0239, 1.1692, 1.0581, 1.0631, 1.0899, 1.0036,\n", + " 1.1219, 1.0847, 1.0466, 1.0357, 1.0475, 1.0673, 1.0314, 1.0647, 1.0164,\n", + " 1.0947, 1.0652, 1.0888, 1.1126, 1.0377, 1.0581, 1.0609, 1.0574, 1.0381,\n", + " 1.0433, 1.0988, 1.0728, 1.0245, 1.1218, 1.0846, 1.0253, 1.0980, 1.0815,\n", + " 1.0791, 1.0930, 1.0781, 1.0400, 1.0950, 1.0370, 1.0712, 1.0209, 1.0012,\n", + " 1.0378, 1.0469, 1.0637, 1.0556, 1.0836, 1.0990, 1.1014, 1.0262, 1.0404,\n", + " 1.0208, 1.0725, 1.1698, 1.0678, 1.0908, 1.0541, 1.0590, 1.0491, 1.0458,\n", + " 1.0766, 1.1052, 1.0307, 1.0563, 1.0302, 1.1253, 1.0703, 1.0515, 1.0016,\n", + " 1.0125, 1.0734, 1.1176, 1.0469, 1.1395, 1.0709, 1.1377, 1.2334, 1.1187,\n", + " 1.0490, 1.0694, 1.0609, 1.1138, 1.0400, 1.1218, 1.0512, 1.2059, 1.0560,\n", + " 
1.0507, 1.0822, 1.0961, 1.0719, 1.0658, 1.0883, 1.0782, 1.0793, 1.0848,\n", + " 1.0312, 1.0373, 1.0784, 1.0518, 1.0757, 1.0129, 1.0832, 1.0855, 1.0356,\n", + " 1.0360, 1.0601, 1.0969, 1.0702, 1.0355, 1.0328, 1.1190, 1.0092, 1.0068,\n", + " 1.0617, 1.1217, 1.0142, 1.1396, 1.1373, 1.2151, 1.1153, 1.0831, 1.0624,\n", + " 0.9801, 1.0520, 1.0310, 1.0255, 1.0246, 1.0752, 1.0150, 1.1226, 1.1025,\n", + " 1.0545, 1.0117, 1.0097, 1.0898, 1.0338, 1.1749, 0.9866, 1.0784, 1.0726,\n", + " 1.0084, 1.1760, 1.0616, 1.1402, 1.0070, 1.0595, 1.0467, 1.1104, 0.9872,\n", + " 1.0514, 1.0343, 1.0182, 1.0532, 1.0238, 1.1251, 1.0696, 1.0595, 1.1180,\n", + " 1.0848, 1.1539, 1.0253, 1.0455, 1.0185, 1.0646, 1.0264, 1.0715, 1.1070,\n", + " 0.9987, 1.0956, 1.0493, 1.0908, 1.0277, 1.0646, 1.0280, 1.0533, 1.0014,\n", + " 1.0239, 1.0354, 1.0473, 1.0112, 1.0122, 1.1165, 1.0348, 0.9891, 1.0225,\n", + " 1.0456, 0.9976, 1.0481, 1.1477, 1.0721, 1.0166, 1.0151, 1.0839, 1.0700,\n", + " 1.1691, 1.0164, 1.0309, 1.0240, 1.1567, 1.0431, 1.0736, 1.1114, 1.0065,\n", + " 1.0703, 1.0611, 1.0475, 1.0474, 1.0789, 1.0267, 1.1011, 1.0307, 1.0315,\n", + " 1.0288, 1.1784, 1.0545, 1.0427, 1.1500, 1.0942, 1.1086, 1.0019, 1.1033,\n", + " 1.1051, 1.0601, 1.0148, 1.0515, 1.0572, 1.0110, 1.1092, 1.0073, 1.0400,\n", + " 1.0272, 1.0786, 1.0280, 1.1989, 1.1029, 1.0141, 1.0876, 1.0109, 1.0384,\n", + " 1.0236, 1.0347, 1.1355, 1.1481, 1.0085, 1.0619, 1.1081, 1.0697, 1.1674,\n", + " 1.0433, 1.0232, 1.0453, 1.0082, 1.0715, 1.0691, 1.0982, 0.9905, 1.0557,\n", + " 1.0059, 1.0252, 1.0962, 1.1416, 1.0434, 1.0214, 1.0485, 1.1178, 1.0474,\n", + " 1.1043, 1.1483, 1.1027, 1.1413, 1.1306, 1.1123, 1.0489, 1.0220, 1.1077,\n", + " 1.0568, 1.0750, 1.0148, 1.1145, 1.0168, 1.0564, 1.0298, 1.0500, 1.0745,\n", + " 1.1666, 1.0316, 1.0129, 1.0392, 0.9879, 1.0129, 1.0548, 1.0767, 1.1244,\n", + " 1.1325, 1.0386, 1.0411, 1.0769, 1.0749, 1.0645, 1.0663, 1.0729, 1.0543,\n", + " 1.0424, 1.0464, 1.0326, 1.1171, 1.1288, 1.0248, 1.0545, 1.0248, 1.0937,\n", + " 
1.0726, 1.0840, 1.0670, 1.0170, 1.0825, 1.1329, 1.0469, 1.0576, 1.0934,\n", + " 1.0217, 1.0305, 1.0960, 1.0253, 1.0343, 1.0251, 1.0356, 1.0462, 1.1378,\n", + " 1.1591, 1.0126, 1.0977, 1.0571, 1.0547, 1.0241, 1.0492, 1.0013, 1.0310,\n", + " 1.0516, 1.1145, 1.1385, 1.0651, 1.0141, 1.0097, 1.0420, 1.0733, 1.0806,\n", + " 1.0100, 1.1806, 1.0809, 1.1831, 1.0071, 1.0546, 1.0627, 1.0322, 1.0487,\n", + " 1.0382, 1.1233, 1.0546, 1.1694, 1.0623, 1.0617, 1.0711, 1.0351, 1.0413,\n", + " 1.0751, 1.0799, 1.0228, 1.0337, 1.1464, 1.1565, 1.0602, 1.0685, 1.0337,\n", + " 1.0126, 1.0579, 1.0529, 1.1038, 1.1000, 1.1436, 1.0905, 1.0137, 1.1300,\n", + " 1.0555, 1.0546, 1.1316, 1.0681, 1.0819, 1.0160, 1.0665, 1.0106, 1.0965,\n", + " 1.0122, 1.0951, 1.1311, 1.0323, 0.9902, 1.0654, 1.1184, 1.0895, 1.1081,\n", + " 1.1093, 1.0375, 1.0910, 1.0604, 1.0668, 1.1883, 1.0838, 1.0655, 1.0060,\n", + " 1.0086, 1.0424, 1.1064, 1.0452, 1.0224, 1.0966, 1.0517, 1.0848, 1.0836,\n", + " 1.0501, 1.1060, 1.0344, 1.0298, 1.0721, 1.1158, 1.1045, 1.1082, 1.0016,\n", + " 1.1235, 1.1025, 1.0471, 1.1419, 1.1065, 1.0055, 1.0378, 1.0280, 1.0306,\n", + " 1.0631, 1.0714, 1.1036, 1.0619, 1.0554, 1.0397, 1.0618, 1.0505, 1.0305,\n", + " 1.0174, 1.0348, 1.0543, 1.1011, 1.0401, 1.0488, 1.1258, 1.0311])\n", + "layer4.1.bn2.bias Parameter containing:\n", + "tensor([-0.0325, -0.0605, -0.0857, -0.0918, -0.0602, -0.0845, -0.0784, -0.0422,\n", + " -0.0753, -0.0499, -0.0510, -0.0269, -0.0744, -0.0588, -0.0597, -0.0444,\n", + " -0.0797, -0.0465, -0.0286, -0.1052, -0.0191, -0.0452, -0.1614, -0.1270,\n", + " -0.0403, -0.0847, -0.0681, -0.0665, -0.0804, -0.0736, -0.0157, -0.0878,\n", + " -0.1096, -0.0556, -0.0338, -0.0658, -0.0492, -0.0942, -0.1351, -0.0913,\n", + " -0.0605, -0.1077, -0.0760, -0.0268, -0.0929, -0.1248, -0.0825, -0.1334,\n", + " -0.0502, -0.0905, -0.0893, -0.0895, -0.0377, -0.0247, -0.0406, -0.0473,\n", + " -0.1259, -0.0324, -0.1490, -0.0499, -0.0981, -0.0630, -0.0290, -0.0723,\n", + " -0.0771, -0.0589, -0.0483, 
-0.0727, -0.0719, -0.0575, -0.0789, -0.0573,\n", + " -0.0650, -0.0764, -0.0824, -0.0842, -0.0428, -0.0618, -0.0717, -0.0674,\n", + " -0.0505, -0.0385, -0.1024, -0.0690, -0.0712, -0.0958, -0.0633, -0.0369,\n", + " -0.1048, -0.0852, -0.0832, -0.0954, -0.1188, -0.0503, -0.0537, -0.0669,\n", + " -0.1042, -0.0372, -0.0239, -0.0809, -0.0686, -0.0626, -0.0527, -0.0764,\n", + " -0.0807, -0.0843, -0.0461, -0.0379, -0.0419, -0.0656, -0.1475, -0.0798,\n", + " -0.0922, -0.0623, -0.0506, -0.0202, -0.0486, -0.0802, -0.0912, -0.0541,\n", + " -0.0456, -0.0481, -0.0719, -0.0575, -0.0546, -0.0292, -0.0339, -0.0723,\n", + " -0.0812, -0.0887, -0.1355, -0.0588, -0.0987, -0.1827, -0.1166, -0.0825,\n", + " -0.0900, -0.0490, -0.1302, -0.0523, -0.1335, -0.0503, -0.1490, -0.0766,\n", + " -0.0275, -0.0724, -0.0757, -0.0518, -0.0562, -0.1220, -0.0533, -0.0878,\n", + " -0.0788, -0.0503, -0.0515, -0.0457, -0.0664, -0.0668, -0.0324, -0.1042,\n", + " -0.1219, -0.0598, -0.0487, -0.0811, -0.1269, -0.0676, -0.0335, -0.0383,\n", + " -0.1570, -0.0325, -0.0246, -0.0483, -0.0958, -0.0224, -0.0991, -0.1084,\n", + " -0.1514, -0.0992, -0.0957, -0.0658, -0.0233, -0.0615, -0.0628, -0.0498,\n", + " -0.0483, -0.0696, -0.0132, -0.0625, -0.1002, -0.0779, -0.0290, -0.0283,\n", + " -0.0602, -0.0512, -0.1086, -0.0198, -0.1045, -0.0523, -0.0197, -0.1472,\n", + " -0.0456, -0.1203, -0.0306, -0.0667, -0.0519, -0.1100, -0.0186, -0.0970,\n", + " -0.0415, -0.0227, -0.0472, -0.0281, -0.1176, -0.0439, -0.0673, -0.1259,\n", + " -0.0806, -0.1527, -0.0542, -0.0571, -0.0378, -0.0684, -0.0353, -0.0721,\n", + " -0.0857, -0.0396, -0.1194, -0.0586, -0.0896, -0.0594, -0.0764, -0.0524,\n", + " -0.0556, -0.0250, -0.0375, -0.0627, -0.0567, -0.0237, -0.0375, -0.1197,\n", + " -0.0859, -0.0185, -0.0505, -0.0494, -0.0357, -0.0603, -0.1173, -0.1240,\n", + " -0.0454, -0.0445, -0.0784, -0.0812, -0.1168, -0.0637, -0.0371, -0.0569,\n", + " -0.1315, -0.0535, -0.0760, -0.1098, -0.0259, -0.0787, -0.0838, -0.0419,\n", + " -0.0635, -0.0755, -0.0300, 
-0.0900, -0.0478, -0.0320, -0.0406, -0.1227,\n", + " -0.0664, -0.0390, -0.1053, -0.0896, -0.1074, -0.0246, -0.0633, -0.0996,\n", + " -0.0780, -0.0535, -0.0581, -0.0648, -0.0292, -0.0925, -0.0487, -0.0541,\n", + " -0.0306, -0.0418, -0.0345, -0.1240, -0.0759, -0.0462, -0.0935, -0.0202,\n", + " -0.0498, -0.0459, -0.0519, -0.1087, -0.0906, -0.0095, -0.0508, -0.1118,\n", + " -0.0899, -0.1910, -0.0354, -0.0657, -0.0774, -0.0577, -0.0633, -0.0740,\n", + " -0.0751, -0.0207, -0.1020, -0.0199, -0.0419, -0.1173, -0.1266, -0.0585,\n", + " -0.0434, -0.0672, -0.0881, -0.0537, -0.1247, -0.1065, -0.0609, -0.1341,\n", + " -0.1198, -0.0728, -0.0685, -0.0626, -0.0630, -0.0776, -0.0959, -0.0196,\n", + " -0.1088, -0.0125, -0.0495, -0.0377, -0.0659, -0.0954, -0.1193, -0.0432,\n", + " -0.0374, -0.0297, -0.0210, -0.0286, -0.0535, -0.0750, -0.0967, -0.0967,\n", + " -0.0632, -0.0489, -0.0639, -0.0741, -0.0620, -0.0948, -0.0966, -0.0725,\n", + " -0.0739, -0.0649, -0.0182, -0.0995, -0.1026, -0.0422, -0.0829, -0.0523,\n", + " -0.0904, -0.1020, -0.1121, -0.0748, -0.0402, -0.0973, -0.1215, -0.0515,\n", + " -0.0663, -0.0595, -0.0502, -0.0370, -0.1234, -0.0346, -0.0301, -0.0440,\n", + " -0.0398, -0.0457, -0.1115, -0.1188, -0.0378, -0.1049, -0.0657, -0.0521,\n", + " -0.0337, -0.0106, -0.0047, -0.0402, -0.0479, -0.1028, -0.1165, -0.0627,\n", + " -0.0531, -0.0219, -0.0575, -0.0982, -0.1051, -0.0290, -0.1009, -0.0741,\n", + " -0.1253, -0.0165, -0.0490, -0.0601, -0.0461, -0.0342, -0.0526, -0.1105,\n", + " -0.0567, -0.1048, -0.0604, -0.0469, -0.0570, -0.0415, -0.0517, -0.0641,\n", + " -0.0758, -0.0582, -0.0588, -0.1217, -0.1596, -0.0583, -0.0997, -0.0409,\n", + " -0.0497, -0.0700, -0.0495, -0.1100, -0.0712, -0.1107, -0.1000, -0.0489,\n", + " -0.0686, -0.0748, -0.0701, -0.1143, -0.0737, -0.0786, -0.0387, -0.0504,\n", + " -0.0310, -0.0747, -0.0424, -0.0903, -0.1444, -0.0355, -0.0141, -0.0639,\n", + " -0.0912, -0.0753, -0.1046, -0.0843, -0.0553, -0.0972, -0.0655, -0.0793,\n", + " -0.1704, -0.0576, -0.0707, 
-0.0293, -0.0306, -0.0406, -0.0982, -0.0649,\n", + " -0.0475, -0.0868, -0.0556, -0.0701, -0.0867, -0.0442, -0.1121, -0.0544,\n", + " -0.0447, -0.0932, -0.0825, -0.0917, -0.0973, -0.0286, -0.1121, -0.0692,\n", + " -0.0585, -0.1000, -0.1387, -0.0233, -0.0365, -0.0427, -0.0382, -0.0542,\n", + " -0.0804, -0.0513, -0.0809, -0.0928, -0.0485, -0.0691, -0.0529, -0.0574,\n", + " -0.0412, -0.0369, -0.0579, -0.1057, -0.0373, -0.0553, -0.1050, -0.0390])\n", + "linear.weight Parameter containing:\n", + "tensor([[ 0.0370, -0.0410, -0.0149, ..., -0.0283, -0.0105, -0.0024],\n", + " [ 0.0430, 0.0205, -0.0063, ..., -0.0271, 0.0237, -0.0130],\n", + " [ 0.0301, 0.0312, -0.0110, ..., 0.0084, 0.0321, 0.0298],\n", + " ...,\n", + " [ 0.0296, -0.0215, -0.0301, ..., -0.0128, -0.0214, 0.0351],\n", + " [ 0.0431, 0.0256, 0.0101, ..., -0.0294, -0.0066, -0.0428],\n", + " [-0.0257, 0.0246, 0.0169, ..., -0.0290, 0.0167, 0.0032]],\n", + " requires_grad=True)\n", + "linear.bias Parameter containing:\n", + "tensor([ 0.0029, -0.0253, -0.0079, -0.0038, 0.0257, 0.0105, 0.0017, 0.0405,\n", + " 0.0337, 0.0201], requires_grad=True)\n" + ] + } + ], + "source": [ + "model = create_backbone(name='res18', num_classes=10)\n", + "classifier = nn.Linear(in_features=model.output_dim, out_features=10, bias=True)\n", + "for name, value in model.named_parameters():\n", + " if not name.startswith('linear') :\n", + " value.requires_grad = False\n", + "pretrained_model = torch.load('../checkpoint/SimCLR_on_Cifar4CL_lr0.5_lstep1_rn100.ckpt', map_location='cpu')\n", + "model.load_state_dict({k[9:]:v for k, v in pretrained_model['model'].items() if k.startswith('backbone.')}, strict=False)\n", + "\n", + "del pretrained_model\n", + "# model.add_module(\"Linear\", classifier)\n", + "for name, value in model.named_parameters():\n", + " print(name, value)" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "166237f2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + 
"text": [ + "Parameter containing:\n", + "tensor([[-0.0187, -0.0211, -0.0236, ..., -0.0326, 0.0137, 0.0262],\n", + " [-0.0381, -0.0086, 0.0255, ..., -0.0020, 0.0384, -0.0152],\n", + " [ 0.0052, 0.0015, -0.0284, ..., 0.0357, 0.0110, -0.0346],\n", + " ...,\n", + " [-0.0243, 0.0382, -0.0128, ..., 0.0228, 0.0274, -0.0390],\n", + " [-0.0257, 0.0127, 0.0342, ..., 0.0194, 0.0440, -0.0125],\n", + " [-0.0125, 0.0293, -0.0360, ..., -0.0290, 0.0309, 0.0352]],\n", + " requires_grad=True)\n", + "Parameter containing:\n", + "tensor([-0.0302, 0.0121, 0.0225, -0.0034, 0.0179, -0.0157, -0.0225, -0.0358,\n", + " 0.0366, -0.0426], requires_grad=True)\n" + ] + } + ], + "source": [ + "def get_freezed_parameters(module):\n", + " \"\"\"\n", + " Returns names of freezed parameters of the given module.\n", + " \"\"\"\n", + " \n", + " freezed_parameters = []\n", + " for name, parameter in module.named_parameters():\n", + " if not parameter.requires_grad:\n", + " freezed_parameters.append(name)\n", + " \n", + " return freezed_parameters\n", + "\n", + "get_freezed_parameters(model)\n", + "for i in filter(lambda p: p.requires_grad, model.parameters()):\n", + " print(i)" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "d6dc3b50", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)\n", + "optimizer.zero_grad()\n", + "output = model(torch.rand((1,3,32,32)))\n", + "loss = F.cross_entropy(output, torch.tensor([1]))\n", + "loss.backward()\n", + "optimizer.step()" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "d6ff199a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'abcdfgh'.endswith('dfgh')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5293562", 
+ "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/federatedscope/cl/trainer/__init__.py b/federatedscope/cl/trainer/__init__.py new file mode 100644 index 000000000..c0b31382d --- /dev/null +++ b/federatedscope/cl/trainer/__init__.py @@ -0,0 +1,8 @@ +from os.path import dirname, basename, isfile, join +import glob + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py new file mode 100644 index 000000000..b34a8ba3f --- /dev/null +++ b/federatedscope/cl/trainer/trainer.py @@ -0,0 +1,144 @@ +from federatedscope.register import register_trainer +from federatedscope.core.trainers import GeneralTorchTrainer +from federatedscope.core.auxiliaries import utils +import numpy as np + +def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): + # compute cos similarity between each feature vector and feature bank ---> [B, N] + sim_matrix = torch.mm(feature, feature_bank) + # [B, K] + sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) + # [B, K] + sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices) + sim_weight = (sim_weight / knn_t).exp() + + # counts for each class + one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device) + # [B*K, C] + one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0) + # weighted score 
---> [B, C] + pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1) + + pred_labels = pred_scores.argsort(dim=-1, descending=True) + return pred_labels + +def knn_monitor(net, memory_data_loader, test_data_loader, k=200, t=0.1, device="cpu", verbose=True): + net.eval() + classes = len(memory_data_loader.dataset.classes) + total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, [] + feature_labels = [] + with torch.no_grad(): + # generate feature bank + for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=not verbose): + feature = net(data.to(device)) + feature = F.normalize(feature, dim=1) + feature_bank.append(feature) + feature_labels.append(target.to(device)) + # [D, N] + feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() + # [N] + # feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=device) + feature_labels = torch.cat(feature_labels, dim=0).contiguous() + + # loop test data to predict the label by weighted knn search + test_bar = tqdm(test_data_loader, desc='kNN', disable=not verbose) + for data, target in test_bar: + data, target = data.to(device), target.to(device) + feature = net(data) + feature = F.normalize(feature, dim=1) + + pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t) + + total_num += data.size(0) + total_top1 += (pred_labels[:, 0] == target).float().sum().item() + return total_top1 / total_num * 100 + + +class CLTrainer(GeneralTorchTrainer): + def _hook_on_batch_forward(self, ctx): + x, label = [utils.move_to(_, ctx.device) for _ in ctx.data_batch] +# print(len(x), x[0].size(), x[1].size(), label.size()) + x1, x2 = x[0], x[1] + z1, z2 = ctx.model(x1, x2) + if len(label.size()) == 0: + label = label.unsqueeze(0) + ctx.loss_batch = ctx.criterion(z1, z2) + ctx.y_true = label + ctx.y_prob = z1, z2 + + ctx.batch_size = len(label) + + def _hook_on_batch_end(self, ctx): + # update 
statistics + setattr( + ctx, "loss_batch_total_{}".format(ctx.cur_data_split), + ctx.get("loss_batch_total_{}".format(ctx.cur_data_split)) + + ctx.loss_batch.item() * ctx.batch_size) + + if ctx.get("loss_regular", None) is None or ctx.loss_regular == 0: + loss_regular = 0. + else: + loss_regular = ctx.loss_regular.item() + setattr( + ctx, "loss_regular_total_{}".format(ctx.cur_data_split), + ctx.get("loss_regular_total_{}".format(ctx.cur_data_split)) + + loss_regular) + setattr( + ctx, "num_samples_{}".format(ctx.cur_data_split), + ctx.get("num_samples_{}".format(ctx.cur_data_split)) + + ctx.batch_size) + + # cache label for evaluate + ctx.get("{}_y_true".format(ctx.cur_data_split)).append( + ctx.y_true.detach().cpu().numpy()) + +# print(len(ctx.y_prob), ctx.y_prob[0].size(), ctx.y_prob[1].size()) + ctx.get("{}_y_prob".format(ctx.cur_data_split)).append( + ctx.y_prob[0].detach().cpu().numpy()) + + # clean temp ctx + ctx.data_batch = None + ctx.batch_size = None + ctx.loss_task = None + ctx.loss_batch = None + ctx.loss_regular = None + ctx.y_true = None + ctx.y_prob = None + + def _hook_on_fit_end(self, ctx): + """Evaluate metrics. 
+ + """ + setattr( + ctx, "{}_y_true".format(ctx.cur_data_split), + np.concatenate(ctx.get("{}_y_true".format(ctx.cur_data_split)))) + setattr( + ctx, "{}_y_prob".format(ctx.cur_data_split), + np.concatenate(ctx.get("{}_y_prob".format(ctx.cur_data_split)))) + results = self.metric_calculator.eval(ctx) + setattr(ctx, 'eval_metrics', results) + +def linear_prob_forward(ctx): + x, label = [_.to(ctx.device) for _ in ctx.data_batch] + pred = ctx.model(x) + if len(label.size()) == 0: + label = label.unsqueeze(0) + ctx.loss_batch = ctx.criterion(pred, label) + ctx.y_true = label + ctx.y_prob = pred + + ctx.batch_size = len(label) + +class LPTrainer(GeneralTorchTrainer): + pass + +def call_cl_trainer(trainer_type): + if trainer_type == 'cltrainer': + trainer_builder = CLTrainer + return trainer_builder + elif trainer_type == 'lptrainer': + trainer_builder = LPTrainer + return trainer_builder + + +register_trainer('cltrainer', call_cl_trainer) diff --git a/federatedscope/core/auxiliaries/criterion_builder.py b/federatedscope/core/auxiliaries/criterion_builder.py index 1502fc0d3..dfe2f42d5 100644 --- a/federatedscope/core/auxiliaries/criterion_builder.py +++ b/federatedscope/core/auxiliaries/criterion_builder.py @@ -3,6 +3,7 @@ try: from torch import nn from federatedscope.nlp.loss import * + from federatedscope.cl.loss import * except ImportError: nn = None diff --git a/federatedscope/core/auxiliaries/data_builder.py b/federatedscope/core/auxiliaries/data_builder.py index a6e3c13d4..18b6ced12 100644 --- a/federatedscope/core/auxiliaries/data_builder.py +++ b/federatedscope/core/auxiliaries/data_builder.py @@ -541,6 +541,9 @@ def get_data(config): elif config.data.type.lower() in ['femnist', 'celeba']: from federatedscope.cv.dataloader import load_cv_dataset data, modified_config = load_cv_dataset(config) + elif config.data.type.lower() in ['cifar4cl', 'cifar4lp']: + from federatedscope.cl.dataloader import load_cifar_dataset + data, modified_config = 
load_cifar_dataset(config) elif config.data.type.lower() in [ 'shakespeare', 'twitter', 'subreddit', 'synthetic' ]: diff --git a/federatedscope/core/auxiliaries/eunms.py b/federatedscope/core/auxiliaries/eunms.py new file mode 100644 index 000000000..2ef6b478b --- /dev/null +++ b/federatedscope/core/auxiliaries/eunms.py @@ -0,0 +1,30 @@ +class MODE: + """ + + Note: + Currently StrEnum cannot be imported with the environment + `sys.version_info < (3, 11)`, so we simply create a MODE class here. + """ + TRAIN = 'train' + TEST = 'test' + VAL = 'val' + FINETUNE = 'finetune' + + +class TRIGGER: + ON_FIT_START = 'on_fit_start' + ON_EPOCH_START = 'on_epoch_start' + ON_BATCH_START = 'on_batch_start' + ON_BATCH_FORWARD = 'on_batch_forward' + ON_BATCH_BACKWARD = 'on_batch_backward' + ON_BATCH_END = 'on_batch_end' + ON_EPOCH_END = 'on_epoch_end' + ON_FIT_END = 'on_fit_end' + + @classmethod + def contains(cls, item): + return item in [ + "on_fit_start", "on_epoch_start", "on_batch_start", + "on_batch_forward", "on_batch_backward", "on_batch_end", + "on_epoch_end", "on_fit_end" + ] diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index 3dcb460c4..1135df259 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -134,7 +134,14 @@ def get_model(model_config, local_data=None, backend='torch'): elif model_config.type.lower() in ['convnet2', 'convnet5', 'vgg11', 'lr']: from federatedscope.cv.model import get_cnn - model = get_cnn(model_config, input_shape) + model = get_cnn(model_config, local_data) + elif model_config.type.lower() in ['simclr', 'simclr_linear']: + from federatedscope.cl.model import get_simclr + model = get_simclr(model_config, local_data) + if model_config.type.lower().endswith('linear'): + for name, value in model.named_parameters(): + if not name.startswith('linear') : + value.requires_grad = False elif model_config.type.lower() in 
['lstm']: from federatedscope.nlp.model import get_rnn model = get_rnn(model_config, input_shape) diff --git a/federatedscope/core/auxiliaries/optimizer_builder.py b/federatedscope/core/auxiliaries/optimizer_builder.py index 6083dbaee..ce1eeca3c 100644 --- a/federatedscope/core/auxiliaries/optimizer_builder.py +++ b/federatedscope/core/auxiliaries/optimizer_builder.py @@ -10,7 +10,7 @@ def get_optimizer(model, type, lr, **kwargs): if isinstance(type, str): if hasattr(torch.optim, type): if isinstance(model, torch.nn.Module): - return getattr(torch.optim, type)(model.parameters(), lr, + return getattr(torch.optim, type)(filter(lambda p: p.requires_grad, model.parameters()), lr, **kwargs) else: return getattr(torch.optim, type)(model, lr, **kwargs) diff --git a/federatedscope/core/auxiliaries/trainer_builder.py b/federatedscope/core/auxiliaries/trainer_builder.py index c8cd4d5d7..9fbae8259 100644 --- a/federatedscope/core/auxiliaries/trainer_builder.py +++ b/federatedscope/core/auxiliaries/trainer_builder.py @@ -25,6 +25,7 @@ "fedvattrainer": "FedVATTrainer", "fedfocaltrainer": "FedFocalTrainer", "mftrainer": "MFTrainer", + "cltrainer": "CLTrainer", } @@ -61,6 +62,8 @@ def get_trainer(model=None, dict_path = "federatedscope.cv.trainer.trainer" elif config.trainer.type.lower() in ['nlptrainer']: dict_path = "federatedscope.nlp.trainer.trainer" + elif config.trainer.type.lower() in ['cltrainer']: + dict_path = "federatedscope.cl.trainer.trainer" elif config.trainer.type.lower() in [ 'graphminibatch_trainer', ]: From eaa50a109aac883c432708e90b00e47c27cf1ed9 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 2 Aug 2022 02:11:20 +0800 Subject: [PATCH 02/46] delete other yamls --- .../cl/baseline/fedavg_lr_on_twitter.yaml | 32 ----------------- .../baseline/fedavg_lstm_on_shakespeare.yaml | 35 ------------------ .../cl/baseline/fedavg_lstm_on_subreddit.yaml | 32 ----------------- .../baseline/fedavg_transformer_on_imdb.yaml | 36 ------------------- 4 
files changed, 135 deletions(-) delete mode 100644 federatedscope/cl/baseline/fedavg_lr_on_twitter.yaml delete mode 100644 federatedscope/cl/baseline/fedavg_lstm_on_shakespeare.yaml delete mode 100644 federatedscope/cl/baseline/fedavg_lstm_on_subreddit.yaml delete mode 100644 federatedscope/cl/baseline/fedavg_transformer_on_imdb.yaml diff --git a/federatedscope/cl/baseline/fedavg_lr_on_twitter.yaml b/federatedscope/cl/baseline/fedavg_lr_on_twitter.yaml deleted file mode 100644 index 4f0656cdf..000000000 --- a/federatedscope/cl/baseline/fedavg_lr_on_twitter.yaml +++ /dev/null @@ -1,32 +0,0 @@ -use_gpu: True -device: 0 -early_stop: - patience: 5 -federate: - mode: standalone - total_round_num: 100 - sample_client_num: 10 -data: - root: data/ - type: twitter - batch_size: 5 - subsample: 0.005 - num_workers: 0 -model: - type: lr - out_channels: 2 - dropout: 0.0 -train: - local_update_steps: 10 - optimizer: - lr: 0.0003 - weight_decay: 0.0 -criterion: - type: CrossEntropyLoss -trainer: - type: nlptrainer -eval: - freq: 10 - metrics: ['acc', 'correct'] - split: ['train'] - best_res_update_round_wise_key: 'train_loss' \ No newline at end of file diff --git a/federatedscope/cl/baseline/fedavg_lstm_on_shakespeare.yaml b/federatedscope/cl/baseline/fedavg_lstm_on_shakespeare.yaml deleted file mode 100644 index 86d5aca27..000000000 --- a/federatedscope/cl/baseline/fedavg_lstm_on_shakespeare.yaml +++ /dev/null @@ -1,35 +0,0 @@ -use_gpu: True -device: 0 -early_stop: - patience: 10 -federate: - mode: standalone - total_round_num: 1000 - sample_client_rate: 0.2 -data: - root: data/ - type: shakespeare - batch_size: 64 - subsample: 0.2 - num_workers: 0 - splits: [0.6,0.2,0.2] -model: - type: lstm - in_channels: 80 - out_channels: 80 - embed_size: 8 - hidden: 256 - dropout: 0.0 -train: - local_update_steps: 1 - batch_or_epoch: epoch - optimizer: - lr: 0.8 - weight_decay: 0.0 -criterion: - type: character_loss -trainer: - type: nlptrainer -eval: - freq: 10 - metrics: ['acc', 
'correct'] diff --git a/federatedscope/cl/baseline/fedavg_lstm_on_subreddit.yaml b/federatedscope/cl/baseline/fedavg_lstm_on_subreddit.yaml deleted file mode 100644 index 1080bb591..000000000 --- a/federatedscope/cl/baseline/fedavg_lstm_on_subreddit.yaml +++ /dev/null @@ -1,32 +0,0 @@ -use_gpu: True -device: 0 -early_stop: - patience: 10 -federate: - mode: standalone - total_round_num: 100 - sample_client_num: 10 -data: - root: data/ - type: subreddit - batch_size: 5 - subsample: 1.0 -model: - type: lstm - in_channels: 10000 - out_channels: 10000 - hidden: 256 - embed_size: 200 - dropout: 0.0 -train: - local_update_steps: 10 - optimizer: - lr: 8.0 - weight_decay: 0.0 -criterion: - type: CrossEntropyLoss -trainer: - type: nlptrainer -eval: - freq: 10 - metrics: ['acc', 'correct'] \ No newline at end of file diff --git a/federatedscope/cl/baseline/fedavg_transformer_on_imdb.yaml b/federatedscope/cl/baseline/fedavg_transformer_on_imdb.yaml deleted file mode 100644 index a9e818aa1..000000000 --- a/federatedscope/cl/baseline/fedavg_transformer_on_imdb.yaml +++ /dev/null @@ -1,36 +0,0 @@ -use_gpu: True -device: 2 -federate: - mode: standalone - total_round_num: 400 - client_num: 5 - share_local_model: True - online_aggr: True - sample_client_rate: 1.0 -data: - root: 'data' - type: 'IMDB@torchtext' - args: [{'max_len': 512}] - splits: [0.8, 0.2, 0.0] # test is fixed - batch_size: 128 - splitter: 'lda' - splitter_args: [{'alpha': 0.5}] - num_workers: 0 -model: - type: 'google/bert_uncased_L-2_H-128_A-2@transformers' - task: 'SequenceClassification' - out_channels: 2 -train: - local_update_steps: 1 - batch_or_epoch: 'epoch' - optimizer: - lr: 0.0001 - weight_decay: 0.0 -criterion: - type: 'CrossEntropyLoss' -trainer: - type: 'nlptrainer' -eval: - freq: 2 - metrics: ['acc', 'correct', 'f1'] - split: ['test', 'val', 'train'] \ No newline at end of file From 289942b278c445396715ad504c9bf3992196bab5 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 2 Aug 
2022 03:27:44 +0800 Subject: [PATCH 03/46] script debug --- federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml | 2 +- federatedscope/cl/model/SimCLR.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml b/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml index 11d70ce36..b4e85ed31 100644 --- a/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml @@ -6,7 +6,7 @@ federate: client_num: 5 sample_client_rate: 1.0 method: local - restore_from: 'checkpoint/SimCLR_on_Cifar4CL_lr0.1_lstep5_rn100.ckpt' + #restore_from: 'checkpoint/SimCLR_on_Cifar4CL_lr0.1_lstep5_rn100.ckpt' data: root: 'data' type: 'Cifar4LP' diff --git a/federatedscope/cl/model/SimCLR.py b/federatedscope/cl/model/SimCLR.py index 3bf940d21..444e942cb 100644 --- a/federatedscope/cl/model/SimCLR.py +++ b/federatedscope/cl/model/SimCLR.py @@ -206,7 +206,7 @@ def ModelBuilder(model_config, local_data): return model if model_config.type == "SimCLR_linear": model = create_backbone(name='res18', num_classes=10) - pretrained_model = torch.load('checkpoint/SimCLR_on_Cifar4CL_lr0.5_lstep1_rn100.ckpt', map_location='cpu') + pretrained_model = torch.load('checkpoint/SimCLR_on_Cifar4CL_lr0.5_lstep5_rn100.ckpt', map_location='cpu') model.load_state_dict({k[9:]:v for k, v in pretrained_model['model'].items() if k.startswith('backbone.')}, strict=False) # for name, value in model.named_parameters(): # if not name.startswith('linear') : From c1c9beaf96924fcdd1a336905f17f6968813d5f5 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 2 Aug 2022 17:23:19 +0800 Subject: [PATCH 04/46] debug unit test error --- federatedscope/core/auxiliaries/model_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py 
index 1135df259..e04b5b276 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -134,7 +134,7 @@ def get_model(model_config, local_data=None, backend='torch'): elif model_config.type.lower() in ['convnet2', 'convnet5', 'vgg11', 'lr']: from federatedscope.cv.model import get_cnn - model = get_cnn(model_config, local_data) + model = get_cnn(model_config, input_shape) elif model_config.type.lower() in ['simclr', 'simclr_linear']: from federatedscope.cl.model import get_simclr model = get_simclr(model_config, local_data) From f303910cda6101a8cc93ab74f434ea733e0d9bd2 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Wed, 3 Aug 2022 03:56:27 +0800 Subject: [PATCH 05/46] debug --- federatedscope/cl/dataloader/Cifar10.py | 6 +- federatedscope/cl/test.ipynb | 4425 ----------------- federatedscope/cl/trainer/trainer.py | 11 - federatedscope/core/auxiliaries/eunms.py | 30 - .../core/auxiliaries/model_builder.py | 4 + 5 files changed, 6 insertions(+), 4470 deletions(-) delete mode 100644 federatedscope/cl/test.ipynb delete mode 100644 federatedscope/core/auxiliaries/eunms.py diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 73a3ab149..b5cb16103 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -10,7 +10,7 @@ from torchvision.datasets import CIFAR10, CIFAR100 import pickle as pkl import numpy as np - +from federatedscope.register import register_data class SimCLRTransform(): @@ -54,7 +54,6 @@ def Cifar4CL(config): val_per_client = len(data_val) // config.federate.client_num test_per_client = len(data_test) // config.federate.client_num - print("time1") for client_idx in range(1, config.federate.client_num + 1): dataloader_dict = { 'train': @@ -83,7 +82,6 @@ def Cifar4CL(config): shuffle=False) } data_dict[client_idx] = dataloader_dict - print("time2") r""" Returns: @@ -172,7 +170,7 @@ def 
Cifar4LP(config): config = config return data_dict, config -from federatedscope.register import register_data + def load_cifar_dataset(config): if config.data.type == "Cifar4CL": diff --git a/federatedscope/cl/test.ipynb b/federatedscope/cl/test.ipynb deleted file mode 100644 index b3267377c..000000000 --- a/federatedscope/cl/test.ipynb +++ /dev/null @@ -1,4425 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 13, - "id": "2abe03ca", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "3009a354", - "metadata": {}, - "outputs": [], - "source": [ - "def NT_xentloss_(z1, z2, temperature=0.5): \n", - " N, Z = z1.shape \n", - " device = z1.device \n", - " representations = torch.cat([z1, z2], dim=0)\n", - " similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)\n", - "\n", - " l_pos = torch.diag(similarity_matrix, N)\n", - " r_pos = torch.diag(similarity_matrix, -N)\n", - " positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)\n", - "\n", - " diag = torch.eye(2*N, dtype=torch.bool, device=device)\n", - " diag[N:,:N] = diag[:N,N:] = diag[:N,:N]\n", - " negatives = similarity_matrix[~diag].view(2*N, -1)\n", - "\n", - " logits = torch.cat([positives, negatives], dim=1) / temperature\n", - " labels = torch.zeros(2*N, device=device, dtype=torch.int64) # scalar label per sample\n", - " loss = F.cross_entropy(logits, labels, reduction='sum')\n", - "\n", - " return loss / (2 * N)\n", - "\n", - "class NT_xentloss(nn.Module):\n", - " def __init__(self, temperature=0.5):\n", - " super(NT_xentloss, self).__init__()\n", - " self.temperature = temperature\n", - " \n", - " def forward(self, z1, z2):\n", - " N, Z = z1.shape \n", - " device = z1.device \n", - " 
representations = torch.cat([z1, z2], dim=0)\n", - " similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)\n", - "\n", - " l_pos = torch.diag(similarity_matrix, N)\n", - " r_pos = torch.diag(similarity_matrix, -N)\n", - " positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)\n", - "\n", - " diag = torch.eye(2*N, dtype=torch.bool, device=device)\n", - " diag[N:,:N] = diag[:N,N:] = diag[:N,:N]\n", - " negatives = similarity_matrix[~diag].view(2*N, -1)\n", - "\n", - " logits = torch.cat([positives, negatives], dim=1) / temperature\n", - " labels = torch.zeros(2*N, device=device, dtype=torch.int64) # scalar label per sample\n", - " loss = F.cross_entropy(logits, labels, reduction='sum')\n", - " \n", - " return loss / (2 * N)" - ] - }, - { - "cell_type": "code", - "execution_count": 65, - "id": "89177767", - "metadata": {}, - "outputs": [], - "source": [ - "from model.SimCLR import simclr, create_backbone\n", - "import random" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "id": "a0c608b3", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "conv1.weight Parameter containing:\n", - "tensor([[[[-0.0183, -0.0765, 0.2014],\n", - " [-0.0594, -0.0160, -0.0777],\n", - " [-0.1222, -0.0284, 0.1164]],\n", - "\n", - " [[ 0.3198, 0.2831, 0.3139],\n", - " [ 0.3030, -0.0359, 0.2107],\n", - " [ 0.3482, -0.0047, 0.4345]],\n", - "\n", - " [[-0.0811, -0.0312, -0.0877],\n", - " [-0.4400, -0.6757, -0.4297],\n", - " [ 0.0313, -0.4326, 0.0159]]],\n", - "\n", - "\n", - " [[[-0.0643, 0.2237, 0.4458],\n", - " [-0.3193, -0.0773, 0.0561],\n", - " [-0.4273, -0.0085, 0.1814]],\n", - "\n", - " [[-0.0256, -0.1148, 0.1742],\n", - " [-0.4455, -0.1189, 0.2433],\n", - " [-0.4229, -0.2738, -0.1874]],\n", - "\n", - " [[-0.0910, 0.0338, 0.2989],\n", - " [-0.1994, 0.0990, 0.2593],\n", - " [-0.0121, 0.1451, -0.0939]]],\n", - "\n", - "\n", - " [[[-0.3403, -0.4584, -0.3110],\n", - " [ 0.1626, 
0.2688, 0.1822],\n", - " [ 0.3948, 0.3835, 0.1284]],\n", - "\n", - " [[-0.4745, -0.1487, -0.1274],\n", - " [-0.1165, -0.2197, -0.0441],\n", - " [ 0.1622, 0.2059, -0.1887]],\n", - "\n", - " [[-0.0929, -0.2441, 0.0493],\n", - " [-0.0427, -0.0754, -0.0753],\n", - " [ 0.3795, 0.3901, 0.2711]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-0.0482, 0.2177, 0.2685],\n", - " [-0.1009, -0.1820, 0.0026],\n", - " [-0.2059, -0.1552, -0.0143]],\n", - "\n", - " [[ 0.0179, -0.0478, -0.0131],\n", - " [ 0.1983, -0.2138, 0.2272],\n", - " [ 0.0263, -0.2440, -0.1631]],\n", - "\n", - " [[-0.0075, 0.1251, 0.0224],\n", - " [ 0.1411, -0.1159, -0.2196],\n", - " [ 0.0806, -0.3194, -0.2587]]],\n", - "\n", - "\n", - " [[[-0.0033, 0.2087, 0.0977],\n", - " [-0.0519, 0.2213, 0.2196],\n", - " [ 0.1723, 0.0177, 0.2410]],\n", - "\n", - " [[-0.1598, 0.1688, 0.1923],\n", - " [ 0.0352, 0.0465, 0.1310],\n", - " [ 0.1351, 0.2713, 0.0726]],\n", - "\n", - " [[-0.1874, 0.1088, -0.0098],\n", - " [-0.0707, 0.2780, 0.0114],\n", - " [-0.0553, 0.1157, 0.0258]]],\n", - "\n", - "\n", - " [[[-0.2044, -0.2053, -0.0668],\n", - " [-0.2899, -0.1816, -0.1537],\n", - " [-0.2074, -0.1091, -0.1762]],\n", - "\n", - " [[-0.1399, -0.2275, -0.0407],\n", - " [-0.1803, -0.0934, -0.0654],\n", - " [ 0.0834, -0.1998, 0.0623]],\n", - "\n", - " [[ 0.2251, 0.0599, 0.1248],\n", - " [ 0.1365, 0.3196, 0.1363],\n", - " [ 0.0486, 0.2550, 0.2294]]]])\n", - "bn1.weight Parameter containing:\n", - "tensor([1.1764, 1.1640, 1.0995, 1.1407, 1.1942, 0.9670, 1.8451, 1.1208, 1.0389,\n", - " 0.9228, 0.8822, 1.3220, 0.8472, 1.0401, 1.3100, 0.9914, 1.0612, 1.2423,\n", - " 1.0459, 1.0871, 0.9735, 0.9283, 1.0515, 0.9037, 1.0665, 0.9346, 1.5852,\n", - " 1.1763, 1.0359, 0.9605, 1.2617, 0.9576, 0.9276, 1.5612, 1.4488, 1.4455,\n", - " 1.0734, 1.0145, 1.0641, 1.1351, 0.7940, 1.5344, 1.0264, 0.9425, 0.9327,\n", - " 0.8563, 0.8882, 1.1045, 1.1882, 1.0556, 1.2136, 0.9525, 0.9602, 0.9024,\n", - " 1.0303, 0.9192, 0.8604, 1.1484, 1.9566, 0.9445, 
1.1349, 0.9673, 0.8733,\n", - " 0.8793])\n", - "bn1.bias Parameter containing:\n", - "tensor([ 0.2469, 0.1301, 0.2798, 0.2101, 0.2512, -0.0913, 0.1757, -0.1319,\n", - " 0.0583, 0.1632, 0.0488, 0.4202, -0.0444, 0.0935, 0.2882, -0.0337,\n", - " 0.0455, 0.3127, 0.0145, 0.2779, -0.0525, 0.0980, 0.2111, -0.2265,\n", - " 0.3618, 0.0158, 0.4913, 0.3805, -0.1700, -0.0907, 0.3175, -0.0993,\n", - " 0.0133, 0.5292, 0.6708, -0.1043, 0.0500, 0.1086, 0.1487, 0.1016,\n", - " 0.1520, 0.6627, 0.0677, 0.1277, -0.1703, 0.2342, -0.0021, 0.0561,\n", - " 0.2777, 0.1305, 0.3191, 0.0372, -0.1303, 0.2400, -0.1451, -0.1098,\n", - " -0.1606, 0.2635, 0.5051, 0.1788, -0.0703, -0.1441, -0.0713, -0.1347])\n", - "layer1.0.conv1.weight Parameter containing:\n", - "tensor([[[[ 2.6822e-02, 2.7844e-02, 5.7835e-02],\n", - " [ 5.3032e-02, -1.9933e-02, 4.4936e-03],\n", - " [-7.0999e-03, -3.9780e-02, 8.8800e-03]],\n", - "\n", - " [[-4.3002e-02, -1.6001e-02, 2.9426e-02],\n", - " [-7.6142e-03, 7.4012e-02, 6.8242e-02],\n", - " [ 4.2627e-02, 9.5829e-02, 2.2561e-02]],\n", - "\n", - " [[ 2.0417e-02, 2.1891e-02, 7.8952e-02],\n", - " [-9.9509e-03, 8.9533e-02, 1.3280e-01],\n", - " [ 5.9915e-03, 1.0857e-01, 4.9066e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-4.3741e-02, -1.6707e-04, -8.0110e-02],\n", - " [-7.0618e-02, -5.8168e-02, -1.4400e-02],\n", - " [-7.1466e-02, -3.5159e-02, -1.8935e-02]],\n", - "\n", - " [[-5.9944e-02, -2.6373e-02, 6.4819e-02],\n", - " [-3.7174e-02, 4.6251e-02, 5.9955e-02],\n", - " [ 2.8329e-02, 7.1514e-02, 5.7975e-02]],\n", - "\n", - " [[-2.5611e-02, 3.2862e-02, 2.4414e-02],\n", - " [ 6.2642e-02, 9.8129e-03, 7.5067e-02],\n", - " [ 9.4383e-02, 3.7375e-02, 5.1015e-02]]],\n", - "\n", - "\n", - " [[[ 2.6016e-02, 3.3790e-03, 4.1034e-02],\n", - " [ 2.0423e-02, 3.9459e-02, 5.0560e-02],\n", - " [ 7.3355e-02, 5.9024e-02, 1.0244e-01]],\n", - "\n", - " [[ 9.3631e-02, -9.1128e-03, -2.5420e-03],\n", - " [ 1.2080e-01, 4.7242e-02, 4.0391e-02],\n", - " [ 6.4514e-02, 8.8472e-02, 4.4476e-02]],\n", - "\n", - 
" [[ 1.5563e-02, 6.9989e-02, 8.5023e-02],\n", - " [-6.7027e-04, 1.0338e-02, 6.1742e-02],\n", - " [-5.4143e-03, 2.5595e-02, 6.4282e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 3.9204e-02, -2.7143e-02, -2.1611e-02],\n", - " [-1.0465e-02, 6.9390e-03, 1.8000e-02],\n", - " [-2.1076e-02, 3.1111e-02, -2.4030e-02]],\n", - "\n", - " [[ 9.1457e-02, 5.4195e-02, 3.1970e-03],\n", - " [ 7.9195e-02, 3.1202e-02, 3.8534e-02],\n", - " [ 7.0259e-02, 5.2701e-02, 4.9941e-02]],\n", - "\n", - " [[ 2.0938e-02, 4.2828e-02, -1.8668e-02],\n", - " [ 6.0307e-02, 6.1750e-02, -1.7545e-02],\n", - " [ 7.9255e-02, 1.0614e-03, -2.1844e-02]]],\n", - "\n", - "\n", - " [[[-9.9994e-03, 5.3842e-02, 1.6255e-02],\n", - " [ 3.0797e-02, 7.7342e-02, 1.1806e-01],\n", - " [ 1.1641e-01, 8.0531e-02, 1.3195e-01]],\n", - "\n", - " [[ 3.2614e-02, 4.4080e-02, -9.1549e-02],\n", - " [ 3.5523e-02, 2.9408e-03, -3.7003e-02],\n", - " [ 6.3749e-02, 3.2515e-02, 4.9107e-02]],\n", - "\n", - " [[-1.5121e-03, 1.2614e-02, 4.8092e-02],\n", - " [-5.1469e-02, -6.2799e-03, -5.0903e-02],\n", - " [-2.3207e-02, -3.2161e-02, -4.1511e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.3646e-02, -7.9403e-02, -8.0696e-02],\n", - " [-5.3833e-02, -5.3324e-02, -4.2518e-02],\n", - " [-5.0440e-03, 2.9950e-02, -3.2976e-02]],\n", - "\n", - " [[-2.6310e-02, 5.8497e-02, 3.2514e-02],\n", - " [-4.3040e-02, -3.1510e-02, -9.4330e-03],\n", - " [-7.6044e-02, -1.1833e-02, -6.2875e-02]],\n", - "\n", - " [[ 1.9541e-02, 3.7317e-02, 4.0267e-02],\n", - " [-7.7281e-03, 5.9121e-03, -2.1681e-02],\n", - " [-2.5764e-02, 6.2579e-03, 9.6221e-04]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 1.1449e-02, 5.3609e-02, -3.0909e-03],\n", - " [ 6.0895e-02, 4.4686e-02, 4.9324e-02],\n", - " [-1.6167e-02, -3.9159e-02, 1.5942e-02]],\n", - "\n", - " [[-3.4361e-02, 3.7426e-02, 3.8636e-02],\n", - " [ 2.5543e-03, -3.0453e-02, -1.5798e-02],\n", - " [ 2.1582e-05, -8.5651e-02, -3.5822e-02]],\n", - "\n", - " [[-4.8913e-02, -6.9554e-02, -2.0775e-03],\n", - " [-1.4599e-02, 
-5.2028e-02, -3.5874e-02],\n", - " [-9.9485e-02, -1.4734e-03, -2.0859e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1.8716e-02, 4.1136e-04, 3.3992e-02],\n", - " [-3.7199e-02, -4.4121e-03, 3.3526e-02],\n", - " [-1.1656e-02, -3.2170e-02, 2.0356e-02]],\n", - "\n", - " [[ 2.5727e-02, 5.4506e-02, -2.8860e-02],\n", - " [-2.2427e-02, -1.3907e-02, -2.6238e-02],\n", - " [-2.0807e-03, 6.2578e-03, 1.1508e-02]],\n", - "\n", - " [[-1.3074e-02, 6.6535e-02, 3.0361e-02],\n", - " [ 3.9119e-02, 3.1634e-02, 2.4559e-02],\n", - " [-2.3985e-02, -3.8256e-02, -6.4856e-02]]],\n", - "\n", - "\n", - " [[[-4.8634e-02, -7.1717e-02, -3.4695e-02],\n", - " [-7.3348e-04, -3.3484e-02, 4.1865e-03],\n", - " [-3.7279e-02, -3.6906e-02, -1.5438e-02]],\n", - "\n", - " [[-8.8918e-02, -5.0840e-02, -3.6884e-02],\n", - " [-5.6966e-02, -8.8720e-02, -3.8041e-02],\n", - " [ 9.9085e-03, -8.7135e-02, -1.0016e-01]],\n", - "\n", - " [[-1.7784e-02, -8.5942e-03, -4.1659e-03],\n", - " [ 2.7800e-02, 2.1069e-02, -1.0392e-03],\n", - " [-4.0285e-02, -2.9785e-03, 3.1919e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-2.0661e-02, 1.6509e-02, 5.3129e-02],\n", - " [-4.6699e-03, 7.5870e-03, 1.6611e-02],\n", - " [-6.6152e-02, -4.0273e-02, 4.6010e-02]],\n", - "\n", - " [[-9.7571e-03, -2.1538e-02, -3.4537e-04],\n", - " [ 5.6689e-02, 1.5892e-02, -4.1756e-02],\n", - " [ 5.6645e-02, 3.2210e-02, -3.3068e-03]],\n", - "\n", - " [[ 1.7798e-02, -3.1845e-02, -2.4694e-02],\n", - " [-6.8033e-02, -4.6940e-02, -2.4673e-02],\n", - " [-6.2354e-02, -4.0748e-02, 2.6249e-02]]],\n", - "\n", - "\n", - " [[[-6.6686e-02, -4.4081e-02, 4.6792e-02],\n", - " [-3.4211e-02, -6.6569e-02, 5.2638e-02],\n", - " [-5.5514e-02, -7.6365e-02, -4.6954e-03]],\n", - "\n", - " [[ 5.5348e-03, -4.3402e-02, -9.7526e-03],\n", - " [ 5.4384e-02, 2.8204e-02, 1.9586e-02],\n", - " [-4.1902e-02, -1.1975e-02, -4.1381e-03]],\n", - "\n", - " [[-7.1795e-03, -8.9517e-03, -4.6672e-02],\n", - " [-7.6845e-02, -8.6473e-02, -9.3363e-02],\n", - " [-4.8047e-03, -6.1196e-02, 
-5.3406e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-5.5690e-02, -3.6102e-02, 3.2068e-02],\n", - " [-4.7767e-02, -3.4354e-02, -2.1960e-02],\n", - " [-1.8896e-02, -2.4625e-02, -1.4877e-02]],\n", - "\n", - " [[-7.4994e-03, -1.6231e-03, -7.5342e-02],\n", - " [-8.7427e-03, -1.9412e-02, -9.5213e-02],\n", - " [-3.1043e-02, 3.6101e-03, -3.2579e-02]],\n", - "\n", - " [[ 9.4366e-02, 1.0742e-02, -7.4900e-03],\n", - " [ 7.4800e-02, 2.6321e-02, 5.5021e-02],\n", - " [ 2.0209e-02, 8.3471e-02, 9.2501e-02]]]])\n", - "layer1.0.bn1.weight Parameter containing:\n", - "tensor([1.0136, 0.9218, 0.9679, 1.0025, 0.9043, 0.8406, 1.2799, 1.0009, 1.0661,\n", - " 1.0704, 0.9351, 1.3376, 1.1102, 0.9364, 0.8600, 1.2370, 0.9971, 1.0767,\n", - " 1.3791, 0.9053, 0.9246, 1.0014, 1.0106, 0.9543, 0.8708, 0.9405, 1.0075,\n", - " 1.0012, 1.0660, 1.3460, 1.0126, 0.9835, 0.9491, 1.0240, 0.9015, 0.9719,\n", - " 0.8948, 1.3173, 0.9648, 0.9181, 0.9111, 0.8892, 0.8699, 0.9280, 1.0618,\n", - " 0.8207, 0.8929, 1.0177, 0.9283, 0.8522, 0.9356, 0.9263, 0.8903, 0.8923,\n", - " 0.9502, 0.9133, 1.1011, 1.0126, 0.8934, 0.9556, 1.1305, 0.9480, 1.1295,\n", - " 0.8721])\n", - "layer1.0.bn1.bias Parameter containing:\n", - "tensor([-0.0695, -0.1471, 0.0139, -0.0895, -0.0935, -0.1319, 0.2851, 0.0956,\n", - " 0.0141, -0.0588, -0.0934, 0.1814, 0.0189, -0.1585, -0.1160, 0.2663,\n", - " -0.0549, -0.0755, 0.2770, -0.2165, -0.1437, -0.0044, -0.0689, -0.2161,\n", - " -0.1360, 0.0215, -0.0459, 0.0850, -0.0693, 0.2729, -0.0163, 0.0720,\n", - " -0.0108, -0.0080, -0.0584, -0.0065, -0.0323, 0.0700, -0.0114, 0.0549,\n", - " 0.0068, -0.0173, -0.0960, -0.0148, 0.0259, -0.1088, -0.0520, -0.0149,\n", - " -0.0885, 0.0346, 0.1636, 0.0960, -0.0292, -0.0059, 0.1584, -0.2165,\n", - " -0.0836, -0.0022, -0.1441, -0.0754, 0.2161, -0.0256, 0.1242, -0.0562])\n", - "layer1.0.conv2.weight Parameter containing:\n", - "tensor([[[[ 1.1813e-02, -2.9603e-02, 8.2939e-03],\n", - " [-3.9731e-02, -4.8097e-02, 9.8338e-03],\n", - " [-4.8729e-02, -2.6406e-02, 
3.8050e-02]],\n", - "\n", - " [[-1.3992e-01, -5.0215e-02, -3.7980e-02],\n", - " [-7.6246e-02, -8.1081e-02, -7.5231e-02],\n", - " [-1.0410e-01, -6.5782e-03, -2.2425e-02]],\n", - "\n", - " [[ 1.9414e-02, 3.7513e-02, 1.0480e-02],\n", - " [ 3.1074e-02, 4.5602e-02, 3.8310e-02],\n", - " [ 8.8012e-03, 2.3819e-03, 4.6176e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-3.1133e-02, -2.8605e-03, -5.2947e-02],\n", - " [ 3.7506e-02, 6.3998e-04, -1.2003e-03],\n", - " [ 1.6662e-02, -3.0633e-02, -1.1023e-02]],\n", - "\n", - " [[-7.1672e-02, -3.3559e-02, 2.7411e-02],\n", - " [-1.2149e-02, -3.2008e-02, 5.0320e-02],\n", - " [-4.2444e-02, -6.7399e-02, 2.0852e-02]],\n", - "\n", - " [[ 1.1870e-04, -9.6183e-02, -8.2698e-02],\n", - " [ 1.0616e-02, -8.2367e-02, -6.7266e-02],\n", - " [-8.5162e-02, -2.5697e-02, -3.6014e-02]]],\n", - "\n", - "\n", - " [[[-2.4725e-02, 1.3686e-02, 4.1195e-02],\n", - " [ 1.0350e-02, 1.4489e-02, -3.2543e-02],\n", - " [-3.9483e-02, 4.3526e-03, -6.4983e-02]],\n", - "\n", - " [[-8.8199e-02, -6.7863e-02, -8.9248e-02],\n", - " [-9.4662e-02, -8.7976e-02, -1.0673e-01],\n", - " [-4.4556e-02, -1.0821e-01, -8.3664e-02]],\n", - "\n", - " [[-4.2361e-02, 2.5027e-02, 7.9239e-02],\n", - " [-7.9938e-02, 8.8911e-03, 3.7579e-02],\n", - " [ 2.5401e-02, 4.2908e-02, 8.5704e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-2.8075e-02, 2.0473e-02, -9.1930e-03],\n", - " [ 3.9986e-02, 9.4368e-03, 1.3627e-02],\n", - " [ 4.2203e-02, 3.4609e-02, 3.5526e-02]],\n", - "\n", - " [[-5.1567e-02, -1.1104e-01, -2.5351e-02],\n", - " [-7.1857e-02, -1.0935e-01, -5.1078e-02],\n", - " [-4.2563e-02, -1.3174e-01, -1.1441e-01]],\n", - "\n", - " [[-2.1442e-02, -1.0601e-02, 2.1566e-02],\n", - " [ 4.4194e-02, -2.4603e-03, -2.8993e-02],\n", - " [ 4.4593e-02, 2.4336e-02, 4.1785e-02]]],\n", - "\n", - "\n", - " [[[-3.7110e-02, -2.1384e-02, -8.8386e-02],\n", - " [ 8.9586e-03, -2.1911e-02, -1.9369e-02],\n", - " [ 2.1264e-05, 1.0292e-02, -7.6875e-03]],\n", - "\n", - " [[-6.9804e-02, -2.7804e-02, -3.7626e-02],\n", - " 
[-1.0934e-04, -5.0990e-02, -1.1562e-01],\n", - " [-4.4917e-02, -3.5801e-02, -7.5220e-02]],\n", - "\n", - " [[-7.0519e-02, -1.0848e-02, 2.0123e-02],\n", - " [-3.8126e-02, -3.2042e-02, -2.0311e-02],\n", - " [-6.6819e-02, 3.2167e-04, -5.7737e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 3.5913e-03, 7.8448e-02, 8.0313e-02],\n", - " [ 1.4435e-02, -2.4045e-03, 6.7934e-03],\n", - " [-9.2509e-03, -9.8345e-03, 4.7002e-02]],\n", - "\n", - " [[ 1.6466e-02, 4.2463e-02, 7.0508e-02],\n", - " [ 3.4950e-02, 1.0808e-01, 7.7231e-02],\n", - " [-6.9534e-02, 4.1800e-02, 7.8877e-02]],\n", - "\n", - " [[-9.4433e-04, -6.5545e-02, 1.0643e-02],\n", - " [-1.2707e-02, -2.0921e-02, -7.0154e-02],\n", - " [-3.4571e-02, -3.0644e-02, -4.7103e-02]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-7.4073e-02, -5.6490e-02, -7.8497e-03],\n", - " [-3.7687e-02, -7.1005e-02, -1.6297e-02],\n", - " [-4.7876e-02, -5.8232e-02, -6.4043e-02]],\n", - "\n", - " [[ 1.3346e-01, 9.8550e-02, 9.0538e-02],\n", - " [ 1.6279e-01, 4.8807e-02, 1.0326e-01],\n", - " [ 6.8256e-02, 8.8236e-02, 2.5214e-02]],\n", - "\n", - " [[ 5.0286e-02, -2.8655e-02, -3.9797e-02],\n", - " [ 8.1590e-02, 8.8670e-02, 3.2581e-02],\n", - " [ 6.6629e-02, 5.5634e-02, 6.7553e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-4.8289e-02, -1.8756e-02, -1.9947e-02],\n", - " [ 1.5246e-03, 2.4550e-02, -1.8470e-02],\n", - " [ 7.0791e-02, 5.5384e-02, 4.3716e-02]],\n", - "\n", - " [[-3.0825e-02, -2.1276e-02, -3.7002e-02],\n", - " [-6.0327e-02, -7.0949e-02, -8.2225e-02],\n", - " [-3.8889e-02, -3.2270e-02, -2.6290e-02]],\n", - "\n", - " [[-1.0067e-01, -1.1138e-02, -6.7776e-02],\n", - " [-9.9491e-02, -4.8505e-02, -5.5871e-02],\n", - " [-1.2496e-01, -1.1068e-01, -6.7283e-02]]],\n", - "\n", - "\n", - " [[[-1.5475e-03, -7.9660e-03, -3.8476e-02],\n", - " [ 5.9998e-02, 5.8792e-02, 1.5293e-02],\n", - " [ 9.0893e-02, -2.4707e-03, -1.3622e-02]],\n", - "\n", - " [[-2.2771e-02, 2.6777e-02, -2.0016e-03],\n", - " [-1.1872e-03, 7.3358e-02, -4.8738e-02],\n", - " 
[-3.2554e-03, -2.7650e-02, -3.6791e-03]],\n", - "\n", - " [[ 1.1710e-01, 6.0996e-03, -2.3733e-02],\n", - " [-6.7923e-04, 1.2859e-03, 2.7628e-03],\n", - " [-4.2115e-02, -2.6475e-02, 7.1679e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[ 3.7477e-02, 1.1140e-03, 6.3222e-02],\n", - " [ 4.8929e-02, 5.5243e-02, 1.1429e-03],\n", - " [-3.3385e-02, -4.7582e-02, -3.9440e-02]],\n", - "\n", - " [[-4.3734e-02, 1.0802e-02, 1.5586e-02],\n", - " [-2.0318e-02, 3.2691e-02, 2.9043e-03],\n", - " [-5.3472e-03, 5.9829e-03, 3.4189e-02]],\n", - "\n", - " [[-3.2688e-03, -5.7450e-02, 3.6852e-02],\n", - " [ 3.2528e-03, -1.5297e-03, 5.3320e-03],\n", - " [ 1.7151e-02, -1.1268e-03, 8.1022e-02]]],\n", - "\n", - "\n", - " [[[ 4.3947e-02, 4.7477e-02, -1.7206e-02],\n", - " [-5.8523e-02, -6.2340e-03, -4.0691e-02],\n", - " [-8.4128e-02, -2.6183e-02, 1.9188e-02]],\n", - "\n", - " [[-3.2490e-02, -9.6267e-02, -7.1499e-02],\n", - " [-7.1048e-04, -8.0449e-02, -5.1654e-02],\n", - " [ 2.3208e-02, -2.1737e-02, 4.9298e-02]],\n", - "\n", - " [[-4.4389e-02, -5.6683e-02, -2.3320e-02],\n", - " [-6.4499e-02, -3.0856e-02, 6.9859e-03],\n", - " [-5.6850e-02, -3.9121e-02, -4.1350e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1.7911e-02, 4.6162e-02, 1.9657e-02],\n", - " [ 7.5653e-03, -4.2511e-02, -2.7971e-02],\n", - " [-3.6817e-02, -9.0424e-02, -1.6890e-02]],\n", - "\n", - " [[-2.2699e-02, 1.6050e-02, -2.7669e-02],\n", - " [-2.6073e-02, -5.1785e-02, -3.8797e-02],\n", - " [-6.3523e-02, -5.3364e-02, -9.2193e-02]],\n", - "\n", - " [[-5.3270e-02, -4.1235e-03, -5.8335e-03],\n", - " [-1.9613e-02, 5.4143e-02, -1.6930e-02],\n", - " [ 3.3865e-02, -2.8973e-03, -1.6124e-02]]]])\n", - "layer1.0.bn2.weight Parameter containing:\n", - "tensor([0.9079, 0.9314, 0.9273, 0.9928, 0.9441, 0.8893, 1.3586, 0.9319, 0.9911,\n", - " 0.8918, 0.9668, 1.2491, 0.9299, 0.8372, 0.9112, 1.1834, 0.8022, 0.9417,\n", - " 0.8730, 1.4600, 1.0404, 1.0169, 0.9842, 0.8882, 0.9578, 1.0128, 0.9548,\n", - " 0.8057, 0.7646, 0.9008, 1.0668, 0.8646, 0.9026, 
1.0237, 0.7285, 1.0507,\n", - " 0.9604, 0.9052, 0.8586, 0.7162, 0.9041, 0.8687, 0.9544, 0.8602, 0.9671,\n", - " 1.0384, 0.9835, 0.9222, 0.9371, 0.7231, 0.8584, 0.9737, 0.9182, 0.7807,\n", - " 0.9644, 0.9615, 0.9434, 0.8674, 1.1694, 1.0835, 0.7759, 0.8638, 1.0058,\n", - " 0.9209])\n", - "layer1.0.bn2.bias Parameter containing:\n", - "tensor([-0.0163, -0.0560, -0.0696, -0.2069, -0.0330, -0.0317, 0.0248, -0.0275,\n", - " -0.0437, -0.0688, -0.0788, 0.1443, -0.1303, -0.0720, -0.0025, 0.0403,\n", - " 0.0716, -0.0432, -0.0148, 0.2234, -0.0015, 0.0852, 0.0106, -0.0835,\n", - " -0.0506, 0.0035, 0.0694, 0.0313, -0.0780, -0.0077, 0.0154, 0.0233,\n", - " -0.0140, 0.0316, 0.0475, -0.0478, -0.0771, 0.0058, -0.0953, 0.0017,\n", - " 0.0255, 0.0810, -0.0305, 0.0693, -0.1069, -0.0252, -0.0649, -0.0238,\n", - " -0.0523, 0.0308, 0.0118, 0.0199, -0.0296, 0.0458, -0.0450, -0.0723,\n", - " -0.0193, 0.0517, 0.0489, -0.0302, -0.1371, -0.1570, 0.0574, -0.1233])\n", - "layer1.1.conv1.weight Parameter containing:\n", - "tensor([[[[ 2.6593e-02, -3.4198e-02, 6.7904e-02],\n", - " [ 4.9502e-02, 1.7229e-02, 5.6829e-02],\n", - " [-8.7818e-03, 6.0911e-02, 4.0403e-02]],\n", - "\n", - " [[-7.6700e-02, -9.6916e-02, -5.3278e-02],\n", - " [-5.9268e-03, -2.2775e-02, -8.3431e-02],\n", - " [ 9.8249e-03, 4.8766e-02, -5.0661e-02]],\n", - "\n", - " [[-4.0173e-02, 2.9130e-02, 3.3362e-02],\n", - " [-1.5308e-02, 2.6540e-02, -9.1486e-03],\n", - " [-8.1136e-02, -7.1668e-02, 1.7501e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 4.9759e-02, 6.7114e-02, 9.5368e-03],\n", - " [ 2.4895e-02, 3.7341e-02, -9.1378e-03],\n", - " [ 1.3741e-02, -2.4999e-02, -9.5631e-03]],\n", - "\n", - " [[-2.3029e-02, -5.6797e-02, -7.2388e-02],\n", - " [ 2.0829e-02, 3.3307e-02, -4.1836e-02],\n", - " [-2.4193e-02, 1.8719e-02, -3.0852e-02]],\n", - "\n", - " [[-4.1146e-02, -5.9768e-02, -1.3481e-02],\n", - " [-2.4279e-02, -2.0338e-02, -4.9791e-02],\n", - " [ 1.4678e-02, -2.1303e-02, 3.9306e-02]]],\n", - "\n", - "\n", - " [[[-2.4511e-02, 
-1.0791e-02, -5.8488e-02],\n", - " [-3.8544e-02, -6.9196e-03, -4.3451e-02],\n", - " [-3.9745e-02, -4.2166e-02, -4.0130e-02]],\n", - "\n", - " [[ 2.2328e-03, 2.0923e-02, 7.5556e-02],\n", - " [ 1.6712e-02, 6.0325e-02, 2.8951e-02],\n", - " [ 5.7430e-02, 7.4455e-03, 1.1946e-02]],\n", - "\n", - " [[ 2.1712e-02, -1.3877e-02, -2.8872e-03],\n", - " [-6.3337e-02, -6.5354e-02, -4.3473e-02],\n", - " [-7.9494e-03, -4.2597e-02, -3.6345e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1.3227e-02, 1.3381e-02, 1.2781e-02],\n", - " [ 1.1976e-02, 1.7176e-03, 7.5455e-02],\n", - " [ 3.3403e-03, 4.9971e-02, 2.0079e-02]],\n", - "\n", - " [[-7.2390e-03, -3.0566e-02, 1.6535e-02],\n", - " [ 2.0263e-02, -1.2138e-02, 4.2431e-02],\n", - " [ 2.3446e-02, 4.1872e-04, -8.2290e-03]],\n", - "\n", - " [[-2.5881e-03, -3.2973e-02, 4.8453e-02],\n", - " [ 1.2530e-02, 2.2892e-02, 4.6676e-02],\n", - " [-4.0720e-02, 4.1090e-02, 3.1536e-02]]],\n", - "\n", - "\n", - " [[[-4.4590e-02, 3.2132e-02, 5.0317e-02],\n", - " [ 5.8723e-02, 2.7518e-02, 5.2837e-02],\n", - " [ 4.5984e-02, -4.8859e-02, 2.8074e-02]],\n", - "\n", - " [[ 3.5439e-02, 4.2154e-02, 6.2620e-03],\n", - " [ 1.2440e-02, 9.4835e-02, -1.4920e-02],\n", - " [ 9.3768e-02, 6.6834e-02, -2.9676e-02]],\n", - "\n", - " [[-1.0067e-02, -2.6803e-02, -3.5374e-02],\n", - " [-8.0500e-03, 2.7707e-02, -4.1172e-02],\n", - " [-4.7432e-02, -3.9620e-02, -4.4891e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 4.3233e-02, 2.7552e-02, 1.3127e-02],\n", - " [ 6.8530e-02, 4.6166e-02, 7.0684e-03],\n", - " [ 5.6551e-02, -3.3509e-02, -5.4552e-02]],\n", - "\n", - " [[-1.7251e-02, 3.8119e-02, 1.4739e-02],\n", - " [-3.9602e-02, 3.8198e-02, 7.9852e-02],\n", - " [-2.2415e-02, 1.7422e-02, 7.2721e-02]],\n", - "\n", - " [[-2.1743e-02, -2.1620e-02, -2.9930e-02],\n", - " [ 5.2186e-02, 4.4009e-02, -1.4880e-02],\n", - " [ 8.2744e-02, 4.0110e-02, 6.9476e-02]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 5.4301e-02, 5.8203e-02, 5.1051e-02],\n", - " [ 6.0607e-03, -2.2593e-03, 
4.8936e-02],\n", - " [ 2.2871e-02, 9.6164e-03, 1.8444e-02]],\n", - "\n", - " [[ 5.7914e-02, 6.4484e-02, -2.3707e-03],\n", - " [ 2.6293e-02, 4.1130e-02, 1.4373e-02],\n", - " [ 4.6521e-02, 6.0853e-02, 3.2299e-02]],\n", - "\n", - " [[ 4.5852e-02, 2.4686e-02, 2.4303e-02],\n", - " [ 3.1945e-02, 3.4674e-02, -1.7316e-02],\n", - " [-2.4934e-02, 5.7327e-02, 2.0576e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1.4928e-02, 9.5490e-03, -6.2237e-02],\n", - " [-2.0978e-02, 1.9526e-02, 6.8305e-03],\n", - " [ 4.8388e-03, -3.7300e-02, -3.6364e-02]],\n", - "\n", - " [[-3.6045e-02, 1.8570e-02, 4.3668e-02],\n", - " [-2.9774e-02, -1.0319e-02, 1.3255e-02],\n", - " [-2.2707e-02, 1.9602e-02, 7.9169e-04]],\n", - "\n", - " [[ 2.3576e-03, -1.9243e-03, 1.2709e-02],\n", - " [ 3.2156e-03, 3.6086e-02, 3.8457e-02],\n", - " [-2.2701e-02, 1.5551e-02, -4.0218e-02]]],\n", - "\n", - "\n", - " [[[ 4.7197e-02, 2.8875e-02, 5.8667e-03],\n", - " [-2.7055e-03, 3.3594e-02, 1.2208e-02],\n", - " [ 9.2712e-04, -1.1782e-02, -1.3044e-03]],\n", - "\n", - " [[-8.2965e-03, -7.0715e-03, -4.0541e-02],\n", - " [-1.0041e-02, -3.4150e-02, -2.7035e-02],\n", - " [ 1.0155e-02, 3.3201e-02, -1.4504e-02]],\n", - "\n", - " [[-7.8979e-02, -7.4035e-02, -8.0417e-02],\n", - " [-3.3371e-03, -4.3919e-04, -9.0875e-02],\n", - " [-4.7493e-02, -5.0132e-02, -2.4897e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-4.9027e-02, -3.2874e-02, -1.1205e-02],\n", - " [-4.4780e-05, -1.0628e-02, 1.3734e-04],\n", - " [-2.6351e-02, -6.2976e-03, -3.4885e-02]],\n", - "\n", - " [[-3.6896e-02, -5.0080e-02, -1.3959e-02],\n", - " [-3.0033e-02, -1.8250e-02, -1.6171e-02],\n", - " [-1.2968e-02, -7.3765e-02, -3.8385e-02]],\n", - "\n", - " [[-1.1514e-02, -1.0031e-02, -4.4923e-02],\n", - " [-5.2885e-02, -8.0522e-02, 2.2228e-04],\n", - " [-8.0144e-03, -5.1980e-02, -6.4364e-02]]],\n", - "\n", - "\n", - " [[[ 2.0123e-02, 4.2277e-02, -8.7214e-03],\n", - " [-1.0368e-02, 3.4194e-02, 4.5498e-02],\n", - " [ 9.1709e-02, 9.3509e-02, 9.3661e-03]],\n", - "\n", - " [[ 
1.2534e-02, 5.9424e-03, 1.4623e-02],\n", - " [ 3.2083e-02, 1.1867e-02, 6.8766e-02],\n", - " [ 1.9279e-02, 1.5177e-03, 1.6413e-02]],\n", - "\n", - " [[ 4.6366e-03, 8.8012e-02, 3.8882e-02],\n", - " [ 6.3008e-02, 9.7451e-02, 2.7030e-03],\n", - " [ 7.2544e-02, 4.6869e-02, 7.6242e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 8.5150e-03, -1.8553e-02, -4.8858e-02],\n", - " [-2.0450e-02, -4.0801e-03, -4.4920e-02],\n", - " [-1.4873e-02, 1.2599e-02, -2.8559e-02]],\n", - "\n", - " [[ 2.6483e-02, -2.5318e-03, 2.3260e-02],\n", - " [ 1.4355e-02, 2.9071e-03, 3.7970e-02],\n", - " [ 1.7436e-02, 1.5663e-02, 2.8179e-02]],\n", - "\n", - " [[ 5.3114e-02, 4.9337e-02, 1.1110e-01],\n", - " [ 4.0582e-02, 7.0569e-02, 3.7544e-02],\n", - " [ 5.0659e-02, 6.7754e-02, 6.0954e-02]]]])\n", - "layer1.1.bn1.weight Parameter containing:\n", - "tensor([1.0599, 0.9466, 1.0883, 1.0216, 0.9549, 1.0047, 0.9343, 1.1112, 0.9535,\n", - " 0.9474, 0.9471, 1.0289, 1.0615, 0.9349, 0.9411, 0.8893, 1.0767, 1.0939,\n", - " 1.0169, 0.9414, 1.1864, 0.9885, 1.0232, 1.3587, 0.9706, 0.9471, 0.9786,\n", - " 0.9206, 0.9999, 0.9639, 1.0253, 1.0520, 1.1145, 0.9727, 0.9409, 0.9877,\n", - " 0.9381, 0.9588, 1.0095, 1.0971, 0.9343, 1.0072, 0.9908, 0.9240, 0.9385,\n", - " 0.9874, 0.8669, 1.1432, 0.8880, 0.9495, 0.9635, 0.9576, 1.1564, 0.9145,\n", - " 0.9872, 0.9669, 0.9895, 1.0291, 0.9717, 0.8540, 1.0361, 0.9782, 0.9716,\n", - " 1.0026])\n", - "layer1.1.bn1.bias Parameter containing:\n", - "tensor([-0.0409, -0.0949, 0.0122, -0.0208, 0.0112, 0.0394, -0.0883, -0.0212,\n", - " -0.0387, -0.0367, -0.1526, -0.0442, -0.0031, -0.0585, 0.0076, -0.1048,\n", - " -0.0662, -0.0024, -0.0183, 0.0024, 0.0064, -0.0465, -0.0546, 0.2693,\n", - " -0.0167, -0.0446, -0.0386, -0.0651, -0.0097, 0.0304, 0.0280, 0.0169,\n", - " 0.0934, -0.1040, -0.0748, -0.0120, -0.0615, -0.0757, -0.0773, -0.0143,\n", - " -0.0449, -0.0480, -0.0719, -0.1599, -0.0103, -0.0578, -0.1456, 0.0326,\n", - " -0.1470, -0.1311, -0.0437, -0.0365, 0.0418, -0.0848, -0.1881, 
-0.0249,\n", - " 0.0098, -0.0431, 0.0236, -0.1328, -0.0497, -0.0338, -0.0273, -0.0168])\n", - "layer1.1.conv2.weight Parameter containing:\n", - "tensor([[[[-0.0425, 0.0580, 0.0541],\n", - " [-0.0597, 0.0107, -0.0336],\n", - " [ 0.0133, -0.0100, -0.0221]],\n", - "\n", - " [[-0.0303, -0.0345, -0.0381],\n", - " [-0.0578, -0.0366, 0.0113],\n", - " [-0.0443, -0.0016, 0.0039]],\n", - "\n", - " [[-0.0271, -0.0357, -0.0787],\n", - " [ 0.0236, -0.0049, -0.0762],\n", - " [-0.0060, -0.0611, -0.0399]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0652, -0.0172, 0.0024],\n", - " [-0.0414, -0.0345, 0.0072],\n", - " [ 0.0390, 0.0039, -0.0515]],\n", - "\n", - " [[-0.0112, -0.1150, -0.0668],\n", - " [-0.0464, -0.0827, -0.0895],\n", - " [-0.0056, -0.0743, -0.0646]],\n", - "\n", - " [[ 0.0104, 0.0106, 0.0315],\n", - " [ 0.0361, -0.0349, -0.0232],\n", - " [ 0.0040, -0.0481, -0.0696]]],\n", - "\n", - "\n", - " [[[-0.0087, 0.0416, 0.0631],\n", - " [ 0.0012, -0.0020, 0.0648],\n", - " [-0.0254, 0.0522, 0.0958]],\n", - "\n", - " [[ 0.0291, -0.0027, 0.0749],\n", - " [ 0.0146, 0.0427, 0.0398],\n", - " [ 0.0216, 0.0151, 0.0305]],\n", - "\n", - " [[-0.0142, -0.0535, -0.0562],\n", - " [-0.0465, 0.0428, -0.0320],\n", - " [ 0.0533, 0.0121, -0.0514]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0050, 0.0087, -0.0373],\n", - " [-0.0023, -0.0032, -0.0080],\n", - " [-0.0019, 0.0070, -0.0359]],\n", - "\n", - " [[ 0.0295, 0.0137, -0.0127],\n", - " [ 0.0067, -0.0188, -0.0028],\n", - " [-0.0058, 0.0057, -0.0498]],\n", - "\n", - " [[-0.0124, -0.0034, 0.0735],\n", - " [ 0.0373, 0.0344, 0.0147],\n", - " [-0.0111, -0.0272, -0.0573]]],\n", - "\n", - "\n", - " [[[-0.0064, -0.0499, -0.0069],\n", - " [ 0.0320, 0.0582, 0.0055],\n", - " [ 0.0213, 0.0547, 0.0568]],\n", - "\n", - " [[-0.0671, -0.0392, -0.0588],\n", - " [-0.0143, -0.0592, -0.0417],\n", - " [-0.0555, -0.0098, 0.0159]],\n", - "\n", - " [[-0.0233, 0.0063, 0.0173],\n", - " [ 0.0356, 0.0272, 0.0058],\n", - " [ 0.0783, -0.0183, 0.0077]],\n", - "\n", - " 
...,\n", - "\n", - " [[-0.0018, -0.0190, -0.0682],\n", - " [ 0.0239, -0.0067, -0.0636],\n", - " [ 0.0168, 0.0153, -0.0345]],\n", - "\n", - " [[ 0.0611, -0.0152, -0.0178],\n", - " [ 0.0012, -0.0079, 0.0055],\n", - " [-0.0247, 0.0308, -0.0334]],\n", - "\n", - " [[ 0.0216, -0.0141, -0.0028],\n", - " [ 0.0329, 0.0348, -0.0168],\n", - " [ 0.0369, 0.0396, 0.0111]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 0.0227, -0.0292, 0.0172],\n", - " [ 0.0439, -0.0371, -0.0349],\n", - " [-0.0452, -0.0315, -0.0630]],\n", - "\n", - " [[ 0.0070, 0.0408, 0.0525],\n", - " [ 0.0316, 0.0162, 0.0199],\n", - " [-0.0101, 0.0084, 0.0361]],\n", - "\n", - " [[-0.0197, -0.0297, 0.0734],\n", - " [-0.0344, 0.0061, 0.0714],\n", - " [-0.0281, 0.0119, 0.0195]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0260, -0.0386, -0.0111],\n", - " [-0.0640, -0.0643, -0.0123],\n", - " [-0.0149, -0.0794, -0.0550]],\n", - "\n", - " [[ 0.0072, -0.0051, 0.0791],\n", - " [ 0.0199, 0.0463, 0.0407],\n", - " [ 0.0329, 0.0058, -0.0410]],\n", - "\n", - " [[ 0.0060, -0.0138, 0.0508],\n", - " [-0.0337, 0.0088, 0.0093],\n", - " [ 0.0649, 0.0187, 0.0765]]],\n", - "\n", - "\n", - " [[[-0.0246, -0.0090, 0.0458],\n", - " [-0.0393, 0.0483, 0.0357],\n", - " [-0.0376, 0.0400, 0.0190]],\n", - "\n", - " [[-0.0568, 0.0053, -0.0201],\n", - " [ 0.0060, -0.0178, -0.0189],\n", - " [ 0.0246, 0.0467, 0.0395]],\n", - "\n", - " [[-0.0069, 0.0289, 0.0062],\n", - " [ 0.0210, -0.0279, 0.0359],\n", - " [ 0.0072, -0.0665, -0.0325]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0284, 0.0276, -0.0132],\n", - " [-0.0086, 0.0012, -0.0279],\n", - " [ 0.0037, -0.0484, -0.0750]],\n", - "\n", - " [[-0.0388, -0.0717, -0.0088],\n", - " [ 0.0191, -0.0185, -0.0208],\n", - " [-0.0333, -0.0268, 0.0093]],\n", - "\n", - " [[ 0.0156, -0.0465, -0.0360],\n", - " [ 0.0242, 0.0537, 0.0047],\n", - " [ 0.0500, -0.0474, -0.0115]]],\n", - "\n", - "\n", - " [[[-0.0350, -0.0504, -0.0429],\n", - " [-0.0300, -0.0106, -0.0264],\n", - " [-0.0137, -0.0023, 
0.0497]],\n", - "\n", - " [[-0.0818, -0.0710, -0.0014],\n", - " [-0.0329, -0.0905, -0.0708],\n", - " [-0.0799, -0.0574, -0.0885]],\n", - "\n", - " [[-0.0414, -0.0102, 0.0219],\n", - " [-0.0278, -0.0462, -0.0012],\n", - " [ 0.0137, -0.0312, -0.0250]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0034, 0.0267, 0.0137],\n", - " [ 0.0144, 0.0389, 0.0243],\n", - " [-0.0171, 0.0232, 0.0449]],\n", - "\n", - " [[-0.0145, -0.0584, -0.0150],\n", - " [-0.0247, 0.0208, -0.0215],\n", - " [ 0.0675, 0.0202, 0.0110]],\n", - "\n", - " [[ 0.0119, -0.0197, 0.0159],\n", - " [ 0.0221, -0.0354, -0.0394],\n", - " [ 0.0256, -0.0396, -0.0515]]]])\n", - "layer1.1.bn2.weight Parameter containing:\n", - "tensor([0.8659, 0.9570, 0.8891, 0.9045, 1.0087, 1.0288, 0.9630, 0.8839, 0.9474,\n", - " 0.8482, 0.9522, 1.0135, 0.9082, 0.7224, 0.7214, 0.9890, 0.8385, 0.9520,\n", - " 0.9855, 1.1083, 0.9786, 0.8302, 0.9394, 1.0564, 0.8630, 0.9151, 1.0819,\n", - " 0.8662, 0.7969, 0.9081, 0.8537, 0.8439, 0.9650, 0.9277, 0.7604, 0.8773,\n", - " 1.0263, 0.9107, 0.8561, 0.7252, 0.9475, 1.0708, 0.8672, 0.9553, 0.9996,\n", - " 0.9246, 0.9422, 0.9534, 0.7877, 0.7036, 0.9022, 0.9321, 0.8827, 0.7797,\n", - " 0.9681, 0.9808, 0.9166, 0.8894, 0.7313, 1.0279, 0.8668, 0.7882, 0.9513,\n", - " 0.9528])\n", - "layer1.1.bn2.bias Parameter containing:\n", - "tensor([-0.0263, -0.0246, -0.0380, -0.1532, 0.0098, 0.0709, -0.0320, -0.0192,\n", - " -0.0470, -0.0412, -0.0918, -0.1333, -0.0979, -0.1266, -0.0348, -0.0114,\n", - " -0.1043, -0.0150, -0.0406, 0.0311, 0.0046, -0.0594, 0.0503, 0.0359,\n", - " -0.0266, -0.0136, -0.0352, -0.0149, -0.1466, 0.0198, -0.0174, 0.0008,\n", - " -0.0475, 0.0132, -0.0096, -0.1214, -0.0484, -0.0154, -0.0565, -0.0403,\n", - " 0.0498, 0.0305, -0.0371, 0.0178, -0.0004, -0.0951, -0.0434, -0.0246,\n", - " -0.0990, -0.0080, -0.0547, -0.0164, -0.0435, 0.0026, -0.0411, -0.0411,\n", - " 0.0899, -0.0145, 0.0086, -0.0031, -0.0222, -0.0939, 0.1107, -0.0317])\n", - "layer2.0.conv1.weight Parameter containing:\n", - 
"tensor([[[[ 0.0158, 0.0590, 0.0726],\n", - " [ 0.0895, 0.0307, 0.0540],\n", - " [ 0.0257, 0.0603, 0.0361]],\n", - "\n", - " [[-0.0074, 0.0527, 0.0438],\n", - " [-0.0517, 0.0088, 0.0551],\n", - " [ 0.0119, 0.0110, 0.0201]],\n", - "\n", - " [[-0.0075, 0.0335, 0.0092],\n", - " [ 0.0356, -0.0221, -0.0678],\n", - " [ 0.0104, 0.0493, -0.0018]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0101, 0.0331, 0.0198],\n", - " [-0.0509, 0.0216, 0.0288],\n", - " [ 0.0267, 0.0396, 0.0091]],\n", - "\n", - " [[ 0.0025, 0.0212, 0.0050],\n", - " [-0.0509, 0.0145, -0.0062],\n", - " [-0.0173, 0.0092, 0.0103]],\n", - "\n", - " [[ 0.0235, 0.0254, -0.0029],\n", - " [ 0.0238, -0.0455, -0.0025],\n", - " [-0.0164, -0.0220, -0.0331]]],\n", - "\n", - "\n", - " [[[-0.0046, 0.0073, -0.0030],\n", - " [ 0.0025, 0.0385, 0.0005],\n", - " [-0.0113, 0.0183, -0.0256]],\n", - "\n", - " [[-0.0707, -0.0429, 0.0003],\n", - " [-0.0114, -0.0280, 0.0282],\n", - " [-0.0454, -0.0544, -0.0074]],\n", - "\n", - " [[-0.0286, 0.0242, 0.0237],\n", - " [ 0.0335, -0.0229, -0.0133],\n", - " [ 0.0259, 0.0158, 0.0215]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0169, -0.0294, -0.0425],\n", - " [-0.0587, -0.0123, 0.0025],\n", - " [-0.0308, -0.0002, 0.0036]],\n", - "\n", - " [[ 0.0214, -0.0275, -0.0157],\n", - " [-0.0303, -0.0238, 0.0302],\n", - " [-0.0468, 0.0252, -0.0138]],\n", - "\n", - " [[ 0.0201, -0.0123, -0.0140],\n", - " [ 0.0291, -0.0242, 0.0228],\n", - " [-0.0296, -0.0166, 0.0087]]],\n", - "\n", - "\n", - " [[[-0.0199, 0.0162, -0.0597],\n", - " [-0.0074, 0.0055, 0.0203],\n", - " [ 0.0311, -0.0217, 0.0030]],\n", - "\n", - " [[ 0.0209, 0.0524, -0.0269],\n", - " [ 0.0419, 0.0137, -0.0015],\n", - " [ 0.0096, 0.0223, -0.0197]],\n", - "\n", - " [[-0.0443, -0.0454, -0.0844],\n", - " [ 0.0445, 0.0140, -0.0364],\n", - " [ 0.0462, 0.0051, 0.0261]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0373, -0.0513, -0.0283],\n", - " [ 0.0204, -0.0505, -0.0076],\n", - " [-0.0416, -0.0456, -0.0535]],\n", - "\n", - " [[-0.0130, -0.0107, 
0.0106],\n", - " [-0.0009, 0.0307, 0.0062],\n", - " [ 0.0187, 0.0160, -0.0581]],\n", - "\n", - " [[-0.0268, -0.0571, -0.0176],\n", - " [ 0.0213, -0.0509, -0.0372],\n", - " [ 0.0095, 0.0063, 0.0298]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-0.0013, -0.0230, 0.0109],\n", - " [ 0.0346, 0.0035, 0.0074],\n", - " [ 0.0350, 0.0549, 0.0220]],\n", - "\n", - " [[-0.0379, 0.0252, -0.0286],\n", - " [ 0.0210, -0.0278, 0.0077],\n", - " [ 0.0042, -0.0041, 0.0536]],\n", - "\n", - " [[ 0.0404, -0.0401, -0.0298],\n", - " [-0.0400, -0.0587, -0.0317],\n", - " [-0.0152, -0.0491, -0.0283]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0188, -0.0144, -0.0377],\n", - " [-0.0201, -0.0266, -0.0110],\n", - " [-0.0168, 0.0445, -0.0206]],\n", - "\n", - " [[-0.0047, -0.0460, 0.0059],\n", - " [-0.0314, -0.0195, -0.0260],\n", - " [-0.0873, -0.0550, -0.0475]],\n", - "\n", - " [[ 0.0206, -0.0238, -0.0055],\n", - " [-0.0567, 0.0145, 0.0048],\n", - " [ 0.0315, -0.0404, -0.0358]]],\n", - "\n", - "\n", - " [[[ 0.0376, 0.0002, 0.0187],\n", - " [-0.0714, -0.0410, -0.0170],\n", - " [-0.0173, -0.0522, -0.0071]],\n", - "\n", - " [[ 0.0622, 0.0299, -0.0092],\n", - " [ 0.0005, -0.0267, 0.0519],\n", - " [ 0.0598, 0.0151, 0.0693]],\n", - "\n", - " [[-0.0858, -0.0869, -0.0007],\n", - " [-0.0152, -0.0226, 0.0009],\n", - " [-0.0550, 0.0011, -0.0550]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0283, 0.0049, 0.0274],\n", - " [ 0.0528, 0.0098, -0.0040],\n", - " [ 0.0354, 0.0302, 0.0458]],\n", - "\n", - " [[-0.0104, 0.0141, -0.0294],\n", - " [-0.0277, -0.0633, 0.0178],\n", - " [ 0.0387, 0.0451, -0.0141]],\n", - "\n", - " [[-0.0456, -0.0294, -0.0607],\n", - " [ 0.0065, -0.0081, -0.0275],\n", - " [-0.0255, -0.0203, -0.0485]]],\n", - "\n", - "\n", - " [[[-0.0105, 0.0096, -0.0166],\n", - " [-0.0391, -0.0354, -0.0003],\n", - " [-0.0502, 0.0183, 0.0312]],\n", - "\n", - " [[ 0.0169, -0.0024, -0.0177],\n", - " [-0.0148, 0.0153, 0.0100],\n", - " [-0.0366, 0.0393, 0.0128]],\n", - "\n", - " [[ 0.0077, 0.0533, 
0.0733],\n", - " [ 0.0163, -0.0186, -0.0017],\n", - " [-0.0240, 0.0063, 0.0089]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0015, 0.0222, -0.0031],\n", - " [-0.0379, -0.0089, 0.0327],\n", - " [-0.0173, -0.0029, 0.0263]],\n", - "\n", - " [[-0.0621, 0.0345, 0.0175],\n", - " [-0.0643, 0.0059, 0.0182],\n", - " [ 0.0182, 0.0213, 0.0743]],\n", - "\n", - " [[ 0.0029, 0.0253, 0.0361],\n", - " [-0.0325, 0.0326, 0.0458],\n", - " [ 0.0121, -0.0050, 0.0581]]]])\n", - "layer2.0.bn1.weight Parameter containing:\n", - "tensor([0.9856, 0.9561, 1.0298, 0.9849, 1.0096, 1.0142, 0.9732, 1.0974, 0.9764,\n", - " 0.9295, 0.9870, 0.9391, 0.9860, 0.9508, 0.9131, 0.9831, 1.0883, 1.0482,\n", - " 1.0211, 1.0542, 0.9074, 0.9631, 0.9923, 1.0211, 0.9294, 0.9730, 0.9817,\n", - " 1.1598, 0.9025, 0.9847, 1.0318, 1.0859, 1.0417, 1.0281, 1.0176, 1.0079,\n", - " 1.0154, 1.0588, 1.0517, 1.1073, 1.0900, 0.9344, 0.9602, 1.0996, 1.0282,\n", - " 1.0062, 1.0235, 1.0970, 0.9816, 1.0375, 0.9065, 0.9550, 1.0763, 0.8873,\n", - " 0.9492, 0.9854, 0.9797, 0.9520, 1.0777, 1.0084, 0.9395, 0.9863, 0.8887,\n", - " 0.9469, 1.1168, 0.9005, 1.1311, 0.9260, 0.9452, 0.9881, 0.8533, 0.9687,\n", - " 0.9784, 0.9794, 1.0136, 0.9528, 0.9522, 0.9981, 0.8862, 1.0042, 0.9056,\n", - " 0.8985, 1.1169, 0.9522, 1.0590, 0.8980, 0.9304, 1.0703, 1.0742, 0.9503,\n", - " 1.0123, 1.0671, 0.9040, 0.9954, 1.1495, 1.0520, 1.0443, 0.9952, 0.9958,\n", - " 0.9039, 1.0339, 0.9977, 0.9204, 0.9406, 0.9400, 1.0673, 0.9102, 1.0681,\n", - " 0.8994, 0.9404, 0.9741, 0.9414, 0.9928, 1.0733, 1.0143, 1.0218, 1.2354,\n", - " 1.1403, 1.0164, 1.0249, 0.9757, 0.9149, 0.9944, 1.1335, 0.8665, 1.1207,\n", - " 0.9400, 0.9708])\n", - "layer2.0.bn1.bias Parameter containing:\n", - "tensor([-4.4408e-02, -6.3460e-02, -5.9487e-02, 3.9959e-02, 1.3809e-02,\n", - " 5.7571e-03, -6.3010e-02, 1.1827e-02, -4.0396e-02, -6.0959e-02,\n", - " -6.8656e-02, -5.4246e-02, 1.2467e-03, -7.8717e-02, -8.8926e-02,\n", - " -4.7148e-03, -6.2139e-02, -1.5113e-02, -8.7445e-02, -4.2238e-02,\n", 
- " -5.4938e-02, -3.2369e-02, -7.4581e-02, -3.3462e-02, -4.9141e-02,\n", - " -4.9546e-02, -7.5260e-02, 5.9013e-02, -1.0577e-01, -1.0421e-01,\n", - " -2.5820e-04, -2.0431e-02, -5.1903e-02, -1.8563e-02, -6.3990e-02,\n", - " -2.6435e-02, -1.3204e-02, -3.7876e-02, -2.3634e-02, -5.6133e-02,\n", - " -4.8304e-02, -8.8873e-02, -2.3738e-02, -1.1452e-02, -4.6369e-02,\n", - " -2.9358e-03, 8.4575e-03, 2.1221e-02, -4.8208e-02, 1.7766e-02,\n", - " -1.3707e-01, -5.6345e-02, 2.1062e-02, -1.4223e-01, -6.9033e-03,\n", - " -6.6139e-02, -7.3860e-02, -5.2747e-02, 7.9480e-02, -4.0997e-02,\n", - " -1.0225e-02, 1.3592e-02, -8.0994e-02, -6.3811e-02, -1.8184e-02,\n", - " -6.4378e-02, 3.6648e-02, -8.9201e-02, -6.7080e-02, -3.2439e-02,\n", - " -1.4202e-01, -1.4953e-02, -5.7625e-02, -4.2923e-02, -2.5661e-02,\n", - " -3.7217e-02, -2.7865e-02, -5.7743e-02, -4.3572e-02, 9.9375e-03,\n", - " -6.7939e-02, -9.3802e-02, -2.8521e-02, -5.0249e-02, 5.2249e-02,\n", - " -1.2405e-01, -5.6959e-02, -3.5813e-02, -1.4070e-02, -1.1760e-01,\n", - " 1.4158e-03, -1.2018e-02, -8.2153e-02, -1.9360e-02, 1.9967e-02,\n", - " -3.5138e-02, -8.8690e-02, 5.2727e-03, -1.1467e-02, -6.7700e-02,\n", - " 4.8841e-02, -3.8626e-02, -1.0003e-01, -4.1781e-02, -5.3253e-02,\n", - " -3.4413e-02, -7.0329e-02, -9.5402e-05, -1.0329e-01, -5.2985e-02,\n", - " 1.8760e-02, -2.0112e-02, -6.0653e-02, -5.8835e-02, -7.4210e-02,\n", - " 3.0239e-02, -7.5367e-02, -2.2648e-02, -4.0464e-02, -6.7306e-02,\n", - " 4.4584e-02, -6.6090e-02, 2.0774e-02, -4.0260e-02, -1.1678e-01,\n", - " 6.9928e-02, -5.5673e-02, -3.2468e-02])\n", - "layer2.0.conv2.weight Parameter containing:\n", - "tensor([[[[-2.9326e-02, 1.4776e-04, 1.3197e-02],\n", - " [-1.6475e-02, -4.4652e-02, -2.1389e-03],\n", - " [-4.7745e-04, 1.1555e-02, 1.1161e-02]],\n", - "\n", - " [[-8.2114e-03, 7.7707e-03, 3.8702e-02],\n", - " [ 4.8851e-02, 2.2664e-02, -3.6508e-04],\n", - " [-1.8164e-02, -1.9086e-02, -7.5711e-04]],\n", - "\n", - " [[-1.2012e-02, 5.1844e-03, -1.6460e-02],\n", - " [-4.1903e-02, 
1.6704e-02, 1.1525e-02],\n", - " [-2.8750e-02, -1.8928e-02, -2.8496e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-3.4732e-02, -2.8028e-02, -4.9673e-02],\n", - " [ 2.3525e-02, -7.4910e-03, -1.2319e-02],\n", - " [-1.1767e-02, -2.2846e-03, 1.4591e-02]],\n", - "\n", - " [[ 2.5598e-02, 3.6546e-02, 3.5433e-02],\n", - " [ 3.2633e-02, 4.1257e-03, -2.0384e-02],\n", - " [-1.6832e-02, -2.4976e-02, -1.0932e-02]],\n", - "\n", - " [[-2.4550e-02, 1.4871e-02, -2.1896e-03],\n", - " [-3.5753e-02, 3.3367e-04, 3.5690e-03],\n", - " [ 9.3313e-03, -1.0965e-02, -3.1038e-03]]],\n", - "\n", - "\n", - " [[[-6.1483e-02, -1.5989e-02, -5.0814e-02],\n", - " [-7.6276e-03, 1.1751e-02, -2.3630e-02],\n", - " [-1.1072e-02, -8.0032e-03, 1.2706e-02]],\n", - "\n", - " [[ 3.1783e-02, 2.4595e-02, 4.2654e-02],\n", - " [ 3.6219e-02, 1.1338e-02, 4.2761e-02],\n", - " [-1.9528e-03, 2.3325e-02, -2.9664e-02]],\n", - "\n", - " [[ 1.7203e-02, 2.9008e-02, 1.6512e-02],\n", - " [-2.1788e-02, -1.5973e-02, 3.8330e-03],\n", - " [-6.3857e-02, -4.3809e-02, -1.8129e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1.3763e-02, -1.1545e-02, -1.4335e-02],\n", - " [-3.4292e-02, -3.5164e-03, -3.0173e-02],\n", - " [-4.4377e-02, -1.6386e-02, -2.8011e-02]],\n", - "\n", - " [[-3.7781e-02, -3.0508e-02, -9.6721e-04],\n", - " [ 1.6204e-02, 2.7993e-02, 5.1552e-02],\n", - " [ 5.3977e-03, 1.0185e-02, -3.6914e-02]],\n", - "\n", - " [[ 3.8443e-02, 3.5896e-02, 4.3196e-02],\n", - " [-4.8815e-04, -3.2587e-02, -7.3151e-03],\n", - " [ 1.6852e-02, -3.9039e-02, -3.5906e-02]]],\n", - "\n", - "\n", - " [[[ 1.6662e-02, 3.7526e-02, 3.4958e-02],\n", - " [ 4.3510e-02, 3.0857e-02, 1.3001e-02],\n", - " [-5.7000e-03, -7.4834e-03, -4.0627e-02]],\n", - "\n", - " [[-3.2959e-02, -2.2397e-02, 1.4952e-02],\n", - " [-2.7921e-02, 1.1206e-02, 5.2647e-03],\n", - " [-5.5147e-03, 3.3321e-02, 6.1338e-02]],\n", - "\n", - " [[-1.4455e-02, 1.6981e-02, 2.7939e-02],\n", - " [ 4.5812e-02, 6.6108e-02, 5.7073e-02],\n", - " [-1.3593e-02, -1.7858e-03, 1.6507e-02]],\n", - "\n", - 
" ...,\n", - "\n", - " [[-5.2863e-02, -1.2127e-02, 3.7099e-02],\n", - " [-4.4639e-02, 2.2652e-02, -3.6863e-05],\n", - " [-4.7820e-02, 1.3688e-02, 3.1525e-02]],\n", - "\n", - " [[-3.2835e-02, -4.7681e-02, 2.5508e-02],\n", - " [-2.3645e-02, -7.2019e-03, -2.9590e-02],\n", - " [-1.4003e-02, -2.3828e-02, -2.2238e-02]],\n", - "\n", - " [[-2.7537e-04, 1.6434e-02, 3.0733e-02],\n", - " [ 7.4750e-03, -5.3029e-03, 2.6011e-02],\n", - " [ 4.8595e-02, 2.9357e-03, 5.6132e-02]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-3.5815e-02, -1.3693e-02, -2.7741e-02],\n", - " [-3.0585e-02, -4.0935e-02, -3.9116e-02],\n", - " [-2.7633e-02, -2.2705e-02, 1.9924e-02]],\n", - "\n", - " [[ 2.8751e-02, 1.9905e-02, 2.9578e-02],\n", - " [ 3.6153e-02, -4.5332e-03, 3.3908e-02],\n", - " [ 3.7232e-03, 3.7734e-04, 2.3319e-02]],\n", - "\n", - " [[-7.5011e-03, -1.1992e-02, -1.8491e-02],\n", - " [-5.2015e-03, 1.5962e-02, -6.3561e-03],\n", - " [ 2.9083e-03, 1.9739e-02, -1.9642e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.7470e-02, -9.8844e-03, 8.6894e-03],\n", - " [ 1.1370e-02, -2.8649e-02, -1.7614e-02],\n", - " [ 7.5035e-03, -2.8190e-02, 2.3285e-02]],\n", - "\n", - " [[ 2.2195e-02, 1.1344e-02, -1.1869e-02],\n", - " [ 1.4061e-03, 3.4397e-02, 2.2290e-02],\n", - " [-8.9835e-03, -3.6462e-04, 4.3276e-03]],\n", - "\n", - " [[ 1.3146e-02, -1.4805e-03, -3.4839e-02],\n", - " [ 9.8965e-03, 1.2123e-03, -1.2450e-02],\n", - " [-2.6127e-02, -1.2640e-03, -2.4437e-02]]],\n", - "\n", - "\n", - " [[[-3.1953e-02, -4.1924e-02, -1.0290e-02],\n", - " [-3.3615e-03, -3.2658e-02, -8.7536e-03],\n", - " [-1.8593e-02, 4.7878e-04, 4.5571e-02]],\n", - "\n", - " [[-1.6380e-02, 9.0561e-03, 3.3792e-02],\n", - " [-4.8492e-03, -3.1668e-02, -1.1819e-02],\n", - " [-2.6978e-02, -4.0084e-03, -7.3130e-03]],\n", - "\n", - " [[-4.7184e-03, -6.5889e-02, -6.8211e-02],\n", - " [ 1.6105e-02, 3.9260e-02, -2.7005e-02],\n", - " [ 2.3144e-02, 3.0311e-02, 1.1514e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-5.8818e-03, -3.2534e-02, 
2.3114e-02],\n", - " [ 1.2380e-02, 1.3630e-02, -4.6785e-02],\n", - " [ 1.2566e-02, 3.2790e-02, 2.4120e-02]],\n", - "\n", - " [[-4.9905e-02, -4.5534e-03, -4.3831e-02],\n", - " [-4.3033e-02, -1.5961e-02, -2.9293e-02],\n", - " [-1.2634e-02, 4.5129e-03, 3.3744e-02]],\n", - "\n", - " [[-1.0258e-02, 1.6650e-02, -1.9402e-02],\n", - " [ 4.0746e-02, 3.2129e-02, 3.6650e-02],\n", - " [-4.0866e-02, -4.9628e-02, -2.9688e-02]]],\n", - "\n", - "\n", - " [[[ 2.0279e-02, 8.4315e-03, -2.2303e-02],\n", - " [ 2.8924e-03, -3.4284e-02, -2.5508e-02],\n", - " [ 7.2640e-03, -1.9162e-02, -3.7586e-02]],\n", - "\n", - " [[ 1.6705e-02, 1.3032e-02, -8.9937e-03],\n", - " [ 4.2467e-02, 8.9482e-03, 2.3359e-02],\n", - " [ 6.0444e-02, 1.8572e-02, -2.0024e-02]],\n", - "\n", - " [[ 5.3415e-02, -1.9900e-02, -2.3915e-02],\n", - " [ 9.5328e-03, 6.5805e-03, -2.5049e-03],\n", - " [ 4.5412e-03, 1.4086e-02, 3.1413e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1.6861e-02, -5.0322e-03, -5.3216e-03],\n", - " [ 4.4522e-02, 3.5843e-02, 2.8520e-02],\n", - " [-2.2439e-02, 4.2831e-02, 2.1473e-02]],\n", - "\n", - " [[-1.4825e-02, -1.6843e-02, -3.1605e-02],\n", - " [-1.7422e-02, 1.0006e-02, -1.6374e-02],\n", - " [ 1.0959e-02, -2.0312e-02, 2.9210e-02]],\n", - "\n", - " [[ 2.0450e-02, -1.3104e-03, -2.7232e-02],\n", - " [ 1.8129e-02, 9.4467e-03, 5.8069e-05],\n", - " [ 2.1227e-02, -3.1317e-02, 1.4729e-02]]]])\n", - "layer2.0.bn2.weight Parameter containing:\n", - "tensor([0.9931, 1.0149, 1.0834, 1.3062, 0.9436, 1.0593, 1.1705, 0.8661, 1.2560,\n", - " 1.0011, 1.0543, 1.0866, 1.0357, 0.9802, 1.0470, 1.0681, 0.9577, 1.0531,\n", - " 1.1537, 0.9731, 0.8574, 1.3457, 0.9703, 0.8292, 1.1049, 1.0989, 1.0087,\n", - " 0.9208, 1.0278, 0.9294, 0.9599, 1.0768, 1.0131, 0.9518, 1.0358, 1.1543,\n", - " 1.0567, 1.1767, 1.0570, 0.9420, 1.1126, 0.9607, 0.9808, 1.0174, 1.1352,\n", - " 1.0871, 0.9884, 1.0924, 1.0687, 0.9786, 0.8349, 0.9612, 1.1718, 1.0225,\n", - " 1.0773, 0.9246, 0.9866, 1.1723, 0.8953, 1.0282, 1.0341, 0.9762, 0.8887,\n", - " 
0.8342, 0.8589, 0.9796, 1.0104, 1.0486, 1.0344, 1.0154, 0.8721, 0.9595,\n", - " 1.0041, 1.2363, 1.0264, 1.0738, 0.8888, 0.7846, 0.9974, 1.0654, 1.0377,\n", - " 0.7588, 0.8791, 0.9104, 0.8667, 0.8919, 1.0725, 1.1285, 0.9482, 0.9912,\n", - " 1.0816, 0.9307, 1.1489, 0.9555, 0.9685, 0.9760, 0.9407, 0.9562, 1.0751,\n", - " 0.9656, 1.0497, 1.3493, 1.0319, 1.0755, 1.1694, 0.9989, 0.9907, 0.9567,\n", - " 0.8660, 1.1067, 1.2941, 1.0724, 0.8705, 1.0477, 0.8153, 1.1535, 0.8960,\n", - " 0.9657, 1.0321, 1.0597, 0.9351, 1.0514, 0.9853, 0.9279, 1.0190, 1.0215,\n", - " 0.9553, 0.9368])\n", - "layer2.0.bn2.bias Parameter containing:\n", - "tensor([-1.2981e-02, -8.6542e-02, -2.7465e-02, -8.9233e-02, -4.8785e-02,\n", - " -5.1820e-02, -7.4506e-02, -8.4871e-03, -8.0788e-02, -4.5926e-02,\n", - " -2.8059e-02, -3.0307e-02, -2.2497e-02, -5.4316e-02, -1.2210e-01,\n", - " -2.3084e-02, -1.5159e-02, -3.7485e-02, -3.9592e-02, -1.7644e-02,\n", - " -9.2507e-02, -6.1548e-02, -4.2106e-02, -2.7228e-02, -4.7395e-02,\n", - " 4.0894e-02, -8.1798e-02, 7.9138e-03, -2.8363e-02, -2.0816e-02,\n", - " 4.5710e-02, -4.7065e-02, -6.4959e-02, -5.2531e-02, -1.4958e-02,\n", - " -2.6834e-02, -7.6824e-02, -4.7256e-02, -2.8755e-02, -3.9290e-02,\n", - " 2.5110e-03, 6.5448e-02, -2.4467e-02, -6.3492e-02, 4.6999e-03,\n", - " -4.1704e-02, -3.9343e-02, 4.9145e-05, -6.0200e-03, -3.4718e-02,\n", - " -5.2317e-02, -5.0237e-02, -4.9025e-02, -2.7047e-02, -5.2278e-02,\n", - " -4.9766e-02, -5.2599e-02, -1.2772e-01, -1.0851e-02, -6.3345e-02,\n", - " -5.3914e-02, -1.0681e-02, -2.0828e-02, -4.9643e-02, -4.6539e-02,\n", - " 2.7344e-02, -2.4344e-02, -3.5232e-02, -3.5927e-02, -4.2018e-02,\n", - " -5.2375e-02, -4.9287e-02, -1.5841e-03, -4.7609e-02, -6.0663e-02,\n", - " -4.4702e-02, -7.3340e-02, -1.0380e-01, -5.8219e-02, -9.6217e-03,\n", - " -6.6633e-03, -3.2506e-02, 4.1905e-02, -5.3879e-02, -3.6673e-02,\n", - " -4.4338e-02, -5.6134e-02, -6.1860e-03, -1.4108e-02, -7.2612e-02,\n", - " -4.0473e-02, -2.9741e-03, -9.7863e-02, -1.9554e-02, 
-8.3787e-02,\n", - " -3.0235e-02, -8.5696e-02, -7.1485e-02, -4.9228e-02, 1.4736e-02,\n", - " 2.4039e-02, -6.1848e-02, -5.7710e-02, -9.6235e-02, -1.1919e-02,\n", - " -2.6103e-02, -4.3060e-02, 1.3986e-02, -5.9900e-02, -2.8015e-02,\n", - " -1.5239e-02, -2.3644e-02, -5.2308e-02, -7.7621e-02, -1.2540e-01,\n", - " -1.4630e-02, -5.4057e-02, -3.7307e-02, -5.4744e-02, -2.6498e-02,\n", - " -1.0680e-01, -1.1601e-02, -1.6164e-02, -3.0841e-02, -5.8472e-02,\n", - " -1.2627e-02, -8.5479e-02, -4.6238e-02])\n", - "layer2.0.shortcut_conv.weight Parameter containing:\n", - "tensor([[[[ 0.2162]],\n", - "\n", - " [[ 0.0339]],\n", - "\n", - " [[ 0.1120]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.1301]],\n", - "\n", - " [[ 0.0351]],\n", - "\n", - " [[-0.1245]]],\n", - "\n", - "\n", - " [[[-0.1042]],\n", - "\n", - " [[ 0.1120]],\n", - "\n", - " [[-0.0587]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.1505]],\n", - "\n", - " [[ 0.0478]],\n", - "\n", - " [[-0.0350]]],\n", - "\n", - "\n", - " [[[ 0.0636]],\n", - "\n", - " [[-0.0045]],\n", - "\n", - " [[-0.0074]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0249]],\n", - "\n", - " [[-0.0809]],\n", - "\n", - " [[ 0.0892]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-0.0323]],\n", - "\n", - " [[ 0.1375]],\n", - "\n", - " [[ 0.0807]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0398]],\n", - "\n", - " [[ 0.0997]],\n", - "\n", - " [[ 0.0951]]],\n", - "\n", - "\n", - " [[[-0.0811]],\n", - "\n", - " [[ 0.1213]],\n", - "\n", - " [[ 0.0251]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.1344]],\n", - "\n", - " [[ 0.1247]],\n", - "\n", - " [[-0.0439]]],\n", - "\n", - "\n", - " [[[ 0.0143]],\n", - "\n", - " [[-0.0505]],\n", - "\n", - " [[-0.0858]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0293]],\n", - "\n", - " [[-0.0258]],\n", - "\n", - " [[-0.1046]]]])\n", - "layer2.0.shortcut_bn.weight Parameter containing:\n", - "tensor([1.0056, 1.0531, 1.0043, 0.9518, 0.9567, 1.0751, 1.1187, 1.0293, 1.0514,\n", - " 1.0665, 1.0102, 1.0254, 1.0789, 1.0322, 
1.1177, 1.0840, 1.0340, 1.0115,\n", - " 1.1417, 1.0085, 0.9390, 1.1613, 1.0394, 1.0185, 0.9537, 1.0575, 1.1024,\n", - " 1.0102, 1.0730, 1.0289, 1.0484, 1.1239, 1.0378, 1.0818, 0.9914, 1.0452,\n", - " 1.0040, 1.0889, 1.1251, 1.0409, 1.0309, 1.0383, 1.0509, 0.9937, 1.0173,\n", - " 1.0388, 1.1697, 1.0248, 1.0797, 0.9914, 0.9138, 0.9931, 0.9795, 1.0041,\n", - " 0.9696, 0.9805, 0.9996, 1.0714, 0.9691, 1.0337, 1.0087, 1.1008, 1.0351,\n", - " 0.9816, 1.0039, 1.0413, 1.0856, 1.0510, 0.9705, 1.0622, 0.9942, 1.0241,\n", - " 1.0249, 1.0274, 1.0713, 1.0870, 0.9671, 0.9346, 1.0485, 1.0752, 1.0161,\n", - " 0.9574, 1.0789, 1.0132, 0.9525, 1.0304, 0.9617, 1.0545, 0.9895, 1.0222,\n", - " 0.9965, 0.9874, 1.0452, 0.9885, 0.9668, 1.0647, 0.9278, 0.9529, 1.0227,\n", - " 1.0354, 1.0489, 1.1118, 0.9953, 1.0224, 1.0967, 1.0673, 1.0372, 1.0253,\n", - " 0.9920, 1.1055, 1.0380, 1.0451, 0.9972, 1.0180, 0.9287, 1.0507, 0.9703,\n", - " 1.0450, 1.0503, 1.0136, 1.0655, 1.0190, 1.0881, 1.0169, 1.0723, 0.9594,\n", - " 0.9919, 1.0379])\n", - "layer2.0.shortcut_bn.bias Parameter containing:\n", - "tensor([-1.2981e-02, -8.6542e-02, -2.7465e-02, -8.9233e-02, -4.8785e-02,\n", - " -5.1820e-02, -7.4506e-02, -8.4871e-03, -8.0788e-02, -4.5926e-02,\n", - " -2.8059e-02, -3.0307e-02, -2.2497e-02, -5.4316e-02, -1.2210e-01,\n", - " -2.3084e-02, -1.5159e-02, -3.7485e-02, -3.9592e-02, -1.7644e-02,\n", - " -9.2507e-02, -6.1548e-02, -4.2106e-02, -2.7228e-02, -4.7395e-02,\n", - " 4.0894e-02, -8.1798e-02, 7.9138e-03, -2.8363e-02, -2.0816e-02,\n", - " 4.5710e-02, -4.7065e-02, -6.4959e-02, -5.2531e-02, -1.4958e-02,\n", - " -2.6834e-02, -7.6824e-02, -4.7256e-02, -2.8755e-02, -3.9290e-02,\n", - " 2.5110e-03, 6.5448e-02, -2.4467e-02, -6.3492e-02, 4.6999e-03,\n", - " -4.1704e-02, -3.9343e-02, 4.9145e-05, -6.0200e-03, -3.4718e-02,\n", - " -5.2317e-02, -5.0237e-02, -4.9025e-02, -2.7047e-02, -5.2278e-02,\n", - " -4.9766e-02, -5.2599e-02, -1.2772e-01, -1.0851e-02, -6.3345e-02,\n", - " -5.3914e-02, -1.0681e-02, -2.0828e-02, 
-4.9643e-02, -4.6539e-02,\n", - " 2.7344e-02, -2.4344e-02, -3.5232e-02, -3.5927e-02, -4.2018e-02,\n", - " -5.2375e-02, -4.9287e-02, -1.5841e-03, -4.7609e-02, -6.0663e-02,\n", - " -4.4702e-02, -7.3340e-02, -1.0380e-01, -5.8219e-02, -9.6217e-03,\n", - " -6.6633e-03, -3.2506e-02, 4.1905e-02, -5.3879e-02, -3.6673e-02,\n", - " -4.4338e-02, -5.6134e-02, -6.1860e-03, -1.4108e-02, -7.2612e-02,\n", - " -4.0473e-02, -2.9741e-03, -9.7863e-02, -1.9554e-02, -8.3787e-02,\n", - " -3.0235e-02, -8.5696e-02, -7.1485e-02, -4.9228e-02, 1.4736e-02,\n", - " 2.4039e-02, -6.1848e-02, -5.7710e-02, -9.6235e-02, -1.1919e-02,\n", - " -2.6103e-02, -4.3060e-02, 1.3986e-02, -5.9900e-02, -2.8015e-02,\n", - " -1.5239e-02, -2.3644e-02, -5.2308e-02, -7.7621e-02, -1.2540e-01,\n", - " -1.4630e-02, -5.4057e-02, -3.7307e-02, -5.4744e-02, -2.6498e-02,\n", - " -1.0680e-01, -1.1601e-02, -1.6164e-02, -3.0841e-02, -5.8472e-02,\n", - " -1.2627e-02, -8.5479e-02, -4.6238e-02])\n", - "layer2.1.conv1.weight Parameter containing:\n", - "tensor([[[[-1.2671e-02, -2.3938e-02, -1.9099e-02],\n", - " [ 1.1139e-02, -2.8324e-02, -2.8830e-02],\n", - " [-2.3069e-02, -1.9006e-02, 2.0482e-02]],\n", - "\n", - " [[-1.5592e-03, -3.6956e-02, -4.0356e-02],\n", - " [ 1.4791e-02, -3.0037e-02, -3.2506e-02],\n", - " [ 5.5356e-03, 1.8473e-02, -1.3575e-02]],\n", - "\n", - " [[-1.5707e-02, 2.3170e-02, 9.3615e-04],\n", - " [-2.3199e-02, -3.8213e-03, -1.6252e-02],\n", - " [ 3.6089e-02, -3.3433e-03, 2.4761e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[ 2.0869e-02, 2.5472e-02, 1.1077e-02],\n", - " [ 4.3677e-02, 4.7397e-02, 5.2356e-04],\n", - " [-2.2244e-02, 3.8463e-02, 1.2060e-04]],\n", - "\n", - " [[ 3.4758e-02, 5.2871e-02, 4.4993e-02],\n", - " [ 4.9931e-02, 4.9706e-02, 2.1665e-02],\n", - " [ 4.7596e-02, 1.3522e-02, 2.3822e-02]],\n", - "\n", - " [[ 2.2113e-02, -4.2537e-03, -4.7076e-02],\n", - " [ 3.1440e-02, 5.0481e-03, 3.0701e-03],\n", - " [ 1.2286e-02, 2.8508e-02, 1.2381e-03]]],\n", - "\n", - "\n", - " [[[-2.3738e-02, -1.3931e-02, 
-2.1727e-02],\n", - " [-2.1146e-02, -5.2254e-03, -8.0258e-05],\n", - " [-3.2895e-02, -1.6613e-02, -4.7330e-02]],\n", - "\n", - " [[ 3.0972e-02, -1.8397e-02, 2.1567e-02],\n", - " [ 3.8275e-02, 3.8623e-02, 2.3262e-02],\n", - " [ 1.7491e-02, 2.6040e-02, -9.0549e-03]],\n", - "\n", - " [[ 2.3262e-02, 7.3751e-04, 1.5063e-02],\n", - " [ 4.1470e-02, 2.8551e-02, 3.1405e-02],\n", - " [ 2.1048e-02, 1.2375e-02, 4.5953e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-2.8296e-02, -9.1726e-03, -7.3690e-03],\n", - " [ 1.7543e-02, 1.2311e-03, 1.9880e-02],\n", - " [ 1.0991e-02, -5.1386e-02, 3.1852e-03]],\n", - "\n", - " [[-2.8715e-02, -2.7118e-02, -7.1705e-03],\n", - " [ 4.2610e-02, 2.4501e-02, 1.3790e-02],\n", - " [-6.7407e-03, -2.6416e-03, -7.7918e-03]],\n", - "\n", - " [[ 1.2130e-02, 1.8906e-02, 1.1645e-03],\n", - " [ 3.3207e-02, 1.5823e-02, -3.7893e-03],\n", - " [ 4.3907e-02, 3.4477e-02, 3.8621e-02]]],\n", - "\n", - "\n", - " [[[ 4.2056e-02, 5.3050e-03, 7.1378e-02],\n", - " [ 2.3840e-02, 3.1197e-02, 6.1677e-02],\n", - " [ 4.0494e-02, -2.5194e-02, -6.7998e-03]],\n", - "\n", - " [[-3.6769e-02, -1.3199e-02, -3.2012e-02],\n", - " [-7.7791e-03, -6.0538e-02, -1.0908e-02],\n", - " [ 3.7365e-02, 5.9451e-02, 1.1005e-02]],\n", - "\n", - " [[-3.7870e-02, 3.5183e-02, -2.5808e-02],\n", - " [-2.6994e-02, -3.9472e-02, -2.0497e-02],\n", - " [-5.8050e-02, -4.7024e-03, -1.8292e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-7.0749e-03, -3.6702e-02, -2.4704e-02],\n", - " [ 2.5994e-02, 9.4172e-03, -2.3231e-02],\n", - " [ 2.3903e-02, 2.8629e-02, 1.9264e-02]],\n", - "\n", - " [[-6.7792e-03, -7.7618e-03, -6.9538e-02],\n", - " [ 4.0581e-02, 5.6283e-02, 2.9467e-02],\n", - " [-3.3667e-03, -6.5002e-02, -4.8670e-02]],\n", - "\n", - " [[ 1.8321e-02, -2.4471e-02, 9.8842e-03],\n", - " [ 3.9916e-03, 2.2426e-02, -3.6026e-03],\n", - " [ 3.8083e-02, 3.0053e-02, 1.9913e-02]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 2.5227e-03, 5.2476e-02, 2.3908e-02],\n", - " [-3.4207e-02, -3.0952e-02, 
-3.6447e-02],\n", - " [ 1.3552e-02, 2.7308e-02, 5.1622e-02]],\n", - "\n", - " [[-2.4969e-02, -2.1576e-02, -4.5853e-02],\n", - " [ 1.3879e-02, 6.2150e-02, 1.4080e-02],\n", - " [ 1.1243e-02, -4.2055e-02, 7.8067e-03]],\n", - "\n", - " [[-2.3491e-02, -5.1456e-03, -1.2547e-02],\n", - " [-1.6141e-02, 1.0643e-02, -2.5370e-02],\n", - " [-1.9203e-03, -5.2159e-02, 5.9291e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.1122e-02, -1.4899e-02, 2.1300e-02],\n", - " [ 3.6331e-03, 6.5192e-03, 4.7576e-02],\n", - " [ 2.3239e-02, -2.0316e-02, 2.7265e-02]],\n", - "\n", - " [[ 1.8296e-02, 4.2794e-02, 4.8028e-02],\n", - " [-8.7513e-03, -2.8567e-02, -4.4760e-02],\n", - " [ 5.6493e-03, -1.7444e-02, -2.9112e-02]],\n", - "\n", - " [[-2.2753e-03, -2.9479e-03, 2.8816e-02],\n", - " [ 1.9996e-03, 2.7665e-03, 4.4159e-02],\n", - " [-3.1448e-02, -4.9231e-02, 1.4012e-02]]],\n", - "\n", - "\n", - " [[[ 3.5169e-02, 3.1395e-02, -8.4316e-03],\n", - " [ 5.0173e-02, -1.6380e-02, 3.1588e-03],\n", - " [ 9.7549e-03, -1.6155e-03, 2.5622e-02]],\n", - "\n", - " [[ 2.9783e-03, -1.0025e-02, 4.0323e-03],\n", - " [ 2.4023e-02, 1.8000e-02, -2.0841e-02],\n", - " [ 1.7581e-03, 2.7513e-02, 3.2596e-02]],\n", - "\n", - " [[ 8.9868e-03, 1.2796e-02, -1.1676e-02],\n", - " [-8.7491e-03, 3.1234e-03, -8.2818e-03],\n", - " [ 2.3207e-02, 7.2682e-03, 4.6091e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 5.1432e-02, 3.8542e-02, 1.0702e-02],\n", - " [ 1.3778e-03, 1.6634e-02, -2.2224e-02],\n", - " [ 2.5061e-02, 3.7500e-02, -1.4531e-02]],\n", - "\n", - " [[ 3.4381e-02, -4.6145e-05, -2.3151e-02],\n", - " [ 4.0692e-02, 1.4137e-02, 3.8377e-03],\n", - " [ 6.2407e-03, -4.0193e-02, -4.6756e-02]],\n", - "\n", - " [[ 1.8292e-02, 3.8774e-03, -2.9178e-02],\n", - " [-3.8458e-02, 9.2545e-03, 2.3120e-02],\n", - " [ 1.1178e-02, 4.1157e-02, -2.7000e-03]]],\n", - "\n", - "\n", - " [[[ 4.6502e-04, 2.7367e-02, -1.8018e-02],\n", - " [-8.8800e-04, 3.2783e-02, 8.7174e-03],\n", - " [-1.3216e-02, 3.0975e-02, -1.3285e-03]],\n", - "\n", - " [[-1.9450e-02, 
2.7712e-03, 2.3936e-02],\n", - " [-2.4075e-02, -2.3674e-03, -3.2410e-02],\n", - " [-1.9093e-02, -2.5896e-02, -7.2411e-03]],\n", - "\n", - " [[-1.0581e-02, 1.2100e-02, -2.1283e-02],\n", - " [-1.3377e-02, 2.6153e-04, -3.6690e-02],\n", - " [-2.5659e-02, -2.2961e-02, -4.4005e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-3.8530e-02, -1.6303e-02, -2.1558e-02],\n", - " [ 2.9950e-02, -1.8199e-02, -2.5115e-04],\n", - " [-2.1965e-02, -1.9656e-02, -2.8129e-03]],\n", - "\n", - " [[-6.2657e-03, -6.5198e-02, -3.4879e-02],\n", - " [-1.6735e-02, -2.4723e-02, -1.7410e-02],\n", - " [ 1.5641e-02, -1.8898e-02, -2.1710e-02]],\n", - "\n", - " [[ 1.3554e-02, 5.3783e-03, 2.0453e-03],\n", - " [ 2.4260e-02, -2.5494e-02, 1.7761e-02],\n", - " [-2.7732e-02, -2.5852e-02, -9.7575e-03]]]])\n", - "layer2.1.bn1.weight Parameter containing:\n", - "tensor([0.9469, 0.9599, 1.0408, 1.0384, 0.9564, 0.9625, 0.9672, 0.9335, 0.9827,\n", - " 1.0143, 1.1081, 0.9712, 0.9803, 1.0061, 1.0102, 0.9363, 1.0412, 0.9488,\n", - " 1.0403, 1.0078, 0.9938, 0.9206, 1.0497, 0.9752, 0.9866, 0.9762, 1.0210,\n", - " 0.9545, 0.9374, 1.0035, 0.9426, 0.9650, 0.9559, 0.9546, 1.1425, 0.9578,\n", - " 1.0181, 0.9881, 1.0366, 1.0273, 1.0111, 0.9961, 1.0984, 0.9692, 0.9251,\n", - " 1.0338, 0.9482, 1.0294, 0.9981, 0.9767, 0.9457, 0.9172, 1.0938, 1.0430,\n", - " 1.1256, 1.0880, 1.0017, 1.0415, 0.9938, 1.1456, 1.0022, 0.9620, 0.9321,\n", - " 1.0135, 1.0086, 0.9947, 0.8943, 1.0304, 1.1431, 0.9540, 1.0271, 0.9725,\n", - " 1.0664, 1.0102, 1.0321, 1.0395, 1.0043, 1.0035, 0.9879, 1.0392, 1.0987,\n", - " 0.9446, 0.9369, 0.9410, 1.0415, 1.0260, 0.9566, 0.9980, 0.8972, 1.0363,\n", - " 0.9065, 0.9226, 0.9759, 1.0496, 0.9694, 0.9439, 0.9542, 0.9537, 1.0676,\n", - " 1.0573, 1.0023, 1.0358, 1.0215, 0.8992, 0.9458, 0.9192, 0.8680, 1.0038,\n", - " 1.0073, 0.9660, 1.0339, 1.0393, 0.9894, 1.0238, 1.0558, 0.9541, 0.9365,\n", - " 1.0021, 0.9507, 1.0065, 1.0754, 0.9610, 0.9913, 0.9653, 1.0200, 1.0461,\n", - " 1.0000, 0.9818])\n", - "layer2.1.bn1.bias 
Parameter containing:\n", - "tensor([-0.0673, -0.0338, -0.0798, -0.1658, -0.0523, -0.0830, -0.1435, -0.0267,\n", - " -0.0379, -0.0789, -0.1152, -0.0774, -0.0791, -0.0604, -0.0548, -0.0850,\n", - " -0.0389, -0.1077, -0.0098, -0.0761, -0.0727, -0.1523, -0.0636, -0.0924,\n", - " -0.0278, -0.0917, -0.0908, -0.0543, -0.0616, -0.0411, -0.0775, -0.0810,\n", - " -0.0747, -0.0256, -0.1169, -0.0433, -0.0451, -0.1127, -0.0723, -0.1305,\n", - " -0.1152, -0.0542, -0.0730, -0.0952, -0.1198, -0.0611, -0.0426, -0.0912,\n", - " -0.0861, -0.0774, -0.1032, -0.0950, -0.0548, -0.0913, -0.0961, -0.0706,\n", - " -0.0829, -0.0349, -0.0849, -0.1585, -0.0403, -0.0749, -0.0919, -0.0914,\n", - " -0.0395, -0.0244, -0.0733, -0.0359, -0.1861, -0.0820, -0.0695, -0.0439,\n", - " -0.0523, -0.0789, -0.0342, -0.1350, -0.0378, -0.0595, -0.0199, -0.0518,\n", - " -0.0643, -0.0442, -0.0830, -0.0320, -0.0403, -0.0410, -0.1096, -0.0638,\n", - " -0.0641, -0.0460, -0.0740, -0.0923, -0.0401, -0.0058, -0.0803, -0.0584,\n", - " -0.0683, -0.0805, -0.0518, -0.0837, -0.0866, -0.1192, -0.0260, -0.0665,\n", - " -0.0630, -0.1184, -0.0660, -0.0332, -0.0414, -0.1052, -0.0006, -0.0555,\n", - " -0.0702, -0.0166, -0.1229, -0.0376, -0.0829, -0.0606, -0.0979, -0.0913,\n", - " -0.1539, -0.0495, -0.0757, -0.1031, -0.0495, -0.0982, -0.1310, -0.0706])\n", - "layer2.1.conv2.weight Parameter containing:\n", - "tensor([[[[ 9.0620e-03, 1.3316e-02, 1.0495e-02],\n", - " [ 2.3512e-02, -1.3791e-03, 1.4756e-02],\n", - " [ 3.8027e-02, 1.9696e-02, -2.3292e-03]],\n", - "\n", - " [[-1.1708e-02, 1.3712e-02, -5.3562e-03],\n", - " [-3.9382e-02, -3.5332e-02, 2.0547e-02],\n", - " [ 6.4936e-03, 9.7651e-04, -3.3298e-03]],\n", - "\n", - " [[-2.7199e-02, -3.8818e-02, -2.6458e-02],\n", - " [-2.8592e-02, -4.5702e-02, -5.0769e-03],\n", - " [ 2.2421e-02, -1.0246e-02, 6.8318e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-2.4165e-02, -6.6939e-03, 2.7444e-02],\n", - " [-1.9542e-02, -4.5656e-02, 6.5853e-03],\n", - " [-3.9443e-02, -2.1822e-03, 
-3.6982e-02]],\n", - "\n", - " [[ 1.1827e-02, -1.2058e-02, -6.3863e-03],\n", - " [-3.9754e-02, -7.1089e-02, -6.8766e-02],\n", - " [-1.2733e-02, -3.8078e-02, -2.0078e-02]],\n", - "\n", - " [[ 1.2081e-02, 6.0684e-02, -1.3995e-02],\n", - " [-2.9236e-03, 2.1464e-03, -2.5115e-02],\n", - " [ 4.3272e-02, 7.2460e-02, 1.9598e-02]]],\n", - "\n", - "\n", - " [[[ 1.3322e-02, 3.8556e-02, 2.1243e-02],\n", - " [ 1.8053e-02, 2.3932e-02, -1.5429e-02],\n", - " [ 1.7548e-02, -7.2034e-03, 7.1181e-03]],\n", - "\n", - " [[-5.8165e-03, -4.9490e-04, -3.5662e-02],\n", - " [-1.0573e-02, -1.0955e-02, -3.1837e-02],\n", - " [ 1.2337e-02, -1.6183e-02, 3.5962e-04]],\n", - "\n", - " [[-1.6080e-02, 9.4185e-03, -1.9492e-02],\n", - " [ 5.1737e-02, 4.8596e-02, 5.4428e-02],\n", - " [-7.2937e-03, 5.3389e-03, 3.1466e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-4.3382e-02, -5.6812e-02, -2.2632e-02],\n", - " [-3.3871e-03, -1.2200e-02, -6.9654e-03],\n", - " [ 1.3559e-02, 1.6472e-02, -1.0585e-02]],\n", - "\n", - " [[ 1.4523e-03, 1.7408e-02, 6.0656e-03],\n", - " [ 3.9184e-02, 1.5487e-02, 1.1933e-02],\n", - " [ 1.2125e-02, -6.5983e-03, 4.9426e-02]],\n", - "\n", - " [[ 7.4308e-03, -1.3106e-02, 1.6033e-02],\n", - " [ 2.4728e-02, -7.3725e-03, 2.4851e-02],\n", - " [-9.9803e-03, -4.9287e-02, 7.5864e-03]]],\n", - "\n", - "\n", - " [[[-4.9769e-02, -1.4696e-02, -2.3030e-02],\n", - " [-4.7186e-02, -6.6357e-02, -2.2492e-02],\n", - " [-5.2339e-02, -4.6180e-02, -5.6023e-03]],\n", - "\n", - " [[ 3.9187e-03, 1.5691e-02, 7.6873e-03],\n", - " [-3.1093e-02, -4.1031e-02, 2.3911e-02],\n", - " [-2.7188e-02, 7.6569e-03, -2.5654e-02]],\n", - "\n", - " [[ 1.9282e-02, -1.3629e-02, 1.8839e-02],\n", - " [-5.8584e-03, 2.7730e-02, 7.9821e-03],\n", - " [ 8.8666e-03, -3.1551e-02, 1.9405e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 2.4015e-02, -3.3378e-02, -5.3954e-04],\n", - " [-4.5346e-03, -2.1869e-03, 3.7101e-02],\n", - " [-3.1134e-02, -1.5882e-02, -2.8024e-02]],\n", - "\n", - " [[ 2.4350e-02, 1.7124e-02, -3.3614e-02],\n", - " [ 
2.1010e-02, -2.8533e-02, -2.6490e-02],\n", - " [ 3.0194e-02, -3.2890e-02, 4.3761e-03]],\n", - "\n", - " [[-3.4936e-02, -1.5606e-02, 1.6964e-04],\n", - " [-3.0659e-02, -8.4506e-03, 2.2456e-02],\n", - " [ 2.6422e-02, 2.0289e-02, 9.1066e-03]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-2.6236e-02, -3.4802e-02, -5.1656e-02],\n", - " [-2.8055e-02, -8.5368e-03, -2.1985e-02],\n", - " [-2.5813e-02, -1.0077e-02, -2.8292e-02]],\n", - "\n", - " [[-4.4758e-02, -5.8627e-02, -3.4626e-03],\n", - " [ 4.3765e-03, 2.5680e-02, 2.2966e-02],\n", - " [-1.1803e-02, -1.6649e-02, -1.0021e-02]],\n", - "\n", - " [[ 8.3297e-03, -1.7979e-03, -2.6646e-02],\n", - " [ 5.8005e-03, -4.2564e-02, 2.8873e-02],\n", - " [-1.8539e-02, -3.3966e-02, 1.7152e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-4.1487e-02, -3.3227e-02, -4.9985e-03],\n", - " [-1.9580e-02, 7.5260e-03, 2.4756e-02],\n", - " [ 2.8643e-03, -1.1592e-02, -1.8972e-02]],\n", - "\n", - " [[-5.8783e-02, -2.4775e-02, 6.1720e-04],\n", - " [-3.2235e-02, -5.0035e-03, 1.1334e-02],\n", - " [ 1.2860e-02, 2.2570e-03, 3.1842e-02]],\n", - "\n", - " [[ 5.6097e-03, -1.5921e-02, -1.5595e-02],\n", - " [-2.3235e-02, -2.1884e-02, 1.2557e-03],\n", - " [ 3.4493e-03, 7.8864e-03, -2.4334e-02]]],\n", - "\n", - "\n", - " [[[-3.9457e-03, 8.6690e-03, -2.3602e-02],\n", - " [ 8.4130e-03, -7.7556e-03, 2.1163e-02],\n", - " [-3.4924e-05, 1.1466e-02, 4.7752e-03]],\n", - "\n", - " [[-7.1314e-02, -2.9460e-02, -5.2409e-02],\n", - " [-2.2031e-02, -3.2676e-02, -2.9026e-02],\n", - " [-2.9801e-02, -3.6596e-02, -5.7815e-02]],\n", - "\n", - " [[-3.4363e-02, -1.0897e-03, -5.6942e-03],\n", - " [ 8.1054e-03, 4.0558e-02, 6.3196e-03],\n", - " [ 5.4653e-02, -4.7664e-03, 1.6212e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 2.5112e-02, 3.1311e-02, -1.5590e-02],\n", - " [ 7.0375e-02, 6.6667e-02, 6.6758e-02],\n", - " [-2.0276e-02, 2.5930e-02, 7.1680e-03]],\n", - "\n", - " [[ 5.9966e-04, -3.9549e-03, 3.0694e-02],\n", - " [ 2.4426e-02, 1.7874e-02, -2.1178e-02],\n", - " 
[-2.7890e-03, -7.3431e-03, -2.7056e-02]],\n", - "\n", - " [[-5.6897e-02, 1.9620e-02, -5.0968e-02],\n", - " [-2.3845e-02, 2.4067e-03, -3.8237e-02],\n", - " [-2.7367e-02, -4.8455e-02, -4.9325e-02]]],\n", - "\n", - "\n", - " [[[-2.0665e-02, -2.9083e-02, -2.0078e-02],\n", - " [-1.2909e-02, -4.0212e-02, 1.2505e-02],\n", - " [ 3.1596e-02, 2.8582e-02, 5.2544e-02]],\n", - "\n", - " [[-1.4806e-02, -3.6197e-02, -3.2830e-03],\n", - " [-3.7018e-02, 3.2018e-02, 4.3647e-02],\n", - " [ 1.0543e-02, 4.0195e-03, 5.5371e-02]],\n", - "\n", - " [[-1.9200e-02, -6.9304e-02, -5.1689e-02],\n", - " [-1.2329e-03, -3.0719e-02, -5.0884e-03],\n", - " [ 1.5455e-02, -2.5330e-03, 4.8852e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-6.2182e-03, 1.1251e-02, -3.3474e-02],\n", - " [-1.6241e-02, -1.3244e-02, -7.0312e-02],\n", - " [ 2.9194e-02, -9.2372e-04, 1.0911e-02]],\n", - "\n", - " [[-1.2244e-02, -1.1533e-02, 2.8962e-02],\n", - " [ 3.9788e-02, 2.6832e-02, 1.3362e-02],\n", - " [ 3.6984e-02, -2.2891e-02, -3.2071e-02]],\n", - "\n", - " [[ 1.0560e-02, 1.4959e-02, 3.4440e-02],\n", - " [-3.0914e-03, 7.2294e-04, 4.0611e-03],\n", - " [ 5.2256e-03, 2.2626e-02, 5.8988e-03]]]])\n", - "layer2.1.bn2.weight Parameter containing:\n", - "tensor([0.9669, 0.9217, 0.9698, 1.0776, 0.8574, 0.9466, 0.9596, 0.8409, 0.9237,\n", - " 0.9604, 0.9425, 0.9107, 0.9738, 0.9234, 1.0417, 0.9799, 0.8988, 0.9993,\n", - " 0.8893, 0.9971, 0.7985, 0.9634, 0.9848, 0.8414, 0.9464, 0.8792, 1.0549,\n", - " 0.8567, 0.9816, 0.8909, 0.8655, 0.9836, 0.9409, 0.9513, 1.0185, 1.0638,\n", - " 0.9734, 0.9440, 0.9009, 0.9552, 0.9628, 0.9267, 0.9959, 0.8419, 1.0082,\n", - " 0.9657, 0.9144, 1.0230, 0.9963, 0.8837, 0.8098, 0.9780, 1.0287, 0.9983,\n", - " 1.0574, 0.8762, 0.9858, 0.9770, 0.9042, 0.8879, 0.9769, 0.9435, 0.8430,\n", - " 0.8512, 0.8121, 0.9104, 1.1214, 1.0722, 0.9728, 1.0472, 0.8562, 0.9894,\n", - " 0.9725, 1.0077, 0.9654, 1.0169, 0.8221, 0.7715, 0.9867, 0.9983, 0.9642,\n", - " 0.7586, 0.8606, 0.9114, 0.8180, 0.9168, 0.8913, 1.0396, 
0.9633, 1.0078,\n", - " 1.0352, 0.8088, 0.9190, 0.9156, 0.9024, 0.9893, 0.9497, 0.9773, 1.0055,\n", - " 0.9963, 0.9586, 0.9777, 1.0456, 0.9290, 1.0590, 0.8944, 1.0062, 0.9626,\n", - " 0.9110, 0.9895, 0.9059, 1.0135, 0.8311, 0.9899, 0.8107, 0.9532, 1.0057,\n", - " 0.9510, 0.9788, 1.0119, 0.8398, 0.9277, 0.9727, 0.9752, 0.9373, 0.9794,\n", - " 0.8550, 0.9457])\n", - "layer2.1.bn2.bias Parameter containing:\n", - "tensor([-0.0631, -0.1395, -0.0694, -0.1654, -0.0476, -0.1273, -0.0678, -0.1002,\n", - " -0.0325, -0.0740, -0.1281, -0.1002, -0.0175, -0.0615, -0.0975, -0.0528,\n", - " -0.0580, -0.0978, -0.0452, -0.0390, -0.0917, -0.1588, -0.0876, -0.0832,\n", - " -0.0911, -0.0724, -0.1224, -0.0939, -0.0113, -0.0898, -0.0591, -0.0562,\n", - " -0.0751, -0.0556, -0.0274, -0.1637, -0.1727, -0.0791, -0.0973, -0.0137,\n", - " -0.0436, -0.0590, -0.0689, -0.0996, -0.0387, -0.0633, -0.1175, -0.0700,\n", - " -0.0791, -0.1303, -0.0326, -0.1016, -0.0775, -0.0637, -0.1030, -0.0783,\n", - " -0.0339, -0.1278, -0.0296, -0.0876, -0.0457, -0.0382, -0.0050, -0.0750,\n", - " -0.0725, -0.0017, -0.0492, -0.0742, -0.0510, -0.0604, -0.0823, -0.0722,\n", - " -0.0629, -0.0963, -0.0659, -0.0422, -0.1470, -0.0926, -0.0610, -0.0476,\n", - " -0.0626, -0.1338, -0.0497, -0.0442, -0.0355, -0.0271, -0.0796, -0.1819,\n", - " -0.0443, -0.1882, -0.0896, -0.0531, -0.1230, -0.0728, -0.1342, -0.0824,\n", - " -0.0452, -0.0723, -0.0767, -0.0520, -0.0523, -0.0455, -0.0583, -0.1160,\n", - " -0.0359, -0.0998, -0.0496, -0.0457, -0.0740, -0.0542, -0.0391, -0.0431,\n", - " -0.1116, -0.1315, -0.1144, -0.0804, -0.0101, -0.0375, -0.0806, -0.0234,\n", - " -0.0276, -0.0227, -0.0626, 0.0198, -0.0983, -0.0464, -0.1192, -0.0449])\n", - "layer3.0.conv1.weight Parameter containing:\n", - "tensor([[[[-2.8938e-03, -3.0210e-02, 2.4601e-02],\n", - " [ 6.5808e-03, -1.7178e-02, 2.3446e-03],\n", - " [-1.6810e-02, 1.1131e-02, -3.5020e-03]],\n", - "\n", - " [[ 2.4755e-02, -3.2800e-02, 3.2009e-02],\n", - " [ 1.9366e-02, -1.1100e-03, 
4.9418e-03],\n", - " [ 2.5708e-02, -1.1415e-02, 1.0658e-02]],\n", - "\n", - " [[ 1.6535e-02, -5.5367e-03, 2.4412e-02],\n", - " [ 4.0183e-03, -2.3662e-02, -4.1649e-02],\n", - " [ 1.5828e-02, -3.0302e-02, -2.3527e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 3.0539e-03, -3.9847e-03, -3.5681e-02],\n", - " [ 5.4152e-03, -1.0178e-02, 1.4431e-02],\n", - " [-7.5026e-02, -5.5198e-02, 4.9839e-02]],\n", - "\n", - " [[ 3.7400e-02, 6.3145e-03, 4.9693e-02],\n", - " [ 2.0811e-03, 4.8901e-03, -1.7024e-02],\n", - " [ 3.5467e-02, -4.9344e-03, 1.6703e-02]],\n", - "\n", - " [[-4.6327e-03, 2.8507e-02, 4.8495e-02],\n", - " [ 1.7117e-02, -2.4190e-03, 3.9677e-02],\n", - " [-1.7740e-02, 1.2457e-02, 6.9044e-02]]],\n", - "\n", - "\n", - " [[[ 2.1504e-02, 5.2210e-02, 1.2908e-02],\n", - " [-2.6065e-02, -1.5139e-02, 3.4843e-02],\n", - " [-1.7559e-02, -5.9843e-02, 1.8697e-02]],\n", - "\n", - " [[ 3.6117e-02, -2.0767e-02, -1.6192e-02],\n", - " [ 4.7451e-02, 2.7762e-02, -2.7552e-02],\n", - " [-2.8505e-03, -5.8121e-02, -5.2318e-03]],\n", - "\n", - " [[-3.4431e-02, -2.0220e-02, 1.0149e-02],\n", - " [-1.5946e-02, 9.6492e-03, 1.6029e-04],\n", - " [-3.6388e-02, 1.3628e-02, -7.8374e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-4.6442e-02, -2.9462e-02, 5.9477e-05],\n", - " [ 7.8429e-03, 2.8139e-02, 3.6842e-02],\n", - " [ 8.4957e-04, 1.0251e-03, 1.3028e-02]],\n", - "\n", - " [[-8.7456e-03, 4.2044e-02, -2.0930e-02],\n", - " [-2.2964e-02, 8.5618e-03, -3.3057e-02],\n", - " [ 1.2229e-02, -2.6854e-02, -1.5350e-03]],\n", - "\n", - " [[-2.5714e-02, 8.3247e-03, -1.1987e-02],\n", - " [-1.9772e-02, -9.5624e-03, 4.1201e-02],\n", - " [ 2.0758e-02, 1.6308e-02, 1.8109e-02]]],\n", - "\n", - "\n", - " [[[ 2.0949e-02, 1.8669e-02, -1.4822e-02],\n", - " [ 3.7938e-02, -3.4789e-02, 1.0094e-02],\n", - " [-2.0915e-03, -4.5284e-03, 1.6016e-02]],\n", - "\n", - " [[ 2.8699e-02, -3.7645e-02, -5.9033e-03],\n", - " [ 5.7764e-02, -7.0500e-04, 5.1310e-02],\n", - " [-2.5892e-02, 3.4094e-02, -1.7359e-02]],\n", - "\n", - " 
[[-2.8955e-02, -1.7045e-02, -9.4707e-03],\n", - " [-9.5516e-03, -8.9949e-03, 1.2905e-02],\n", - " [-3.1667e-02, 1.7581e-02, 2.1029e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-8.9293e-04, 8.4387e-03, 1.9757e-02],\n", - " [-2.4149e-02, -1.4277e-02, 5.7769e-03],\n", - " [-3.3923e-02, 2.1901e-02, -1.3470e-02]],\n", - "\n", - " [[ 2.6607e-02, -1.6824e-03, 1.5170e-02],\n", - " [ 3.2252e-02, -2.1156e-02, 7.6330e-04],\n", - " [ 1.6243e-02, 1.0417e-02, -3.2937e-02]],\n", - "\n", - " [[-8.1500e-03, 6.0820e-03, -2.3481e-02],\n", - " [ 7.4835e-03, 2.2905e-02, 2.9158e-02],\n", - " [-5.4513e-02, -8.0323e-02, 1.0160e-02]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-1.7320e-02, -1.1638e-02, 3.1909e-02],\n", - " [-8.7472e-03, -1.5413e-02, -1.9191e-02],\n", - " [-1.7741e-03, 4.8668e-03, 9.9375e-03]],\n", - "\n", - " [[ 5.7990e-03, 2.6241e-03, 4.7727e-03],\n", - " [-6.8299e-03, 1.6355e-02, 1.5993e-02],\n", - " [-3.5001e-02, 4.5646e-02, -1.3821e-02]],\n", - "\n", - " [[-2.1266e-02, -5.8349e-03, -3.3456e-03],\n", - " [ 1.5735e-02, -5.0056e-02, -3.8845e-02],\n", - " [ 1.7596e-02, -1.8245e-03, 1.6151e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.0309e-02, -1.3287e-02, 3.8571e-03],\n", - " [ 6.3270e-03, -3.5081e-02, -2.1760e-02],\n", - " [ 2.4283e-02, -3.5668e-02, 1.6085e-02]],\n", - "\n", - " [[ 2.4923e-02, 1.6274e-02, 5.0354e-02],\n", - " [ 2.7379e-03, 1.3269e-02, 3.0485e-02],\n", - " [-4.1538e-02, -1.9868e-02, -4.0987e-02]],\n", - "\n", - " [[-6.5920e-03, 9.8657e-03, -1.8862e-02],\n", - " [ 2.4254e-02, -2.3374e-02, -2.2944e-02],\n", - " [-5.6579e-03, 5.4663e-03, 1.4742e-02]]],\n", - "\n", - "\n", - " [[[ 1.6981e-03, 1.7682e-02, 2.2609e-02],\n", - " [ 1.6195e-02, -3.6837e-03, -9.3915e-03],\n", - " [ 1.2224e-03, -2.6607e-02, -2.7341e-03]],\n", - "\n", - " [[ 2.0577e-02, -2.1010e-02, 4.7610e-02],\n", - " [ 1.2225e-02, -1.6388e-02, -4.9970e-03],\n", - " [ 2.0274e-02, -1.1799e-02, 3.1679e-02]],\n", - "\n", - " [[-2.8022e-02, -2.3188e-02, -1.6347e-02],\n", - " 
[-1.5444e-02, 3.4093e-02, 1.0499e-02],\n", - " [ 9.3748e-03, -1.0355e-02, 5.8578e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-2.1788e-02, 1.5121e-02, 3.4046e-02],\n", - " [-1.4173e-02, -9.8143e-03, -1.2345e-02],\n", - " [ 3.0159e-02, 3.3297e-02, 1.0358e-02]],\n", - "\n", - " [[-2.5139e-03, 2.3068e-03, 2.6721e-02],\n", - " [-2.3739e-02, 2.3444e-02, -1.1328e-02],\n", - " [ 6.5631e-03, 4.0345e-03, -1.9840e-02]],\n", - "\n", - " [[ 1.7692e-02, -2.0572e-02, -2.8111e-04],\n", - " [-1.2023e-02, 2.4996e-02, 2.4255e-02],\n", - " [-1.9400e-02, 3.1629e-02, -2.4052e-02]]],\n", - "\n", - "\n", - " [[[-2.4696e-02, 2.3890e-02, 1.3941e-03],\n", - " [ 6.2046e-03, 7.4803e-03, -1.8126e-02],\n", - " [-1.8920e-02, -3.0955e-02, -8.0770e-03]],\n", - "\n", - " [[ 7.2004e-03, -4.3122e-02, -3.4520e-02],\n", - " [ 6.0169e-03, -1.7975e-02, 1.6013e-02],\n", - " [-1.2754e-02, 7.1421e-03, -2.0949e-02]],\n", - "\n", - " [[-9.5524e-03, -1.2836e-02, -3.4424e-03],\n", - " [ 5.2468e-04, -2.8843e-02, -6.9024e-03],\n", - " [-3.4978e-03, 7.0781e-02, 6.6465e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.7539e-02, -2.4638e-02, -5.0424e-03],\n", - " [ 6.8030e-03, 2.3334e-02, 1.4998e-02],\n", - " [ 2.7096e-02, 3.0525e-02, -7.3968e-04]],\n", - "\n", - " [[-3.0516e-02, -3.9977e-02, -1.0588e-02],\n", - " [-2.5134e-03, 8.7634e-05, -2.6856e-02],\n", - " [ 1.7962e-02, 6.0497e-03, -6.6210e-05]],\n", - "\n", - " [[ 6.9240e-03, -1.9080e-02, 1.1087e-02],\n", - " [ 7.3493e-03, 4.4861e-03, 1.1068e-02],\n", - " [-2.2379e-02, 1.1944e-02, -1.2094e-02]]]])\n", - "layer3.0.bn1.weight Parameter containing:\n", - "tensor([0.9061, 1.0347, 1.0760, 1.0512, 0.9397, 1.0302, 0.9768, 1.0078, 1.0257,\n", - " 0.9997, 1.0039, 1.0391, 1.1620, 1.0561, 0.9218, 1.0240, 0.9866, 1.0253,\n", - " 0.9590, 0.9844, 1.0767, 1.0360, 1.0178, 1.0186, 1.0044, 1.0089, 1.0140,\n", - " 1.0252, 1.0301, 1.0595, 1.0679, 0.9882, 0.9082, 1.0731, 1.0115, 0.9792,\n", - " 0.9640, 1.0047, 0.8549, 1.0938, 0.8506, 0.9995, 0.9996, 0.9056, 0.9211,\n", - " 0.8706, 
0.9308, 1.0217, 1.0128, 1.0225, 0.9260, 0.9987, 1.0239, 1.0860,\n", - " 0.9033, 0.9618, 1.0172, 1.0064, 1.0140, 1.0230, 0.8893, 1.0679, 0.9589,\n", - " 0.8701, 1.0907, 0.9620, 1.0249, 1.0191, 1.0027, 1.0092, 0.9996, 1.0263,\n", - " 0.9019, 0.9809, 1.0138, 0.9840, 1.0444, 1.0324, 1.0631, 0.9025, 1.0824,\n", - " 0.9573, 0.9810, 1.0726, 1.0022, 1.0539, 1.0418, 1.0293, 1.0278, 0.9352,\n", - " 0.9852, 0.9486, 1.0826, 1.0651, 1.0367, 1.0072, 1.0827, 1.0474, 1.0641,\n", - " 1.0313, 0.9774, 0.9143, 0.9635, 1.0805, 0.9259, 1.0173, 0.8997, 1.0448,\n", - " 0.9693, 1.0160, 1.0576, 0.9509, 1.0058, 0.9660, 1.0458, 1.0454, 0.8917,\n", - " 0.8641, 1.0203, 1.0884, 0.8875, 0.8684, 1.0861, 1.0081, 0.9273, 1.0426,\n", - " 1.1006, 0.9325, 1.0276, 0.9169, 0.9800, 1.0925, 1.0594, 0.8335, 1.0556,\n", - " 1.0944, 1.0630, 0.9928, 1.0146, 0.9829, 0.8950, 0.9355, 0.9619, 1.0186,\n", - " 0.9641, 1.0488, 0.9760, 1.0066, 0.9680, 1.0791, 1.0448, 0.9901, 1.0138,\n", - " 0.9991, 1.0693, 1.0020, 1.0520, 0.9433, 1.0379, 0.9295, 0.9748, 1.0369,\n", - " 1.0888, 1.0418, 0.9144, 0.9162, 0.9564, 0.9888, 0.8738, 0.9832, 0.9302,\n", - " 0.9893, 0.8909, 0.9891, 0.8922, 0.9782, 0.9909, 1.0133, 0.9086, 0.9760,\n", - " 0.9767, 1.0158, 0.9915, 1.0143, 0.9966, 0.9907, 0.9776, 0.9047, 1.0013,\n", - " 0.9864, 1.0067, 0.9248, 0.9688, 0.9908, 0.8471, 0.8832, 1.0943, 1.0505,\n", - " 1.0315, 1.0334, 0.8834, 0.9405, 1.0267, 0.9765, 1.0431, 0.9543, 0.9125,\n", - " 1.0020, 1.0517, 0.9656, 1.0035, 0.9053, 0.9510, 1.0689, 1.0049, 0.9757,\n", - " 1.0286, 0.8977, 1.0217, 1.0082, 1.0706, 1.0439, 0.9330, 1.0338, 1.1586,\n", - " 0.9350, 0.9765, 0.8640, 1.0371, 1.0314, 0.9612, 0.9800, 1.0228, 1.0002,\n", - " 1.0706, 1.0041, 1.0228, 0.9090, 0.9900, 1.0428, 0.9930, 1.0704, 0.9801,\n", - " 0.9030, 0.9475, 0.9210, 1.0132, 0.9513, 1.0110, 1.0776, 1.0226, 1.0207,\n", - " 0.9829, 0.9680, 0.9230, 1.0265])\n", - "layer3.0.bn1.bias Parameter containing:\n", - "tensor([-0.1787, -0.0670, -0.1104, -0.0911, -0.0801, -0.1031, -0.1063, 
-0.1099,\n", - " -0.0839, -0.1178, -0.0499, -0.1453, -0.0722, -0.0695, -0.1185, -0.1001,\n", - " -0.0961, -0.1003, -0.0839, -0.0936, -0.0897, -0.0942, -0.0798, -0.0622,\n", - " -0.0943, -0.0427, -0.0825, -0.0613, -0.0811, -0.0732, -0.0727, -0.0806,\n", - " -0.1232, -0.0963, -0.0856, -0.1087, -0.0833, -0.0848, -0.1352, -0.0787,\n", - " -0.1479, -0.0362, -0.1301, -0.1184, -0.1073, -0.1560, -0.0911, -0.0814,\n", - " -0.0877, -0.1103, -0.1217, -0.0934, -0.1284, -0.0638, -0.1504, -0.0730,\n", - " -0.0922, -0.1440, -0.0523, -0.0886, -0.1526, -0.0309, -0.1076, -0.1189,\n", - " -0.0806, -0.1403, -0.0696, -0.0722, -0.0710, -0.0701, -0.1449, -0.1027,\n", - " -0.1150, -0.0578, -0.0902, -0.0958, -0.0891, -0.0910, -0.0921, -0.0914,\n", - " -0.0470, -0.0853, -0.1400, -0.0813, -0.0946, -0.0688, -0.0459, -0.0683,\n", - " -0.0741, -0.0944, -0.1044, -0.0858, -0.0640, -0.0608, -0.1008, -0.0600,\n", - " -0.0892, -0.1165, -0.1024, -0.1160, -0.1608, -0.1190, -0.1343, -0.1135,\n", - " -0.1093, -0.0494, -0.0885, -0.0808, -0.0432, -0.0850, -0.0874, -0.0886,\n", - " -0.0916, -0.0610, -0.1053, -0.0913, -0.1357, -0.0879, -0.1158, -0.0854,\n", - " -0.1429, -0.1255, -0.0469, -0.0949, -0.0887, -0.1348, -0.0951, -0.1010,\n", - " -0.0553, -0.1506, -0.0773, -0.0961, -0.1000, -0.1271, -0.0111, -0.0747,\n", - " -0.0983, -0.0901, -0.1035, -0.0523, -0.1340, -0.1469, -0.1036, -0.1320,\n", - " -0.0788, -0.0828, -0.0806, -0.0947, -0.1260, -0.1188, -0.0600, -0.0903,\n", - " -0.1103, -0.0842, -0.0682, -0.0880, -0.0689, -0.1146, -0.0663, -0.1001,\n", - " -0.0970, -0.1130, -0.0439, -0.0618, -0.1329, -0.0918, -0.1008, -0.1022,\n", - " -0.1366, -0.0886, -0.0923, -0.1017, -0.0743, -0.1107, -0.1248, -0.0922,\n", - " -0.0982, -0.1354, -0.1418, -0.0818, -0.0899, -0.0553, -0.0871, -0.1282,\n", - " -0.0706, -0.0805, -0.1150, -0.1173, -0.0729, -0.0961, -0.0577, -0.0929,\n", - " -0.0875, -0.1348, -0.1249, -0.1365, -0.0689, -0.0694, -0.0860, -0.0885,\n", - " -0.1044, -0.0753, -0.0606, -0.1009, -0.1671, -0.1205, -0.1800, 
-0.0734,\n", - " -0.0862, -0.1019, -0.0975, -0.1568, -0.1117, -0.1548, -0.0654, -0.1274,\n", - " -0.0855, -0.1067, -0.1046, -0.0634, -0.0671, -0.0478, -0.1219, -0.0808,\n", - " -0.0356, -0.0614, -0.0897, -0.1689, -0.1237, -0.1068, -0.0926, -0.0682,\n", - " -0.0673, -0.0980, -0.0866, -0.1098, -0.0595, -0.1140, -0.0806, -0.1057,\n", - " -0.1030, -0.0569, -0.0721, -0.1279, -0.1018, -0.1139, -0.0615, -0.1067,\n", - " -0.0735, -0.1219, -0.1145, -0.0392, -0.1084, -0.0808, -0.0675, -0.0412])\n", - "layer3.0.conv2.weight Parameter containing:\n", - "tensor([[[[-0.0214, -0.0107, -0.0109],\n", - " [-0.0276, -0.0046, -0.0019],\n", - " [-0.0285, 0.0086, 0.0139]],\n", - "\n", - " [[ 0.0022, 0.0144, -0.0277],\n", - " [ 0.0292, 0.0160, -0.0108],\n", - " [ 0.0132, 0.0574, 0.0289]],\n", - "\n", - " [[-0.0278, 0.0145, -0.0049],\n", - " [ 0.0079, 0.0042, 0.0348],\n", - " [ 0.0134, 0.0028, -0.0223]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0053, -0.0094, 0.0096],\n", - " [ 0.0040, -0.0227, -0.0190],\n", - " [ 0.0357, 0.0016, 0.0191]],\n", - "\n", - " [[ 0.0082, -0.0114, -0.0138],\n", - " [ 0.0073, -0.0244, -0.0040],\n", - " [ 0.0254, 0.0098, 0.0303]],\n", - "\n", - " [[-0.0097, 0.0010, -0.0121],\n", - " [-0.0180, 0.0130, -0.0219],\n", - " [ 0.0072, 0.0129, -0.0006]]],\n", - "\n", - "\n", - " [[[ 0.0099, 0.0353, 0.0159],\n", - " [-0.0033, -0.0035, -0.0017],\n", - " [-0.0143, -0.0117, -0.0098]],\n", - "\n", - " [[ 0.0281, -0.0012, 0.0075],\n", - " [ 0.0288, 0.0100, -0.0020],\n", - " [ 0.0016, 0.0260, -0.0053]],\n", - "\n", - " [[ 0.0293, -0.0065, -0.0090],\n", - " [-0.0288, -0.0237, -0.0069],\n", - " [-0.0452, -0.0268, -0.0010]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0031, -0.0219, 0.0129],\n", - " [-0.0017, 0.0037, 0.0417],\n", - " [-0.0069, 0.0084, -0.0066]],\n", - "\n", - " [[ 0.0249, 0.0099, 0.0340],\n", - " [ 0.0362, 0.0196, 0.0393],\n", - " [ 0.0078, 0.0027, 0.0078]],\n", - "\n", - " [[ 0.0255, 0.0275, 0.0331],\n", - " [ 0.0168, 0.0004, 0.0083],\n", - " [-0.0081, -0.0078, 
-0.0197]]],\n", - "\n", - "\n", - " [[[ 0.0481, 0.0492, 0.0372],\n", - " [ 0.0172, -0.0387, -0.0074],\n", - " [-0.0222, -0.0242, -0.0616]],\n", - "\n", - " [[ 0.0019, 0.0307, 0.0023],\n", - " [ 0.0091, 0.0094, -0.0030],\n", - " [ 0.0219, -0.0130, -0.0093]],\n", - "\n", - " [[ 0.0172, 0.0382, 0.0198],\n", - " [ 0.0220, 0.0401, 0.0123],\n", - " [-0.0001, -0.0152, 0.0177]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0011, 0.0025, 0.0049],\n", - " [-0.0053, -0.0176, -0.0262],\n", - " [ 0.0189, 0.0165, -0.0020]],\n", - "\n", - " [[ 0.0171, 0.0041, -0.0016],\n", - " [-0.0089, -0.0153, -0.0008],\n", - " [ 0.0189, 0.0374, 0.0036]],\n", - "\n", - " [[ 0.0264, -0.0059, -0.0073],\n", - " [-0.0090, 0.0259, 0.0117],\n", - " [ 0.0016, -0.0017, 0.0198]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 0.0057, -0.0139, -0.0260],\n", - " [-0.0335, -0.0236, -0.0194],\n", - " [-0.0076, 0.0078, -0.0027]],\n", - "\n", - " [[ 0.0255, -0.0182, -0.0244],\n", - " [ 0.0128, -0.0183, -0.0070],\n", - " [ 0.0053, -0.0110, 0.0198]],\n", - "\n", - " [[ 0.0052, -0.0080, 0.0055],\n", - " [-0.0020, -0.0061, -0.0151],\n", - " [ 0.0397, 0.0304, 0.0184]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0044, -0.0023, -0.0192],\n", - " [ 0.0090, -0.0075, 0.0229],\n", - " [-0.0066, 0.0129, 0.0241]],\n", - "\n", - " [[-0.0238, -0.0174, 0.0229],\n", - " [ 0.0057, -0.0127, -0.0277],\n", - " [-0.0239, -0.0150, 0.0047]],\n", - "\n", - " [[-0.0127, -0.0376, -0.0095],\n", - " [-0.0089, -0.0277, -0.0007],\n", - " [ 0.0142, 0.0022, -0.0197]]],\n", - "\n", - "\n", - " [[[ 0.0211, -0.0185, -0.0099],\n", - " [-0.0169, 0.0075, 0.0242],\n", - " [-0.0085, -0.0022, 0.0176]],\n", - "\n", - " [[-0.0230, -0.0001, 0.0112],\n", - " [ 0.0105, -0.0100, 0.0319],\n", - " [ 0.0436, 0.0082, 0.0280]],\n", - "\n", - " [[-0.0323, -0.0262, -0.0027],\n", - " [ 0.0013, 0.0182, -0.0145],\n", - " [-0.0491, -0.0368, 0.0062]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0272, 0.0284, 0.0171],\n", - " [ 0.0156, 0.0173, 0.0069],\n", - " 
[-0.0398, 0.0368, 0.0365]],\n", - "\n", - " [[-0.0210, -0.0316, -0.0054],\n", - " [-0.0044, 0.0173, -0.0213],\n", - " [ 0.0172, 0.0107, 0.0009]],\n", - "\n", - " [[-0.0052, -0.0406, 0.0017],\n", - " [ 0.0289, 0.0095, -0.0076],\n", - " [-0.0059, 0.0249, -0.0022]]],\n", - "\n", - "\n", - " [[[-0.0250, -0.0216, 0.0087],\n", - " [ 0.0114, -0.0257, -0.0030],\n", - " [-0.0314, -0.0055, -0.0429]],\n", - "\n", - " [[-0.0375, -0.0081, -0.0144],\n", - " [ 0.0251, 0.0122, -0.0046],\n", - " [ 0.0348, 0.0360, 0.0058]],\n", - "\n", - " [[ 0.0087, 0.0354, 0.0022],\n", - " [ 0.0136, 0.0033, -0.0153],\n", - " [-0.0095, 0.0035, -0.0282]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0125, -0.0045, 0.0161],\n", - " [ 0.0001, 0.0047, -0.0010],\n", - " [ 0.0173, 0.0316, 0.0081]],\n", - "\n", - " [[-0.0102, -0.0060, 0.0126],\n", - " [-0.0195, -0.0114, 0.0203],\n", - " [-0.0016, 0.0137, 0.0091]],\n", - "\n", - " [[ 0.0342, 0.0114, 0.0099],\n", - " [ 0.0082, -0.0026, -0.0128],\n", - " [-0.0224, 0.0479, -0.0131]]]])\n", - "layer3.0.bn2.weight Parameter containing:\n", - "tensor([1.0119, 0.9384, 1.0827, 1.0648, 1.0254, 1.0561, 1.0390, 1.0293, 1.0542,\n", - " 1.1104, 1.0157, 1.0315, 1.0969, 1.1143, 1.1288, 1.0071, 1.0915, 1.0707,\n", - " 0.9861, 1.0539, 1.0197, 1.0023, 1.0544, 1.0703, 1.0772, 1.0126, 1.0308,\n", - " 0.9459, 1.0869, 1.0696, 1.0180, 1.0500, 0.9495, 0.9905, 0.9891, 0.9669,\n", - " 1.1350, 1.0054, 1.0277, 1.1071, 1.0835, 1.0175, 1.1351, 0.9666, 1.0443,\n", - " 1.0653, 0.9898, 1.0989, 1.1159, 1.0458, 1.1057, 0.9943, 0.9539, 0.9514,\n", - " 1.1175, 0.8948, 1.1319, 1.1366, 1.0852, 1.0669, 1.0453, 0.9932, 1.0121,\n", - " 1.0484, 1.0055, 1.0050, 1.0672, 1.0756, 1.0323, 0.9517, 1.0874, 0.9885,\n", - " 0.9905, 1.0063, 1.0442, 0.9888, 1.0702, 1.0889, 0.9140, 0.9522, 1.0362,\n", - " 1.0711, 1.0422, 1.1315, 1.0167, 1.1342, 0.8848, 1.0353, 1.0933, 1.0097,\n", - " 1.0541, 1.0233, 1.0751, 1.0267, 0.9831, 0.9609, 1.0753, 1.0302, 1.0146,\n", - " 1.0836, 1.1139, 1.0072, 1.0257, 0.9531, 1.0608, 
1.0504, 1.0427, 0.9754,\n", - " 1.1409, 1.0762, 1.1007, 1.0515, 1.0316, 1.0814, 1.0715, 1.0389, 1.0805,\n", - " 0.9638, 1.0554, 0.9259, 1.0075, 1.0366, 0.9909, 1.0715, 1.0229, 1.0265,\n", - " 0.9960, 1.0285, 1.0408, 1.0881, 1.0641, 1.1841, 0.9831, 1.1062, 1.0657,\n", - " 1.0388, 1.1276, 1.0904, 1.0004, 1.0091, 0.9559, 1.0164, 0.9498, 0.9922,\n", - " 1.1715, 1.0042, 1.0876, 1.1105, 1.0552, 1.0394, 1.0504, 0.9417, 1.0226,\n", - " 0.9802, 1.0365, 0.9911, 0.9943, 0.9997, 0.9738, 0.9340, 1.1164, 1.0912,\n", - " 1.0402, 1.0733, 1.0423, 0.9550, 1.0115, 0.9951, 1.0020, 0.9947, 1.0187,\n", - " 1.0672, 0.9908, 1.0988, 0.9963, 1.0964, 1.0812, 1.1380, 1.0009, 0.9751,\n", - " 1.0534, 1.1093, 1.0231, 1.0245, 0.8350, 1.0952, 1.0597, 0.9853, 1.0292,\n", - " 1.0843, 1.0527, 1.0730, 0.9615, 1.1262, 0.9904, 1.0345, 1.0790, 1.0669,\n", - " 1.0724, 1.0305, 0.9846, 1.0156, 1.0134, 0.9018, 1.1278, 1.0555, 1.0536,\n", - " 1.0175, 0.9649, 1.0734, 1.0182, 0.9771, 0.9521, 1.0517, 0.9989, 0.9906,\n", - " 1.0523, 0.9449, 1.0905, 1.0786, 1.1407, 1.0546, 0.9064, 1.0434, 0.9211,\n", - " 1.0347, 0.8981, 1.0382, 1.0147, 1.0623, 0.9599, 1.0316, 1.0457, 1.0186,\n", - " 1.1254, 1.1296, 1.1115, 0.9257, 0.9933, 1.0522, 0.9944, 1.0506, 0.9876,\n", - " 0.9906, 1.0165, 0.8943, 0.9334, 1.0913, 0.9658, 1.0258, 1.0106, 0.9817,\n", - " 1.0933, 0.9433, 1.0289, 1.0667])\n", - "layer3.0.bn2.bias Parameter containing:\n", - "tensor([-0.0755, -0.0719, -0.0956, -0.1168, -0.1131, -0.0915, -0.0598, -0.1081,\n", - " -0.1179, -0.0952, -0.0843, -0.0633, -0.0956, -0.0895, -0.0859, -0.0657,\n", - " -0.0718, -0.0734, -0.0954, -0.0896, -0.0917, -0.0895, -0.0989, -0.0541,\n", - " -0.0895, -0.0864, -0.0950, -0.0644, -0.1172, -0.0876, -0.1271, -0.1238,\n", - " -0.0690, -0.1027, -0.0964, -0.0774, -0.0782, -0.1324, -0.0958, -0.0867,\n", - " -0.0766, -0.0599, -0.1177, -0.1013, -0.0862, -0.0843, -0.0909, -0.0763,\n", - " -0.0594, -0.0976, -0.0880, -0.0501, -0.0928, -0.0974, -0.1033, -0.1063,\n", - " -0.0736, -0.0788, -0.0619, 
-0.0717, -0.0822, -0.0713, -0.0758, -0.0622,\n", - " -0.0840, -0.0806, -0.1197, -0.0879, -0.0915, -0.1236, -0.0803, -0.0861,\n", - " -0.0713, -0.1014, -0.1262, -0.0877, -0.0946, -0.0900, -0.0831, -0.1069,\n", - " -0.0975, -0.1071, -0.0907, -0.1050, -0.0708, -0.0908, -0.1228, -0.0650,\n", - " -0.1774, -0.1430, -0.0872, -0.0692, -0.1001, -0.0892, -0.0924, -0.0851,\n", - " -0.1070, -0.0903, -0.0953, -0.0785, -0.0789, -0.1276, -0.0917, -0.0950,\n", - " -0.1073, -0.0781, -0.1001, -0.0755, -0.0976, -0.0857, -0.1092, -0.1065,\n", - " -0.1036, -0.0960, -0.0764, -0.0968, -0.1196, -0.0907, -0.0702, -0.0925,\n", - " -0.1132, -0.0908, -0.0933, -0.0812, -0.0916, -0.1031, -0.0769, -0.1044,\n", - " -0.1095, -0.0834, -0.0466, -0.0939, -0.1198, -0.0800, -0.1381, -0.0655,\n", - " -0.1258, -0.0981, -0.0926, -0.0740, -0.1012, -0.1057, -0.1134, -0.0877,\n", - " -0.1252, -0.0627, -0.1136, -0.0758, -0.1047, -0.0449, -0.0940, -0.1118,\n", - " -0.1121, -0.0740, -0.0836, -0.0855, -0.0764, -0.1345, -0.1051, -0.1177,\n", - " -0.0892, -0.0686, -0.0428, -0.0641, -0.0841, -0.0900, -0.1286, -0.0907,\n", - " -0.0964, -0.0649, -0.0511, -0.1129, -0.0960, -0.0700, -0.0757, -0.0907,\n", - " -0.0469, -0.0820, -0.1011, -0.0797, -0.1031, -0.0720, -0.0966, -0.0867,\n", - " -0.1202, -0.0668, -0.0907, -0.1095, -0.0925, -0.0882, -0.1421, -0.0695,\n", - " -0.1184, -0.0824, -0.0652, -0.1343, -0.1453, -0.0800, -0.0881, -0.0690,\n", - " -0.0882, -0.0614, -0.1213, -0.1122, -0.0952, -0.0945, -0.0591, -0.0854,\n", - " -0.1119, -0.0706, -0.0883, -0.1093, -0.0873, -0.0830, -0.0958, -0.0559,\n", - " -0.1080, -0.1117, -0.0335, -0.0701, -0.1163, -0.0996, -0.0995, -0.0798,\n", - " -0.0977, -0.0878, -0.0828, -0.1044, -0.0683, -0.0817, -0.0601, -0.0860,\n", - " -0.1298, -0.0825, -0.1011, -0.0311, -0.0796, -0.1116, -0.1409, -0.0530,\n", - " -0.0907, -0.1049, -0.0740, -0.1359, -0.0716, -0.0704, -0.1097, -0.0979,\n", - " -0.0780, -0.0969, -0.0487, -0.1053, -0.0492, -0.1102, -0.0792, -0.0921])\n", - 
"layer3.0.shortcut_conv.weight Parameter containing:\n", - "tensor([[[[ 0.1287]],\n", - "\n", - " [[-0.0332]],\n", - "\n", - " [[-0.0495]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0439]],\n", - "\n", - " [[ 0.0160]],\n", - "\n", - " [[ 0.0147]]],\n", - "\n", - "\n", - " [[[ 0.0304]],\n", - "\n", - " [[-0.0746]],\n", - "\n", - " [[-0.0246]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0141]],\n", - "\n", - " [[-0.0442]],\n", - "\n", - " [[ 0.0355]]],\n", - "\n", - "\n", - " [[[ 0.1153]],\n", - "\n", - " [[-0.0623]],\n", - "\n", - " [[-0.0273]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0472]],\n", - "\n", - " [[-0.0022]],\n", - "\n", - " [[-0.0559]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-0.0353]],\n", - "\n", - " [[-0.0721]],\n", - "\n", - " [[ 0.0526]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0233]],\n", - "\n", - " [[ 0.1033]],\n", - "\n", - " [[-0.0018]]],\n", - "\n", - "\n", - " [[[-0.0268]],\n", - "\n", - " [[ 0.1036]],\n", - "\n", - " [[ 0.0685]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0291]],\n", - "\n", - " [[-0.0261]],\n", - "\n", - " [[-0.0402]]],\n", - "\n", - "\n", - " [[[-0.0194]],\n", - "\n", - " [[-0.0895]],\n", - "\n", - " [[ 0.0330]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0139]],\n", - "\n", - " [[-0.0454]],\n", - "\n", - " [[-0.0404]]]])\n", - "layer3.0.shortcut_bn.weight Parameter containing:\n", - "tensor([0.9961, 1.0297, 1.0394, 0.9886, 0.9974, 0.9935, 1.0201, 0.9893, 1.0292,\n", - " 1.0017, 1.0291, 0.9921, 0.9860, 1.0360, 1.0332, 0.9975, 0.9960, 0.9750,\n", - " 1.0053, 0.9879, 0.9756, 0.9804, 1.0613, 1.0357, 1.0245, 1.0222, 0.9865,\n", - " 1.0318, 0.9853, 1.0022, 0.9951, 0.9852, 0.9804, 0.9720, 1.0117, 1.0068,\n", - " 1.0167, 0.9779, 0.9935, 1.0345, 0.9953, 1.0129, 0.9515, 0.9842, 1.0005,\n", - " 1.0093, 1.0047, 1.0619, 1.0454, 1.0080, 1.0235, 0.9593, 0.9769, 0.9948,\n", - " 1.0086, 0.9753, 1.0169, 1.0139, 1.0104, 1.0031, 1.0335, 0.9674, 0.9881,\n", - " 1.0314, 0.9951, 1.0025, 0.9884, 0.9770, 1.0086, 0.9939, 0.9998, 
0.9796,\n", - " 0.9894, 1.0292, 1.0134, 0.9811, 1.0466, 1.0405, 0.9304, 0.9861, 0.9896,\n", - " 0.9120, 1.0059, 1.0098, 0.9744, 0.9987, 0.8989, 1.0100, 0.9922, 0.9780,\n", - " 0.9932, 1.0067, 0.9855, 0.9746, 0.9926, 0.9818, 1.0324, 0.9936, 1.0296,\n", - " 1.0343, 0.9481, 0.9449, 0.9791, 0.9461, 1.0283, 1.0003, 1.0000, 0.9881,\n", - " 0.9747, 0.9710, 0.9883, 1.0095, 0.9782, 1.0363, 0.9810, 1.0152, 1.0161,\n", - " 0.9391, 1.0096, 0.9620, 0.9905, 1.0231, 1.0099, 1.0045, 1.0248, 0.9783,\n", - " 0.9516, 1.0396, 1.0492, 1.0304, 0.9702, 0.9914, 1.0251, 1.0117, 1.0149,\n", - " 1.0228, 1.0268, 0.9980, 1.0049, 0.9876, 1.0151, 0.9991, 0.9651, 0.9457,\n", - " 1.0144, 0.9495, 0.9432, 1.0230, 1.0877, 1.0209, 0.9865, 0.9554, 1.0235,\n", - " 0.9386, 1.0033, 1.0058, 0.9780, 1.0165, 0.9841, 0.9669, 1.0396, 1.0317,\n", - " 1.0194, 0.9639, 1.0583, 0.9585, 1.0002, 1.0167, 0.9869, 1.0233, 1.0380,\n", - " 0.9895, 1.0125, 0.9881, 1.0050, 0.9928, 1.0075, 1.0085, 1.0002, 0.9457,\n", - " 1.0325, 1.0370, 1.0178, 1.0055, 0.9116, 1.0270, 1.0444, 0.9693, 0.9844,\n", - " 0.9828, 0.9783, 0.9802, 0.9674, 1.0374, 0.9531, 1.0550, 1.0066, 0.9793,\n", - " 1.0020, 0.9739, 0.9739, 1.0216, 1.0090, 0.9594, 0.9608, 1.0677, 1.0088,\n", - " 0.9768, 0.9741, 0.9988, 1.0017, 0.9587, 0.9666, 0.9970, 0.9987, 1.0117,\n", - " 0.9786, 0.9544, 0.9625, 1.0127, 0.9780, 1.0227, 0.9385, 1.0324, 0.9408,\n", - " 1.0003, 0.9506, 1.0368, 1.0347, 0.9962, 0.9938, 1.0976, 1.0253, 1.0174,\n", - " 0.9370, 0.9770, 1.0025, 0.9594, 0.9978, 1.0000, 1.0066, 1.0002, 0.9581,\n", - " 0.9819, 0.9930, 0.9400, 0.9619, 1.0309, 0.9531, 0.9920, 0.9624, 1.0137,\n", - " 1.0539, 0.9746, 0.9789, 0.9881])\n", - "layer3.0.shortcut_bn.bias Parameter containing:\n", - "tensor([-0.0755, -0.0719, -0.0956, -0.1168, -0.1131, -0.0915, -0.0598, -0.1081,\n", - " -0.1179, -0.0952, -0.0843, -0.0633, -0.0956, -0.0895, -0.0859, -0.0657,\n", - " -0.0718, -0.0734, -0.0954, -0.0896, -0.0917, -0.0895, -0.0989, -0.0541,\n", - " -0.0895, -0.0864, -0.0950, -0.0644, 
-0.1172, -0.0876, -0.1271, -0.1238,\n", - " -0.0690, -0.1027, -0.0964, -0.0774, -0.0782, -0.1324, -0.0958, -0.0867,\n", - " -0.0766, -0.0599, -0.1177, -0.1013, -0.0862, -0.0843, -0.0909, -0.0763,\n", - " -0.0594, -0.0976, -0.0880, -0.0501, -0.0928, -0.0974, -0.1033, -0.1063,\n", - " -0.0736, -0.0788, -0.0619, -0.0717, -0.0822, -0.0713, -0.0758, -0.0622,\n", - " -0.0840, -0.0806, -0.1197, -0.0879, -0.0915, -0.1236, -0.0803, -0.0861,\n", - " -0.0713, -0.1014, -0.1262, -0.0877, -0.0946, -0.0900, -0.0831, -0.1069,\n", - " -0.0975, -0.1071, -0.0907, -0.1050, -0.0708, -0.0908, -0.1228, -0.0650,\n", - " -0.1774, -0.1430, -0.0872, -0.0692, -0.1001, -0.0892, -0.0924, -0.0851,\n", - " -0.1070, -0.0903, -0.0953, -0.0785, -0.0789, -0.1276, -0.0917, -0.0950,\n", - " -0.1073, -0.0781, -0.1001, -0.0755, -0.0976, -0.0857, -0.1092, -0.1065,\n", - " -0.1036, -0.0960, -0.0764, -0.0968, -0.1196, -0.0907, -0.0702, -0.0925,\n", - " -0.1132, -0.0908, -0.0933, -0.0812, -0.0916, -0.1031, -0.0769, -0.1044,\n", - " -0.1095, -0.0834, -0.0466, -0.0939, -0.1198, -0.0800, -0.1381, -0.0655,\n", - " -0.1258, -0.0981, -0.0926, -0.0740, -0.1012, -0.1057, -0.1134, -0.0877,\n", - " -0.1252, -0.0627, -0.1136, -0.0758, -0.1047, -0.0449, -0.0940, -0.1118,\n", - " -0.1121, -0.0740, -0.0836, -0.0855, -0.0764, -0.1345, -0.1051, -0.1177,\n", - " -0.0892, -0.0686, -0.0428, -0.0641, -0.0841, -0.0900, -0.1286, -0.0907,\n", - " -0.0964, -0.0649, -0.0511, -0.1129, -0.0960, -0.0700, -0.0757, -0.0907,\n", - " -0.0469, -0.0820, -0.1011, -0.0797, -0.1031, -0.0720, -0.0966, -0.0867,\n", - " -0.1202, -0.0668, -0.0907, -0.1095, -0.0925, -0.0882, -0.1421, -0.0695,\n", - " -0.1184, -0.0824, -0.0652, -0.1343, -0.1453, -0.0800, -0.0881, -0.0690,\n", - " -0.0882, -0.0614, -0.1213, -0.1122, -0.0952, -0.0945, -0.0591, -0.0854,\n", - " -0.1119, -0.0706, -0.0883, -0.1093, -0.0873, -0.0830, -0.0958, -0.0559,\n", - " -0.1080, -0.1117, -0.0335, -0.0701, -0.1163, -0.0996, -0.0995, -0.0798,\n", - " -0.0977, -0.0878, -0.0828, -0.1044, 
-0.0683, -0.0817, -0.0601, -0.0860,\n", - " -0.1298, -0.0825, -0.1011, -0.0311, -0.0796, -0.1116, -0.1409, -0.0530,\n", - " -0.0907, -0.1049, -0.0740, -0.1359, -0.0716, -0.0704, -0.1097, -0.0979,\n", - " -0.0780, -0.0969, -0.0487, -0.1053, -0.0492, -0.1102, -0.0792, -0.0921])\n", - "layer3.1.conv1.weight Parameter containing:\n", - "tensor([[[[-0.0089, -0.0060, -0.0078],\n", - " [-0.0083, -0.0144, -0.0087],\n", - " [ 0.0079, -0.0106, -0.0188]],\n", - "\n", - " [[ 0.0221, 0.0075, -0.0010],\n", - " [ 0.0078, 0.0238, 0.0237],\n", - " [-0.0274, -0.0009, -0.0108]],\n", - "\n", - " [[-0.0164, -0.0243, 0.0176],\n", - " [-0.0234, 0.0124, 0.0029],\n", - " [ 0.0187, 0.0050, 0.0236]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0223, 0.0203, 0.0211],\n", - " [ 0.0131, 0.0264, 0.0292],\n", - " [ 0.0029, 0.0059, -0.0041]],\n", - "\n", - " [[ 0.0133, 0.0358, 0.0329],\n", - " [-0.0023, -0.0191, -0.0157],\n", - " [ 0.0123, -0.0018, 0.0313]],\n", - "\n", - " [[ 0.0187, -0.0173, -0.0203],\n", - " [-0.0146, -0.0170, 0.0135],\n", - " [-0.0026, -0.0255, 0.0016]]],\n", - "\n", - "\n", - " [[[ 0.0035, -0.0063, -0.0086],\n", - " [ 0.0185, -0.0081, -0.0105],\n", - " [ 0.0020, 0.0077, 0.0059]],\n", - "\n", - " [[-0.0036, 0.0112, -0.0225],\n", - " [ 0.0080, 0.0294, 0.0012],\n", - " [-0.0424, -0.0239, 0.0066]],\n", - "\n", - " [[ 0.0141, 0.0135, -0.0167],\n", - " [-0.0067, 0.0298, -0.0052],\n", - " [-0.0146, 0.0024, -0.0051]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0113, -0.0088, 0.0199],\n", - " [ 0.0106, 0.0112, 0.0165],\n", - " [ 0.0094, 0.0180, -0.0111]],\n", - "\n", - " [[-0.0123, -0.0243, 0.0136],\n", - " [-0.0087, 0.0113, 0.0470],\n", - " [-0.0176, -0.0201, -0.0065]],\n", - "\n", - " [[ 0.0094, -0.0160, -0.0185],\n", - " [-0.0041, 0.0019, -0.0137],\n", - " [-0.0054, -0.0282, 0.0160]]],\n", - "\n", - "\n", - " [[[ 0.0223, 0.0120, 0.0172],\n", - " [-0.0038, -0.0191, -0.0039],\n", - " [-0.0211, -0.0171, -0.0283]],\n", - "\n", - " [[-0.0338, -0.0060, -0.0046],\n", - " [-0.0102, 0.0018, 
-0.0045],\n", - " [ 0.0404, 0.0651, 0.0263]],\n", - "\n", - " [[-0.0213, -0.0207, 0.0050],\n", - " [-0.0230, -0.0233, 0.0018],\n", - " [-0.0216, -0.0057, -0.0030]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0011, 0.0098, -0.0103],\n", - " [-0.0096, 0.0110, 0.0143],\n", - " [-0.0016, -0.0168, -0.0251]],\n", - "\n", - " [[ 0.0131, 0.0102, 0.0009],\n", - " [-0.0201, -0.0286, -0.0234],\n", - " [ 0.0084, 0.0084, 0.0085]],\n", - "\n", - " [[ 0.0129, 0.0274, 0.0181],\n", - " [-0.0245, -0.0168, -0.0007],\n", - " [ 0.0059, -0.0039, -0.0160]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 0.0009, 0.0100, -0.0076],\n", - " [-0.0015, -0.0030, 0.0290],\n", - " [ 0.0140, 0.0312, -0.0070]],\n", - "\n", - " [[ 0.0038, 0.0222, 0.0296],\n", - " [-0.0521, 0.0081, 0.0183],\n", - " [-0.0018, 0.0102, 0.0051]],\n", - "\n", - " [[ 0.0273, 0.0253, 0.0158],\n", - " [ 0.0034, 0.0326, 0.0349],\n", - " [ 0.0144, 0.0285, -0.0126]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0185, 0.0185, -0.0245],\n", - " [-0.0177, 0.0276, 0.0237],\n", - " [ 0.0372, 0.0134, 0.0167]],\n", - "\n", - " [[ 0.0060, -0.0221, -0.0462],\n", - " [-0.0235, -0.0135, 0.0108],\n", - " [ 0.0101, 0.0221, 0.0232]],\n", - "\n", - " [[-0.0022, 0.0149, 0.0309],\n", - " [-0.0073, -0.0145, -0.0246],\n", - " [-0.0053, 0.0137, -0.0079]]],\n", - "\n", - "\n", - " [[[-0.0138, 0.0086, -0.0118],\n", - " [-0.0012, -0.0124, 0.0198],\n", - " [-0.0138, -0.0054, -0.0149]],\n", - "\n", - " [[-0.0191, 0.0080, 0.0100],\n", - " [ 0.0100, -0.0196, -0.0140],\n", - " [ 0.0006, 0.0038, 0.0065]],\n", - "\n", - " [[ 0.0015, -0.0038, 0.0016],\n", - " [ 0.0130, 0.0152, -0.0030],\n", - " [-0.0230, 0.0007, 0.0087]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0062, 0.0096, -0.0146],\n", - " [ 0.0090, -0.0076, 0.0036],\n", - " [ 0.0204, -0.0055, 0.0151]],\n", - "\n", - " [[-0.0062, 0.0229, 0.0011],\n", - " [ 0.0150, 0.0210, -0.0116],\n", - " [ 0.0360, 0.0070, 0.0110]],\n", - "\n", - " [[ 0.0160, 0.0141, 0.0097],\n", - " [ 0.0261, 0.0006, 
0.0071],\n", - " [ 0.0121, 0.0004, -0.0189]]],\n", - "\n", - "\n", - " [[[-0.0238, -0.0239, -0.0015],\n", - " [ 0.0065, -0.0170, 0.0101],\n", - " [ 0.0099, -0.0033, 0.0196]],\n", - "\n", - " [[ 0.0159, 0.0363, -0.0085],\n", - " [-0.0036, 0.0383, 0.0210],\n", - " [-0.0322, 0.0326, -0.0029]],\n", - "\n", - " [[-0.0424, -0.0062, 0.0063],\n", - " [ 0.0033, -0.0216, -0.0250],\n", - " [ 0.0295, -0.0276, -0.0038]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0210, -0.0169, -0.0170],\n", - " [ 0.0021, 0.0172, -0.0203],\n", - " [-0.0038, 0.0303, -0.0193]],\n", - "\n", - " [[-0.0305, -0.0032, -0.0458],\n", - " [-0.0235, -0.0152, -0.0123],\n", - " [ 0.0110, -0.0191, -0.0239]],\n", - "\n", - " [[-0.0303, -0.0097, -0.0278],\n", - " [-0.0252, 0.0002, -0.0171],\n", - " [-0.0116, -0.0104, 0.0082]]]])\n", - "layer3.1.bn1.weight Parameter containing:\n", - "tensor([0.8894, 1.0014, 0.9618, 1.0368, 0.9632, 1.0074, 0.9507, 1.0399, 1.0297,\n", - " 0.9589, 0.9957, 0.9815, 1.0905, 0.9845, 0.9782, 0.9698, 1.0131, 0.9468,\n", - " 1.0746, 1.0122, 0.9194, 1.0240, 0.9088, 1.0045, 0.9597, 0.9893, 1.0038,\n", - " 0.9854, 0.9923, 0.9757, 0.9893, 0.9965, 1.0263, 0.9475, 1.0418, 1.0299,\n", - " 0.9838, 0.9242, 0.9551, 1.0142, 1.0137, 0.9776, 0.9908, 0.9987, 1.0010,\n", - " 1.0373, 0.9471, 0.9570, 1.0478, 0.9621, 0.9507, 0.9834, 1.0207, 1.0203,\n", - " 1.0108, 0.9961, 0.9667, 1.0613, 0.9601, 0.9387, 1.0048, 0.9592, 1.0594,\n", - " 1.0004, 1.0143, 0.9796, 1.0497, 0.9597, 0.9803, 0.9279, 0.9861, 1.0094,\n", - " 0.9469, 0.9917, 0.9735, 0.9667, 0.9915, 1.0499, 0.9900, 1.0731, 1.0523,\n", - " 0.9240, 1.0704, 1.0292, 0.9927, 1.0082, 0.9503, 1.0278, 0.9857, 0.9680,\n", - " 0.9988, 0.9566, 1.0067, 1.0650, 1.0072, 1.0185, 0.9237, 1.0230, 0.9879,\n", - " 0.9544, 0.9825, 1.0734, 0.9985, 0.9623, 1.0024, 1.0158, 1.0302, 1.0588,\n", - " 0.9510, 0.9827, 1.0133, 0.9888, 0.9917, 0.9738, 1.0048, 0.9558, 0.9368,\n", - " 0.9419, 0.9548, 0.9739, 0.9351, 1.0554, 0.9564, 0.9554, 0.9746, 0.9900,\n", - " 0.9743, 1.0129, 
1.0568, 1.0261, 0.8503, 0.9825, 1.0072, 1.0561, 1.0295,\n", - " 0.9250, 0.9449, 0.9652, 1.0274, 0.9811, 0.9868, 0.9987, 0.9758, 0.9687,\n", - " 0.9447, 1.0006, 0.9603, 1.0075, 1.0120, 0.9941, 1.0805, 1.0201, 0.9765,\n", - " 0.9931, 1.0093, 1.0358, 1.0294, 1.0004, 1.0088, 1.0204, 0.9912, 0.9467,\n", - " 1.0267, 0.9717, 0.9932, 1.0262, 0.9624, 0.9232, 1.0127, 0.9702, 1.0690,\n", - " 0.9871, 0.9572, 0.9898, 1.0525, 1.0083, 1.0368, 0.9964, 1.0338, 0.8748,\n", - " 0.9964, 0.9670, 0.9804, 1.0126, 0.9944, 1.0274, 0.9512, 0.9921, 0.9697,\n", - " 1.0625, 0.9837, 0.9659, 1.0131, 0.9245, 1.0110, 0.9741, 0.9294, 0.9803,\n", - " 1.0149, 1.0334, 0.9848, 1.0037, 0.9699, 1.0264, 0.9817, 0.9894, 0.8620,\n", - " 1.0600, 1.0214, 0.9951, 1.0127, 0.9452, 0.9733, 1.0047, 1.0568, 1.0307,\n", - " 1.0134, 1.0109, 0.9813, 0.9758, 0.9878, 0.9932, 0.9886, 0.9710, 1.0007,\n", - " 1.0168, 0.9869, 0.9622, 1.0478, 1.0258, 0.9697, 1.0451, 1.0032, 0.9455,\n", - " 0.9796, 1.0273, 1.0156, 1.0052, 0.9761, 1.0013, 0.9775, 0.9848, 0.9489,\n", - " 0.9581, 1.0368, 0.9701, 1.0158, 1.0756, 1.0068, 0.9919, 0.9778, 0.9810,\n", - " 1.0233, 1.0328, 0.9609, 1.0135])\n", - "layer3.1.bn1.bias Parameter containing:\n", - "tensor([-0.2024, -0.0811, -0.0922, -0.0835, -0.1042, -0.0950, -0.1535, -0.1058,\n", - " -0.0725, -0.1229, -0.1334, -0.1228, -0.1481, -0.1424, -0.0705, -0.1125,\n", - " -0.1025, -0.1020, -0.1740, -0.1615, -0.1094, -0.1909, -0.0887, -0.0977,\n", - " -0.1232, -0.1152, -0.1137, -0.1736, -0.0889, -0.1433, -0.1535, -0.1320,\n", - " -0.1768, -0.1870, -0.0638, -0.0644, -0.1462, -0.1610, -0.0984, -0.0944,\n", - " -0.1171, -0.1277, -0.0961, -0.1230, -0.0759, -0.1204, -0.2193, -0.1107,\n", - " -0.1224, -0.0838, -0.0812, -0.1262, -0.1052, -0.0962, -0.1050, -0.1323,\n", - " -0.1272, -0.0985, -0.0897, -0.1146, -0.0901, -0.0966, -0.0778, -0.0866,\n", - " -0.1539, -0.0671, -0.0921, -0.0988, -0.1000, -0.1371, -0.1845, -0.1234,\n", - " -0.1169, -0.1379, -0.1190, -0.0677, -0.1166, -0.0890, -0.0833, -0.0995,\n", - " 
-0.1120, -0.1498, -0.0999, -0.1339, -0.0877, -0.1662, -0.0699, -0.1084,\n", - " -0.1234, -0.1270, -0.1432, -0.1244, -0.0884, -0.1719, -0.0971, -0.1034,\n", - " -0.0922, -0.0886, -0.1187, -0.1214, -0.0703, -0.1064, -0.0996, -0.1092,\n", - " -0.1004, -0.0719, -0.1040, -0.1137, -0.1492, -0.1240, -0.1093, -0.0972,\n", - " -0.1311, -0.0610, -0.0975, -0.1565, -0.1471, -0.1279, -0.1107, -0.0912,\n", - " -0.1073, -0.1538, -0.1031, -0.1418, -0.1104, -0.1107, -0.0868, -0.1587,\n", - " -0.0774, -0.1199, -0.1776, -0.0625, -0.1020, -0.1542, -0.1789, -0.1120,\n", - " -0.1138, -0.1263, -0.1202, -0.0816, -0.1033, -0.1082, -0.1073, -0.1751,\n", - " -0.1201, -0.0667, -0.1008, -0.1173, -0.1198, -0.1258, -0.0948, -0.1198,\n", - " -0.0755, -0.0960, -0.1107, -0.1115, -0.0655, -0.1048, -0.1052, -0.1628,\n", - " -0.1215, -0.0877, -0.1073, -0.1295, -0.1083, -0.1357, -0.1008, -0.0685,\n", - " -0.1130, -0.1215, -0.1278, -0.1115, -0.0825, -0.0946, -0.0917, -0.1774,\n", - " -0.1626, -0.1593, -0.1224, -0.1393, -0.0884, -0.1989, -0.0981, -0.1358,\n", - " -0.1323, -0.1139, -0.1106, -0.1173, -0.0777, -0.1334, -0.0942, -0.1504,\n", - " -0.0807, -0.1390, -0.0976, -0.1683, -0.2453, -0.1118, -0.1057, -0.1144,\n", - " -0.1056, -0.0634, -0.1000, -0.0946, -0.0647, -0.1244, -0.2231, -0.1076,\n", - " -0.1368, -0.1047, -0.1290, -0.0922, -0.1252, -0.1696, -0.0970, -0.1317,\n", - " -0.1719, -0.0943, -0.1631, -0.1203, -0.1298, -0.1731, -0.0776, -0.0930,\n", - " -0.1405, -0.0987, -0.1073, -0.1154, -0.1397, -0.1594, -0.0856, -0.0981,\n", - " -0.0991, -0.0981, -0.0937, -0.1345, -0.0902, -0.0947, -0.1005, -0.0755,\n", - " -0.0743, -0.1312, -0.1276, -0.1481, -0.1267, -0.1051, -0.1152, -0.0936,\n", - " -0.1153, -0.1397, -0.0979, -0.1554, -0.0985, -0.1164, -0.1008, -0.1003])\n", - "layer3.1.conv2.weight Parameter containing:\n", - "tensor([[[[ 0.0057, 0.0059, 0.0141],\n", - " [ 0.0202, 0.0084, -0.0168],\n", - " [ 0.0042, -0.0093, -0.0161]],\n", - "\n", - " [[-0.0071, -0.0256, -0.0123],\n", - " [ 0.0112, 0.0096, 
-0.0154],\n", - " [ 0.0074, 0.0087, 0.0290]],\n", - "\n", - " [[ 0.0193, 0.0200, 0.0237],\n", - " [ 0.0042, 0.0013, -0.0049],\n", - " [ 0.0255, 0.0222, 0.0247]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0049, 0.0150, 0.0204],\n", - " [ 0.0048, 0.0343, 0.0375],\n", - " [ 0.0155, -0.0052, 0.0337]],\n", - "\n", - " [[ 0.0311, -0.0015, -0.0184],\n", - " [ 0.0050, 0.0187, -0.0070],\n", - " [ 0.0176, 0.0203, 0.0322]],\n", - "\n", - " [[ 0.0042, 0.0106, -0.0159],\n", - " [-0.0078, 0.0319, -0.0009],\n", - " [ 0.0051, 0.0336, 0.0013]]],\n", - "\n", - "\n", - " [[[ 0.0060, -0.0394, -0.0219],\n", - " [-0.0232, -0.0402, -0.0123],\n", - " [-0.0100, -0.0346, -0.0340]],\n", - "\n", - " [[ 0.0050, 0.0313, 0.0401],\n", - " [-0.0034, -0.0222, 0.0018],\n", - " [ 0.0058, 0.0137, -0.0026]],\n", - "\n", - " [[ 0.0069, 0.0209, -0.0171],\n", - " [-0.0181, -0.0249, -0.0093],\n", - " [ 0.0046, -0.0241, 0.0079]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0178, 0.0215, -0.0115],\n", - " [ 0.0114, -0.0228, 0.0059],\n", - " [-0.0155, -0.0046, -0.0166]],\n", - "\n", - " [[ 0.0264, 0.0007, 0.0135],\n", - " [ 0.0048, -0.0157, 0.0191],\n", - " [-0.0309, -0.0115, -0.0107]],\n", - "\n", - " [[ 0.0129, 0.0396, 0.0171],\n", - " [ 0.0243, 0.0272, -0.0214],\n", - " [ 0.0051, 0.0171, -0.0188]]],\n", - "\n", - "\n", - " [[[ 0.0133, -0.0035, -0.0081],\n", - " [-0.0374, -0.0194, 0.0275],\n", - " [-0.0025, -0.0212, 0.0177]],\n", - "\n", - " [[-0.0197, -0.0340, -0.0122],\n", - " [ 0.0075, 0.0080, -0.0352],\n", - " [-0.0108, -0.0002, -0.0069]],\n", - "\n", - " [[ 0.0178, 0.0160, -0.0038],\n", - " [ 0.0167, -0.0179, -0.0075],\n", - " [ 0.0003, -0.0214, -0.0138]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0011, 0.0236, 0.0067],\n", - " [ 0.0328, -0.0104, 0.0123],\n", - " [-0.0324, -0.0213, -0.0151]],\n", - "\n", - " [[ 0.0477, 0.0054, 0.0267],\n", - " [ 0.0214, 0.0166, 0.0255],\n", - " [ 0.0306, 0.0288, 0.0306]],\n", - "\n", - " [[-0.0001, 0.0285, 0.0163],\n", - " [-0.0139, 0.0104, -0.0082],\n", - " [ 0.0160, 
0.0019, -0.0099]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 0.0320, 0.0127, -0.0041],\n", - " [ 0.0126, 0.0071, -0.0102],\n", - " [ 0.0449, -0.0208, 0.0264]],\n", - "\n", - " [[-0.0218, -0.0197, 0.0082],\n", - " [ 0.0212, -0.0167, -0.0232],\n", - " [-0.0246, -0.0144, -0.0146]],\n", - "\n", - " [[ 0.0196, 0.0360, 0.0037],\n", - " [ 0.0174, 0.0118, -0.0163],\n", - " [ 0.0149, -0.0078, 0.0053]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0014, -0.0044, -0.0235],\n", - " [-0.0170, -0.0056, -0.0082],\n", - " [-0.0130, 0.0023, 0.0069]],\n", - "\n", - " [[ 0.0082, 0.0044, 0.0025],\n", - " [ 0.0071, -0.0022, -0.0127],\n", - " [ 0.0079, -0.0118, 0.0401]],\n", - "\n", - " [[ 0.0127, 0.0259, 0.0020],\n", - " [-0.0149, -0.0037, -0.0306],\n", - " [ 0.0252, 0.0114, -0.0398]]],\n", - "\n", - "\n", - " [[[-0.0067, 0.0159, -0.0116],\n", - " [-0.0255, 0.0141, -0.0340],\n", - " [-0.0379, -0.0369, -0.0058]],\n", - "\n", - " [[-0.0109, -0.0326, -0.0156],\n", - " [-0.0299, -0.0425, -0.0245],\n", - " [-0.0656, -0.0205, 0.0051]],\n", - "\n", - " [[-0.0484, -0.0066, -0.0070],\n", - " [-0.0176, -0.0068, -0.0352],\n", - " [-0.0163, 0.0053, -0.0278]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0009, 0.0022, -0.0277],\n", - " [ 0.0112, 0.0170, -0.0237],\n", - " [-0.0069, 0.0070, -0.0239]],\n", - "\n", - " [[ 0.0105, 0.0199, -0.0199],\n", - " [-0.0018, 0.0375, 0.0256],\n", - " [ 0.0020, -0.0027, 0.0162]],\n", - "\n", - " [[-0.0299, 0.0109, -0.0016],\n", - " [-0.0151, 0.0335, 0.0003],\n", - " [-0.0382, 0.0336, 0.0335]]],\n", - "\n", - "\n", - " [[[ 0.0120, -0.0307, -0.0122],\n", - " [-0.0485, 0.0204, -0.0066],\n", - " [-0.0292, 0.0023, -0.0102]],\n", - "\n", - " [[ 0.0103, -0.0036, -0.0009],\n", - " [-0.0090, -0.0189, -0.0218],\n", - " [-0.0022, 0.0280, 0.0087]],\n", - "\n", - " [[ 0.0105, -0.0090, 0.0102],\n", - " [-0.0376, -0.0115, -0.0124],\n", - " [ 0.0302, 0.0081, 0.0010]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0258, 0.0199, 0.0164],\n", - " [-0.0088, -0.0327, 
0.0286],\n", - " [-0.0255, -0.0063, 0.0088]],\n", - "\n", - " [[ 0.0173, -0.0106, -0.0176],\n", - " [-0.0141, 0.0150, -0.0110],\n", - " [-0.0014, 0.0298, 0.0401]],\n", - "\n", - " [[-0.0097, 0.0207, 0.0122],\n", - " [ 0.0098, 0.0361, 0.0345],\n", - " [ 0.0262, 0.0573, -0.0103]]]])\n", - "layer3.1.bn2.weight Parameter containing:\n", - "tensor([0.9815, 0.9771, 0.9937, 0.9977, 0.9733, 0.9092, 0.9829, 0.9169, 0.9609,\n", - " 0.9104, 0.8757, 0.9897, 0.9988, 1.0001, 1.0297, 0.8942, 0.9350, 0.9427,\n", - " 0.9168, 0.9627, 0.9683, 0.9354, 0.9495, 0.9964, 0.9769, 0.9127, 1.0025,\n", - " 0.9153, 0.9689, 0.9639, 0.9897, 1.0123, 0.8282, 0.9270, 0.9236, 0.9664,\n", - " 0.9928, 0.9019, 0.9426, 0.9955, 0.9358, 0.9749, 0.9828, 0.9415, 1.0004,\n", - " 0.9123, 0.9367, 0.9801, 0.9494, 1.0004, 0.9750, 0.9334, 1.0004, 0.9339,\n", - " 0.8686, 0.8470, 0.9903, 1.0119, 1.0165, 0.9926, 0.9256, 0.9330, 0.9856,\n", - " 0.9424, 1.0011, 0.9819, 0.9341, 1.0159, 0.9502, 0.9183, 0.9090, 0.9548,\n", - " 1.0113, 0.9721, 0.9436, 0.9371, 0.9473, 0.9046, 0.9444, 0.9053, 0.9450,\n", - " 0.9036, 1.0273, 0.9669, 0.9598, 0.9456, 0.8899, 0.9520, 0.9313, 0.9089,\n", - " 0.9674, 0.9503, 0.8887, 0.9139, 0.9624, 0.9340, 1.0025, 0.8952, 0.9185,\n", - " 0.9119, 0.9541, 0.9354, 0.9323, 0.9323, 1.0323, 0.9860, 0.9632, 0.9852,\n", - " 0.9095, 0.9466, 0.9968, 1.0321, 0.8953, 0.9952, 0.9438, 0.9600, 1.0027,\n", - " 0.9396, 0.9181, 0.9470, 0.9278, 0.9733, 1.0050, 0.9730, 0.9877, 0.9527,\n", - " 0.9141, 0.8995, 0.9542, 0.9364, 0.9155, 0.9720, 0.9578, 1.0013, 1.0498,\n", - " 0.9335, 0.9727, 0.9802, 0.9157, 0.8526, 0.8498, 0.9313, 0.9292, 0.9360,\n", - " 0.9503, 0.9614, 1.0159, 0.9742, 0.9591, 0.9834, 0.9546, 0.8601, 0.9151,\n", - " 0.9533, 0.9906, 0.8816, 0.9180, 0.8996, 0.9657, 0.8390, 0.9566, 0.9444,\n", - " 0.9687, 0.9363, 1.0103, 0.8784, 0.9648, 0.9563, 0.9561, 1.0161, 0.9506,\n", - " 0.9709, 0.9270, 0.9539, 0.9587, 0.9816, 1.0350, 0.9984, 0.9343, 0.9864,\n", - " 0.9699, 1.0122, 0.9517, 0.9404, 0.9368, 0.9919, 
0.9200, 0.9307, 0.9481,\n", - " 0.9009, 0.9305, 0.9460, 0.9792, 0.9991, 0.8814, 0.9784, 0.9849, 0.9834,\n", - " 0.9355, 0.9520, 0.9651, 0.9303, 0.9693, 0.9200, 0.9991, 0.9569, 0.9554,\n", - " 0.9542, 0.8686, 1.0286, 0.9254, 0.8529, 0.9538, 0.9630, 0.8646, 0.9576,\n", - " 0.9614, 0.9150, 0.9488, 1.0792, 0.9474, 0.9442, 0.9534, 0.8712, 0.9632,\n", - " 0.9681, 0.9126, 0.9659, 0.9386, 0.9330, 0.9911, 0.9612, 0.9779, 0.9494,\n", - " 0.8655, 0.9691, 0.9830, 1.0095, 0.9631, 1.0426, 1.0039, 0.9078, 0.9335,\n", - " 0.9570, 0.9473, 0.9557, 0.9158, 0.9798, 0.9176, 0.8368, 0.9515, 0.9396,\n", - " 0.9669, 0.9546, 0.9310, 0.8975])\n", - "layer3.1.bn2.bias Parameter containing:\n", - "tensor([-0.1031, -0.1217, -0.1326, -0.1258, -0.1390, -0.0922, -0.0841, -0.1330,\n", - " -0.1585, -0.1332, -0.1653, -0.0609, -0.0702, -0.1586, -0.1070, -0.1527,\n", - " -0.1529, -0.1246, -0.1743, -0.1270, -0.1165, -0.1431, -0.1666, -0.0825,\n", - " -0.1168, -0.1519, -0.1148, -0.0939, -0.1390, -0.1273, -0.1161, -0.1486,\n", - " -0.1854, -0.1222, -0.1213, -0.1401, -0.1382, -0.1273, -0.1230, -0.1030,\n", - " -0.1122, -0.1197, -0.1796, -0.1384, -0.0926, -0.1447, -0.1543, -0.1356,\n", - " -0.0909, -0.1386, -0.1104, -0.1318, -0.1326, -0.1224, -0.1635, -0.1458,\n", - " -0.0994, -0.1207, -0.0871, -0.1215, -0.1321, -0.1024, -0.1104, -0.1533,\n", - " -0.0775, -0.1287, -0.1569, -0.1287, -0.1081, -0.0918, -0.1163, -0.1112,\n", - " -0.0950, -0.1562, -0.1089, -0.0883, -0.1193, -0.0902, -0.1262, -0.1601,\n", - " -0.1588, -0.2315, -0.1030, -0.1270, -0.0776, -0.1719, -0.1275, -0.1466,\n", - " -0.1918, -0.2018, -0.1322, -0.1183, -0.1523, -0.1425, -0.1505, -0.1293,\n", - " -0.1456, -0.1323, -0.1759, -0.1544, -0.1577, -0.1870, -0.1309, -0.0995,\n", - " -0.1034, -0.1288, -0.1301, -0.1242, -0.1568, -0.1373, -0.1541, -0.1013,\n", - " -0.1291, -0.0844, -0.1278, -0.1105, -0.1554, -0.1066, -0.0992, -0.1333,\n", - " -0.1511, -0.1137, -0.0950, -0.1285, -0.1438, -0.1257, -0.1515, -0.1425,\n", - " -0.1811, -0.1336, -0.1211, 
-0.1363, -0.1193, -0.1516, -0.1115, -0.1189,\n", - " -0.1272, -0.1287, -0.1752, -0.1583, -0.2094, -0.1342, -0.1100, -0.1305,\n", - " -0.1064, -0.1320, -0.1358, -0.1208, -0.1355, -0.1005, -0.1263, -0.1729,\n", - " -0.1199, -0.0826, -0.0964, -0.1561, -0.1532, -0.1950, -0.1397, -0.1617,\n", - " -0.1192, -0.1215, -0.0973, -0.1043, -0.1371, -0.1510, -0.1281, -0.1020,\n", - " -0.1469, -0.0536, -0.0947, -0.1230, -0.1114, -0.1059, -0.1259, -0.1463,\n", - " -0.0779, -0.1328, -0.1282, -0.1450, -0.1135, -0.1125, -0.1291, -0.1261,\n", - " -0.1648, -0.1404, -0.1629, -0.1324, -0.1645, -0.1539, -0.1878, -0.1183,\n", - " -0.1369, -0.1324, -0.1300, -0.1355, -0.1158, -0.1134, -0.1212, -0.1176,\n", - " -0.1827, -0.0990, -0.1417, -0.1345, -0.1632, -0.1474, -0.1211, -0.1373,\n", - " -0.1854, -0.1219, -0.1227, -0.1246, -0.1039, -0.0776, -0.1508, -0.0964,\n", - " -0.1037, -0.1521, -0.0945, -0.0795, -0.1662, -0.1484, -0.0996, -0.1285,\n", - " -0.1003, -0.1070, -0.1675, -0.1326, -0.1186, -0.1332, -0.0866, -0.1046,\n", - " -0.1211, -0.0580, -0.1816, -0.0841, -0.1563, -0.1192, -0.1486, -0.0603,\n", - " -0.0866, -0.1133, -0.1044, -0.2205, -0.1373, -0.0997, -0.1647, -0.1118,\n", - " -0.0883, -0.1750, -0.0968, -0.1916, -0.1200, -0.1369, -0.1829, -0.1554])\n", - "layer4.0.conv1.weight Parameter containing:\n", - "tensor([[[[-1.6195e-02, -4.3060e-02, -3.5322e-02],\n", - " [ 1.3287e-02, -3.1615e-02, 6.6042e-03],\n", - " [-4.0014e-04, -1.5452e-02, -9.1024e-03]],\n", - "\n", - " [[-1.4157e-02, -2.3567e-02, 1.4208e-02],\n", - " [-1.2353e-02, 1.6866e-02, 1.7598e-02],\n", - " [-2.0796e-03, 2.8303e-02, 1.2474e-02]],\n", - "\n", - " [[-1.2214e-02, -2.6098e-03, -3.1828e-02],\n", - " [-1.0792e-02, -2.9335e-02, 3.1974e-03],\n", - " [-6.6637e-03, -4.7699e-03, -1.9124e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-5.4387e-03, -7.4101e-03, -1.2255e-02],\n", - " [-9.4177e-03, -3.5547e-02, 8.5545e-03],\n", - " [-4.9719e-03, 4.6679e-03, 1.3081e-02]],\n", - "\n", - " [[ 9.7803e-03, -9.9373e-03, -1.0429e-02],\n", - 
" [ 5.0466e-03, -2.2508e-03, 1.5473e-02],\n", - " [-8.3493e-03, 8.0397e-03, -6.2141e-03]],\n", - "\n", - " [[-3.4990e-02, -1.2068e-02, -4.8616e-03],\n", - " [ 1.5458e-02, 3.9716e-05, 1.8083e-02],\n", - " [-2.0388e-02, -2.6572e-02, -3.2755e-02]]],\n", - "\n", - "\n", - " [[[ 2.7451e-02, -1.6241e-02, -5.4663e-03],\n", - " [-2.5122e-03, 9.7752e-03, 1.9901e-03],\n", - " [-1.1972e-02, 6.6506e-03, 2.7320e-02]],\n", - "\n", - " [[ 8.9008e-03, -3.0568e-03, 1.7271e-02],\n", - " [-2.8983e-03, 2.2242e-02, 1.4016e-02],\n", - " [ 1.3805e-02, 1.3811e-02, 1.8592e-02]],\n", - "\n", - " [[-1.1898e-02, 1.1536e-02, -1.6407e-02],\n", - " [ 1.8869e-02, 1.0717e-02, -4.6789e-03],\n", - " [-3.3365e-03, 1.0226e-02, 7.3902e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-6.0078e-03, 2.0500e-02, 3.9849e-03],\n", - " [ 1.4595e-02, 1.4047e-02, -2.0757e-02],\n", - " [ 1.1118e-02, 3.0116e-02, 5.9856e-03]],\n", - "\n", - " [[-1.8926e-02, 2.3320e-02, 2.0495e-03],\n", - " [-9.0993e-03, 1.3444e-02, 1.7396e-02],\n", - " [ 5.3953e-03, -1.1074e-02, -1.0029e-02]],\n", - "\n", - " [[ 3.5316e-02, -1.0975e-02, 2.5410e-02],\n", - " [ 1.1206e-02, 3.4463e-02, 8.5281e-03],\n", - " [ 5.2337e-03, -1.8434e-02, 2.2219e-02]]],\n", - "\n", - "\n", - " [[[ 1.5784e-02, -2.8938e-02, -2.7249e-02],\n", - " [ 3.6373e-04, 1.7282e-03, 1.1883e-02],\n", - " [ 1.6549e-02, -2.4496e-02, -1.4036e-02]],\n", - "\n", - " [[-9.9012e-03, 1.3803e-02, -2.6475e-02],\n", - " [ 2.0324e-03, -3.2427e-02, 2.6785e-03],\n", - " [-2.3451e-02, -1.3071e-02, -1.3452e-02]],\n", - "\n", - " [[ 8.0489e-03, -8.7835e-03, 2.3109e-02],\n", - " [ 5.7620e-03, 5.4874e-03, 7.9275e-03],\n", - " [ 2.2045e-02, -7.0428e-03, -2.5761e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-6.4942e-05, -9.1919e-03, 1.5764e-02],\n", - " [ 1.1611e-02, -8.8280e-03, 8.2890e-03],\n", - " [ 9.9335e-03, 1.7437e-02, -2.8565e-03]],\n", - "\n", - " [[-2.0067e-02, -8.6341e-03, -1.6160e-02],\n", - " [-2.8909e-02, -2.4014e-03, 2.9025e-02],\n", - " [ 1.4045e-02, -2.3337e-02, 
-7.0309e-03]],\n", - "\n", - " [[-2.1758e-02, 1.3736e-03, 8.1537e-03],\n", - " [-6.1176e-03, 1.5924e-02, -5.0352e-03],\n", - " [ 1.9466e-02, 2.9483e-02, 1.9495e-02]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-2.3250e-02, 2.0028e-02, 2.9691e-02],\n", - " [ 4.9181e-03, 6.5347e-04, -3.8907e-03],\n", - " [-5.9318e-03, 2.8015e-02, -4.9769e-02]],\n", - "\n", - " [[-6.7370e-03, -2.8322e-02, -2.8513e-02],\n", - " [ 1.4032e-02, -4.3892e-04, -7.2716e-03],\n", - " [ 2.4244e-02, -1.1580e-02, -1.0610e-04]],\n", - "\n", - " [[-2.0042e-02, -4.6635e-02, -8.6490e-03],\n", - " [-6.1748e-03, 2.5440e-02, 3.8151e-02],\n", - " [-1.4358e-02, 3.1131e-03, -1.8387e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-4.9713e-02, -1.6618e-02, -2.0875e-02],\n", - " [ 1.5922e-02, -1.2795e-02, 1.7373e-02],\n", - " [ 2.1518e-02, 1.2405e-02, -2.8119e-03]],\n", - "\n", - " [[ 1.7024e-02, -1.3888e-03, 1.4596e-02],\n", - " [-2.7721e-02, 1.4081e-02, -4.4708e-03],\n", - " [ 1.6595e-02, 7.6307e-03, 1.7189e-03]],\n", - "\n", - " [[-3.4501e-02, -2.7804e-02, 3.3934e-02],\n", - " [-3.6700e-02, -2.7948e-03, -2.7233e-02],\n", - " [ 3.1503e-02, 2.9363e-03, -1.7693e-02]]],\n", - "\n", - "\n", - " [[[-7.1497e-03, -9.2680e-03, -1.6912e-02],\n", - " [-2.5124e-02, 1.3490e-02, -3.0191e-03],\n", - " [-3.4414e-02, -1.7793e-02, -6.2751e-04]],\n", - "\n", - " [[-1.0422e-02, 7.3811e-03, 4.8907e-04],\n", - " [-3.3726e-02, -3.7883e-02, 6.4931e-03],\n", - " [ 1.0105e-02, -1.0586e-02, 8.0654e-03]],\n", - "\n", - " [[ 1.3769e-02, -2.7422e-02, 1.1539e-02],\n", - " [ 3.9909e-02, 1.7185e-02, -9.6350e-03],\n", - " [ 2.6662e-02, 2.6606e-02, 3.8052e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1.3738e-02, -2.4817e-02, 1.7280e-02],\n", - " [ 6.5198e-03, 3.2310e-02, 6.7140e-03],\n", - " [ 1.7105e-03, 1.2700e-02, 1.7214e-02]],\n", - "\n", - " [[-1.8066e-02, -2.9962e-02, -1.9691e-02],\n", - " [-1.8091e-03, 4.9750e-03, 3.6765e-02],\n", - " [-6.4435e-04, -1.4589e-02, 3.2546e-02]],\n", - "\n", - " [[-3.8230e-02, 7.1454e-03, 
-1.3799e-02],\n", - " [-1.3242e-02, 3.7163e-03, 4.9353e-03],\n", - " [ 2.6065e-03, -8.7617e-03, 1.9123e-04]]],\n", - "\n", - "\n", - " [[[ 2.7979e-02, 2.8406e-02, 1.1925e-03],\n", - " [ 2.6382e-03, -1.0972e-02, -2.5136e-02],\n", - " [ 4.6399e-03, -1.9916e-02, -3.2441e-02]],\n", - "\n", - " [[-5.8886e-03, -2.7692e-02, -3.8166e-04],\n", - " [-1.0334e-02, -1.1055e-02, -2.7850e-02],\n", - " [-8.6016e-05, -2.0821e-02, -2.4210e-02]],\n", - "\n", - " [[-1.0134e-02, 2.7711e-02, 1.8510e-02],\n", - " [ 2.5513e-02, 4.7151e-02, 2.7833e-02],\n", - " [-1.1608e-02, -1.1251e-03, 4.3334e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.5657e-02, -1.6549e-02, 2.1792e-02],\n", - " [-2.4143e-02, -7.6517e-03, -1.7215e-02],\n", - " [-1.3531e-02, 1.7204e-02, 2.0492e-02]],\n", - "\n", - " [[ 1.7322e-03, -2.6014e-02, -1.5783e-02],\n", - " [-1.8252e-02, -3.1492e-02, 2.1263e-02],\n", - " [-4.6617e-03, -4.9437e-02, 1.2522e-02]],\n", - "\n", - " [[-5.0840e-03, -1.1199e-02, 1.7259e-02],\n", - " [-2.2503e-03, -2.5246e-02, -1.4658e-02],\n", - " [ 2.6199e-03, -2.5215e-02, -1.4805e-02]]]])\n", - "layer4.0.bn1.weight Parameter containing:\n", - "tensor([0.9534, 0.9690, 0.9779, 0.9481, 0.9661, 1.0123, 0.9933, 0.9159, 0.9811,\n", - " 0.9206, 1.0421, 1.0466, 1.0091, 0.9920, 0.9846, 0.9807, 0.9996, 1.0425,\n", - " 0.9745, 0.9941, 1.0109, 0.9837, 1.0157, 0.9821, 1.0174, 0.9381, 1.0259,\n", - " 0.9520, 0.9610, 0.9995, 1.0986, 1.0051, 0.9452, 0.9813, 1.0712, 0.9902,\n", - " 1.0526, 1.0237, 1.0363, 0.9657, 0.9260, 1.0403, 1.0011, 0.9876, 0.9901,\n", - " 0.9819, 1.0163, 0.9651, 1.0024, 1.0359, 0.9798, 0.9556, 0.9099, 0.9861,\n", - " 1.0158, 0.9633, 1.0052, 1.0268, 0.8805, 0.9295, 0.9899, 1.0083, 1.0109,\n", - " 1.0104, 1.0468, 1.0115, 1.0763, 0.9179, 0.9844, 0.9861, 0.9508, 0.9012,\n", - " 0.9999, 0.9355, 0.9764, 0.9445, 0.9730, 1.0513, 1.0292, 1.0491, 1.0474,\n", - " 0.9812, 1.0326, 1.0261, 1.0150, 0.9971, 0.9672, 1.0045, 0.9567, 1.0128,\n", - " 1.0227, 1.0236, 1.0422, 0.9762, 0.9627, 1.0149, 1.0167, 1.0260, 
1.0085,\n", - " 1.0187, 0.9780, 1.0015, 0.9958, 0.9427, 0.9465, 1.0131, 0.9742, 0.9766,\n", - " 0.9988, 0.9577, 1.0082, 0.9634, 0.9935, 1.0256, 0.9242, 1.0261, 0.9989,\n", - " 0.9949, 1.0067, 1.0322, 0.9711, 0.9497, 1.0195, 0.9055, 1.0064, 0.9844,\n", - " 0.9814, 0.9728, 1.0224, 0.9679, 0.9614, 1.0349, 1.0423, 0.9596, 0.9673,\n", - " 0.9739, 0.9891, 0.9619, 0.9872, 1.0058, 0.9381, 1.0493, 0.9425, 0.9895,\n", - " 1.0559, 1.0504, 1.0078, 0.9718, 0.9642, 0.9505, 1.0015, 0.9588, 0.9706,\n", - " 0.9973, 1.0377, 0.9871, 0.9899, 0.9458, 1.0140, 0.9517, 1.0294, 0.9697,\n", - " 1.0190, 1.0207, 0.9097, 0.9653, 0.9344, 1.0160, 0.9589, 0.9415, 0.9669,\n", - " 1.0531, 0.9557, 0.9938, 0.9876, 1.0481, 0.9405, 0.9630, 1.0038, 0.9767,\n", - " 1.0061, 1.0094, 1.0396, 0.9756, 0.9531, 1.0328, 0.9250, 0.9481, 0.9716,\n", - " 0.9042, 0.9550, 1.0123, 0.9906, 0.9989, 0.9321, 1.0071, 1.0029, 0.9474,\n", - " 0.9859, 1.0049, 1.0039, 1.0362, 1.0165, 1.0032, 0.9968, 1.0036, 1.0069,\n", - " 0.9967, 0.9749, 0.9296, 0.9297, 1.0247, 0.9740, 1.0138, 0.9648, 0.9531,\n", - " 1.0081, 1.0393, 0.9835, 1.0207, 1.0476, 0.9793, 1.0733, 1.0040, 1.0099,\n", - " 0.9120, 0.9656, 0.9239, 0.9971, 1.0198, 0.9772, 0.9728, 1.0036, 0.9730,\n", - " 1.0232, 0.9582, 0.9908, 0.9818, 0.9472, 0.9884, 1.0278, 1.0084, 1.0249,\n", - " 0.9170, 0.9236, 1.0222, 1.0302, 0.9916, 0.9527, 0.9615, 0.9906, 0.9947,\n", - " 0.9456, 0.9808, 0.9915, 0.9550, 1.0177, 0.9746, 0.9998, 0.9447, 0.9958,\n", - " 0.9704, 1.0024, 0.9991, 1.0001, 0.9900, 0.9976, 0.9804, 1.0698, 0.9950,\n", - " 0.9954, 0.9723, 0.9784, 0.9220, 1.0127, 1.0142, 0.9826, 1.0532, 1.0166,\n", - " 1.0256, 0.9422, 0.9629, 1.0517, 0.9915, 1.0066, 1.0042, 0.9927, 1.0177,\n", - " 1.0036, 0.9713, 0.9261, 0.9843, 0.9568, 0.8978, 1.0141, 1.0160, 0.9223,\n", - " 0.9694, 0.9814, 0.9412, 0.9750, 1.0287, 1.0105, 0.9912, 1.0344, 0.9992,\n", - " 0.9417, 1.0092, 0.9744, 0.9925, 0.9815, 0.9724, 0.9699, 0.9620, 0.9853,\n", - " 0.9886, 0.9213, 0.9863, 1.0109, 0.9926, 0.8956, 0.9822, 0.9506, 
0.9584,\n", - " 0.9641, 1.0227, 0.9630, 1.0442, 0.8903, 1.0128, 0.9640, 1.0305, 0.9512,\n", - " 0.9702, 1.0634, 0.9853, 1.0459, 1.0792, 0.9543, 0.9602, 1.0012, 1.0700,\n", - " 1.0283, 1.0327, 1.0158, 1.0004, 0.9954, 1.0422, 0.9440, 1.0295, 0.9736,\n", - " 0.9947, 1.0174, 0.9618, 1.0027, 0.9854, 1.0047, 1.0212, 0.9810, 1.0320,\n", - " 1.0045, 1.0151, 0.9683, 0.9965, 1.0195, 0.9927, 0.9648, 0.9768, 1.0195,\n", - " 0.9632, 1.0413, 0.9660, 0.9954, 0.9986, 0.9687, 0.9979, 0.9959, 0.9528,\n", - " 0.9811, 0.9361, 0.9754, 0.9515, 1.0222, 0.9065, 0.9596, 1.0143, 0.9654,\n", - " 1.0097, 1.0257, 0.9529, 0.9987, 1.0023, 1.0235, 0.9968, 0.9391, 1.0246,\n", - " 1.0415, 0.8944, 1.0037, 0.9884, 0.9838, 1.0100, 0.9934, 1.0158, 0.9992,\n", - " 0.9945, 1.0269, 0.9988, 0.9746, 0.9849, 0.9837, 0.9853, 0.9961, 0.9961,\n", - " 0.9367, 1.0434, 1.0180, 0.9880, 1.0014, 0.9977, 1.0152, 0.9922, 0.9402,\n", - " 0.9401, 0.9328, 0.9936, 0.9267, 0.9819, 0.9703, 1.0273, 1.0675, 0.9921,\n", - " 1.0293, 0.9936, 0.9824, 0.9465, 0.9486, 0.9946, 0.9840, 1.0305, 1.0158,\n", - " 1.0217, 1.0431, 0.9639, 1.0107, 1.0006, 1.0016, 0.9956, 1.0443, 0.9965,\n", - " 1.0049, 1.0160, 1.0023, 0.9719, 1.0077, 0.9105, 1.0384, 1.0231, 1.0132,\n", - " 1.0314, 0.9939, 0.9616, 1.0067, 1.0498, 0.9761, 1.0339, 0.9703, 1.0237,\n", - " 0.9617, 0.9809, 1.0074, 0.9658, 0.9654, 0.9655, 0.9909, 0.9582, 0.9765,\n", - " 1.0340, 0.9436, 0.9821, 0.9701, 1.0471, 0.9938, 1.0092, 1.0402, 1.0519,\n", - " 0.9758, 0.9893, 1.0087, 0.9748, 1.0169, 1.0258, 0.9980, 0.9642, 0.9609,\n", - " 1.0336, 1.0385, 1.0035, 0.9217, 1.0398, 0.9457, 0.9696, 0.9964, 1.0307,\n", - " 0.9936, 0.9372, 0.9538, 0.9759, 0.9730, 1.0237, 0.9669, 1.0096])\n", - "layer4.0.bn1.bias Parameter containing:\n", - "tensor([-0.1313, -0.1636, -0.1204, -0.1072, -0.1112, -0.0995, -0.1602, -0.1600,\n", - " -0.1085, -0.1312, -0.1073, -0.1452, -0.0914, -0.1252, -0.1654, -0.1372,\n", - " -0.1688, -0.1468, -0.1604, -0.0992, -0.1659, -0.1345, -0.1312, -0.1844,\n", - " -0.1081, 
-0.1385, -0.0724, -0.1348, -0.1353, -0.1345, -0.1175, -0.1163,\n", - " -0.1461, -0.1335, -0.1096, -0.1257, -0.1304, -0.1433, -0.1206, -0.1464,\n", - " -0.1422, -0.1156, -0.1401, -0.1501, -0.1286, -0.1440, -0.1155, -0.1654,\n", - " -0.1292, -0.1382, -0.1352, -0.1472, -0.1765, -0.1114, -0.1139, -0.1491,\n", - " -0.1383, -0.1242, -0.1776, -0.1247, -0.1412, -0.1067, -0.1106, -0.1121,\n", - " -0.1345, -0.1257, -0.1265, -0.1531, -0.1490, -0.1013, -0.1468, -0.1372,\n", - " -0.1792, -0.1360, -0.1198, -0.1086, -0.1617, -0.1660, -0.1228, -0.1483,\n", - " -0.0945, -0.1110, -0.1253, -0.1089, -0.1319, -0.1490, -0.1170, -0.1364,\n", - " -0.1354, -0.1216, -0.1171, -0.1352, -0.1398, -0.1338, -0.1219, -0.1710,\n", - " -0.1311, -0.1457, -0.1274, -0.1073, -0.1173, -0.1264, -0.1163, -0.1548,\n", - " -0.1421, -0.1454, -0.1357, -0.1657, -0.1303, -0.1507, -0.1211, -0.1941,\n", - " -0.1222, -0.1104, -0.1506, -0.1259, -0.1086, -0.1291, -0.1060, -0.1140,\n", - " -0.1336, -0.1632, -0.1211, -0.1490, -0.1231, -0.1739, -0.1502, -0.1617,\n", - " -0.1118, -0.1397, -0.1093, -0.1283, -0.0986, -0.1305, -0.1638, -0.1666,\n", - " -0.1335, -0.1897, -0.1499, -0.1116, -0.1827, -0.1186, -0.1934, -0.1323,\n", - " -0.1338, -0.1509, -0.1360, -0.1442, -0.1689, -0.1680, -0.1318, -0.1438,\n", - " -0.1223, -0.1627, -0.1273, -0.1161, -0.1623, -0.1373, -0.1423, -0.1192,\n", - " -0.1059, -0.0972, -0.1507, -0.1390, -0.1647, -0.1244, -0.1644, -0.1498,\n", - " -0.1398, -0.1459, -0.1667, -0.1316, -0.1936, -0.1124, -0.1577, -0.1089,\n", - " -0.1290, -0.1378, -0.1233, -0.1526, -0.1156, -0.1673, -0.1438, -0.1725,\n", - " -0.1309, -0.1138, -0.1615, -0.1055, -0.1204, -0.1736, -0.1382, -0.1065,\n", - " -0.1216, -0.1760, -0.1800, -0.0827, -0.0992, -0.1225, -0.1353, -0.1260,\n", - " -0.1208, -0.1600, -0.1191, -0.1677, -0.1315, -0.1472, -0.1356, -0.1218,\n", - " -0.1187, -0.1620, -0.1179, -0.1591, -0.1496, -0.1036, -0.2044, -0.1428,\n", - " -0.1366, -0.1065, -0.1027, -0.1475, -0.1195, -0.1082, -0.1053, -0.1461,\n", - " -0.1691, 
-0.1706, -0.1300, -0.1864, -0.2089, -0.1170, -0.1480, -0.1576,\n", - " -0.1383, -0.1583, -0.1084, -0.1296, -0.1503, -0.1452, -0.1285, -0.1061,\n", - " -0.1148, -0.1278, -0.1292, -0.1543, -0.1608, -0.0848, -0.1095, -0.1184,\n", - " -0.1345, -0.1086, -0.1362, -0.1644, -0.1193, -0.1414, -0.1349, -0.1336,\n", - " -0.1059, -0.1538, -0.0908, -0.1335, -0.1084, -0.1521, -0.1355, -0.0934,\n", - " -0.1003, -0.1576, -0.1681, -0.1774, -0.1273, -0.1339, -0.1386, -0.1601,\n", - " -0.1350, -0.1508, -0.1035, -0.1223, -0.1325, -0.1116, -0.1383, -0.1226,\n", - " -0.1313, -0.1304, -0.1158, -0.1140, -0.1451, -0.1867, -0.1453, -0.1258,\n", - " -0.1335, -0.1790, -0.1468, -0.1700, -0.1195, -0.2057, -0.2011, -0.1837,\n", - " -0.1499, -0.1168, -0.1462, -0.2006, -0.1992, -0.1075, -0.1263, -0.1031,\n", - " -0.1248, -0.1572, -0.1477, -0.1345, -0.1228, -0.1207, -0.1314, -0.1488,\n", - " -0.2015, -0.1650, -0.1175, -0.1292, -0.1483, -0.1341, -0.1506, -0.1534,\n", - " -0.1612, -0.1565, -0.1284, -0.1653, -0.2038, -0.1381, -0.2177, -0.1264,\n", - " -0.1406, -0.1023, -0.1211, -0.1226, -0.1256, -0.1343, -0.1095, -0.1531,\n", - " -0.1050, -0.1293, -0.1657, -0.1386, -0.1151, -0.1445, -0.1054, -0.1113,\n", - " -0.1317, -0.1385, -0.1116, -0.1155, -0.1675, -0.1208, -0.1604, -0.1439,\n", - " -0.1369, -0.1378, -0.1473, -0.1286, -0.0956, -0.0966, -0.1479, -0.1595,\n", - " -0.1180, -0.1170, -0.1364, -0.1524, -0.1400, -0.1881, -0.1347, -0.1378,\n", - " -0.1508, -0.0998, -0.1386, -0.1385, -0.1280, -0.1320, -0.1358, -0.1685,\n", - " -0.1500, -0.1462, -0.1385, -0.1401, -0.1104, -0.1173, -0.1374, -0.1525,\n", - " -0.1615, -0.1578, -0.1641, -0.1375, -0.1990, -0.1294, -0.1518, -0.1708,\n", - " -0.1386, -0.1386, -0.1358, -0.0890, -0.1142, -0.1281, -0.1446, -0.1275,\n", - " -0.1392, -0.1407, -0.1318, -0.1416, -0.1491, -0.1211, -0.1878, -0.1545,\n", - " -0.1248, -0.1439, -0.1822, -0.1450, -0.1298, -0.1392, -0.1355, -0.1164,\n", - " -0.1082, -0.1086, -0.1520, -0.1792, -0.1319, -0.1390, -0.1580, -0.1233,\n", - " -0.1509, 
-0.1281, -0.1462, -0.1493, -0.1781, -0.1589, -0.1167, -0.1008,\n", - " -0.0953, -0.1499, -0.1674, -0.1244, -0.1516, -0.1418, -0.1308, -0.1808,\n", - " -0.0961, -0.1359, -0.1715, -0.1258, -0.1250, -0.1408, -0.1399, -0.1728,\n", - " -0.1399, -0.1178, -0.1152, -0.1127, -0.1319, -0.1285, -0.1043, -0.1546,\n", - " -0.1249, -0.1142, -0.1647, -0.1136, -0.0991, -0.1799, -0.1017, -0.1487,\n", - " -0.1128, -0.1135, -0.1450, -0.1197, -0.1501, -0.1437, -0.1226, -0.1115,\n", - " -0.1573, -0.1187, -0.2205, -0.1382, -0.1460, -0.0959, -0.1433, -0.1307,\n", - " -0.1407, -0.1565, -0.1459, -0.1592, -0.0840, -0.1889, -0.1479, -0.1313,\n", - " -0.1135, -0.1300, -0.1179, -0.0983, -0.1496, -0.1093, -0.1448, -0.1395,\n", - " -0.1300, -0.1570, -0.1539, -0.1475, -0.1276, -0.1268, -0.1325, -0.1324,\n", - " -0.1404, -0.1444, -0.1513, -0.1405, -0.1721, -0.1507, -0.1526, -0.0866])\n", - "layer4.0.conv2.weight Parameter containing:\n", - "tensor([[[[-1.7886e-02, -1.8696e-02, -6.5051e-03],\n", - " [ 7.2588e-03, -2.0011e-02, -6.7659e-03],\n", - " [-5.8594e-03, -2.5900e-02, -1.4812e-02]],\n", - "\n", - " [[-1.9013e-02, 1.5084e-03, 1.4686e-03],\n", - " [-4.1722e-03, 1.4034e-02, 6.9301e-03],\n", - " [ 3.9946e-03, 3.1264e-02, 1.7031e-02]],\n", - "\n", - " [[-4.3595e-03, -1.5497e-02, -7.2879e-04],\n", - " [-2.1562e-02, -1.8119e-02, -2.8660e-02],\n", - " [ 4.3782e-03, 1.1628e-02, 2.1428e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.3216e-02, 1.7132e-04, -1.1600e-02],\n", - " [ 1.0509e-02, 6.1292e-03, 9.3679e-04],\n", - " [ 9.9824e-03, -2.7226e-03, -7.6713e-04]],\n", - "\n", - " [[ 1.9380e-02, 4.5168e-04, -1.6261e-02],\n", - " [ 9.4663e-03, 9.7768e-03, 1.0472e-02],\n", - " [ 1.0579e-03, -1.0838e-02, -1.6651e-02]],\n", - "\n", - " [[ 7.1076e-03, -1.0444e-02, 1.2159e-02],\n", - " [-4.5764e-03, -6.8533e-03, -2.4884e-02],\n", - " [ 3.1004e-03, -5.1720e-03, -9.0846e-03]]],\n", - "\n", - "\n", - " [[[-9.9561e-03, -2.9929e-02, -1.3286e-02],\n", - " [ 1.2054e-02, 1.7240e-02, 3.4074e-02],\n", - " [-1.3281e-02, 
2.6545e-02, 5.2089e-03]],\n", - "\n", - " [[ 1.0709e-02, -6.2842e-03, 3.2054e-03],\n", - " [-2.6670e-02, -1.0388e-03, -7.6723e-03],\n", - " [ 6.0955e-03, 2.3389e-02, 4.6433e-03]],\n", - "\n", - " [[ 3.2169e-03, -2.2192e-03, 1.3805e-02],\n", - " [ 1.7897e-02, 1.2070e-02, -1.4670e-02],\n", - " [ 6.3306e-03, 3.5517e-02, 3.2253e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-3.2339e-04, -1.4459e-02, -6.0876e-03],\n", - " [ 1.4088e-02, 4.1493e-03, 2.0326e-03],\n", - " [ 2.6903e-02, 6.3054e-03, 1.0476e-02]],\n", - "\n", - " [[-1.3155e-02, 1.9230e-02, -2.7132e-02],\n", - " [ 1.4182e-02, 1.3118e-02, 6.7906e-04],\n", - " [ 1.9165e-02, -8.0848e-03, 1.1247e-02]],\n", - "\n", - " [[-3.5696e-03, -1.3386e-02, 1.5999e-02],\n", - " [-5.2860e-03, 4.1539e-03, -6.4848e-03],\n", - " [-9.2003e-04, 2.5272e-03, -2.2210e-02]]],\n", - "\n", - "\n", - " [[[-6.4596e-04, 2.6375e-03, 1.9750e-02],\n", - " [ 3.5908e-03, -9.7965e-03, 3.7559e-03],\n", - " [ 2.3296e-02, 1.5638e-02, 4.5195e-03]],\n", - "\n", - " [[-2.1403e-02, -1.8759e-02, 1.0075e-02],\n", - " [-4.0992e-02, -1.0487e-02, 4.4678e-03],\n", - " [-2.7970e-02, -1.7760e-02, -7.7252e-04]],\n", - "\n", - " [[-2.1670e-02, -1.3132e-02, 9.6461e-03],\n", - " [-3.7094e-02, 3.3480e-03, -2.0624e-02],\n", - " [-3.3094e-02, -2.6236e-02, -1.4729e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-2.3484e-02, 3.4918e-02, 5.6633e-03],\n", - " [ 1.3121e-02, -4.1489e-03, 1.8868e-02],\n", - " [ 5.6241e-03, 2.3452e-02, -1.3444e-02]],\n", - "\n", - " [[ 1.8781e-02, -6.7572e-03, -1.7906e-02],\n", - " [-2.1573e-02, 2.1743e-02, -1.1799e-02],\n", - " [-8.9948e-04, 9.6979e-03, 9.4103e-03]],\n", - "\n", - " [[ 3.0581e-04, -1.2199e-02, 1.0750e-02],\n", - " [-3.6209e-02, -1.1889e-03, 4.2709e-03],\n", - " [ 3.2966e-02, 1.4731e-02, 1.1646e-02]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 7.0561e-03, -1.1558e-02, -1.6542e-02],\n", - " [ 7.7660e-03, 2.1384e-02, -2.2963e-02],\n", - " [-2.4396e-02, 6.4466e-03, 2.3426e-02]],\n", - "\n", - " [[ 9.9911e-04, 
-1.6690e-02, -7.5850e-03],\n", - " [-1.5195e-03, -2.1704e-03, -3.2688e-02],\n", - " [-1.4812e-02, 9.4784e-06, -1.1086e-02]],\n", - "\n", - " [[ 1.3946e-02, 1.2512e-02, -4.8051e-03],\n", - " [ 1.3617e-02, 3.2708e-03, 2.3468e-02],\n", - " [ 7.1597e-03, -3.1580e-04, -1.6377e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-9.0539e-03, 1.7373e-02, -2.9568e-02],\n", - " [-4.2147e-03, -1.6024e-02, -1.4590e-02],\n", - " [-2.5991e-02, 1.3760e-02, -4.2937e-03]],\n", - "\n", - " [[ 2.5654e-02, -1.5972e-02, -1.3203e-02],\n", - " [ 2.3070e-03, -2.2005e-03, 8.4325e-03],\n", - " [ 8.5600e-03, -3.0322e-02, -3.4817e-02]],\n", - "\n", - " [[-2.3064e-02, 1.9884e-02, -6.6768e-03],\n", - " [-1.3571e-02, -8.7839e-03, -6.3031e-04],\n", - " [-1.0146e-03, 1.1744e-02, 2.2794e-02]]],\n", - "\n", - "\n", - " [[[-5.9354e-03, -2.4573e-02, -6.1468e-03],\n", - " [-1.7341e-02, 1.6610e-02, -2.3676e-03],\n", - " [-4.0607e-03, 2.3387e-02, -2.7001e-02]],\n", - "\n", - " [[ 2.9649e-02, 1.0233e-03, 1.1201e-02],\n", - " [-8.8892e-03, 3.3015e-03, 9.1620e-03],\n", - " [ 4.9614e-02, 1.6541e-02, 5.4178e-03]],\n", - "\n", - " [[ 5.7061e-03, -2.0687e-04, 1.8121e-02],\n", - " [ 9.9960e-03, -2.2996e-02, -3.8443e-02],\n", - " [ 2.9459e-02, -8.1446e-03, -7.2190e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-2.1950e-02, -5.8360e-03, -8.6734e-03],\n", - " [-6.4084e-04, 1.6201e-02, -2.3953e-03],\n", - " [-9.0676e-04, 1.2011e-02, -1.3787e-02]],\n", - "\n", - " [[-7.7046e-04, 7.1940e-03, 1.6575e-02],\n", - " [-4.9565e-02, -1.0073e-02, -2.1843e-02],\n", - " [-4.4395e-03, -2.4704e-02, -3.3025e-02]],\n", - "\n", - " [[-2.1877e-02, 2.5838e-02, 1.5783e-02],\n", - " [-4.3521e-02, 4.0280e-03, -1.4476e-04],\n", - " [-1.0410e-03, 2.4805e-03, 7.8111e-03]]],\n", - "\n", - "\n", - " [[[ 6.3897e-03, -1.0151e-02, -7.8149e-03],\n", - " [ 3.2832e-03, -3.7016e-03, -1.6792e-03],\n", - " [ 3.0218e-04, -7.0481e-03, 1.4635e-02]],\n", - "\n", - " [[-6.8840e-03, -4.6635e-03, -1.7792e-02],\n", - " [-5.9637e-03, -1.8775e-02, 7.1388e-03],\n", - " 
[-1.8852e-02, 4.5739e-03, -8.0146e-04]],\n", - "\n", - " [[-8.8560e-03, 9.4655e-03, -2.5622e-03],\n", - " [-1.2442e-02, -5.4622e-03, 4.5498e-03],\n", - " [-1.0570e-02, -7.3137e-04, -2.4904e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-2.5443e-02, -8.4710e-03, -4.8651e-05],\n", - " [ 2.7532e-02, -7.2788e-03, 3.8603e-02],\n", - " [ 1.2313e-02, 3.8152e-03, 6.0188e-03]],\n", - "\n", - " [[ 4.4933e-03, 7.1814e-03, 1.5395e-02],\n", - " [-1.4873e-02, 1.3530e-02, -7.5244e-03],\n", - " [-5.2787e-03, -5.8126e-04, -2.5811e-02]],\n", - "\n", - " [[-2.2218e-02, -7.7884e-03, 2.2616e-03],\n", - " [-1.0254e-02, -2.6434e-02, 1.5597e-02],\n", - " [-1.9715e-02, 1.5517e-02, 1.6434e-02]]]])\n", - "layer4.0.bn2.weight Parameter containing:\n", - "tensor([0.9555, 0.9785, 0.9496, 0.9762, 0.9477, 0.9735, 0.9830, 0.9709, 0.9828,\n", - " 0.9800, 0.9667, 0.9463, 0.9382, 0.9838, 0.9651, 0.9838, 0.9959, 0.9737,\n", - " 0.9975, 0.9388, 1.0001, 0.9637, 0.9640, 0.9840, 0.9687, 0.9413, 0.9266,\n", - " 0.9813, 0.9410, 0.9559, 0.9835, 0.9626, 0.9521, 0.9745, 1.0076, 0.9600,\n", - " 0.9822, 0.9820, 0.9353, 0.9054, 0.9364, 0.9649, 0.9596, 0.9963, 0.9624,\n", - " 0.9268, 0.9765, 1.0038, 0.9494, 0.9595, 0.9965, 1.0301, 0.9637, 0.9209,\n", - " 0.9634, 0.9810, 1.0096, 0.9736, 0.9506, 0.9889, 0.9523, 0.9745, 0.9406,\n", - " 0.9872, 0.9760, 1.0253, 0.9682, 0.9641, 0.9850, 1.0172, 0.9645, 0.9324,\n", - " 0.9707, 0.9516, 0.9803, 0.8969, 0.9797, 0.9510, 0.9722, 0.9290, 0.9546,\n", - " 0.9707, 0.9955, 0.9561, 0.9662, 1.0180, 0.9996, 0.9805, 0.9548, 0.9760,\n", - " 0.8972, 0.9846, 0.9154, 0.9770, 0.9695, 0.9210, 0.9686, 0.9508, 0.9365,\n", - " 0.9962, 0.9534, 0.9397, 0.9752, 0.9787, 0.9633, 0.9820, 0.9829, 0.9563,\n", - " 1.0000, 0.9692, 1.0029, 0.9389, 0.9895, 0.9848, 0.9618, 0.9606, 0.8920,\n", - " 0.9950, 1.0084, 1.0011, 0.9852, 0.9696, 0.9810, 0.9393, 1.0241, 0.9766,\n", - " 0.9407, 1.0100, 0.9954, 1.0044, 0.9674, 0.9553, 0.9904, 0.9750, 1.0066,\n", - " 0.9549, 0.9426, 0.9569, 0.9396, 0.9754, 0.9547, 1.0431, 
0.9609, 0.9645,\n", - " 0.9803, 0.9494, 0.9675, 1.0050, 1.0065, 0.9358, 0.9977, 1.0043, 1.0053,\n", - " 0.9672, 0.9477, 0.9334, 0.9452, 0.9477, 0.9697, 0.9561, 0.9601, 0.9548,\n", - " 1.0286, 0.9487, 0.9529, 0.9548, 0.9512, 0.9689, 0.9678, 0.9416, 0.9688,\n", - " 0.9717, 0.9849, 0.9832, 0.9674, 0.8993, 0.9652, 0.9835, 0.9444, 0.9778,\n", - " 0.9206, 0.9684, 0.9371, 0.9898, 0.9586, 1.0066, 0.9892, 0.9753, 0.9446,\n", - " 0.9486, 0.9731, 0.9771, 0.9552, 0.9496, 0.9854, 0.8991, 0.9244, 0.9771,\n", - " 0.9557, 0.9823, 0.9617, 0.9556, 0.9713, 0.9306, 0.9830, 0.9637, 0.9615,\n", - " 0.9432, 0.9936, 0.9965, 0.9939, 0.9662, 0.9847, 0.9459, 1.0106, 0.9978,\n", - " 0.9864, 0.9836, 0.9472, 0.9943, 0.9877, 0.9788, 0.9589, 0.9004, 0.9368,\n", - " 0.9213, 0.9875, 0.9673, 0.9724, 0.9362, 0.9608, 0.9686, 0.9624, 0.9685,\n", - " 0.9486, 0.9572, 0.9623, 0.9320, 0.9418, 0.9778, 0.9820, 0.9436, 0.9621,\n", - " 0.9683, 0.9400, 0.9803, 0.9238, 0.9475, 0.9266, 0.9434, 0.9181, 1.0136,\n", - " 0.9859, 0.9299, 0.9896, 0.9835, 0.9883, 0.9865, 0.9275, 1.0005, 0.9368,\n", - " 0.9942, 0.9573, 0.9808, 0.9787, 0.9534, 1.0137, 0.8962, 0.9409, 0.9807,\n", - " 0.9453, 1.0381, 0.9634, 0.9773, 0.9643, 0.9484, 0.9605, 0.9253, 0.9943,\n", - " 0.9505, 0.9833, 0.9996, 0.9519, 0.9952, 1.0069, 0.9695, 0.9791, 0.9849,\n", - " 0.9753, 1.0021, 0.9570, 0.9735, 0.9800, 1.0057, 0.9682, 0.9981, 0.9178,\n", - " 0.9467, 1.0186, 0.9955, 0.9593, 0.9700, 0.9298, 0.9841, 0.9576, 0.9683,\n", - " 0.9715, 0.9853, 0.9751, 0.9591, 0.9580, 0.9272, 0.9904, 0.9850, 0.9661,\n", - " 0.9751, 1.0004, 0.9607, 0.9932, 0.9582, 0.9322, 0.9509, 1.0128, 0.9531,\n", - " 0.9501, 0.9875, 0.9256, 0.9406, 0.9517, 0.9849, 0.9961, 0.9599, 0.9165,\n", - " 0.9653, 0.9585, 0.9558, 0.9775, 0.9540, 0.9849, 0.9814, 0.9843, 0.9834,\n", - " 0.9572, 0.9750, 1.0024, 0.9440, 0.9530, 0.9244, 0.9580, 0.9388, 0.9935,\n", - " 0.9627, 0.9777, 0.9621, 0.9731, 0.9123, 0.9389, 0.9417, 0.9847, 0.9396,\n", - " 0.9555, 0.9516, 1.0058, 0.9725, 0.8901, 0.9674, 0.9561, 
0.9598, 1.0162,\n", - " 0.9572, 0.9861, 0.9811, 0.9995, 1.0045, 0.9723, 0.9668, 0.9831, 0.9109,\n", - " 0.9407, 0.9635, 0.9981, 0.9481, 0.9573, 0.9463, 0.9792, 0.9070, 0.9721,\n", - " 0.9638, 0.9737, 0.9175, 1.0011, 0.9757, 1.0017, 0.9957, 0.9562, 0.9789,\n", - " 0.9631, 0.9546, 0.9913, 0.9549, 1.0033, 0.9550, 0.9639, 0.9441, 0.9702,\n", - " 0.9904, 0.9402, 0.9744, 1.0061, 0.9798, 1.0017, 1.0137, 1.0099, 0.9962,\n", - " 0.9511, 0.9597, 1.0050, 0.9719, 0.9704, 0.9813, 0.9900, 0.9667, 0.9582,\n", - " 1.0062, 0.9669, 1.0110, 0.9397, 1.0132, 0.9442, 0.9814, 0.8946, 0.9956,\n", - " 0.9420, 0.9621, 0.9445, 0.9682, 0.9952, 0.9799, 0.9466, 0.9718, 0.9509,\n", - " 0.9710, 0.9798, 0.9832, 0.9664, 0.9887, 0.9617, 0.9754, 1.0175, 0.9922,\n", - " 1.0063, 0.9900, 0.9609, 0.9710, 0.9622, 0.9798, 0.9414, 1.0162, 0.9600,\n", - " 0.9837, 1.0000, 0.9270, 0.9603, 0.9866, 0.9681, 1.0009, 0.9595, 0.9672,\n", - " 0.9514, 0.9584, 0.9382, 0.9560, 0.9736, 0.9953, 0.9513, 0.9762, 0.9818,\n", - " 0.9787, 0.9892, 0.9558, 0.9719, 0.9303, 0.9718, 0.9597, 0.9713, 0.9641,\n", - " 0.9278, 0.9673, 0.9914, 0.9579, 0.9361, 1.0325, 0.9826, 0.9888, 0.9810,\n", - " 0.9229, 1.0155, 1.0057, 0.9610, 0.9168, 0.9803, 0.9771, 0.9940, 0.9315,\n", - " 0.9625, 0.9365, 0.9514, 0.9863, 0.9985, 1.0196, 1.0129, 0.9435])\n", - "layer4.0.bn2.bias Parameter containing:\n", - "tensor([-0.1077, -0.1421, -0.1821, -0.1668, -0.1509, -0.1312, -0.1380, -0.1309,\n", - " -0.1376, -0.0952, -0.1164, -0.1412, -0.1481, -0.1340, -0.0982, -0.1111,\n", - " -0.1355, -0.1359, -0.0763, -0.1281, -0.1032, -0.1161, -0.1639, -0.1614,\n", - " -0.1359, -0.1743, -0.1449, -0.1202, -0.1613, -0.1218, -0.0912, -0.1359,\n", - " -0.1473, -0.1254, -0.1260, -0.1315, -0.1130, -0.1370, -0.1554, -0.1598,\n", - " -0.1412, -0.1671, -0.1722, -0.1153, -0.1522, -0.1593, -0.1292, -0.1383,\n", - " -0.1708, -0.1354, -0.1689, -0.1445, -0.1057, -0.1420, -0.1475, -0.1008,\n", - " -0.1661, -0.1452, -0.1685, -0.0975, -0.1722, -0.1182, -0.0939, -0.1553,\n", - " -0.1652, 
-0.1251, -0.1249, -0.1043, -0.1387, -0.1344, -0.1790, -0.1616,\n", - " -0.1750, -0.1710, -0.1289, -0.1422, -0.1303, -0.1413, -0.1527, -0.1282,\n", - " -0.1279, -0.1226, -0.1151, -0.1201, -0.1322, -0.1368, -0.1222, -0.1290,\n", - " -0.1492, -0.1467, -0.1736, -0.1057, -0.1849, -0.0941, -0.1291, -0.1398,\n", - " -0.1382, -0.1133, -0.1222, -0.1267, -0.1743, -0.1167, -0.1169, -0.1327,\n", - " -0.1233, -0.1442, -0.1083, -0.1543, -0.1278, -0.1602, -0.1623, -0.1351,\n", - " -0.1327, -0.1430, -0.1688, -0.1055, -0.1419, -0.1694, -0.1434, -0.1472,\n", - " -0.1157, -0.1287, -0.1160, -0.0923, -0.1165, -0.1386, -0.1188, -0.1284,\n", - " -0.1347, -0.1452, -0.1594, -0.1360, -0.1470, -0.1863, -0.1705, -0.1492,\n", - " -0.1368, -0.1487, -0.1645, -0.1125, -0.1228, -0.1313, -0.1517, -0.1273,\n", - " -0.1473, -0.1441, -0.1204, -0.1410, -0.1254, -0.1298, -0.1175, -0.1477,\n", - " -0.1687, -0.1658, -0.1260, -0.1066, -0.1634, -0.1069, -0.1038, -0.1724,\n", - " -0.1575, -0.1456, -0.1222, -0.1235, -0.1357, -0.1030, -0.1516, -0.1180,\n", - " -0.2283, -0.1081, -0.0908, -0.1186, -0.1629, -0.1338, -0.1590, -0.1619,\n", - " -0.1434, -0.1391, -0.1531, -0.1400, -0.1208, -0.1121, -0.1173, -0.1413,\n", - " -0.1178, -0.1539, -0.1020, -0.1621, -0.1607, -0.1491, -0.1088, -0.0963,\n", - " -0.1369, -0.1579, -0.1558, -0.1364, -0.1629, -0.1167, -0.1064, -0.1564,\n", - " -0.1016, -0.1513, -0.1344, -0.1452, -0.1330, -0.1143, -0.1711, -0.1544,\n", - " -0.1010, -0.1073, -0.1114, -0.1060, -0.2010, -0.1259, -0.1225, -0.1472,\n", - " -0.1328, -0.2026, -0.1359, -0.1276, -0.1133, -0.1622, -0.1322, -0.1666,\n", - " -0.1334, -0.1379, -0.1284, -0.1683, -0.1491, -0.1184, -0.1172, -0.1208,\n", - " -0.1506, -0.1204, -0.1467, -0.1479, -0.1494, -0.0968, -0.1067, -0.1380,\n", - " -0.1635, -0.0869, -0.1374, -0.1196, -0.1598, -0.1280, -0.1609, -0.1608,\n", - " -0.1404, -0.1249, -0.2054, -0.1497, -0.1637, -0.1352, -0.1547, -0.1162,\n", - " -0.1721, -0.1412, -0.1409, -0.1169, -0.1034, -0.1353, -0.1230, -0.1313,\n", - " -0.1110, 
-0.1485, -0.1141, -0.1649, -0.1388, -0.0882, -0.1378, -0.1768,\n", - " -0.1178, -0.1274, -0.1642, -0.1661, -0.1437, -0.1595, -0.1388, -0.2104,\n", - " -0.1090, -0.1467, -0.1431, -0.1173, -0.1046, -0.1935, -0.1478, -0.1360,\n", - " -0.1500, -0.1175, -0.1217, -0.1539, -0.1516, -0.1312, -0.1663, -0.0784,\n", - " -0.1408, -0.1435, -0.1334, -0.1537, -0.1486, -0.1203, -0.1474, -0.1198,\n", - " -0.1518, -0.1653, -0.1304, -0.1242, -0.1329, -0.1106, -0.1312, -0.1558,\n", - " -0.1402, -0.0926, -0.1435, -0.0997, -0.1235, -0.1709, -0.1523, -0.1424,\n", - " -0.1248, -0.1423, -0.1372, -0.1385, -0.1349, -0.1320, -0.1480, -0.1883,\n", - " -0.1528, -0.1106, -0.1115, -0.1545, -0.2491, -0.1437, -0.1366, -0.1148,\n", - " -0.1526, -0.0981, -0.1287, -0.0970, -0.1171, -0.1253, -0.1360, -0.1191,\n", - " -0.1091, -0.1361, -0.1243, -0.1733, -0.1326, -0.1444, -0.1465, -0.1514,\n", - " -0.1156, -0.1539, -0.1099, -0.1416, -0.1336, -0.1725, -0.1333, -0.1748,\n", - " -0.1506, -0.1016, -0.0820, -0.1540, -0.1404, -0.1233, -0.1787, -0.1224,\n", - " -0.1389, -0.1248, -0.1288, -0.1181, -0.1983, -0.1307, -0.1433, -0.1397,\n", - " -0.1570, -0.1298, -0.1285, -0.1475, -0.1472, -0.0788, -0.0871, -0.1186,\n", - " -0.1054, -0.1655, -0.1623, -0.1446, -0.0858, -0.1977, -0.1148, -0.1272,\n", - " -0.1156, -0.1302, -0.0837, -0.1272, -0.1368, -0.1431, -0.1257, -0.1206,\n", - " -0.1128, -0.1089, -0.1071, -0.1777, -0.1622, -0.1102, -0.2165, -0.1409,\n", - " -0.2035, -0.0944, -0.1122, -0.1559, -0.1246, -0.1746, -0.1242, -0.1357,\n", - " -0.1038, -0.1866, -0.1053, -0.1582, -0.1073, -0.1082, -0.0941, -0.1234,\n", - " -0.1571, -0.1522, -0.1285, -0.1717, -0.1760, -0.1226, -0.1873, -0.1178,\n", - " -0.1140, -0.1283, -0.1302, -0.1645, -0.1375, -0.1337, -0.1517, -0.1147,\n", - " -0.1548, -0.1192, -0.1427, -0.1613, -0.1634, -0.1307, -0.1150, -0.1227,\n", - " -0.1003, -0.1405, -0.1071, -0.1345, -0.1354, -0.1312, -0.0875, -0.1288,\n", - " -0.1407, -0.1009, -0.1498, -0.1397, -0.1114, -0.1694, -0.1349, -0.1294,\n", - " -0.1948, 
-0.1227, -0.1455, -0.1091, -0.1289, -0.0721, -0.1536, -0.1525,\n", - " -0.1659, -0.1372, -0.1063, -0.1048, -0.1303, -0.1184, -0.2030, -0.1300,\n", - " -0.1413, -0.1604, -0.1453, -0.1992, -0.1359, -0.1073, -0.1755, -0.1209,\n", - " -0.1073, -0.1238, -0.1732, -0.1200, -0.0910, -0.1428, -0.1074, -0.1454,\n", - " -0.1542, -0.1509, -0.1703, -0.1571, -0.1164, -0.1268, -0.1080, -0.1576,\n", - " -0.1055, -0.1228, -0.1551, -0.1254, -0.1390, -0.1149, -0.1284, -0.1413])\n", - "layer4.0.shortcut_conv.weight Parameter containing:\n", - "tensor([[[[-0.0372]],\n", - "\n", - " [[-0.0450]],\n", - "\n", - " [[-0.0141]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0531]],\n", - "\n", - " [[ 0.0417]],\n", - "\n", - " [[ 0.0791]]],\n", - "\n", - "\n", - " [[[-0.0139]],\n", - "\n", - " [[ 0.0392]],\n", - "\n", - " [[-0.0129]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.0276]],\n", - "\n", - " [[ 0.0577]],\n", - "\n", - " [[ 0.0175]]],\n", - "\n", - "\n", - " [[[-0.0635]],\n", - "\n", - " [[ 0.0519]],\n", - "\n", - " [[-0.0476]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0614]],\n", - "\n", - " [[-0.0259]],\n", - "\n", - " [[-0.0251]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 0.0758]],\n", - "\n", - " [[ 0.0010]],\n", - "\n", - " [[-0.0365]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0223]],\n", - "\n", - " [[-0.0071]],\n", - "\n", - " [[ 0.0088]]],\n", - "\n", - "\n", - " [[[-0.0509]],\n", - "\n", - " [[-0.0467]],\n", - "\n", - " [[ 0.0209]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0022]],\n", - "\n", - " [[-0.0226]],\n", - "\n", - " [[ 0.0770]]],\n", - "\n", - "\n", - " [[[-0.0103]],\n", - "\n", - " [[ 0.0359]],\n", - "\n", - " [[-0.0126]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.0311]],\n", - "\n", - " [[ 0.0278]],\n", - "\n", - " [[-0.0614]]]])\n", - "layer4.0.shortcut_bn.weight Parameter containing:\n", - "tensor([0.9454, 0.9364, 0.8905, 0.9216, 0.9349, 0.9505, 0.9566, 0.9267, 0.9410,\n", - " 0.9398, 0.9523, 0.9230, 0.9259, 0.9375, 0.9117, 0.9437, 0.9479, 
0.9363,\n", - " 0.9458, 0.9379, 0.9648, 0.9435, 0.9325, 0.9490, 0.9359, 0.9109, 0.9319,\n", - " 0.9613, 0.9306, 0.9377, 0.9287, 0.9642, 0.9484, 0.9523, 0.9468, 0.9423,\n", - " 0.9534, 0.9389, 0.9448, 0.8943, 0.9367, 0.9419, 0.9420, 0.9698, 0.9606,\n", - " 0.9161, 0.9627, 0.9337, 0.9162, 0.9507, 0.9376, 0.9306, 0.9499, 0.9384,\n", - " 0.9663, 0.9517, 0.9150, 0.9380, 0.9580, 0.9460, 0.9379, 0.9400, 0.9419,\n", - " 0.9045, 0.9396, 0.9372, 0.9189, 0.9340, 0.9214, 0.9442, 0.9297, 0.9199,\n", - " 0.9387, 0.9416, 0.9435, 0.9209, 0.9542, 0.9221, 0.9461, 0.9202, 0.9453,\n", - " 0.9545, 0.9388, 0.9386, 0.9114, 0.9444, 0.9671, 0.9598, 0.8984, 0.9551,\n", - " 0.9159, 0.9347, 0.9070, 0.9479, 0.9400, 0.9062, 0.9460, 0.9422, 0.9532,\n", - " 0.9363, 0.9118, 0.9303, 0.9167, 0.9567, 0.9494, 0.9324, 0.9519, 0.9788,\n", - " 0.9430, 0.9440, 0.8834, 0.9436, 0.9480, 0.9337, 0.9357, 0.9402, 0.8856,\n", - " 0.9098, 0.9426, 0.9636, 0.9661, 0.9518, 0.9420, 0.9275, 0.9252, 0.9419,\n", - " 0.9410, 0.9383, 0.9162, 0.9629, 0.9515, 0.9156, 0.9188, 0.9282, 0.9260,\n", - " 0.9329, 0.9382, 0.9005, 0.9181, 0.9358, 0.9311, 0.9387, 0.9482, 0.9345,\n", - " 0.9523, 0.9490, 0.9399, 0.9148, 0.9448, 0.9227, 0.9735, 0.9258, 0.9416,\n", - " 0.9587, 0.9017, 0.9389, 0.9152, 0.9303, 0.9709, 0.9155, 0.9271, 0.9396,\n", - " 0.9448, 0.9264, 0.9333, 0.9276, 0.9324, 0.9544, 0.9016, 0.9318, 0.9596,\n", - " 0.9114, 0.9449, 0.9452, 0.9285, 0.9156, 0.9248, 0.9164, 0.9393, 0.9621,\n", - " 0.9354, 0.9452, 0.9382, 0.9522, 0.9404, 0.9453, 0.9509, 0.9473, 0.9435,\n", - " 0.9179, 0.9595, 0.9643, 0.9439, 0.9360, 0.9280, 0.9137, 0.9217, 0.9421,\n", - " 0.9285, 0.9196, 0.9618, 0.9617, 0.9308, 0.9222, 0.9328, 0.9397, 0.9418,\n", - " 0.9508, 0.9493, 0.9743, 0.9421, 0.9489, 0.9166, 0.9135, 0.9594, 0.9322,\n", - " 0.9255, 0.8959, 0.9352, 0.9526, 0.9536, 0.9314, 0.9578, 0.8772, 0.9699,\n", - " 0.9241, 0.9499, 0.9198, 0.9376, 0.9467, 0.9394, 0.9376, 0.9464, 0.9573,\n", - " 0.9343, 0.9149, 0.9295, 0.9281, 0.9246, 0.9344, 0.9389, 0.9547, 
0.9402,\n", - " 0.9658, 0.9411, 0.9382, 0.9156, 0.9437, 0.9275, 0.9152, 0.9184, 0.9217,\n", - " 0.9339, 0.9276, 0.9419, 0.9479, 0.9600, 0.9451, 0.9488, 0.9270, 0.9495,\n", - " 0.9479, 0.9402, 0.9447, 0.9458, 0.9371, 0.9498, 0.8992, 0.9151, 0.9553,\n", - " 0.9536, 0.9450, 0.9511, 0.9409, 0.9393, 0.9152, 0.9384, 0.9335, 0.9241,\n", - " 0.8940, 0.9487, 0.9354, 0.9424, 0.9438, 0.9535, 0.9028, 0.9661, 0.9582,\n", - " 0.9399, 0.9600, 0.9415, 0.9428, 0.9680, 0.9553, 0.9248, 0.9557, 0.9113,\n", - " 0.9374, 0.9374, 0.9461, 0.9443, 0.9416, 0.9279, 0.9388, 0.9399, 0.9466,\n", - " 0.9435, 0.9422, 0.9609, 0.9509, 0.9310, 0.9209, 0.9452, 0.9577, 0.9346,\n", - " 0.9400, 0.9428, 0.9525, 0.9346, 0.9641, 0.9349, 0.9470, 0.9137, 0.9350,\n", - " 0.9322, 0.9196, 0.9232, 0.9252, 0.9330, 0.9467, 0.9543, 0.9248, 0.8629,\n", - " 0.9180, 0.9414, 0.9429, 0.9231, 0.9056, 0.9382, 0.9187, 0.9495, 0.9341,\n", - " 0.9322, 0.9511, 0.9465, 0.9153, 0.9365, 0.9653, 0.9359, 0.9252, 0.9117,\n", - " 0.9598, 0.9450, 0.9524, 0.9153, 0.9032, 0.9297, 0.8921, 0.9509, 0.9215,\n", - " 0.9579, 0.9169, 0.9369, 0.9061, 0.9091, 0.9617, 0.9450, 0.9466, 0.9391,\n", - " 0.9113, 0.9307, 0.9253, 0.9217, 0.9238, 0.9481, 0.9549, 0.9279, 0.9138,\n", - " 0.9555, 0.9042, 0.9309, 0.9231, 0.9194, 0.9409, 0.9385, 0.9175, 0.9201,\n", - " 0.9591, 0.9504, 0.9133, 0.9520, 0.9449, 0.9599, 0.9433, 0.9160, 0.9297,\n", - " 0.9570, 0.9110, 0.9338, 0.9370, 0.9589, 0.9375, 0.9261, 0.9169, 0.9310,\n", - " 0.9374, 0.9363, 0.9514, 0.9272, 0.9535, 0.9502, 0.9374, 0.9295, 0.9511,\n", - " 0.9599, 0.9449, 0.9497, 0.8779, 0.9328, 0.9454, 0.9495, 0.9331, 0.9647,\n", - " 0.9698, 0.9484, 0.9300, 0.9280, 0.9345, 0.9237, 0.9337, 0.9030, 0.9565,\n", - " 0.9279, 0.9271, 0.9723, 0.9468, 0.9512, 0.9526, 0.9374, 0.9558, 0.9394,\n", - " 0.9587, 0.9495, 0.9402, 0.9312, 0.9395, 0.9419, 0.9758, 0.9342, 0.9564,\n", - " 0.9520, 0.9487, 0.9542, 0.9476, 0.9504, 0.9469, 0.9274, 0.9562, 0.9309,\n", - " 0.9451, 0.9384, 0.9386, 0.9480, 0.9413, 0.9142, 0.9645, 0.9315, 
0.9449,\n", - " 0.9300, 0.9518, 0.9041, 0.9299, 0.9298, 0.9482, 0.9564, 0.9643, 0.9353,\n", - " 0.9628, 0.9437, 0.9412, 0.9463, 0.9266, 0.9471, 0.9221, 0.9278, 0.9626,\n", - " 0.9441, 0.9537, 0.9464, 0.9413, 0.9305, 0.9370, 0.9566, 0.9322, 0.9464,\n", - " 0.9445, 0.9437, 0.9459, 0.9764, 0.9289, 0.9477, 0.9661, 0.9630, 0.9282,\n", - " 0.9523, 0.9287, 0.9353, 0.9453, 0.9447, 0.9469, 0.9324, 0.9427])\n", - "layer4.0.shortcut_bn.bias Parameter containing:\n", - "tensor([-0.1077, -0.1421, -0.1821, -0.1668, -0.1509, -0.1312, -0.1380, -0.1309,\n", - " -0.1376, -0.0952, -0.1164, -0.1412, -0.1481, -0.1340, -0.0982, -0.1111,\n", - " -0.1355, -0.1359, -0.0763, -0.1281, -0.1032, -0.1161, -0.1639, -0.1614,\n", - " -0.1359, -0.1743, -0.1449, -0.1202, -0.1613, -0.1218, -0.0912, -0.1359,\n", - " -0.1473, -0.1254, -0.1260, -0.1315, -0.1130, -0.1370, -0.1554, -0.1598,\n", - " -0.1412, -0.1671, -0.1722, -0.1153, -0.1522, -0.1593, -0.1292, -0.1383,\n", - " -0.1708, -0.1354, -0.1689, -0.1445, -0.1057, -0.1420, -0.1475, -0.1008,\n", - " -0.1661, -0.1452, -0.1685, -0.0975, -0.1722, -0.1182, -0.0939, -0.1553,\n", - " -0.1652, -0.1251, -0.1249, -0.1043, -0.1387, -0.1344, -0.1790, -0.1616,\n", - " -0.1750, -0.1710, -0.1289, -0.1422, -0.1303, -0.1413, -0.1527, -0.1282,\n", - " -0.1279, -0.1226, -0.1151, -0.1201, -0.1322, -0.1368, -0.1222, -0.1290,\n", - " -0.1492, -0.1467, -0.1736, -0.1057, -0.1849, -0.0941, -0.1291, -0.1398,\n", - " -0.1382, -0.1133, -0.1222, -0.1267, -0.1743, -0.1167, -0.1169, -0.1327,\n", - " -0.1233, -0.1442, -0.1083, -0.1543, -0.1278, -0.1602, -0.1623, -0.1351,\n", - " -0.1327, -0.1430, -0.1688, -0.1055, -0.1419, -0.1694, -0.1434, -0.1472,\n", - " -0.1157, -0.1287, -0.1160, -0.0923, -0.1165, -0.1386, -0.1188, -0.1284,\n", - " -0.1347, -0.1452, -0.1594, -0.1360, -0.1470, -0.1863, -0.1705, -0.1492,\n", - " -0.1368, -0.1487, -0.1645, -0.1125, -0.1228, -0.1313, -0.1517, -0.1273,\n", - " -0.1473, -0.1441, -0.1204, -0.1410, -0.1254, -0.1298, -0.1175, -0.1477,\n", - " -0.1687, 
-0.1658, -0.1260, -0.1066, -0.1634, -0.1069, -0.1038, -0.1724,\n", - " -0.1575, -0.1456, -0.1222, -0.1235, -0.1357, -0.1030, -0.1516, -0.1180,\n", - " -0.2283, -0.1081, -0.0908, -0.1186, -0.1629, -0.1338, -0.1590, -0.1619,\n", - " -0.1434, -0.1391, -0.1531, -0.1400, -0.1208, -0.1121, -0.1173, -0.1413,\n", - " -0.1178, -0.1539, -0.1020, -0.1621, -0.1607, -0.1491, -0.1088, -0.0963,\n", - " -0.1369, -0.1579, -0.1558, -0.1364, -0.1629, -0.1167, -0.1064, -0.1564,\n", - " -0.1016, -0.1513, -0.1344, -0.1452, -0.1330, -0.1143, -0.1711, -0.1544,\n", - " -0.1010, -0.1073, -0.1114, -0.1060, -0.2010, -0.1259, -0.1225, -0.1472,\n", - " -0.1328, -0.2026, -0.1359, -0.1276, -0.1133, -0.1622, -0.1322, -0.1666,\n", - " -0.1334, -0.1379, -0.1284, -0.1683, -0.1491, -0.1184, -0.1172, -0.1208,\n", - " -0.1506, -0.1204, -0.1467, -0.1479, -0.1494, -0.0968, -0.1067, -0.1380,\n", - " -0.1635, -0.0869, -0.1374, -0.1196, -0.1598, -0.1280, -0.1609, -0.1608,\n", - " -0.1404, -0.1249, -0.2054, -0.1497, -0.1637, -0.1352, -0.1547, -0.1162,\n", - " -0.1721, -0.1412, -0.1409, -0.1169, -0.1034, -0.1353, -0.1230, -0.1313,\n", - " -0.1110, -0.1485, -0.1141, -0.1649, -0.1388, -0.0882, -0.1378, -0.1768,\n", - " -0.1178, -0.1274, -0.1642, -0.1661, -0.1437, -0.1595, -0.1388, -0.2104,\n", - " -0.1090, -0.1467, -0.1431, -0.1173, -0.1046, -0.1935, -0.1478, -0.1360,\n", - " -0.1500, -0.1175, -0.1217, -0.1539, -0.1516, -0.1312, -0.1663, -0.0784,\n", - " -0.1408, -0.1435, -0.1334, -0.1537, -0.1486, -0.1203, -0.1474, -0.1198,\n", - " -0.1518, -0.1653, -0.1304, -0.1242, -0.1329, -0.1106, -0.1312, -0.1558,\n", - " -0.1402, -0.0926, -0.1435, -0.0997, -0.1235, -0.1709, -0.1523, -0.1424,\n", - " -0.1248, -0.1423, -0.1372, -0.1385, -0.1349, -0.1320, -0.1480, -0.1883,\n", - " -0.1528, -0.1106, -0.1115, -0.1545, -0.2491, -0.1437, -0.1366, -0.1148,\n", - " -0.1526, -0.0981, -0.1287, -0.0970, -0.1171, -0.1253, -0.1360, -0.1191,\n", - " -0.1091, -0.1361, -0.1243, -0.1733, -0.1326, -0.1444, -0.1465, -0.1514,\n", - " -0.1156, 
-0.1539, -0.1099, -0.1416, -0.1336, -0.1725, -0.1333, -0.1748,\n", - " -0.1506, -0.1016, -0.0820, -0.1540, -0.1404, -0.1233, -0.1787, -0.1224,\n", - " -0.1389, -0.1248, -0.1288, -0.1181, -0.1983, -0.1307, -0.1433, -0.1397,\n", - " -0.1570, -0.1298, -0.1285, -0.1475, -0.1472, -0.0788, -0.0871, -0.1186,\n", - " -0.1054, -0.1655, -0.1623, -0.1446, -0.0858, -0.1977, -0.1148, -0.1272,\n", - " -0.1156, -0.1302, -0.0837, -0.1272, -0.1368, -0.1431, -0.1257, -0.1206,\n", - " -0.1128, -0.1089, -0.1071, -0.1777, -0.1622, -0.1102, -0.2165, -0.1409,\n", - " -0.2035, -0.0944, -0.1122, -0.1559, -0.1246, -0.1746, -0.1242, -0.1357,\n", - " -0.1038, -0.1866, -0.1053, -0.1582, -0.1073, -0.1082, -0.0941, -0.1234,\n", - " -0.1571, -0.1522, -0.1285, -0.1717, -0.1760, -0.1226, -0.1873, -0.1178,\n", - " -0.1140, -0.1283, -0.1302, -0.1645, -0.1375, -0.1337, -0.1517, -0.1147,\n", - " -0.1548, -0.1192, -0.1427, -0.1613, -0.1634, -0.1307, -0.1150, -0.1227,\n", - " -0.1003, -0.1405, -0.1071, -0.1345, -0.1354, -0.1312, -0.0875, -0.1288,\n", - " -0.1407, -0.1009, -0.1498, -0.1397, -0.1114, -0.1694, -0.1349, -0.1294,\n", - " -0.1948, -0.1227, -0.1455, -0.1091, -0.1289, -0.0721, -0.1536, -0.1525,\n", - " -0.1659, -0.1372, -0.1063, -0.1048, -0.1303, -0.1184, -0.2030, -0.1300,\n", - " -0.1413, -0.1604, -0.1453, -0.1992, -0.1359, -0.1073, -0.1755, -0.1209,\n", - " -0.1073, -0.1238, -0.1732, -0.1200, -0.0910, -0.1428, -0.1074, -0.1454,\n", - " -0.1542, -0.1509, -0.1703, -0.1571, -0.1164, -0.1268, -0.1080, -0.1576,\n", - " -0.1055, -0.1228, -0.1551, -0.1254, -0.1390, -0.1149, -0.1284, -0.1413])\n", - "layer4.1.conv1.weight Parameter containing:\n", - "tensor([[[[-1.1948e-02, -3.3142e-03, 1.0257e-02],\n", - " [-1.0413e-02, 4.3039e-03, 3.6476e-03],\n", - " [ 8.4051e-03, 1.0929e-02, -1.2016e-02]],\n", - "\n", - " [[-2.0995e-04, 1.1175e-02, -9.2195e-03],\n", - " [ 1.3481e-03, -1.4488e-02, -2.9663e-02],\n", - " [-4.7732e-03, -1.0124e-02, -1.2327e-02]],\n", - "\n", - " [[-2.1531e-02, -5.3558e-03, 
-1.5935e-02],\n", - " [-1.2236e-02, -1.0232e-02, -3.2332e-02],\n", - " [-2.7034e-02, -2.6549e-02, -2.6104e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 5.3538e-04, 1.6756e-02, -1.2413e-02],\n", - " [ 1.9543e-03, 3.2035e-03, -1.4134e-02],\n", - " [-2.7545e-04, -2.1481e-03, 2.1559e-02]],\n", - "\n", - " [[-9.9980e-03, 7.8874e-03, -8.9226e-03],\n", - " [ 6.6112e-03, -2.9001e-03, -1.2756e-02],\n", - " [ 2.3838e-04, 5.3704e-03, -7.9757e-03]],\n", - "\n", - " [[ 3.8782e-03, -8.8496e-03, 2.8949e-02],\n", - " [-2.1215e-02, 1.0959e-02, -1.3079e-03],\n", - " [ 2.4892e-04, -2.5248e-03, -1.0246e-02]]],\n", - "\n", - "\n", - " [[[ 5.5631e-03, -9.3964e-04, 4.2188e-03],\n", - " [-1.3459e-02, -6.8220e-03, -1.7020e-02],\n", - " [-4.0890e-02, -2.8097e-02, -1.3561e-02]],\n", - "\n", - " [[-6.0051e-03, -7.8186e-03, 4.0130e-03],\n", - " [-4.4814e-04, -2.0547e-02, -1.1539e-02],\n", - " [ 1.2592e-02, -1.8088e-02, -3.7631e-03]],\n", - "\n", - " [[-8.2642e-03, -5.4152e-03, 2.7797e-03],\n", - " [-1.3859e-03, 8.2239e-03, 4.0218e-04],\n", - " [ 8.4035e-03, 1.4251e-02, 4.7447e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[ 2.0392e-02, 7.7357e-03, 4.1268e-03],\n", - " [-6.8086e-03, 9.4403e-03, 5.3178e-03],\n", - " [ 1.9140e-02, -3.3532e-02, 1.0914e-03]],\n", - "\n", - " [[ 2.3713e-03, 2.2750e-02, 7.3886e-03],\n", - " [-1.9896e-02, -8.8388e-03, 1.8477e-02],\n", - " [-1.7475e-02, 1.4251e-03, 1.3589e-02]],\n", - "\n", - " [[ 1.3121e-02, 4.1141e-02, 7.3961e-03],\n", - " [ 1.3370e-02, 6.1834e-03, 2.0109e-02],\n", - " [ 2.5821e-02, 1.9590e-02, 2.4656e-02]]],\n", - "\n", - "\n", - " [[[ 9.9412e-03, 1.5238e-02, 2.0204e-02],\n", - " [ 4.8507e-03, 1.7608e-02, 5.5749e-03],\n", - " [-1.2087e-02, 1.0358e-02, 1.1422e-02]],\n", - "\n", - " [[-2.1776e-02, -5.6003e-03, 5.6663e-03],\n", - " [ 2.1339e-02, 5.1282e-03, 1.7236e-02],\n", - " [-2.4920e-03, 1.2391e-02, 1.8380e-02]],\n", - "\n", - " [[-1.7954e-02, -1.2279e-02, -1.3557e-02],\n", - " [ 2.7206e-03, 9.9728e-04, 1.4261e-02],\n", - " [-1.4279e-03, 2.9682e-04, 
-7.3328e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-2.3195e-02, 1.6100e-02, -7.7113e-06],\n", - " [-1.0721e-02, -2.1473e-02, 1.6510e-02],\n", - " [ 4.8150e-03, 1.8678e-03, -2.0047e-02]],\n", - "\n", - " [[-1.3962e-02, -8.0512e-03, -1.5752e-02],\n", - " [ 6.9845e-03, -4.3029e-03, 1.1148e-02],\n", - " [ 3.1242e-03, -1.9239e-02, 4.6783e-03]],\n", - "\n", - " [[ 6.3272e-03, -6.4049e-03, -1.9009e-02],\n", - " [-7.0409e-03, -1.3461e-02, -1.2215e-02],\n", - " [ 1.7156e-03, 2.9771e-03, 5.8117e-03]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-1.0419e-02, -6.0178e-03, -1.4953e-02],\n", - " [ 2.8744e-03, -1.3205e-02, 1.2708e-03],\n", - " [ 6.5012e-04, 1.0050e-03, -9.5456e-03]],\n", - "\n", - " [[-1.1000e-02, -2.2153e-04, -1.5086e-03],\n", - " [ 1.0215e-02, -9.4436e-03, 5.2485e-03],\n", - " [-2.4163e-02, -2.5928e-02, -1.3547e-02]],\n", - "\n", - " [[ 5.4092e-03, -6.3877e-03, -1.0579e-02],\n", - " [-8.7557e-04, 1.9436e-02, -4.6480e-03],\n", - " [-1.2522e-02, -3.8739e-03, -1.0423e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[-7.6105e-03, -3.4862e-03, 1.1735e-02],\n", - " [-2.6591e-02, -9.0118e-03, -1.8351e-02],\n", - " [-1.9584e-02, -4.1147e-03, -1.3313e-02]],\n", - "\n", - " [[ 2.4424e-02, 2.6349e-03, 2.4604e-03],\n", - " [ 1.8698e-03, -6.0177e-03, -1.6583e-02],\n", - " [ 8.0879e-03, -2.2503e-03, -2.3768e-02]],\n", - "\n", - " [[-1.1821e-02, 1.1441e-03, 1.3298e-02],\n", - " [-1.4603e-03, 2.7675e-03, -4.1455e-03],\n", - " [ 1.2778e-02, 1.9607e-02, 5.1929e-03]]],\n", - "\n", - "\n", - " [[[ 1.1794e-02, -1.9382e-03, -1.0644e-02],\n", - " [-1.8139e-02, 1.3198e-03, 2.2425e-03],\n", - " [ 1.0553e-03, 2.1395e-03, -3.7677e-03]],\n", - "\n", - " [[-2.0456e-02, -1.6302e-02, 1.1988e-02],\n", - " [-1.3381e-02, -2.8377e-02, -1.1368e-02],\n", - " [ 1.2046e-02, -9.7604e-03, -1.4571e-03]],\n", - "\n", - " [[-5.2547e-03, 2.3563e-03, -2.8907e-02],\n", - " [-4.2318e-03, 2.1925e-02, 3.8629e-03],\n", - " [ 5.9383e-03, -1.0632e-02, 2.3464e-03]],\n", - "\n", - " ...,\n", - "\n", - " 
[[ 7.0141e-04, 5.0503e-03, -9.5930e-03],\n", - " [ 2.6526e-02, 1.5373e-02, 1.9513e-02],\n", - " [-4.0049e-03, -2.8385e-02, -2.0793e-02]],\n", - "\n", - " [[ 1.4260e-02, 1.0030e-02, 2.5354e-02],\n", - " [-7.2404e-03, -2.1980e-02, 6.5759e-03],\n", - " [ 1.8276e-02, 6.0751e-03, 1.1317e-02]],\n", - "\n", - " [[ 2.6095e-02, 2.3253e-02, 7.7629e-03],\n", - " [ 1.2630e-02, 1.5899e-02, -1.1508e-02],\n", - " [-4.3151e-03, -8.6703e-04, 5.9518e-03]]],\n", - "\n", - "\n", - " [[[ 1.7142e-02, 6.6391e-03, -3.0989e-03],\n", - " [-2.6798e-02, -2.7884e-02, -5.9585e-03],\n", - " [-5.1605e-03, -5.6922e-03, -5.1046e-04]],\n", - "\n", - " [[ 1.0339e-02, 1.4060e-02, 8.2361e-03],\n", - " [ 2.2945e-02, 1.6656e-02, 4.4532e-03],\n", - " [-9.3033e-03, -1.7460e-02, -1.4590e-02]],\n", - "\n", - " [[ 6.1705e-03, -1.2017e-02, -1.1784e-02],\n", - " [ 1.8722e-02, 1.7385e-03, 6.6652e-03],\n", - " [-8.8163e-03, 1.6646e-02, 1.8388e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 3.6180e-02, 2.2804e-02, 8.0334e-03],\n", - " [ 7.5858e-03, -3.0590e-03, -1.0771e-02],\n", - " [ 2.6141e-02, 6.5954e-04, 3.5041e-03]],\n", - "\n", - " [[ 1.1716e-02, 2.4431e-03, 1.5212e-02],\n", - " [-8.8621e-03, 1.3009e-02, 8.3022e-03],\n", - " [-1.0354e-02, -3.3272e-02, 4.8393e-03]],\n", - "\n", - " [[-6.0150e-06, 8.0425e-03, 2.5490e-02],\n", - " [ 1.2990e-02, -1.5889e-02, 2.3843e-02],\n", - " [ 2.6896e-03, 1.3003e-02, 1.4933e-02]]]])\n", - "layer4.1.bn1.weight Parameter containing:\n", - "tensor([0.9805, 0.9846, 0.9759, 0.9868, 0.9540, 0.9773, 0.9970, 1.0088, 0.9459,\n", - " 1.0135, 0.9817, 1.0045, 0.9881, 0.9783, 1.0104, 0.9747, 0.9724, 1.0108,\n", - " 1.0235, 1.0103, 0.9986, 0.9730, 0.9935, 0.9986, 1.0021, 1.0400, 0.9961,\n", - " 0.9527, 1.0376, 1.0096, 0.9515, 0.9864, 0.9696, 0.9796, 0.9870, 0.9595,\n", - " 0.9694, 0.9824, 0.9672, 1.0294, 1.0250, 0.9979, 0.9760, 0.9887, 0.9431,\n", - " 1.0026, 0.9854, 0.9683, 1.0087, 0.9879, 0.9638, 0.9985, 1.0065, 1.0173,\n", - " 1.0125, 0.9784, 0.9980, 1.0158, 1.0134, 0.9912, 0.9901, 
1.0090, 0.9841,\n", - " 0.9555, 0.9996, 0.9847, 0.9812, 1.0173, 1.0262, 1.0204, 1.0023, 1.0182,\n", - " 1.0031, 1.0051, 0.9728, 0.9804, 1.0384, 0.9685, 0.9646, 0.9908, 0.9813,\n", - " 0.9950, 0.9732, 1.0086, 0.9890, 1.0129, 1.0204, 0.9924, 0.9849, 0.9708,\n", - " 1.0072, 1.0065, 0.9797, 1.0123, 0.9609, 0.9864, 1.0090, 0.9948, 0.9661,\n", - " 1.0218, 1.0108, 0.9750, 0.9779, 0.9953, 0.9782, 0.9789, 0.9951, 0.9922,\n", - " 1.0282, 1.0167, 1.0085, 0.9837, 0.9874, 1.0074, 1.0155, 1.1245, 0.9838,\n", - " 1.0047, 0.9883, 1.0097, 1.0006, 0.9961, 1.0088, 0.9767, 1.0109, 0.9652,\n", - " 1.0224, 0.9591, 0.9531, 1.0059, 1.0103, 0.9987, 0.9877, 1.0158, 0.9940,\n", - " 0.9723, 1.0097, 1.0078, 0.9928, 0.9844, 1.0161, 0.9910, 0.9563, 0.9602,\n", - " 1.0367, 0.9833, 1.0116, 0.9996, 0.9971, 1.0012, 0.9933, 1.0024, 0.9923,\n", - " 0.9714, 0.9900, 1.0053, 1.0053, 1.0154, 0.9536, 1.0031, 0.9785, 0.9632,\n", - " 0.9639, 1.0079, 1.0122, 0.9868, 1.0039, 0.9561, 1.0111, 1.0003, 0.9817,\n", - " 0.9626, 1.0205, 0.9594, 0.9970, 0.9981, 1.0043, 1.0301, 0.9845, 0.9772,\n", - " 0.9690, 0.9833, 0.9810, 0.9739, 0.9751, 0.9919, 1.0352, 1.0192, 0.9927,\n", - " 0.9860, 1.0235, 0.9994, 0.9919, 1.0123, 0.9797, 0.9880, 1.0050, 0.9768,\n", - " 0.9599, 0.9904, 0.9890, 1.0258, 0.9771, 1.0327, 0.9756, 1.0120, 0.9563,\n", - " 1.0169, 1.0041, 0.9974, 0.9641, 0.9446, 0.9659, 0.9797, 0.9486, 0.9336,\n", - " 1.0025, 0.9776, 0.9898, 0.9870, 1.0090, 0.9526, 0.9759, 1.0003, 0.9790,\n", - " 0.9954, 1.0057, 1.0005, 1.0094, 0.9701, 0.9780, 0.9816, 0.9897, 0.9684,\n", - " 0.9891, 1.0065, 1.0226, 0.9636, 0.9637, 0.9261, 1.0008, 1.0651, 0.9930,\n", - " 0.9958, 0.9700, 0.9869, 0.9898, 1.0047, 1.0040, 1.0012, 1.0196, 0.9930,\n", - " 1.0209, 0.9953, 0.9851, 1.0046, 0.9682, 0.9725, 1.0114, 0.9976, 1.0024,\n", - " 1.0266, 1.0152, 0.9489, 0.9990, 0.9975, 0.9912, 1.0132, 0.9978, 0.9912,\n", - " 0.9794, 0.9792, 0.9685, 1.0182, 0.9790, 0.9918, 0.9876, 1.0032, 1.0073,\n", - " 1.0125, 1.0061, 0.9674, 1.0080, 0.9517, 0.9852, 1.0098, 
0.9646, 0.9963,\n", - " 0.9974, 0.9805, 0.9975, 1.0254, 0.9624, 0.9958, 0.9880, 1.0272, 1.0107,\n", - " 1.0114, 1.0068, 0.9564, 0.9969, 0.9830, 1.0223, 1.0025, 0.9901, 1.0058,\n", - " 0.9905, 0.9872, 1.0171, 0.9750, 0.9838, 1.0035, 0.9567, 1.0092, 1.0087,\n", - " 0.9281, 0.9962, 0.9900, 0.9797, 0.9860, 1.0171, 0.9944, 0.9739, 1.0253,\n", - " 0.9776, 1.0007, 0.9641, 0.9815, 1.0056, 0.9797, 1.0091, 0.9975, 0.9990,\n", - " 1.0092, 1.0226, 1.0119, 0.9579, 1.0084, 1.0219, 0.9920, 0.9944, 1.0221,\n", - " 0.9891, 1.0069, 0.9719, 1.0181, 0.9983, 1.0006, 1.0074, 0.9991, 0.9806,\n", - " 0.9645, 0.9950, 0.9887, 1.0078, 0.9995, 1.0178, 0.9971, 0.9745, 0.9846,\n", - " 1.0115, 0.9732, 0.9848, 0.9896, 0.9884, 1.0157, 1.0317, 1.0030, 0.9966,\n", - " 0.9446, 0.9885, 1.0239, 0.9931, 0.9852, 1.0077, 0.9649, 0.9880, 1.0098,\n", - " 0.9700, 1.0080, 1.0383, 0.9895, 0.9959, 0.9903, 1.0276, 0.9585, 0.9994,\n", - " 0.9939, 0.9725, 1.0185, 1.0011, 0.9972, 0.9493, 0.9815, 0.9937, 1.0199,\n", - " 0.9811, 1.0165, 1.0112, 1.0096, 0.9968, 0.9923, 0.9888, 0.9956, 1.0406,\n", - " 0.9957, 1.0087, 0.9953, 0.9961, 0.9790, 0.9766, 0.9792, 0.9992, 0.9895,\n", - " 0.9994, 0.9961, 0.9934, 0.9956, 0.9952, 0.9332, 0.9898, 0.9428, 0.9966,\n", - " 0.9634, 0.9670, 1.0047, 1.0183, 0.9669, 0.9316, 0.9917, 1.0281, 0.9686,\n", - " 0.9995, 1.0016, 1.0001, 0.9854, 0.9899, 0.9941, 0.9777, 0.9883, 0.9760,\n", - " 0.9763, 1.0214, 0.9761, 0.9893, 1.0096, 1.0141, 0.9906, 1.0030, 0.9963,\n", - " 0.9774, 1.0134, 0.9830, 1.0211, 1.0154, 0.9978, 0.9926, 1.0075, 1.0314,\n", - " 0.9650, 1.0000, 1.0233, 1.0141, 1.0137, 0.9987, 0.9872, 0.9927, 1.0081,\n", - " 1.0121, 0.9797, 1.0049, 0.9735, 1.0206, 0.9841, 0.9722, 1.0008, 0.9902,\n", - " 0.9560, 1.0136, 0.9915, 0.9973, 1.0004, 0.9762, 0.9707, 0.9804, 0.9942,\n", - " 0.9903, 1.0008, 0.9939, 1.0084, 0.9891, 0.9804, 0.9916, 0.9853, 0.9925,\n", - " 0.9732, 0.9747, 0.9848, 0.9919, 0.9396, 1.0009, 0.9794, 0.9971, 1.0151,\n", - " 1.0144, 0.9861, 0.9910, 0.9841, 0.9594, 0.9910, 0.9671, 
0.9948])\n", - "layer4.1.bn1.bias Parameter containing:\n", - "tensor([-0.0608, -0.1106, -0.1077, -0.0862, -0.0966, -0.1186, -0.0844, -0.1174,\n", - " -0.1269, -0.1427, -0.1088, -0.1491, -0.1391, -0.1073, -0.0685, -0.1510,\n", - " -0.0840, -0.1244, -0.1091, -0.0766, -0.0785, -0.0935, -0.1171, -0.0805,\n", - " -0.1100, -0.1681, -0.1164, -0.1179, -0.0757, -0.1134, -0.1257, -0.0916,\n", - " -0.1223, -0.0713, -0.1444, -0.0968, -0.0749, -0.1484, -0.1897, -0.1927,\n", - " -0.1047, -0.1176, -0.0849, -0.0631, -0.0967, -0.0724, -0.0853, -0.1475,\n", - " -0.2147, -0.1051, -0.0995, -0.1090, -0.1502, -0.0868, -0.1374, -0.0629,\n", - " -0.1739, -0.0584, -0.1067, -0.1148, -0.1047, -0.1238, -0.1547, -0.1756,\n", - " -0.1839, -0.1441, -0.1100, -0.1534, -0.1557, -0.1117, -0.0860, -0.1076,\n", - " -0.1687, -0.1198, -0.1197, -0.0775, -0.1608, -0.1607, -0.0424, -0.1094,\n", - " -0.0943, -0.0690, -0.1275, -0.1888, -0.1153, -0.0499, -0.1390, -0.1021,\n", - " -0.1419, -0.0608, -0.0607, -0.1145, -0.1181, -0.1917, -0.1177, -0.1105,\n", - " -0.0889, -0.0545, -0.0642, -0.1006, -0.0981, -0.1013, -0.0728, -0.0794,\n", - " -0.0851, -0.1038, -0.1028, -0.0633, -0.1673, -0.1243, -0.1759, -0.0961,\n", - " -0.1261, -0.1214, -0.1933, -0.3384, -0.0859, -0.1718, -0.1494, -0.1796,\n", - " -0.0852, -0.0871, -0.0846, -0.1100, -0.2103, -0.1524, -0.0914, -0.0854,\n", - " -0.1211, -0.0944, -0.2191, -0.2019, -0.0837, -0.0704, -0.0908, -0.1619,\n", - " -0.0854, -0.1587, -0.0729, -0.1280, -0.1806, -0.1186, -0.1148, -0.0597,\n", - " -0.1322, -0.1507, -0.1482, -0.0712, -0.1087, -0.1658, -0.1355, -0.1157,\n", - " -0.0734, -0.1032, -0.0741, -0.1099, -0.0713, -0.0955, -0.1165, -0.1179,\n", - " -0.0909, -0.1548, -0.0737, -0.1618, -0.1329, -0.1200, -0.0681, -0.0856,\n", - " -0.1546, -0.0685, -0.1340, -0.0716, -0.0849, -0.0728, -0.0803, -0.0587,\n", - " -0.1320, -0.1410, -0.1361, -0.1207, -0.0786, -0.1054, -0.0963, -0.0907,\n", - " -0.0963, -0.1160, -0.0913, -0.1016, -0.1288, -0.1441, -0.1249, -0.1070,\n", - " -0.1004, 
-0.0959, -0.1122, -0.1121, -0.0655, -0.1088, -0.1389, -0.0901,\n", - " -0.1826, -0.2301, -0.0609, -0.0993, -0.1086, -0.1292, -0.1134, -0.1160,\n", - " -0.1750, -0.0728, -0.1286, -0.0742, -0.1474, -0.1549, -0.0860, -0.0976,\n", - " -0.1396, -0.1356, -0.0913, -0.0766, -0.1393, -0.1175, -0.0719, -0.0979,\n", - " -0.1313, -0.1787, -0.0839, -0.0961, -0.1154, -0.1008, -0.1131, -0.0837,\n", - " -0.0736, -0.1305, -0.0767, -0.0568, -0.1916, -0.0585, -0.1332, -0.1342,\n", - " -0.1559, -0.3032, -0.0988, -0.1053, -0.0608, -0.1198, -0.1132, -0.1624,\n", - " -0.1054, -0.0622, -0.1352, -0.1380, -0.1693, -0.1173, -0.0603, -0.1139,\n", - " -0.1178, -0.1277, -0.2131, -0.0784, -0.0519, -0.2277, -0.1055, -0.0931,\n", - " -0.1010, -0.1456, -0.0494, -0.1482, -0.1402, -0.0407, -0.1182, -0.1223,\n", - " -0.1086, -0.1254, -0.1419, -0.0713, -0.1254, -0.0799, -0.0662, -0.1664,\n", - " -0.0995, -0.1176, -0.1106, -0.1328, -0.1268, -0.1084, -0.0495, -0.0686,\n", - " -0.1277, -0.0775, -0.0593, -0.1401, -0.0968, -0.0929, -0.0689, -0.0705,\n", - " -0.1046, -0.1467, -0.1568, -0.0931, -0.1599, -0.0977, -0.1093, -0.0855,\n", - " -0.0871, -0.0872, -0.0939, -0.1416, -0.1200, -0.1911, -0.1143, -0.0808,\n", - " -0.1169, -0.1221, -0.1957, -0.1343, -0.1106, -0.0971, -0.0942, -0.1319,\n", - " -0.1436, -0.0967, -0.0854, -0.1190, -0.1104, -0.1925, -0.0976, -0.1878,\n", - " -0.1187, -0.0819, -0.0849, -0.1193, -0.0944, -0.1071, -0.0700, -0.0783,\n", - " -0.1317, -0.1517, -0.0699, -0.1153, -0.1113, -0.2056, -0.1347, -0.1889,\n", - " -0.1641, -0.0675, -0.1322, -0.1328, -0.1147, -0.1123, -0.1087, -0.1466,\n", - " -0.0859, -0.1492, -0.1310, -0.0981, -0.0566, -0.1827, -0.0625, -0.0621,\n", - " -0.0790, -0.0667, -0.0876, -0.0773, -0.0938, -0.1823, -0.1910, -0.1502,\n", - " -0.1416, -0.0941, -0.1280, -0.1437, -0.1407, -0.0356, -0.1092, -0.1139,\n", - " -0.0943, -0.1434, -0.0862, -0.1915, -0.1062, -0.0828, -0.0734, -0.0818,\n", - " -0.1325, -0.0910, -0.0997, -0.1065, -0.1666, -0.0842, -0.1114, -0.1090,\n", - " -0.1106, 
-0.1000, -0.1096, -0.2158, -0.1140, -0.1128, -0.0802, -0.0938,\n", - " -0.1751, -0.0998, -0.1030, -0.1405, -0.1131, -0.1083, -0.0755, -0.0839,\n", - " -0.1115, -0.1142, -0.0762, -0.0731, -0.1117, -0.1029, -0.0728, -0.0489,\n", - " -0.0535, -0.1026, -0.1306, -0.1330, -0.1135, -0.0804, -0.2111, -0.1288,\n", - " -0.1213, -0.1006, -0.1309, -0.1047, -0.0553, -0.0531, -0.1103, -0.0808,\n", - " -0.0620, -0.0879, -0.1219, -0.1101, -0.1087, -0.0573, -0.1147, -0.0571,\n", - " -0.0626, -0.1239, -0.2353, -0.1909, -0.2048, -0.0762, -0.1088, -0.1020,\n", - " -0.0687, -0.1871, -0.1064, -0.1255, -0.0896, -0.1078, -0.1170, -0.0867,\n", - " -0.0720, -0.1306, -0.1985, -0.0807, -0.1060, -0.1028, -0.2111, -0.1270,\n", - " -0.0658, -0.0956, -0.1027, -0.1447, -0.2264, -0.1177, -0.1414, -0.1139,\n", - " -0.1850, -0.0533, -0.0952, -0.0898, -0.1356, -0.1355, -0.2044, -0.0864,\n", - " -0.0759, -0.0708, -0.0850, -0.1273, -0.1374, -0.1031, -0.0824, -0.1243,\n", - " -0.1191, -0.0874, -0.1535, -0.1514, -0.1384, -0.0627, -0.0857, -0.1217,\n", - " -0.1527, -0.1059, -0.1265, -0.0937, -0.0859, -0.0676, -0.0940, -0.1024,\n", - " -0.1152, -0.0735, -0.1281, -0.1194, -0.0971, -0.0706, -0.1200, -0.0926])\n", - "layer4.1.conv2.weight Parameter containing:\n", - "tensor([[[[ 8.3330e-03, -8.8291e-03, 1.2275e-02],\n", - " [-1.2374e-02, -8.4982e-03, 9.2257e-03],\n", - " [ 1.2488e-02, -3.5636e-03, 4.6247e-04]],\n", - "\n", - " [[-2.5105e-03, -8.5636e-03, -1.5846e-02],\n", - " [-1.0081e-02, -6.2206e-03, 8.4948e-04],\n", - " [ 7.9979e-03, 4.0291e-03, -1.0241e-02]],\n", - "\n", - " [[ 1.0911e-02, -2.9711e-04, -5.6554e-03],\n", - " [ 2.8144e-03, 8.9447e-03, -1.0828e-02],\n", - " [ 4.2974e-03, 1.0024e-02, 1.3600e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-5.2114e-03, -5.0998e-03, -8.2229e-03],\n", - " [-2.2549e-03, -3.9222e-04, -1.6133e-03],\n", - " [-1.4629e-02, 4.1414e-03, 5.7456e-03]],\n", - "\n", - " [[ 1.2655e-02, 2.8707e-03, 3.0632e-03],\n", - " [ 6.5042e-03, 8.0699e-03, 5.7052e-03],\n", - " [ 1.0453e-02, 
1.2078e-02, 3.9071e-04]],\n", - "\n", - " [[ 1.7813e-02, -6.4311e-03, 1.0226e-02],\n", - " [ 1.2140e-02, 4.2405e-03, 9.8047e-03],\n", - " [-1.4421e-02, 5.1734e-03, -1.9546e-03]]],\n", - "\n", - "\n", - " [[[ 1.8379e-02, 1.1177e-02, 1.7544e-02],\n", - " [-8.6856e-04, 3.5306e-03, 1.3847e-02],\n", - " [-1.4978e-03, 1.8132e-02, -3.5698e-03]],\n", - "\n", - " [[ 2.4527e-04, 1.4042e-02, -8.8348e-03],\n", - " [ 4.5480e-03, 4.4421e-03, -7.4060e-03],\n", - " [ 7.7895e-03, 5.7431e-05, 4.0264e-03]],\n", - "\n", - " [[ 5.3564e-03, -1.4270e-02, -1.9118e-02],\n", - " [ 1.3765e-02, 8.8396e-04, 6.0510e-03],\n", - " [ 3.6769e-03, -3.0135e-03, -1.5168e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.7418e-02, -9.8899e-03, -2.1458e-03],\n", - " [ 7.3479e-03, 7.0623e-03, 1.1940e-03],\n", - " [ 3.7772e-03, 2.9783e-04, -2.8873e-03]],\n", - "\n", - " [[-9.7363e-03, 1.1990e-02, 5.8910e-04],\n", - " [-1.3291e-02, -1.3032e-02, 2.7085e-03],\n", - " [-1.2218e-02, -2.0081e-02, -2.5794e-04]],\n", - "\n", - " [[-1.4517e-03, -1.3347e-03, -1.9954e-02],\n", - " [-9.5822e-03, 8.9759e-03, -5.2620e-03],\n", - " [ 1.1365e-02, 6.5043e-03, -1.7677e-03]]],\n", - "\n", - "\n", - " [[[ 2.8814e-03, -1.1459e-02, 7.2866e-03],\n", - " [-1.1217e-02, -6.3151e-03, -4.2912e-04],\n", - " [ 4.4563e-03, -3.7896e-03, -1.9523e-03]],\n", - "\n", - " [[-1.3908e-02, 1.0951e-02, 7.5778e-03],\n", - " [ 9.1734e-05, -2.2147e-03, -4.2200e-03],\n", - " [ 9.6189e-03, -1.1441e-02, -1.4601e-02]],\n", - "\n", - " [[ 1.0115e-02, 1.7201e-02, 6.6180e-03],\n", - " [-4.3699e-03, 1.8709e-03, -1.8911e-02],\n", - " [ 6.8236e-03, 1.1824e-02, -5.8826e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1.9188e-02, 8.9656e-03, 3.0848e-02],\n", - " [-5.0819e-03, -3.8729e-03, 1.8700e-02],\n", - " [ 4.7682e-03, 5.8786e-03, 9.7623e-03]],\n", - "\n", - " [[-3.5205e-04, -2.5114e-03, 1.1785e-04],\n", - " [ 8.3362e-03, 7.0353e-03, -1.7875e-03],\n", - " [-1.1702e-03, -8.3737e-03, 1.4237e-02]],\n", - "\n", - " [[-2.8220e-03, -2.0585e-02, -1.9547e-02],\n", - " 
[-3.0085e-04, -2.0964e-02, -1.9344e-02],\n", - " [-1.0963e-02, -2.8152e-03, -1.9004e-02]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[-1.1269e-02, 1.0234e-02, -1.5817e-03],\n", - " [ 3.7629e-03, -3.9181e-03, -1.3499e-02],\n", - " [-7.1245e-04, -2.1378e-03, 1.7662e-02]],\n", - "\n", - " [[ 1.0952e-02, 6.3739e-03, 1.2541e-02],\n", - " [ 1.6861e-02, -1.0851e-02, 4.3220e-03],\n", - " [ 1.3744e-02, 3.8529e-03, 5.6580e-04]],\n", - "\n", - " [[-1.3857e-02, -2.3699e-04, -2.2844e-03],\n", - " [-2.0008e-02, 6.1069e-03, -1.1951e-02],\n", - " [ 1.0057e-02, -6.6128e-03, 2.5271e-03]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1.3057e-02, 1.1391e-02, 2.4781e-03],\n", - " [ 5.9918e-03, 2.2100e-02, 3.2220e-03],\n", - " [ 1.6557e-02, 1.9436e-02, 4.1398e-03]],\n", - "\n", - " [[ 1.3594e-02, 7.4428e-03, 2.5966e-02],\n", - " [ 1.4304e-02, -4.4501e-03, -4.0079e-03],\n", - " [-1.0729e-02, 9.3449e-03, -7.2407e-03]],\n", - "\n", - " [[ 8.2458e-03, 1.2307e-02, -1.8730e-03],\n", - " [ 3.8755e-03, -1.0899e-02, -1.0243e-02],\n", - " [ 1.4385e-02, 1.4387e-02, -6.9691e-06]]],\n", - "\n", - "\n", - " [[[-1.2377e-03, -2.4913e-02, -6.6713e-03],\n", - " [-1.1979e-02, -3.9882e-02, -3.5081e-02],\n", - " [-3.0059e-02, -1.3726e-02, -1.6513e-02]],\n", - "\n", - " [[-7.6263e-03, 4.6258e-03, 4.1316e-03],\n", - " [-2.0899e-02, 9.7261e-03, 1.7532e-02],\n", - " [ 1.3484e-02, -1.5733e-02, 5.7539e-03]],\n", - "\n", - " [[ 6.4326e-03, 1.1697e-02, -9.0290e-03],\n", - " [ 3.7352e-03, -7.8998e-03, -2.0597e-03],\n", - " [ 1.2484e-02, 6.9576e-03, 1.9936e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-9.8303e-03, -1.0628e-02, -3.4814e-03],\n", - " [ 4.3456e-03, -1.1256e-02, 4.4709e-03],\n", - " [ 5.8977e-03, 1.3371e-03, 1.3130e-03]],\n", - "\n", - " [[ 1.2398e-02, 8.9216e-03, 1.2770e-03],\n", - " [-1.1820e-02, -7.3181e-03, -8.0942e-04],\n", - " [-5.5888e-03, -8.0208e-03, 3.3651e-03]],\n", - "\n", - " [[-4.3322e-03, -9.2950e-03, -1.2784e-02],\n", - " [-6.3165e-03, -1.3445e-03, 8.2466e-03],\n", - " [ 
2.8841e-03, -8.8737e-03, 3.1212e-03]]],\n", - "\n", - "\n", - " [[[-6.0997e-03, -1.6026e-02, -1.7981e-02],\n", - " [-1.8672e-02, 3.5800e-03, -1.1836e-02],\n", - " [-2.5707e-03, 6.3123e-03, -3.1627e-03]],\n", - "\n", - " [[ 2.3208e-03, -9.5134e-03, 5.2807e-03],\n", - " [-4.3437e-03, -2.9705e-03, -1.8156e-02],\n", - " [ 1.3944e-02, -1.7958e-02, 4.8227e-03]],\n", - "\n", - " [[-1.5250e-03, 1.4135e-02, -7.6070e-03],\n", - " [-6.7339e-03, 2.5677e-03, -8.7000e-03],\n", - " [-5.2387e-03, 1.7277e-02, 1.6578e-02]],\n", - "\n", - " ...,\n", - "\n", - " [[-3.9090e-03, -4.0446e-03, 1.0471e-02],\n", - " [ 9.9475e-03, -4.6785e-03, 1.9391e-02],\n", - " [ 1.5370e-02, 4.5712e-03, 8.4700e-03]],\n", - "\n", - " [[ 1.7564e-02, -3.1784e-03, -2.9029e-03],\n", - " [-1.3421e-02, -2.4112e-02, -2.0311e-02],\n", - " [-6.3522e-03, -1.1957e-02, -3.5473e-03]],\n", - "\n", - " [[ 3.6304e-03, 2.8775e-04, -5.7601e-03],\n", - " [-2.3570e-03, 1.9330e-02, -3.9670e-03],\n", - " [-7.8952e-03, -1.0448e-02, -2.4362e-02]]]])\n", - "layer4.1.bn2.weight Parameter containing:\n", - "tensor([1.0242, 1.0146, 1.0980, 1.1073, 1.0217, 1.0784, 1.1271, 1.0397, 1.0641,\n", - " 1.0312, 1.0438, 1.0239, 1.0947, 1.0322, 1.0603, 1.0867, 1.0833, 1.0292,\n", - " 1.0249, 1.0900, 1.0181, 1.0410, 1.1284, 1.0774, 1.0293, 1.0697, 1.0609,\n", - " 1.0724, 1.0830, 1.1673, 1.0062, 1.1154, 1.1345, 1.0555, 1.0160, 1.0482,\n", - " 1.0454, 1.1229, 1.2171, 1.0859, 1.0210, 1.1659, 1.1155, 1.0243, 1.1353,\n", - " 1.0889, 1.0895, 1.1472, 1.0563, 1.0662, 1.1582, 1.0786, 1.0374, 0.9917,\n", - " 1.0358, 1.0606, 1.1683, 1.0239, 1.1692, 1.0581, 1.0631, 1.0899, 1.0036,\n", - " 1.1219, 1.0847, 1.0466, 1.0357, 1.0475, 1.0673, 1.0314, 1.0647, 1.0164,\n", - " 1.0947, 1.0652, 1.0888, 1.1126, 1.0377, 1.0581, 1.0609, 1.0574, 1.0381,\n", - " 1.0433, 1.0988, 1.0728, 1.0245, 1.1218, 1.0846, 1.0253, 1.0980, 1.0815,\n", - " 1.0791, 1.0930, 1.0781, 1.0400, 1.0950, 1.0370, 1.0712, 1.0209, 1.0012,\n", - " 1.0378, 1.0469, 1.0637, 1.0556, 1.0836, 1.0990, 1.1014, 
1.0262, 1.0404,\n", - " 1.0208, 1.0725, 1.1698, 1.0678, 1.0908, 1.0541, 1.0590, 1.0491, 1.0458,\n", - " 1.0766, 1.1052, 1.0307, 1.0563, 1.0302, 1.1253, 1.0703, 1.0515, 1.0016,\n", - " 1.0125, 1.0734, 1.1176, 1.0469, 1.1395, 1.0709, 1.1377, 1.2334, 1.1187,\n", - " 1.0490, 1.0694, 1.0609, 1.1138, 1.0400, 1.1218, 1.0512, 1.2059, 1.0560,\n", - " 1.0507, 1.0822, 1.0961, 1.0719, 1.0658, 1.0883, 1.0782, 1.0793, 1.0848,\n", - " 1.0312, 1.0373, 1.0784, 1.0518, 1.0757, 1.0129, 1.0832, 1.0855, 1.0356,\n", - " 1.0360, 1.0601, 1.0969, 1.0702, 1.0355, 1.0328, 1.1190, 1.0092, 1.0068,\n", - " 1.0617, 1.1217, 1.0142, 1.1396, 1.1373, 1.2151, 1.1153, 1.0831, 1.0624,\n", - " 0.9801, 1.0520, 1.0310, 1.0255, 1.0246, 1.0752, 1.0150, 1.1226, 1.1025,\n", - " 1.0545, 1.0117, 1.0097, 1.0898, 1.0338, 1.1749, 0.9866, 1.0784, 1.0726,\n", - " 1.0084, 1.1760, 1.0616, 1.1402, 1.0070, 1.0595, 1.0467, 1.1104, 0.9872,\n", - " 1.0514, 1.0343, 1.0182, 1.0532, 1.0238, 1.1251, 1.0696, 1.0595, 1.1180,\n", - " 1.0848, 1.1539, 1.0253, 1.0455, 1.0185, 1.0646, 1.0264, 1.0715, 1.1070,\n", - " 0.9987, 1.0956, 1.0493, 1.0908, 1.0277, 1.0646, 1.0280, 1.0533, 1.0014,\n", - " 1.0239, 1.0354, 1.0473, 1.0112, 1.0122, 1.1165, 1.0348, 0.9891, 1.0225,\n", - " 1.0456, 0.9976, 1.0481, 1.1477, 1.0721, 1.0166, 1.0151, 1.0839, 1.0700,\n", - " 1.1691, 1.0164, 1.0309, 1.0240, 1.1567, 1.0431, 1.0736, 1.1114, 1.0065,\n", - " 1.0703, 1.0611, 1.0475, 1.0474, 1.0789, 1.0267, 1.1011, 1.0307, 1.0315,\n", - " 1.0288, 1.1784, 1.0545, 1.0427, 1.1500, 1.0942, 1.1086, 1.0019, 1.1033,\n", - " 1.1051, 1.0601, 1.0148, 1.0515, 1.0572, 1.0110, 1.1092, 1.0073, 1.0400,\n", - " 1.0272, 1.0786, 1.0280, 1.1989, 1.1029, 1.0141, 1.0876, 1.0109, 1.0384,\n", - " 1.0236, 1.0347, 1.1355, 1.1481, 1.0085, 1.0619, 1.1081, 1.0697, 1.1674,\n", - " 1.0433, 1.0232, 1.0453, 1.0082, 1.0715, 1.0691, 1.0982, 0.9905, 1.0557,\n", - " 1.0059, 1.0252, 1.0962, 1.1416, 1.0434, 1.0214, 1.0485, 1.1178, 1.0474,\n", - " 1.1043, 1.1483, 1.1027, 1.1413, 1.1306, 1.1123, 1.0489, 
1.0220, 1.1077,\n", - " 1.0568, 1.0750, 1.0148, 1.1145, 1.0168, 1.0564, 1.0298, 1.0500, 1.0745,\n", - " 1.1666, 1.0316, 1.0129, 1.0392, 0.9879, 1.0129, 1.0548, 1.0767, 1.1244,\n", - " 1.1325, 1.0386, 1.0411, 1.0769, 1.0749, 1.0645, 1.0663, 1.0729, 1.0543,\n", - " 1.0424, 1.0464, 1.0326, 1.1171, 1.1288, 1.0248, 1.0545, 1.0248, 1.0937,\n", - " 1.0726, 1.0840, 1.0670, 1.0170, 1.0825, 1.1329, 1.0469, 1.0576, 1.0934,\n", - " 1.0217, 1.0305, 1.0960, 1.0253, 1.0343, 1.0251, 1.0356, 1.0462, 1.1378,\n", - " 1.1591, 1.0126, 1.0977, 1.0571, 1.0547, 1.0241, 1.0492, 1.0013, 1.0310,\n", - " 1.0516, 1.1145, 1.1385, 1.0651, 1.0141, 1.0097, 1.0420, 1.0733, 1.0806,\n", - " 1.0100, 1.1806, 1.0809, 1.1831, 1.0071, 1.0546, 1.0627, 1.0322, 1.0487,\n", - " 1.0382, 1.1233, 1.0546, 1.1694, 1.0623, 1.0617, 1.0711, 1.0351, 1.0413,\n", - " 1.0751, 1.0799, 1.0228, 1.0337, 1.1464, 1.1565, 1.0602, 1.0685, 1.0337,\n", - " 1.0126, 1.0579, 1.0529, 1.1038, 1.1000, 1.1436, 1.0905, 1.0137, 1.1300,\n", - " 1.0555, 1.0546, 1.1316, 1.0681, 1.0819, 1.0160, 1.0665, 1.0106, 1.0965,\n", - " 1.0122, 1.0951, 1.1311, 1.0323, 0.9902, 1.0654, 1.1184, 1.0895, 1.1081,\n", - " 1.1093, 1.0375, 1.0910, 1.0604, 1.0668, 1.1883, 1.0838, 1.0655, 1.0060,\n", - " 1.0086, 1.0424, 1.1064, 1.0452, 1.0224, 1.0966, 1.0517, 1.0848, 1.0836,\n", - " 1.0501, 1.1060, 1.0344, 1.0298, 1.0721, 1.1158, 1.1045, 1.1082, 1.0016,\n", - " 1.1235, 1.1025, 1.0471, 1.1419, 1.1065, 1.0055, 1.0378, 1.0280, 1.0306,\n", - " 1.0631, 1.0714, 1.1036, 1.0619, 1.0554, 1.0397, 1.0618, 1.0505, 1.0305,\n", - " 1.0174, 1.0348, 1.0543, 1.1011, 1.0401, 1.0488, 1.1258, 1.0311])\n", - "layer4.1.bn2.bias Parameter containing:\n", - "tensor([-0.0325, -0.0605, -0.0857, -0.0918, -0.0602, -0.0845, -0.0784, -0.0422,\n", - " -0.0753, -0.0499, -0.0510, -0.0269, -0.0744, -0.0588, -0.0597, -0.0444,\n", - " -0.0797, -0.0465, -0.0286, -0.1052, -0.0191, -0.0452, -0.1614, -0.1270,\n", - " -0.0403, -0.0847, -0.0681, -0.0665, -0.0804, -0.0736, -0.0157, -0.0878,\n", - " -0.1096, 
-0.0556, -0.0338, -0.0658, -0.0492, -0.0942, -0.1351, -0.0913,\n", - " -0.0605, -0.1077, -0.0760, -0.0268, -0.0929, -0.1248, -0.0825, -0.1334,\n", - " -0.0502, -0.0905, -0.0893, -0.0895, -0.0377, -0.0247, -0.0406, -0.0473,\n", - " -0.1259, -0.0324, -0.1490, -0.0499, -0.0981, -0.0630, -0.0290, -0.0723,\n", - " -0.0771, -0.0589, -0.0483, -0.0727, -0.0719, -0.0575, -0.0789, -0.0573,\n", - " -0.0650, -0.0764, -0.0824, -0.0842, -0.0428, -0.0618, -0.0717, -0.0674,\n", - " -0.0505, -0.0385, -0.1024, -0.0690, -0.0712, -0.0958, -0.0633, -0.0369,\n", - " -0.1048, -0.0852, -0.0832, -0.0954, -0.1188, -0.0503, -0.0537, -0.0669,\n", - " -0.1042, -0.0372, -0.0239, -0.0809, -0.0686, -0.0626, -0.0527, -0.0764,\n", - " -0.0807, -0.0843, -0.0461, -0.0379, -0.0419, -0.0656, -0.1475, -0.0798,\n", - " -0.0922, -0.0623, -0.0506, -0.0202, -0.0486, -0.0802, -0.0912, -0.0541,\n", - " -0.0456, -0.0481, -0.0719, -0.0575, -0.0546, -0.0292, -0.0339, -0.0723,\n", - " -0.0812, -0.0887, -0.1355, -0.0588, -0.0987, -0.1827, -0.1166, -0.0825,\n", - " -0.0900, -0.0490, -0.1302, -0.0523, -0.1335, -0.0503, -0.1490, -0.0766,\n", - " -0.0275, -0.0724, -0.0757, -0.0518, -0.0562, -0.1220, -0.0533, -0.0878,\n", - " -0.0788, -0.0503, -0.0515, -0.0457, -0.0664, -0.0668, -0.0324, -0.1042,\n", - " -0.1219, -0.0598, -0.0487, -0.0811, -0.1269, -0.0676, -0.0335, -0.0383,\n", - " -0.1570, -0.0325, -0.0246, -0.0483, -0.0958, -0.0224, -0.0991, -0.1084,\n", - " -0.1514, -0.0992, -0.0957, -0.0658, -0.0233, -0.0615, -0.0628, -0.0498,\n", - " -0.0483, -0.0696, -0.0132, -0.0625, -0.1002, -0.0779, -0.0290, -0.0283,\n", - " -0.0602, -0.0512, -0.1086, -0.0198, -0.1045, -0.0523, -0.0197, -0.1472,\n", - " -0.0456, -0.1203, -0.0306, -0.0667, -0.0519, -0.1100, -0.0186, -0.0970,\n", - " -0.0415, -0.0227, -0.0472, -0.0281, -0.1176, -0.0439, -0.0673, -0.1259,\n", - " -0.0806, -0.1527, -0.0542, -0.0571, -0.0378, -0.0684, -0.0353, -0.0721,\n", - " -0.0857, -0.0396, -0.1194, -0.0586, -0.0896, -0.0594, -0.0764, -0.0524,\n", - " -0.0556, 
-0.0250, -0.0375, -0.0627, -0.0567, -0.0237, -0.0375, -0.1197,\n", - " -0.0859, -0.0185, -0.0505, -0.0494, -0.0357, -0.0603, -0.1173, -0.1240,\n", - " -0.0454, -0.0445, -0.0784, -0.0812, -0.1168, -0.0637, -0.0371, -0.0569,\n", - " -0.1315, -0.0535, -0.0760, -0.1098, -0.0259, -0.0787, -0.0838, -0.0419,\n", - " -0.0635, -0.0755, -0.0300, -0.0900, -0.0478, -0.0320, -0.0406, -0.1227,\n", - " -0.0664, -0.0390, -0.1053, -0.0896, -0.1074, -0.0246, -0.0633, -0.0996,\n", - " -0.0780, -0.0535, -0.0581, -0.0648, -0.0292, -0.0925, -0.0487, -0.0541,\n", - " -0.0306, -0.0418, -0.0345, -0.1240, -0.0759, -0.0462, -0.0935, -0.0202,\n", - " -0.0498, -0.0459, -0.0519, -0.1087, -0.0906, -0.0095, -0.0508, -0.1118,\n", - " -0.0899, -0.1910, -0.0354, -0.0657, -0.0774, -0.0577, -0.0633, -0.0740,\n", - " -0.0751, -0.0207, -0.1020, -0.0199, -0.0419, -0.1173, -0.1266, -0.0585,\n", - " -0.0434, -0.0672, -0.0881, -0.0537, -0.1247, -0.1065, -0.0609, -0.1341,\n", - " -0.1198, -0.0728, -0.0685, -0.0626, -0.0630, -0.0776, -0.0959, -0.0196,\n", - " -0.1088, -0.0125, -0.0495, -0.0377, -0.0659, -0.0954, -0.1193, -0.0432,\n", - " -0.0374, -0.0297, -0.0210, -0.0286, -0.0535, -0.0750, -0.0967, -0.0967,\n", - " -0.0632, -0.0489, -0.0639, -0.0741, -0.0620, -0.0948, -0.0966, -0.0725,\n", - " -0.0739, -0.0649, -0.0182, -0.0995, -0.1026, -0.0422, -0.0829, -0.0523,\n", - " -0.0904, -0.1020, -0.1121, -0.0748, -0.0402, -0.0973, -0.1215, -0.0515,\n", - " -0.0663, -0.0595, -0.0502, -0.0370, -0.1234, -0.0346, -0.0301, -0.0440,\n", - " -0.0398, -0.0457, -0.1115, -0.1188, -0.0378, -0.1049, -0.0657, -0.0521,\n", - " -0.0337, -0.0106, -0.0047, -0.0402, -0.0479, -0.1028, -0.1165, -0.0627,\n", - " -0.0531, -0.0219, -0.0575, -0.0982, -0.1051, -0.0290, -0.1009, -0.0741,\n", - " -0.1253, -0.0165, -0.0490, -0.0601, -0.0461, -0.0342, -0.0526, -0.1105,\n", - " -0.0567, -0.1048, -0.0604, -0.0469, -0.0570, -0.0415, -0.0517, -0.0641,\n", - " -0.0758, -0.0582, -0.0588, -0.1217, -0.1596, -0.0583, -0.0997, -0.0409,\n", - " -0.0497, 
-0.0700, -0.0495, -0.1100, -0.0712, -0.1107, -0.1000, -0.0489,\n", - " -0.0686, -0.0748, -0.0701, -0.1143, -0.0737, -0.0786, -0.0387, -0.0504,\n", - " -0.0310, -0.0747, -0.0424, -0.0903, -0.1444, -0.0355, -0.0141, -0.0639,\n", - " -0.0912, -0.0753, -0.1046, -0.0843, -0.0553, -0.0972, -0.0655, -0.0793,\n", - " -0.1704, -0.0576, -0.0707, -0.0293, -0.0306, -0.0406, -0.0982, -0.0649,\n", - " -0.0475, -0.0868, -0.0556, -0.0701, -0.0867, -0.0442, -0.1121, -0.0544,\n", - " -0.0447, -0.0932, -0.0825, -0.0917, -0.0973, -0.0286, -0.1121, -0.0692,\n", - " -0.0585, -0.1000, -0.1387, -0.0233, -0.0365, -0.0427, -0.0382, -0.0542,\n", - " -0.0804, -0.0513, -0.0809, -0.0928, -0.0485, -0.0691, -0.0529, -0.0574,\n", - " -0.0412, -0.0369, -0.0579, -0.1057, -0.0373, -0.0553, -0.1050, -0.0390])\n", - "linear.weight Parameter containing:\n", - "tensor([[ 0.0370, -0.0410, -0.0149, ..., -0.0283, -0.0105, -0.0024],\n", - " [ 0.0430, 0.0205, -0.0063, ..., -0.0271, 0.0237, -0.0130],\n", - " [ 0.0301, 0.0312, -0.0110, ..., 0.0084, 0.0321, 0.0298],\n", - " ...,\n", - " [ 0.0296, -0.0215, -0.0301, ..., -0.0128, -0.0214, 0.0351],\n", - " [ 0.0431, 0.0256, 0.0101, ..., -0.0294, -0.0066, -0.0428],\n", - " [-0.0257, 0.0246, 0.0169, ..., -0.0290, 0.0167, 0.0032]],\n", - " requires_grad=True)\n", - "linear.bias Parameter containing:\n", - "tensor([ 0.0029, -0.0253, -0.0079, -0.0038, 0.0257, 0.0105, 0.0017, 0.0405,\n", - " 0.0337, 0.0201], requires_grad=True)\n" - ] - } - ], - "source": [ - "model = create_backbone(name='res18', num_classes=10)\n", - "classifier = nn.Linear(in_features=model.output_dim, out_features=10, bias=True)\n", - "for name, value in model.named_parameters():\n", - " if not name.startswith('linear') :\n", - " value.requires_grad = False\n", - "pretrained_model = torch.load('../checkpoint/SimCLR_on_Cifar4CL_lr0.5_lstep1_rn100.ckpt', map_location='cpu')\n", - "model.load_state_dict({k[9:]:v for k, v in pretrained_model['model'].items() if k.startswith('backbone.')}, 
strict=False)\n", - "\n", - "del pretrained_model\n", - "# model.add_module(\"Linear\", classifier)\n", - "for name, value in model.named_parameters():\n", - " print(name, value)" - ] - }, - { - "cell_type": "code", - "execution_count": 93, - "id": "166237f2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameter containing:\n", - "tensor([[-0.0187, -0.0211, -0.0236, ..., -0.0326, 0.0137, 0.0262],\n", - " [-0.0381, -0.0086, 0.0255, ..., -0.0020, 0.0384, -0.0152],\n", - " [ 0.0052, 0.0015, -0.0284, ..., 0.0357, 0.0110, -0.0346],\n", - " ...,\n", - " [-0.0243, 0.0382, -0.0128, ..., 0.0228, 0.0274, -0.0390],\n", - " [-0.0257, 0.0127, 0.0342, ..., 0.0194, 0.0440, -0.0125],\n", - " [-0.0125, 0.0293, -0.0360, ..., -0.0290, 0.0309, 0.0352]],\n", - " requires_grad=True)\n", - "Parameter containing:\n", - "tensor([-0.0302, 0.0121, 0.0225, -0.0034, 0.0179, -0.0157, -0.0225, -0.0358,\n", - " 0.0366, -0.0426], requires_grad=True)\n" - ] - } - ], - "source": [ - "def get_freezed_parameters(module):\n", - " \"\"\"\n", - " Returns names of freezed parameters of the given module.\n", - " \"\"\"\n", - " \n", - " freezed_parameters = []\n", - " for name, parameter in module.named_parameters():\n", - " if not parameter.requires_grad:\n", - " freezed_parameters.append(name)\n", - " \n", - " return freezed_parameters\n", - "\n", - "get_freezed_parameters(model)\n", - "for i in filter(lambda p: p.requires_grad, model.parameters()):\n", - " print(i)" - ] - }, - { - "cell_type": "code", - "execution_count": 94, - "id": "d6dc3b50", - "metadata": {}, - "outputs": [], - "source": [ - "optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)\n", - "optimizer.zero_grad()\n", - "output = model(torch.rand((1,3,32,32)))\n", - "loss = F.cross_entropy(output, torch.tensor([1]))\n", - "loss.backward()\n", - "optimizer.step()" - ] - }, - { - "cell_type": 
"code", - "execution_count": 99, - "id": "d6ff199a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 99, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'abcdfgh'.endswith('dfgh')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e5293562", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index b34a8ba3f..4359fe1b7 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -117,17 +117,6 @@ def _hook_on_fit_end(self, ctx): np.concatenate(ctx.get("{}_y_prob".format(ctx.cur_data_split)))) results = self.metric_calculator.eval(ctx) setattr(ctx, 'eval_metrics', results) - -def linear_prob_forward(ctx): - x, label = [_.to(ctx.device) for _ in ctx.data_batch] - pred = ctx.model(x) - if len(label.size()) == 0: - label = label.unsqueeze(0) - ctx.loss_batch = ctx.criterion(pred, label) - ctx.y_true = label - ctx.y_prob = pred - - ctx.batch_size = len(label) class LPTrainer(GeneralTorchTrainer): pass diff --git a/federatedscope/core/auxiliaries/eunms.py b/federatedscope/core/auxiliaries/eunms.py deleted file mode 100644 index 2ef6b478b..000000000 --- a/federatedscope/core/auxiliaries/eunms.py +++ /dev/null @@ -1,30 +0,0 @@ -class MODE: - """ - - Note: - Currently StrEnum cannot be imported with the environment - `sys.version_info < (3, 11)`, so we simply create a MODE class here. 
- """ - TRAIN = 'train' - TEST = 'test' - VAL = 'val' - FINETUNE = 'finetune' - - -class TRIGGER: - ON_FIT_START = 'on_fit_start' - ON_EPOCH_START = 'on_epoch_start' - ON_BATCH_START = 'on_batch_start' - ON_BATCH_FORWARD = 'on_batch_forward' - ON_BATCH_BACKWARD = 'on_batch_backward' - ON_BATCH_END = 'on_batch_end' - ON_EPOCH_END = 'on_epoch_end' - ON_FIT_END = 'on_fit_end' - - @classmethod - def contains(cls, item): - return item in [ - "on_fit_start", "on_epoch_start", "on_batch_start", - "on_batch_forward", "on_batch_backward", "on_batch_end", - "on_epoch_end", "on_fit_end" - ] diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index e04b5b276..9e479bfaa 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -70,10 +70,14 @@ def get_shape_from_data(data, model_config, backend='torch'): import torch if issubclass(type(data_representative), torch.utils.data.DataLoader): x, _ = next(iter(data_representative)) + if x.type == list: + return x[0].shape return x.shape else: try: x, _ = data_representative + if x.type == list: + return x[0].shape return x.shape except: raise TypeError('Unsupported data type.') From 5f8ef6d291583947432fcf6e66774aa44fcb974e Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Wed, 3 Aug 2022 04:41:17 +0800 Subject: [PATCH 06/46] debug --- federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml | 2 ++ federatedscope/core/auxiliaries/model_builder.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml index 3372c874d..aa898bef4 100644 --- a/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml @@ -12,6 +12,8 @@ data: root: 'data' type: 'Cifar4CL' batch_size: 256 + splitter: 'lda' + splitter_args: [{'alpha': 0.5}] num_workers: 
2 model: type: 'SimCLR' diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index 9e479bfaa..8674ce844 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -141,7 +141,7 @@ def get_model(model_config, local_data=None, backend='torch'): model = get_cnn(model_config, input_shape) elif model_config.type.lower() in ['simclr', 'simclr_linear']: from federatedscope.cl.model import get_simclr - model = get_simclr(model_config, local_data) + model = get_simclr(model_config, input_shape) if model_config.type.lower().endswith('linear'): for name, value in model.named_parameters(): if not name.startswith('linear') : From e3e2f642b28ab31c4f7a48e3385470f00cb97abc Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Wed, 3 Aug 2022 21:00:25 +0800 Subject: [PATCH 07/46] debug --- .../fedsimclr_linearprob_on_cifar10.yaml | 6 +- .../cl/baseline/fedsimclr_on_cifar10.yaml | 2 +- .../supervised_fedavg_on_cifar10.yaml | 36 ++++++ .../baseline/supervised_local_on_cifar10.yaml | 34 ++++++ .../unpretrained_linearprob_on_cifar10.yaml | 34 ++++++ federatedscope/cl/dataloader/Cifar10.py | 103 +++++++++--------- federatedscope/cl/model/SimCLR.py | 28 ++++- federatedscope/cl/trainer/trainer.py | 101 +---------------- federatedscope/core/aggregator.py | 27 ++++- .../core/auxiliaries/aggregator_builder.py | 4 +- .../core/auxiliaries/model_builder.py | 6 +- 11 files changed, 222 insertions(+), 159 deletions(-) create mode 100644 federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml create mode 100644 federatedscope/cl/baseline/supervised_local_on_cifar10.yaml create mode 100644 federatedscope/cl/baseline/unpretrained_linearprob_on_cifar10.yaml diff --git a/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml b/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml index b4e85ed31..a866cb4d9 100644 --- 
a/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml @@ -6,12 +6,14 @@ federate: client_num: 5 sample_client_rate: 1.0 method: local - #restore_from: 'checkpoint/SimCLR_on_Cifar4CL_lr0.1_lstep5_rn100.ckpt' +# restore_from: 'checkpoint/SimCLR_on_Cifar4CL_lr0.1_lstep5_rn100.ckpt' data: root: 'data' type: 'Cifar4LP' batch_size: 256 - num_workers: 2 + splitter: 'lda' + splitter_args: [{'alpha': 0.5}] + num_workers: 4 model: type: 'SimCLR_linear' train: diff --git a/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml index aa898bef4..1eadbf684 100644 --- a/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml @@ -14,7 +14,7 @@ data: batch_size: 256 splitter: 'lda' splitter_args: [{'alpha': 0.5}] - num_workers: 2 + num_workers: 4 model: type: 'SimCLR' train: diff --git a/federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml b/federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml new file mode 100644 index 000000000..7fa08d7a6 --- /dev/null +++ b/federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml @@ -0,0 +1,36 @@ +use_gpu: True +device: 2 +federate: + mode: standalone + total_round_num: 50 + client_num: 5 + sample_client_rate: 1.0 + share_local_model: True + online_aggr: True + method: fedavg +data: + root: 'data' + type: 'Cifar4LP' + batch_size: 256 + splitter: 'lda' + splitter_args: [{'alpha': 0.5}] + num_workers: 4 +model: + type: 'supervised_fedavg' +train: + local_update_steps: 1 + batch_or_epoch: 'epoch' + optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0 +early_stop: + patience: 0 +criterion: + type: CrossEntropyLoss +trainer: + type: general +eval: + freq: 2 + metrics: ['acc'] + split: ['val', 'test'] \ No newline at end of file diff --git a/federatedscope/cl/baseline/supervised_local_on_cifar10.yaml 
b/federatedscope/cl/baseline/supervised_local_on_cifar10.yaml new file mode 100644 index 000000000..e790a7bec --- /dev/null +++ b/federatedscope/cl/baseline/supervised_local_on_cifar10.yaml @@ -0,0 +1,34 @@ +use_gpu: True +device: 2 +federate: + mode: standalone + total_round_num: 50 + client_num: 5 + sample_client_rate: 1.0 + method: local +data: + root: 'data' + type: 'Cifar4LP' + batch_size: 256 + splitter: 'lda' + splitter_args: [{'alpha': 0.5}] + num_workers: 4 +model: + type: 'supervised_local' +train: + local_update_steps: 1 + batch_or_epoch: 'epoch' + optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0 +early_stop: + patience: 0 +criterion: + type: CrossEntropyLoss +trainer: + type: general +eval: + freq: 2 + metrics: ['acc'] + split: ['val', 'test'] \ No newline at end of file diff --git a/federatedscope/cl/baseline/unpretrained_linearprob_on_cifar10.yaml b/federatedscope/cl/baseline/unpretrained_linearprob_on_cifar10.yaml new file mode 100644 index 000000000..aac85d048 --- /dev/null +++ b/federatedscope/cl/baseline/unpretrained_linearprob_on_cifar10.yaml @@ -0,0 +1,34 @@ +use_gpu: True +device: 2 +federate: + mode: standalone + total_round_num: 50 + client_num: 5 + sample_client_rate: 1.0 + method: local +data: + root: 'data' + type: 'Cifar4LP' + batch_size: 256 + splitter: 'lda' + splitter_args: [{'alpha': 0.5}] + num_workers: 4 +model: + type: 'SimCLR_linear' +train: + local_update_steps: 1 + batch_or_epoch: 'epoch' + optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0 +early_stop: + patience: 0 +criterion: + type: CrossEntropyLoss +trainer: + type: general +eval: + freq: 2 + metrics: ['acc'] + split: ['val', 'test'] \ No newline at end of file diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index b5cb16103..96ae35c93 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -11,6 +11,8 @@ import pickle as pkl import numpy as np from 
federatedscope.register import register_data +from federatedscope.core.auxiliaries.splitter_builder import get_splitter + class SimCLRTransform(): @@ -43,43 +45,48 @@ def Cifar4CL(config): T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] ) + splits = config.data.splits + path = config.data.root + name = config.data.type.upper() + client_num = config.federate.client_num + batch_size = config.data.batch_size - data_train = CIFAR10(config.data.root, train=True, download=True, transform=transform_train) - data_val = CIFAR10(config.data.root, train=True, download=True, transform=transform_train) - data_test = CIFAR10(config.data.root, train=False, download=True, transform=transform_train) + data_train = CIFAR10(path, train=True, download=True, transform=transform_train) +# data_val = CIFAR10(path, train=True, download=True, transform=transform_train) + data_test = CIFAR10(path, train=False, download=True, transform=transform_train) # Split data into dict data_dict = dict() - train_per_client = len(data_train) // config.federate.client_num - val_per_client = len(data_val) // config.federate.client_num - test_per_client = len(data_test) // config.federate.client_num + + # Splitter + splitter = get_splitter(config) + data_train = splitter(data_train) + data_val = data_train + data_test = splitter(data_test) + + + client_num = min(len(data_train), config.federate.client_num + ) if config.federate.client_num > 0 else len(data_train) + config.merge_from_list(['federate.client_num', client_num]) + - for client_idx in range(1, config.federate.client_num + 1): + for client_idx in range(1, client_num + 1): dataloader_dict = { 'train': - DataLoader([ - data_train[i] - for i in range((client_idx - 1) * - train_per_client, client_idx * train_per_client) - ], + DataLoader(data_train[client_idx - 1], config.data.batch_size, - shuffle=config.data.shuffle), + shuffle=config.data.shuffle, + num_workers=config.data.num_workers), 'val': - DataLoader([ - data_val[i] - for i in 
range((client_idx - 1) * - val_per_client, client_idx * val_per_client) - ], + DataLoader(data_val[client_idx - 1], config.data.batch_size, - shuffle=config.data.shuffle), + shuffle=False, + num_workers=config.data.num_workers), 'test': - DataLoader([ - data_test[i] - for i in range((client_idx - 1) * test_per_client, client_idx * - test_per_client) - ], + DataLoader(data_test[client_idx - 1], config.data.batch_size, - shuffle=False) + shuffle=False, + num_workers=config.data.num_workers), } data_dict[client_idx] = dataloader_dict r""" @@ -119,40 +126,38 @@ def Cifar4LP(config): # Split data into dict data_dict = dict() - train_per_client = len(data_train) // config.federate.client_num - val_per_client = len(data_val) // config.federate.client_num - test_per_client = len(data_test) // config.federate.client_num + + # Splitter + splitter = get_splitter(config) + data_train = splitter(data_train) + data_val = splitter(data_val) + data_test = splitter(data_test) + + + client_num = min(len(data_train), config.federate.client_num + ) if config.federate.client_num > 0 else len(data_train) + config.merge_from_list(['federate.client_num', client_num]) + - print("time1") - for client_idx in range(1, config.federate.client_num + 1): + for client_idx in range(1, client_num + 1): dataloader_dict = { 'train': - DataLoader([ - data_train[i] - for i in range((client_idx - 1) * - train_per_client, client_idx * train_per_client) - ], + DataLoader(data_train[client_idx - 1], config.data.batch_size, - shuffle=config.data.shuffle), + shuffle=config.data.shuffle, + num_workers=config.data.num_workers), 'val': - DataLoader([ - data_val[i] - for i in range((client_idx - 1) * - val_per_client, client_idx * val_per_client) - ], + DataLoader(data_val[client_idx - 1], config.data.batch_size, - shuffle=config.data.shuffle), + shuffle=False, + num_workers=config.data.num_workers), 'test': - DataLoader([ - data_test[i] - for i in range((client_idx - 1) * test_per_client, client_idx * - 
test_per_client) - ], + DataLoader(data_test[client_idx - 1], config.data.batch_size, - shuffle=False) + shuffle=False, + num_workers=config.data.num_workers), } data_dict[client_idx] = dataloader_dict - print("time2") r""" Returns: diff --git a/federatedscope/cl/model/SimCLR.py b/federatedscope/cl/model/SimCLR.py index 444e942cb..a10325153 100644 --- a/federatedscope/cl/model/SimCLR.py +++ b/federatedscope/cl/model/SimCLR.py @@ -183,9 +183,11 @@ def forward(self, x): # SimCLR class simclr(nn.Module): + ''' + source: https://github.com/akhilmathurs/orchestra/blob/main/models.py + ''' def __init__(self, bbone_arch): super(simclr, self).__init__() - self.T = 0.5 self.register_buffer("rounds_done", torch.zeros(1)) self.backbone = create_backbone(bbone_arch, num_classes=0) @@ -197,17 +199,31 @@ def forward(self, x1, x2, x3=None, deg_labels=None): # L = NT_xentloss(z1, z2, temperature=self.T) return z1, z2 + +class simclr_linearprob(nn.Module): + def __init__(self, bbone_arch, num_classes=10): + super(simclr_linearprob, self).__init__() + self.register_buffer("rounds_done", torch.zeros(1)) + + self.backbone = create_backbone(bbone_arch, num_classes=0) + self.linear = nn.Linear(512, num_classes, bias=True) + + def forward(self, x): + N = x.shape[0] + out = self.backbone(x) + out = self.linear(out) + + return out def ModelBuilder(model_config, local_data): # You can also build models without local_data - data = next(iter(local_data['train'])) if model_config.type == "SimCLR": model = simclr(bbone_arch='res18') return model - if model_config.type == "SimCLR_linear": - model = create_backbone(name='res18', num_classes=10) - pretrained_model = torch.load('checkpoint/SimCLR_on_Cifar4CL_lr0.5_lstep5_rn100.ckpt', map_location='cpu') - model.load_state_dict({k[9:]:v for k, v in pretrained_model['model'].items() if k.startswith('backbone.')}, strict=False) + if model_config.type in ["SimCLR_linear","supervised_local","supervised_fedavg"]: + model = 
simclr_linearprob(bbone_arch='res18', num_classes=10) +# pretrained_model = torch.load('checkpoint/SimCLR_on_Cifar4CL_lr0.1_lstep5_rn100.ckpt', map_location='cpu') +# model.load_state_dict(pretrained_model['model'], strict=False) # for name, value in model.named_parameters(): # if not name.startswith('linear') : # value.requires_grad = False diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index 4359fe1b7..c5cc741b8 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -3,56 +3,6 @@ from federatedscope.core.auxiliaries import utils import numpy as np -def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): - # compute cos similarity between each feature vector and feature bank ---> [B, N] - sim_matrix = torch.mm(feature, feature_bank) - # [B, K] - sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) - # [B, K] - sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices) - sim_weight = (sim_weight / knn_t).exp() - - # counts for each class - one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device) - # [B*K, C] - one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0) - # weighted score ---> [B, C] - pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1) - - pred_labels = pred_scores.argsort(dim=-1, descending=True) - return pred_labels - -def knn_monitor(net, memory_data_loader, test_data_loader, k=200, t=0.1, device="cpu", verbose=True): - net.eval() - classes = len(memory_data_loader.dataset.classes) - total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, [] - feature_labels = [] - with torch.no_grad(): - # generate feature bank - for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=not verbose): - feature = net(data.to(device)) - feature = 
F.normalize(feature, dim=1) - feature_bank.append(feature) - feature_labels.append(target.to(device)) - # [D, N] - feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() - # [N] - # feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=device) - feature_labels = torch.cat(feature_labels, dim=0).contiguous() - - # loop test data to predict the label by weighted knn search - test_bar = tqdm(test_data_loader, desc='kNN', disable=not verbose) - for data, target in test_bar: - data, target = data.to(device), target.to(device) - feature = net(data) - feature = F.normalize(feature, dim=1) - - pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t) - - total_num += data.size(0) - total_top1 += (pred_labels[:, 0] == target).float().sum().item() - return total_top1 / total_num * 100 - class CLTrainer(GeneralTorchTrainer): def _hook_on_batch_forward(self, ctx): @@ -70,53 +20,12 @@ def _hook_on_batch_forward(self, ctx): def _hook_on_batch_end(self, ctx): # update statistics - setattr( - ctx, "loss_batch_total_{}".format(ctx.cur_data_split), - ctx.get("loss_batch_total_{}".format(ctx.cur_data_split)) + - ctx.loss_batch.item() * ctx.batch_size) - - if ctx.get("loss_regular", None) is None or ctx.loss_regular == 0: - loss_regular = 0. 
- else: - loss_regular = ctx.loss_regular.item() - setattr( - ctx, "loss_regular_total_{}".format(ctx.cur_data_split), - ctx.get("loss_regular_total_{}".format(ctx.cur_data_split)) + - loss_regular) - setattr( - ctx, "num_samples_{}".format(ctx.cur_data_split), - ctx.get("num_samples_{}".format(ctx.cur_data_split)) + - ctx.batch_size) - + ctx.num_samples += ctx.batch_size + ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size + ctx.loss_regular_total += float(ctx.get("loss_regular", 0.)) # cache label for evaluate - ctx.get("{}_y_true".format(ctx.cur_data_split)).append( - ctx.y_true.detach().cpu().numpy()) - -# print(len(ctx.y_prob), ctx.y_prob[0].size(), ctx.y_prob[1].size()) - ctx.get("{}_y_prob".format(ctx.cur_data_split)).append( - ctx.y_prob[0].detach().cpu().numpy()) - - # clean temp ctx - ctx.data_batch = None - ctx.batch_size = None - ctx.loss_task = None - ctx.loss_batch = None - ctx.loss_regular = None - ctx.y_true = None - ctx.y_prob = None - - def _hook_on_fit_end(self, ctx): - """Evaluate metrics. - - """ - setattr( - ctx, "{}_y_true".format(ctx.cur_data_split), - np.concatenate(ctx.get("{}_y_true".format(ctx.cur_data_split)))) - setattr( - ctx, "{}_y_prob".format(ctx.cur_data_split), - np.concatenate(ctx.get("{}_y_prob".format(ctx.cur_data_split)))) - results = self.metric_calculator.eval(ctx) - setattr(ctx, 'eval_metrics', results) + ctx.ys_true.append(ctx.y_true.detach().cpu().numpy()) + ctx.ys_prob.append(ctx.y_prob[0].detach().cpu().numpy()) class LPTrainer(GeneralTorchTrainer): pass diff --git a/federatedscope/core/aggregator.py b/federatedscope/core/aggregator.py index 41b269a07..584d8c7c2 100644 --- a/federatedscope/core/aggregator.py +++ b/federatedscope/core/aggregator.py @@ -104,12 +104,37 @@ def _para_weighted_avg(self, models, recover_fun=None): return avg_model -class NoCommunicationAggregator(Aggregator): +class NoCommunicationAggregator(ClientsAvgAggregator): """"Clients do not communicate. 
Each client work locally """ + def __init__(self, model=None, device='cpu', config=None): + super(NoCommunicationAggregator, self).__init__(model, device, config) + def aggregate(self, agg_info): # do nothing return {} + + def update(self, model_parameters): + pass + + def save_model(self, path, cur_round=-1): + assert self.model is not None + + ckpt = {'cur_round': cur_round, 'model': self.model.state_dict()} + torch.save(ckpt, path) + + def load_model(self, path): + assert self.model is not None + + if os.path.exists(path): + ckpt = torch.load(path, map_location=self.device) + self.model.load_state_dict(ckpt['model'], strict=False) + return ckpt['cur_round'] + else: + raise ValueError("The file {} does NOT exist".format(path)) + + def _para_weighted_avg(self, models, recover_fun=None): + pass class AsynClientsAvgAggregator(ClientsAvgAggregator): diff --git a/federatedscope/core/auxiliaries/aggregator_builder.py b/federatedscope/core/auxiliaries/aggregator_builder.py index e9cd0b256..97752d824 100644 --- a/federatedscope/core/auxiliaries/aggregator_builder.py +++ b/federatedscope/core/auxiliaries/aggregator_builder.py @@ -48,7 +48,9 @@ def get_aggregator(method, model=None, device=None, online=False, config=None): config=config, beta=config.personalization.beta) elif aggregator_type == 'no_communication': - return NoCommunicationAggregator() + return NoCommunicationAggregator(model=model, + device=device, + config=config) else: raise NotImplementedError( "Aggregator {} is not implemented.".format(aggregator_type)) diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index 8674ce844..0ad808437 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -70,13 +70,13 @@ def get_shape_from_data(data, model_config, backend='torch'): import torch if issubclass(type(data_representative), torch.utils.data.DataLoader): x, _ = 
next(iter(data_representative)) - if x.type == list: + if isinstance(x, list): return x[0].shape return x.shape else: try: x, _ = data_representative - if x.type == list: + if isinstance(x, list): return x[0].shape return x.shape except: @@ -139,7 +139,7 @@ def get_model(model_config, local_data=None, backend='torch'): elif model_config.type.lower() in ['convnet2', 'convnet5', 'vgg11', 'lr']: from federatedscope.cv.model import get_cnn model = get_cnn(model_config, input_shape) - elif model_config.type.lower() in ['simclr', 'simclr_linear']: + elif model_config.type.lower() in ['simclr', 'simclr_linear',"supervised_local","supervised_fedavg"]: from federatedscope.cl.model import get_simclr model = get_simclr(model_config, input_shape) if model_config.type.lower().endswith('linear'): From 6ac15aa7e1e6162ec95eb706e0b948e8678d036d Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Wed, 3 Aug 2022 23:36:37 +0800 Subject: [PATCH 08/46] debug --- .../cl/baseline/supervised_fedavg_on_cifar10.yaml | 4 ++-- federatedscope/cl/model/SimCLR.py | 10 ++-------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml b/federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml index 7fa08d7a6..7a91c110e 100644 --- a/federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml +++ b/federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml @@ -1,5 +1,5 @@ use_gpu: True -device: 2 +device: 0 federate: mode: standalone total_round_num: 50 @@ -7,7 +7,7 @@ federate: sample_client_rate: 1.0 share_local_model: True online_aggr: True - method: fedavg + method: FedAvg data: root: 'data' type: 'Cifar4LP' diff --git a/federatedscope/cl/model/SimCLR.py b/federatedscope/cl/model/SimCLR.py index a10325153..b5b914555 100644 --- a/federatedscope/cl/model/SimCLR.py +++ b/federatedscope/cl/model/SimCLR.py @@ -222,19 +222,13 @@ def ModelBuilder(model_config, local_data): return model if model_config.type in 
["SimCLR_linear","supervised_local","supervised_fedavg"]: model = simclr_linearprob(bbone_arch='res18', num_classes=10) -# pretrained_model = torch.load('checkpoint/SimCLR_on_Cifar4CL_lr0.1_lstep5_rn100.ckpt', map_location='cpu') -# model.load_state_dict(pretrained_model['model'], strict=False) -# for name, value in model.named_parameters(): -# if not name.startswith('linear') : -# value.requires_grad = False return model from federatedscope.register import register_model def get_simclr(model_config, local_data): - if model_config.type in ["SimCLR", "SimCLR_linear"]: - model = ModelBuilder(model_config, local_data) - return model + model = ModelBuilder(model_config, local_data) + return model register_model("SimCLR", get_simclr) From 774419ade90a2f51f48dc497b710fafb70380a64 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Fri, 5 Aug 2022 02:28:46 +0800 Subject: [PATCH 09/46] debug --- .../cl/baseline/fedsimclr_linearprob_on_cifar10.yaml | 2 +- federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml | 2 +- federatedscope/cl/model/SimCLR.py | 9 +-------- federatedscope/cl/trainer/trainer.py | 11 ++++++----- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml b/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml index a866cb4d9..02ed75804 100644 --- a/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml @@ -6,7 +6,7 @@ federate: client_num: 5 sample_client_rate: 1.0 method: local -# restore_from: 'checkpoint/SimCLR_on_Cifar4CL_lr0.1_lstep5_rn100.ckpt' + restore_from: 'SimCLR_on_Cifar4CL_lr0.1_lstep5_rn100.ckpt' data: root: 'data' type: 'Cifar4LP' diff --git a/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml index 1eadbf684..4667f41b5 100644 --- a/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml +++ 
b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml @@ -7,7 +7,7 @@ federate: share_local_model: True online_aggr: True sample_client_rate: 1.0 - save_to: 'checkpoint/SimCLR_on_Cifar4CL_lr0.05_lus5_rn100.ckpt' + save_to: 'SimCLR_on_Cifar4CL_lda0.5_lr0.05_lus5_rn100.ckpt' data: root: 'data' type: 'Cifar4CL' diff --git a/federatedscope/cl/model/SimCLR.py b/federatedscope/cl/model/SimCLR.py index b5b914555..cfa25f1c4 100644 --- a/federatedscope/cl/model/SimCLR.py +++ b/federatedscope/cl/model/SimCLR.py @@ -6,7 +6,6 @@ import numpy as np from collections import OrderedDict -#### ResNets class BasicBlock(nn.Module): expansion = 1 @@ -152,9 +151,7 @@ def ResNet56(num_classes=10, block="BasicBlock"): ### Retrieval function for backbones ### def create_backbone(name, num_classes=10, block='BasicBlock'): - if(name == 'VGG'): - net = VGGmodel(num_classes=num_classes) - elif(name == 'res18'): + if(name == 'res18'): net = ResNet18(num_classes=num_classes, block=block) elif(name == 'res34'): net = ResNet34(num_classes=num_classes, block=block) @@ -183,9 +180,6 @@ def forward(self, x): # SimCLR class simclr(nn.Module): - ''' - source: https://github.com/akhilmathurs/orchestra/blob/main/models.py - ''' def __init__(self, bbone_arch): super(simclr, self).__init__() self.register_buffer("rounds_done", torch.zeros(1)) @@ -196,7 +190,6 @@ def __init__(self, bbone_arch): def forward(self, x1, x2, x3=None, deg_labels=None): N = x1.shape[0] z1, z2 = self.projector(self.backbone(x1)), self.projector(self.backbone(x2)) -# L = NT_xentloss(z1, z2, temperature=self.T) return z1, z2 diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index c5cc741b8..97f163a9f 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -12,11 +12,12 @@ def _hook_on_batch_forward(self, ctx): z1, z2 = ctx.model(x1, x2) if len(label.size()) == 0: label = label.unsqueeze(0) - ctx.loss_batch = ctx.criterion(z1, z2) - ctx.y_true = label - 
ctx.y_prob = z1, z2 - - ctx.batch_size = len(label) + + ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) + ctx.y_prob = CtxVar(z1, z2, LIFECYCLE.BATCH) + ctx.loss_batch = CtxVar(ctx.criterion(z1, z2), LIFECYCLE.BATCH) + ctx.batch_size = CtxVar(len(label), LIFECYCLE.BATCH) + def _hook_on_batch_end(self, ctx): # update statistics From a55b84434b9ca13a464b78f85eca3ed4e14433a2 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Fri, 5 Aug 2022 02:33:33 +0800 Subject: [PATCH 10/46] debug --- federatedscope/cl/dataloader/Cifar10.py | 30 ++----------------------- 1 file changed, 2 insertions(+), 28 deletions(-) diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 96ae35c93..f5ac0e662 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -89,20 +89,7 @@ def Cifar4CL(config): num_workers=config.data.num_workers), } data_dict[client_idx] = dataloader_dict - r""" - - Returns: - data: - { - '{client_id}': { - 'train': Dataset or DataLoader, - 'test': Dataset or DataLoader, - 'val': Dataset or DataLoader - } - } - config: - cfg_node - """ + config = config return data_dict, config @@ -158,20 +145,7 @@ def Cifar4LP(config): num_workers=config.data.num_workers), } data_dict[client_idx] = dataloader_dict - r""" - - Returns: - data: - { - '{client_id}': { - 'train': Dataset or DataLoader, - 'test': Dataset or DataLoader, - 'val': Dataset or DataLoader - } - } - config: - cfg_node - """ + config = config return data_dict, config From 88aa90aac5d978268f2cfe21d7f803ce323391f3 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Fri, 5 Aug 2022 17:23:13 +0800 Subject: [PATCH 11/46] debug --- federatedscope/cl/dataloader/Cifar10.py | 17 +++++------------ federatedscope/cl/model/SimCLR.py | 8 +------- federatedscope/cl/trainer/trainer.py | 4 +++- 3 files changed, 9 insertions(+), 20 deletions(-) diff --git a/federatedscope/cl/dataloader/Cifar10.py 
b/federatedscope/cl/dataloader/Cifar10.py index f5ac0e662..9a5793a3f 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -40,22 +40,14 @@ def __call__(self, x): def Cifar4CL(config): transform_train = SimCLRTransform(is_sup=False, image_size=32) - transform_test = T.Compose([ - T.ToTensor(), - T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - ) - splits = config.data.splits path = config.data.root - name = config.data.type.upper() - client_num = config.federate.client_num - batch_size = config.data.batch_size data_train = CIFAR10(path, train=True, download=True, transform=transform_train) # data_val = CIFAR10(path, train=True, download=True, transform=transform_train) data_test = CIFAR10(path, train=False, download=True, transform=transform_train) - # Split data into dict + # Split data into dict data_dict = dict() # Splitter @@ -106,10 +98,11 @@ def Cifar4LP(config): T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] ) + path = config.data.root - data_train = CIFAR10(config.data.root, train=True, download=True, transform=transform_train) - data_val = CIFAR10(config.data.root, train=True, download=True, transform=transform_test) - data_test = CIFAR10(config.data.root, train=False, download=True, transform=transform_test) + data_train = CIFAR10(path, train=True, download=True, transform=transform_train) + data_val = CIFAR10(path, train=True, download=True, transform=transform_test) + data_test = CIFAR10(path, train=False, download=True, transform=transform_test) # Split data into dict data_dict = dict() diff --git a/federatedscope/cl/model/SimCLR.py b/federatedscope/cl/model/SimCLR.py index cfa25f1c4..2a541eabb 100644 --- a/federatedscope/cl/model/SimCLR.py +++ b/federatedscope/cl/model/SimCLR.py @@ -149,7 +149,7 @@ def ResNet56(num_classes=10, block="BasicBlock"): return ResNet_basic(get_block(block), [9,9,9], num_classes=num_classes) -### Retrieval function for backbones ### + def create_backbone(name, num_classes=10, 
block='BasicBlock'): if(name == 'res18'): net = ResNet18(num_classes=num_classes, block=block) @@ -160,10 +160,6 @@ def create_backbone(name, num_classes=10, block='BasicBlock'): return net - -# SimCLR model - - # Projector class projection_MLP_simclr(nn.Module): def __init__(self, in_dim, hidden_dim=512, out_dim=512): @@ -188,7 +184,6 @@ def __init__(self, bbone_arch): self.projector = projection_MLP_simclr(self.backbone.output_dim, hidden_dim=512, out_dim=512) def forward(self, x1, x2, x3=None, deg_labels=None): - N = x1.shape[0] z1, z2 = self.projector(self.backbone(x1)), self.projector(self.backbone(x2)) return z1, z2 @@ -202,7 +197,6 @@ def __init__(self, bbone_arch, num_classes=10): self.linear = nn.Linear(512, num_classes, bias=True) def forward(self, x): - N = x.shape[0] out = self.backbone(x) out = self.linear(out) diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index 97f163a9f..afa1f7073 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -1,5 +1,7 @@ from federatedscope.register import register_trainer from federatedscope.core.trainers import GeneralTorchTrainer +from federatedscope.core.trainers.context import CtxVar +from federatedscope.core.auxiliaries.enums import LIFECYCLE from federatedscope.core.auxiliaries import utils import numpy as np @@ -14,7 +16,7 @@ def _hook_on_batch_forward(self, ctx): label = label.unsqueeze(0) ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) - ctx.y_prob = CtxVar(z1, z2, LIFECYCLE.BATCH) + ctx.y_prob = CtxVar((z1, z2), LIFECYCLE.BATCH) ctx.loss_batch = CtxVar(ctx.criterion(z1, z2), LIFECYCLE.BATCH) ctx.batch_size = CtxVar(len(label), LIFECYCLE.BATCH) From c716db6e5d246da8e367c3de7470f6d6033921b3 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Sat, 6 Aug 2022 15:39:36 +0800 Subject: [PATCH 12/46] debug --- federatedscope/cl/dataloader/Cifar10.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 9a5793a3f..a55794314 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -44,7 +44,7 @@ def Cifar4CL(config): path = config.data.root data_train = CIFAR10(path, train=True, download=True, transform=transform_train) -# data_val = CIFAR10(path, train=True, download=True, transform=transform_train) + # data_val = CIFAR10(path, train=True, download=True, transform=transform_train) data_test = CIFAR10(path, train=False, download=True, transform=transform_train) # Split data into dict From 063e88ffc835007be10feeead25053be3db8fc86 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Sat, 6 Aug 2022 16:35:27 +0800 Subject: [PATCH 13/46] debug --- federatedscope/cl/dataloader/Cifar10.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index a55794314..e5b1261d5 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -49,8 +49,7 @@ def Cifar4CL(config): # Split data into dict data_dict = dict() - - # Splitter + splitter = get_splitter(config) data_train = splitter(data_train) data_val = data_train From 8058ae6be3c681cdad439c4cccc836525fd618f7 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Sat, 6 Aug 2022 16:43:18 +0800 Subject: [PATCH 14/46] debug --- federatedscope/cl/dataloader/Cifar10.py | 1 - 1 file changed, 1 deletion(-) diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index e5b1261d5..8f99061c0 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -49,7 +49,6 @@ def Cifar4CL(config): # Split data into dict data_dict = dict() - splitter = get_splitter(config) data_train = splitter(data_train) data_val = data_train From e543b1b70ddcafa999ab07576ae11573b8a7e8f8 
Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Mon, 8 Aug 2022 15:02:07 +0800 Subject: [PATCH 15/46] debug --- federatedscope/cl/dataloader/Cifar10.py | 1 - 1 file changed, 1 deletion(-) diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 8f99061c0..eadb681ce 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -44,7 +44,6 @@ def Cifar4CL(config): path = config.data.root data_train = CIFAR10(path, train=True, download=True, transform=transform_train) - # data_val = CIFAR10(path, train=True, download=True, transform=transform_train) data_test = CIFAR10(path, train=False, download=True, transform=transform_train) # Split data into dict From aad20382bb7fcebddd3b24738a048362c68541b0 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Wed, 31 Aug 2022 05:14:27 +0800 Subject: [PATCH 16/46] FedGlobalContrast --- .../cl/baseline/fedgc_on_cifar10.yaml | 36 +++ federatedscope/cl/fedgc/utils.py | 43 ++++ federatedscope/cl/fedgc/worker.py | 226 ++++++++++++++++++ federatedscope/cl/trainer/trainer.py | 54 +++++ federatedscope/core/aggregators/aggregator.py | 29 +++ .../core/auxiliaries/worker_builder.py | 6 + federatedscope/core/configs/constants.py | 3 + 7 files changed, 397 insertions(+) create mode 100644 federatedscope/cl/baseline/fedgc_on_cifar10.yaml create mode 100644 federatedscope/cl/fedgc/utils.py create mode 100644 federatedscope/cl/fedgc/worker.py diff --git a/federatedscope/cl/baseline/fedgc_on_cifar10.yaml b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml new file mode 100644 index 000000000..bb71bdb6e --- /dev/null +++ b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml @@ -0,0 +1,36 @@ +use_gpu: True +device: 2 +federate: + mode: standalone + total_round_num: 100 + client_num: 5 + share_local_model: True + online_aggr: True + sample_client_rate: 1.0 + method: fedgc + save_to: 'FedGC_on_Cifar4CL_lda0.5_lr0.05_lus5_rn100.ckpt' 
+data: + root: 'data' + type: 'Cifar4CL' + batch_size: 256 + splitter: 'lda' + splitter_args: [{'alpha': 0.5}] + num_workers: 4 +model: + type: 'SimCLR' +train: + local_update_steps: 1 + batch_or_epoch: 'batch' + optimizer: + lr: 0.05 + momentum: 0.1 +early_stop: + patience: 0 +criterion: + type: 'NT_xentloss' +trainer: + type: 'cltrainer' +eval: + freq: 2 + metrics: ['loss'] + split: ['val', 'test'] \ No newline at end of file diff --git a/federatedscope/cl/fedgc/utils.py b/federatedscope/cl/fedgc/utils.py new file mode 100644 index 000000000..ff7e3f1f4 --- /dev/null +++ b/federatedscope/cl/fedgc/utils.py @@ -0,0 +1,43 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import networkx as nx + + +def norm(w): + return torch.norm(torch.cat([v.flatten() for v in w.values()])).item() + + +def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5): + """ computes global NT_xentloss""" + N, Z = z1.shape + representations = torch.cat([z1, z2], dim=0) + similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) + + l_pos = torch.diag(similarity_matrix, N) + r_pos = torch.diag(similarity_matrix, -N) + positives = torch.cat([l_pos, r_pos]).view(2 * N, 1) + + diag = torch.eye(2*N, dtype=torch.bool) + diag[N:,:N] = diag[:N,N:] = diag[:N,:N] + negatives = similarity_matrix[~diag].view(2*N, -1) + + + if len(others_z2) != 0: + for z2_ in others_z2: + N2, Z2 = z2_.shape + representations = torch.cat([z1, z2_], dim=0) + similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) + mask = torch.zeros_like(similarity_matrix, dtype=torch.bool) + mask[N:,:N] = True + mask[:N,N:] = True + negatives_other = similarity_matrix[mask].view(2*N, -1) + negatives = torch.cat([negatives, negatives_other], dim=1) + + + logits = torch.cat([positives, negatives], dim=1) / temperature + labels = torch.zeros(2*N, dtype=torch.int64) # scalar label per 
sample + loss = F.cross_entropy(logits, labels, reduction='sum') + + return loss / (2 * N) diff --git a/federatedscope/cl/fedgc/worker.py b/federatedscope/cl/fedgc/worker.py new file mode 100644 index 000000000..4adbbed36 --- /dev/null +++ b/federatedscope/cl/fedgc/worker.py @@ -0,0 +1,226 @@ +import torch +import logging +import copy +import numpy as np + +from federatedscope.core.message import Message +from federatedscope.core.workers.server import Server +from federatedscope.core.workers.client import Client +from federatedscope.core.auxiliaries.utils import merge_dict +from federatedscope.cl.fedgc.utils import compute_global_NT_xentloss + +logger = logging.getLogger(__name__) + + +class GlobalContrastFLServer(Server): + def __init__(self, + ID=-1, + state=0, + config=None, + data=None, + model=None, + client_num=5, + total_round_num=10, + device='cpu', + strategy=None, + **kwargs): + super(GlobalContrastFLServer, + self).__init__(ID, state, config, data, model, client_num, + total_round_num, device, strategy, **kwargs) + # Initial seqs_embedding + self.seqs_embedding = { + idx: () + for idx in range(1, self._cfg.federate.client_num + 1) + } + self.loss_list= { + idx: 0 + for idx in range(1, self._cfg.federate.client_num + 1) + } + + + def check_and_move_on(self, check_eval_result=False): + + if check_eval_result: + # all clients are participating in evaluation + minimal_number = self.client_num + else: + # sampled clients are participating in training + minimal_number = self.sample_client_num + + if self.check_buffer(self.state, minimal_number, check_eval_result): + + if not check_eval_result: # in the training process + # Receiving enough feedback in the training process + aggregated_num = self._perform_federated_aggregation() + + # Get all the message + train_msg_buffer = self.msg_buffer['train'][self.state] + for model_idx in range(self.model_num): + model = self.models[model_idx] + aggregator = self.aggregators[model_idx] + msg_list = list() + for 
client_id in train_msg_buffer: + if self.model_num == 1: + train_data_size, model_para, pred_embedding = \ + train_msg_buffer[client_id] + self.seqs_embedding[client_id] = pred_embedding + msg_list.append((train_data_size, model_para, pred_embedding)) + else: + raise ValueError( + 'GlobalContrastFL server not support multi-model.') + + for client_id in train_msg_buffer: + z1, z2 = self.seqs_embedding[client_id][0], self.seqs_embedding[client_id][1] + others_z2 = [self.seqs_embedding[other_client_id][1] + for other_client_id in train_msg_buffer + if other_client_id != client_id] + print("start cal loss\n") + self.loss_list[client_id] = compute_global_NT_xentloss(z1, z2, others_z2) + print(self.loss_list[client_id]) + print("end cal loss\n") + + + self.state += 1 + if self.state % self._cfg.eval.freq == 0 and self.state != \ + self.total_round_num: + # Evaluate + logger.info( + 'Server: Starting evaluation at round {:d}.'.format( + self.state)) + self.eval() + + if self.state < self.total_round_num: + self._start_new_training_round(aggregated_num) + for client_id in train_msg_buffer: + + msg_list = { + 'global_loss': self.loss_list[client_id], + } + + # Send loss to Clients + self.comm_manager.send( + Message(msg_type='global_loss', + sender=self.ID, + receiver=[client_id], + state=self.state, + content=msg_list)) + + + # Move to next round of training + logger.info( + f'----------- Starting a new traininground(Round ' + f'#{self.state}) -------------') + # Clean the msg_buffer + self.msg_buffer['train'][self.state - 1].clear() + self.msg_buffer['train'][self.state] = dict() + self.staled_msg_buffer.clear() + # Start a new training round + + + else: + # Final Evaluate + logger.info('Server: Training is finished! 
Starting ' + 'evaluation.') + self.eval() + + else: # in the evaluation process + # Get all the message & aggregate + formatted_eval_res = self.merge_eval_results_from_all_clients() + self.history_results = merge_dict(self.history_results, + formatted_eval_res) + self.check_and_save() + + def callback_funcs_model_para(self, message: Message): + """ + The handling function for receiving model parameters, which triggers + check_and_move_on (perform aggregation when enough feedback has + been received). + This handling function is widely used in various FL courses. + + Arguments: + message: The received message, which includes sender, receiver, + state, and content. More detail can be found in + federatedscope.core.message + """ + if self.is_finish: + return 'finish' + + round = message.state + sender = message.sender + timestamp = message.timestamp + content = message.content + self.sampler.change_state(sender, 'idle') + + # update the currency timestamp according to the received message + assert timestamp >= self.cur_timestamp # for test + self.cur_timestamp = timestamp + + if round == self.state: + if round not in self.msg_buffer['train']: + self.msg_buffer['train'][round] = dict() + # Save the messages in this round + self.msg_buffer['train'][round][sender] = content + elif round >= self.state - self.staleness_toleration: + # Save the staled messages + self.staled_msg_buffer.append((round, sender, content)) + else: + # Drop the out-of-date messages + logger.info(f'Drop a out-of-date message from round #{round}') + self.dropout_num += 1 + + if self._cfg.federate.online_aggr: + self.aggregator.inc(content[:2]) + + move_on_flag = self.check_and_move_on() + if self._cfg.asyn.use and self._cfg.asyn.broadcast_manner == \ + 'after_receiving': + self.broadcast_model_para(msg_type='model_para', + sample_client_num=1) + + return move_on_flag + + +class GlobalContrastFLClient(Client): + def _register_default_handlers(self): + self.register_handlers('assign_client_id', + 
self.callback_funcs_for_assign_id) + self.register_handlers('ask_for_join_in_info', + self.callback_funcs_for_join_in_info) + self.register_handlers('address', self.callback_funcs_for_address) + self.register_handlers('model_para', + self.callback_funcs_for_model_para) + self.register_handlers('global_loss', self.callback_funcs_for_global_loss) + self.register_handlers('ss_model_para', + self.callback_funcs_for_model_para) + + self.register_handlers('evaluate', self.callback_funcs_for_evaluate) + self.register_handlers('finish', self.callback_funcs_for_finish) + self.register_handlers('converged', self.callback_funcs_for_converged) + + def callback_funcs_for_global_loss(self, message: Message): + round, sender, content = message.state, message.sender, message.content + global_loss = content['global_loss'] + model_para_old = self.trainer.get_model_para() + model_para = self.trainer.train_with_global_loss(model_para_old, global_loss) + self.trainer.update(model_para) + + def callback_funcs_for_model_para(self, message: Message): + round, sender, content = message.state, message.sender, message.content + self.trainer.update(content) + self.state = round + sample_size, model_para, results = self.trainer.train() + pred_embedding = self.trainer.get_train_pred_embedding() + if self._cfg.federate.share_local_model and not \ + self._cfg.federate.online_aggr: + model_para = copy.deepcopy(model_para) + logger.info( + self._monitor.format_eval_res(results, + rnd=self.state, + role='Client #{}'.format(self.ID))) + + self.comm_manager.send( + Message(msg_type='model_para', + sender=self.ID, + receiver=[sender], + state=self.state, + content=(sample_size, model_para, pred_embedding))) diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index afa1f7073..dcb8b7792 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -1,16 +1,46 @@ +from federatedscope.core.auxiliaries.enums import MODE from 
federatedscope.register import register_trainer +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer +from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler from federatedscope.core.trainers import GeneralTorchTrainer from federatedscope.core.trainers.context import CtxVar from federatedscope.core.auxiliaries.enums import LIFECYCLE from federatedscope.core.auxiliaries import utils +import torch import numpy as np class CLTrainer(GeneralTorchTrainer): + def __init__(self, + model, + data, + device, + config, + only_for_eval=False, + monitor=None): + super(CLTrainer, self).__init__(model, data, device, config, + only_for_eval, monitor) + self.batches_aug_data_1, self.batches_aug_data_2 = [], [] + + @torch.no_grad() + def get_train_pred_embedding(self): + model = self.ctx.model.to(self.ctx.device) + ys_prob_1, ys_prob_2 = [], [] + x1, x2 = torch.cat(self.batches_aug_data_1, dim=0), torch.cat(self.batches_aug_data_2, dim=0) + z1, z2 = model(x1.to(self.ctx.device), x2.to(self.ctx.device)) + ys_prob_1 = z1.detach().cpu() + ys_prob_2 = z2.detach().cpu() + print(ys_prob_1.size()) + self.batches_aug_data_1, self.batches_aug_data_2 = [], [] + + return [ys_prob_1, ys_prob_2] + def _hook_on_batch_forward(self, ctx): x, label = [utils.move_to(_, ctx.device) for _ in ctx.data_batch] # print(len(x), x[0].size(), x[1].size(), label.size()) x1, x2 = x[0], x[1] + self.batches_aug_data_1.append(x1) + self.batches_aug_data_2.append(x2) z1, z2 = ctx.model(x1, x2) if len(label.size()) == 0: label = label.unsqueeze(0) @@ -30,6 +60,30 @@ def _hook_on_batch_end(self, ctx): ctx.ys_true.append(ctx.y_true.detach().cpu().numpy()) ctx.ys_prob.append(ctx.y_prob[0].detach().cpu().numpy()) + def train_with_global_loss(self, model_para, loss): + """ + Arguments: + model_para: model parameters + loss: loss after global calculate + :returns: + grads: grads to optimize the model of other clients + """ + + for key in model_para.keys(): + if 
isinstance(model_para[key], list): + model_para[key] = torch.FloatTensor(model_para[key]) + self.ctx.model.load_state_dict(model_para) + self.ctx.model = self.ctx.model.to(self.ctx.device) + + self.ctx.optimizer.zero_grad() + + loss = loss.requires_grad_() + loss.backward() + self.ctx.optimizer.step() + + return self.ctx.model.state_dict() + + class LPTrainer(GeneralTorchTrainer): pass diff --git a/federatedscope/core/aggregators/aggregator.py b/federatedscope/core/aggregators/aggregator.py index c8e2052ac..af9c842c0 100644 --- a/federatedscope/core/aggregators/aggregator.py +++ b/federatedscope/core/aggregators/aggregator.py @@ -13,6 +13,35 @@ def aggregate(self, agg_info): class NoCommunicationAggregator(Aggregator): """"Clients do not communicate. Each client work locally """ + def __init__(self, model=None, device='cpu', config=None): + super(Aggregator, self).__init__() + self.model = model + self.device = device + self.cfg = config + + def update(self, model_parameters): + ''' + Arguments: + model_parameters (dict): PyTorch Module object's state_dict. 
+ ''' + self.model.load_state_dict(model_parameters, strict=False) + + def save_model(self, path, cur_round=-1): + assert self.model is not None + + ckpt = {'cur_round': cur_round, 'model': self.model.state_dict()} + torch.save(ckpt, path) + + def load_model(self, path): + assert self.model is not None + + if os.path.exists(path): + ckpt = torch.load(path, map_location=self.device) + self.model.load_state_dict(ckpt['model']) + return ckpt['cur_round'] + else: + raise ValueError("The file {} does NOT exist".format(path)) + def aggregate(self, agg_info): # do nothing return {} diff --git a/federatedscope/core/auxiliaries/worker_builder.py b/federatedscope/core/auxiliaries/worker_builder.py index bff5c6183..e0e6e878d 100644 --- a/federatedscope/core/auxiliaries/worker_builder.py +++ b/federatedscope/core/auxiliaries/worker_builder.py @@ -29,6 +29,9 @@ def get_client_cls(cfg): elif client_type == 'gcflplus': from federatedscope.gfl.gcflplus.worker import GCFLPlusClient client_class = GCFLPlusClient + elif client_type == 'fedgc': + from federatedscope.cl.fedgc.worker import GlobalContrastFLClient + client_class = GlobalContrastFLClient else: client_class = Client @@ -85,5 +88,8 @@ def get_server_cls(cfg): elif client_type == 'gcflplus': from federatedscope.gfl.gcflplus.worker import GCFLPlusServer return GCFLPlusServer + elif client_type == 'fedgc': + from federatedscope.cl.fedgc.worker import GlobalContrastFLServer + return GlobalContrastFLServer else: return Server diff --git a/federatedscope/core/configs/constants.py b/federatedscope/core/configs/constants.py index ded6105a6..b06f3c632 100644 --- a/federatedscope/core/configs/constants.py +++ b/federatedscope/core/configs/constants.py @@ -19,6 +19,7 @@ "ditto": "clients_avg", # Ditto "fedsageplus": "clients_avg", "gcflplus": "clients_avg", + "fedgc": "clients_avg", "fedopt": "fedopt" } @@ -31,6 +32,7 @@ # models "fedsageplus": "fedsageplus", # FedSage+ for graph data "gcflplus": "gcflplus", # GCFL+ for graph data + 
"fedgc": "fedgc", "gradascent": "gradascent" } @@ -42,4 +44,5 @@ # models "fedsageplus": "fedsageplus", # FedSage+ for graph data "gcflplus": "gcflplus", # GCFL+ for graph data + "fedgc": "fedgc" } From d03eef5233e6d9e2b5fbebae598816c0547e82a7 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Mon, 5 Sep 2022 05:30:14 +0800 Subject: [PATCH 17/46] modify and add docstring --- .../cl/baseline/fedgc_on_cifar10.yaml | 8 +-- federatedscope/cl/dataloader/Cifar10.py | 32 +++++++++- federatedscope/cl/fedgc/client.py | 59 +++++++++++++++++++ .../cl/fedgc/{worker.py => server.py} | 56 +++--------------- federatedscope/cl/fedgc/utils.py | 14 ++++- federatedscope/cl/loss/NT_xentloss.py | 14 ++++- federatedscope/cl/model/SimCLR.py | 59 +------------------ federatedscope/cl/trainer/trainer.py | 6 +- .../core/auxiliaries/worker_builder.py | 4 +- 9 files changed, 131 insertions(+), 121 deletions(-) create mode 100644 federatedscope/cl/fedgc/client.py rename federatedscope/cl/fedgc/{worker.py => server.py} (74%) diff --git a/federatedscope/cl/baseline/fedgc_on_cifar10.yaml b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml index bb71bdb6e..16a424d9b 100644 --- a/federatedscope/cl/baseline/fedgc_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml @@ -2,13 +2,13 @@ use_gpu: True device: 2 federate: mode: standalone - total_round_num: 100 + total_round_num: 300 client_num: 5 share_local_model: True online_aggr: True sample_client_rate: 1.0 method: fedgc - save_to: 'FedGC_on_Cifar4CL_lda0.5_lr0.05_lus5_rn100.ckpt' + save_to: 'FedGC_on_Cifar4CL_lda0.5_lr0.05_lus5b_rn300.ckpt' data: root: 'data' type: 'Cifar4CL' @@ -19,10 +19,10 @@ data: model: type: 'SimCLR' train: - local_update_steps: 1 + local_update_steps: 5 batch_or_epoch: 'batch' optimizer: - lr: 0.05 + lr: 0.01 momentum: 0.1 early_stop: patience: 0 diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index eadb681ce..22bd75654 100644 --- 
a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -16,6 +16,16 @@ class SimCLRTransform(): + r""" + Data Augmentations of SimCLR refer from https://github.com/akhilmathurs/orchestra/blob/main/utils.py + Arguments: + is_sup (bool): the transform for supervised learning or contrastive learning. + :returns: + torch.tensor: one output for supervised learning. + :returns: + torch.tensor: two output for contrastive learning + torch.tensor: two output for contrastive learning + """ def __init__(self, is_sup, image_size=32): self.transform = T.Compose([ T.RandomResizedCrop(image_size, scale=(0.5, 1.0), interpolation=T.InterpolationMode.BICUBIC), @@ -38,7 +48,16 @@ def __call__(self, x): return x1, x2 def Cifar4CL(config): - + r""" + generate Cifar10 Dataset transform and split dict for contrastive learning + return { + 'client_id': { + 'train': DataLoader(), + 'test': DataLoader(), + 'val': DataLoader() + } + } + """ transform_train = SimCLRTransform(is_sup=False, image_size=32) path = config.data.root @@ -83,7 +102,16 @@ def Cifar4CL(config): return data_dict, config def Cifar4LP(config): - + r""" + generate Cifar10 Dataset transform and split dict for linear prob evaluation of contrastive learning + return { + 'client_id': { + 'train': DataLoader(), + 'test': DataLoader(), + 'val': DataLoader() + } + } + """ transform_train = T.Compose([ T.RandomResizedCrop(32, scale=(0.5, 1.0), interpolation=T.InterpolationMode.BICUBIC), T.RandomHorizontalFlip(p=0.5), diff --git a/federatedscope/cl/fedgc/client.py b/federatedscope/cl/fedgc/client.py new file mode 100644 index 000000000..01f64d4ae --- /dev/null +++ b/federatedscope/cl/fedgc/client.py @@ -0,0 +1,59 @@ +import torch +import logging +import copy +import numpy as np + +from federatedscope.core.message import Message +from federatedscope.core.workers.client import Client +from federatedscope.core.auxiliaries.utils import merge_dict + +logger = logging.getLogger(__name__) + +class 
GlobalContrastFLClient(Client): + r""" + GlobalContrastFL(Fedgc) Client receive aggregated model weight from server then update local + weight; it also receive global loss from server to train model and update weight locally. + """ + def _register_default_handlers(self): + self.register_handlers('assign_client_id', + self.callback_funcs_for_assign_id) + self.register_handlers('ask_for_join_in_info', + self.callback_funcs_for_join_in_info) + self.register_handlers('address', self.callback_funcs_for_address) + self.register_handlers('model_para', + self.callback_funcs_for_model_para) + self.register_handlers('global_loss', self.callback_funcs_for_global_loss) + self.register_handlers('ss_model_para', + self.callback_funcs_for_model_para) + + self.register_handlers('evaluate', self.callback_funcs_for_evaluate) + self.register_handlers('finish', self.callback_funcs_for_finish) + self.register_handlers('converged', self.callback_funcs_for_converged) + + def callback_funcs_for_global_loss(self, message: Message): + round, sender, content = message.state, message.sender, message.content + global_loss = content['global_loss'] + model_para_old = self.trainer.get_model_para() + model_para = self.trainer.train_with_global_loss(model_para_old, global_loss) + self.trainer.update(model_para) + + def callback_funcs_for_model_para(self, message: Message): + round, sender, content = message.state, message.sender, message.content + self.trainer.update(content) + self.state = round + sample_size, model_para, results = self.trainer.train() + pred_embedding = self.trainer.get_train_pred_embedding() + if self._cfg.federate.share_local_model and not \ + self._cfg.federate.online_aggr: + model_para = copy.deepcopy(model_para) + logger.info( + self._monitor.format_eval_res(results, + rnd=self.state, + role='Client #{}'.format(self.ID))) + + self.comm_manager.send( + Message(msg_type='model_para', + sender=self.ID, + receiver=[sender], + state=self.state, + content=(sample_size, model_para, 
pred_embedding))) diff --git a/federatedscope/cl/fedgc/worker.py b/federatedscope/cl/fedgc/server.py similarity index 74% rename from federatedscope/cl/fedgc/worker.py rename to federatedscope/cl/fedgc/server.py index 4adbbed36..157da6e93 100644 --- a/federatedscope/cl/fedgc/worker.py +++ b/federatedscope/cl/fedgc/server.py @@ -5,7 +5,6 @@ from federatedscope.core.message import Message from federatedscope.core.workers.server import Server -from federatedscope.core.workers.client import Client from federatedscope.core.auxiliaries.utils import merge_dict from federatedscope.cl.fedgc.utils import compute_global_NT_xentloss @@ -13,6 +12,11 @@ class GlobalContrastFLServer(Server): + r""" + GlobalContrastFL(Fedgc) Server contain two part in training: Fedavg aggragator + for client model weight and calculate global loss from all sampled client embedding + then broadcast all client to train model. + """ def __init__(self, ID=-1, state=0, @@ -74,10 +78,10 @@ def check_and_move_on(self, check_eval_result=False): others_z2 = [self.seqs_embedding[other_client_id][1] for other_client_id in train_msg_buffer if other_client_id != client_id] - print("start cal loss\n") +# print("start cal loss") self.loss_list[client_id] = compute_global_NT_xentloss(z1, z2, others_z2) print(self.loss_list[client_id]) - print("end cal loss\n") +# print("end cal loss") self.state += 1 @@ -178,49 +182,3 @@ def callback_funcs_model_para(self, message: Message): sample_client_num=1) return move_on_flag - - -class GlobalContrastFLClient(Client): - def _register_default_handlers(self): - self.register_handlers('assign_client_id', - self.callback_funcs_for_assign_id) - self.register_handlers('ask_for_join_in_info', - self.callback_funcs_for_join_in_info) - self.register_handlers('address', self.callback_funcs_for_address) - self.register_handlers('model_para', - self.callback_funcs_for_model_para) - self.register_handlers('global_loss', self.callback_funcs_for_global_loss) - 
self.register_handlers('ss_model_para', - self.callback_funcs_for_model_para) - - self.register_handlers('evaluate', self.callback_funcs_for_evaluate) - self.register_handlers('finish', self.callback_funcs_for_finish) - self.register_handlers('converged', self.callback_funcs_for_converged) - - def callback_funcs_for_global_loss(self, message: Message): - round, sender, content = message.state, message.sender, message.content - global_loss = content['global_loss'] - model_para_old = self.trainer.get_model_para() - model_para = self.trainer.train_with_global_loss(model_para_old, global_loss) - self.trainer.update(model_para) - - def callback_funcs_for_model_para(self, message: Message): - round, sender, content = message.state, message.sender, message.content - self.trainer.update(content) - self.state = round - sample_size, model_para, results = self.trainer.train() - pred_embedding = self.trainer.get_train_pred_embedding() - if self._cfg.federate.share_local_model and not \ - self._cfg.federate.online_aggr: - model_para = copy.deepcopy(model_para) - logger.info( - self._monitor.format_eval_res(results, - rnd=self.state, - role='Client #{}'.format(self.ID))) - - self.comm_manager.send( - Message(msg_type='model_para', - sender=self.ID, - receiver=[sender], - state=self.state, - content=(sample_size, model_para, pred_embedding))) diff --git a/federatedscope/cl/fedgc/utils.py b/federatedscope/cl/fedgc/utils.py index ff7e3f1f4..36e6313f9 100644 --- a/federatedscope/cl/fedgc/utils.py +++ b/federatedscope/cl/fedgc/utils.py @@ -10,7 +10,19 @@ def norm(w): def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5): - """ computes global NT_xentloss""" + r""" + global_NT_xentloss is federated NT_xentloss in server. It collect sample client + embedding and calculate NT_xentloss from local client positive examples and + negative examples of local and other clients. + Arguments: + z1 (torch.tensor): the embedding of local model. 
+ z2 (torch.tensor): the embedding of local model using another augmentation. + others_z2 (list[torch.tensor]): the embedding list of other clients, each client has two embedding. + returns: + loss: the NT_xentloss loss for this aggregation of global clients + :rtype: + torch.FloatTensor + """ N, Z = z1.shape representations = torch.cat([z1, z2], dim=0) similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) diff --git a/federatedscope/cl/loss/NT_xentloss.py b/federatedscope/cl/loss/NT_xentloss.py index 58b257cef..6791cfe97 100644 --- a/federatedscope/cl/loss/NT_xentloss.py +++ b/federatedscope/cl/loss/NT_xentloss.py @@ -5,6 +5,16 @@ from federatedscope.register import register_criterion class NT_xentloss(nn.Module): + r""" + NT_xentloss definition adapted from https://github.com/PatrickHua/SimSiam + Arguments: + z1 (torch.tensor): the embedding of model . + z2 (torch.tensor): the embedding of model using another augmentation. + returns: + loss: the NT_xentloss loss for this batch data + :rtype: + torch.FloatTensor + """ def __init__(self, temperature=0.5): super(NT_xentloss, self).__init__() self.temperature = temperature @@ -25,9 +35,9 @@ def forward(self, z1, z2): logits = torch.cat([positives, negatives], dim=1) / self.temperature labels = torch.zeros(2*N, device=device, dtype=torch.int64) # scalar label per sample - loss = F.cross_entropy(logits, labels, reduction='sum') + loss = F.cross_entropy(logits, labels, reduction='sum') / (2 * N) - return loss / (2 * N) + return loss def create_NT_xentloss(type, device): diff --git a/federatedscope/cl/model/SimCLR.py b/federatedscope/cl/model/SimCLR.py index 2a541eabb..efadfcf71 100644 --- a/federatedscope/cl/model/SimCLR.py +++ b/federatedscope/cl/model/SimCLR.py @@ -5,61 +5,7 @@ from math import pi, cos, e import numpy as np from collections import OrderedDict - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, in_planes, planes, stride=1): - 
super(BasicBlock, self).__init__() - self.use_shortcut = stride != 1 or in_planes != self.expansion*planes - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes, affine=True) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes, affine=True) - - self.shortcut_conv = nn.Sequential() - if self.use_shortcut: - self.shortcut_conv = nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) - self.shortcut_bn = nn.BatchNorm2d(self.expansion*planes, affine=True) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - shortcut = self.shortcut_conv(x) - if self.use_shortcut: - shortcut = self.shortcut_bn(shortcut) - out += shortcut - return F.relu(out) - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, in_planes, planes, stride=1): - super(Bottleneck, self).__init__() - self.use_shortcut = stride != 1 or in_planes != self.expansion*planes - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes, affine=True) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes, affine=True) - self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(self.expansion*planes, affine=True) - - self.shortcut_conv = nn.Sequential() - if self.use_shortcut: - self.shortcut_conv = nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) - self.shortcut_bn = nn.BatchNorm2d(self.expansion*planes, affine=True) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = F.relu(self.bn2(self.conv2(out))) - out = self.bn3(self.conv3(out)) - shortcut = self.shortcut_conv(x) - if self.use_shortcut: - shortcut = self.shortcut_bn(shortcut) - out += 
shortcut - return F.relu(out) - +from federatedscope.contrib.model.resnet import BasicBlock, Bottleneck # Model class class ResNet(nn.Module): @@ -145,9 +91,6 @@ def ResNet18(num_classes=10, block="BasicBlock"): def ResNet34(num_classes=10, block="BasicBlock"): return ResNet(get_block(block), [3,4,6,3], num_classes=num_classes) -def ResNet56(num_classes=10, block="BasicBlock"): - return ResNet_basic(get_block(block), [9,9,9], num_classes=num_classes) - def create_backbone(name, num_classes=10, block='BasicBlock'): diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index dcb8b7792..5d4a5b671 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -30,7 +30,6 @@ def get_train_pred_embedding(self): z1, z2 = model(x1.to(self.ctx.device), x2.to(self.ctx.device)) ys_prob_1 = z1.detach().cpu() ys_prob_2 = z2.detach().cpu() - print(ys_prob_1.size()) self.batches_aug_data_1, self.batches_aug_data_2 = [], [] return [ys_prob_1, ys_prob_2] @@ -39,8 +38,9 @@ def _hook_on_batch_forward(self, ctx): x, label = [utils.move_to(_, ctx.device) for _ in ctx.data_batch] # print(len(x), x[0].size(), x[1].size(), label.size()) x1, x2 = x[0], x[1] - self.batches_aug_data_1.append(x1) - self.batches_aug_data_2.append(x2) + if ctx.cur_mode in [MODE.TRAIN]: + self.batches_aug_data_1.append(x1) + self.batches_aug_data_2.append(x2) z1, z2 = ctx.model(x1, x2) if len(label.size()) == 0: label = label.unsqueeze(0) diff --git a/federatedscope/core/auxiliaries/worker_builder.py b/federatedscope/core/auxiliaries/worker_builder.py index e0e6e878d..b8ef250f5 100644 --- a/federatedscope/core/auxiliaries/worker_builder.py +++ b/federatedscope/core/auxiliaries/worker_builder.py @@ -30,7 +30,7 @@ def get_client_cls(cfg): from federatedscope.gfl.gcflplus.worker import GCFLPlusClient client_class = GCFLPlusClient elif client_type == 'fedgc': - from federatedscope.cl.fedgc.worker import GlobalContrastFLClient + from 
federatedscope.cl.fedgc.client import GlobalContrastFLClient client_class = GlobalContrastFLClient else: client_class = Client @@ -89,7 +89,7 @@ def get_server_cls(cfg): from federatedscope.gfl.gcflplus.worker import GCFLPlusServer return GCFLPlusServer elif client_type == 'fedgc': - from federatedscope.cl.fedgc.worker import GlobalContrastFLServer + from federatedscope.cl.fedgc.server import GlobalContrastFLServer return GlobalContrastFLServer else: return Server From a43ceef9243c200365a76e8932d3e9266702b31b Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Mon, 5 Sep 2022 16:06:34 +0800 Subject: [PATCH 18/46] create repro_exp shell and report exp result the linear prob experiment result of fedgc in cifar10 is val_acc=0.49 and test_acc=0.4985, the exp parameters are total_round_num: 300 lr:0.05 local_update_steps: 5 batch_or_epoch: 'batch' --- .../baseline/repro_exp/args_cifar10_fedgc.sh | 11 +++++ .../repro_exp/args_cifar10_fedsimclr.sh | 7 +++ .../repro_exp/run_contrastive_learning.sh | 45 +++++++++++++++++++ 3 files changed, 63 insertions(+) create mode 100644 federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh create mode 100644 federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh create mode 100644 federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh diff --git a/federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh new file mode 100644 index 000000000..6561e5289 --- /dev/null +++ b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh @@ -0,0 +1,11 @@ +# ---------------------------------------------------------------------- # +# Fedgc +# ---------------------------------------------------------------------- # + +bash run_contrastive_learning.sh 1 fedgc cifar10 + + + + + + diff --git a/federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh new file mode 100644 index
000000000..16f811cb3 --- /dev/null +++ b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh @@ -0,0 +1,7 @@ +# ---------------------------------------------------------------------- # +# Fedsimclr +# ---------------------------------------------------------------------- # + +bash run_contrastive_learning.sh 1 fedsimclr cifar10 + + diff --git a/federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh b/federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh new file mode 100644 index 000000000..06fc40076 --- /dev/null +++ b/federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh @@ -0,0 +1,45 @@ +set -e + +cudaid=$1 +method_name=$2 +dataset=$3 + +cd ../../../.. + +if [ ! -d "out" ];then + mkdir out +fi + +if [[ $method_name = 'fedgc' ]]; then + method='fedgc' + total_round_num='300' + batch_or_epoch='batch' +elif [[ $method_name = 'fedsimclr' ]]; then + method='Fedavg' + total_round_num='100' + batch_or_epoch='epoch' +fi + +echo "Fed Contrastive Learning starts..." 
+ +lrs=(0.01 0.05 0.25) +local_updates=(1 3 5) + + +for (( i=0; i<${#lrs[@]}; i++ )) +do + for (( j=0; j<${#local_updates[@]}; j++ )) + do + for k in {1..5} + do + train_yaml=${method_name}_on_${dataset}.yaml + save_path=${method_name}_on_Cifar4CL_lda0.5_lr${lrs[$i]}_lus${local_updates[$j]}_rn${total_round_num}${batch_or_epoch}.ckpt + python federatedscope/main.py --cfg federatedscope/cl/baseline/${train_yaml} device ${cudaid} federate.save_to ${save_path} train.optimizer.lr ${lrs[$i]} train.local_update_steps ${local_updates[$j]} seed $k >>out/${method_name}_on_Cifar4CL_lda0.5_lr${lrs[$i]}_lus${local_updates[$j]}_rn${total_round_num}${batch_or_epoch}.log 2>&1 + linear_prob_yaml=fedcontrastlearning_linearprob_on_cifar10.yaml + python federatedscope/main.py --cfg federatedscope/cl/baseline/${linear_prob_yaml} device ${cudaid} federate.restore_from ${save_path} >>out/${method_name}_on_Cifar4CL_lda0.5_lr${lrs[$i]}_lus${local_updates[$j]}_rn${total_round_num}${batch_or_epoch}.log 2>&1 + done + done +done + + +echo "Fed Contrastive Learning ends." 
From 6b42eeceb6a29872d53a76a8713334d600b494e7 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Wed, 14 Sep 2022 20:31:17 +0800 Subject: [PATCH 19/46] Accelerate computing global loss with GPU and repair load model problem --- ...ontrastlearning_linearprob_on_cifar10.yaml | 36 +++++++++++++++++++ .../cl/baseline/fedgc_on_cifar10.yaml | 9 ++--- .../baseline/repro_exp/args_cifar10_fedgc.sh | 2 +- .../repro_exp/args_cifar10_fedsimclr.sh | 2 +- .../repro_exp/run_contrastive_learning.sh | 15 ++++---- federatedscope/cl/dataloader/Cifar10.py | 1 + federatedscope/cl/fedgc/server.py | 2 +- federatedscope/cl/fedgc/utils.py | 10 +++--- federatedscope/cl/trainer/trainer.py | 14 +++++++- 9 files changed, 72 insertions(+), 19 deletions(-) create mode 100644 federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml diff --git a/federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml b/federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml new file mode 100644 index 000000000..482c4400c --- /dev/null +++ b/federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml @@ -0,0 +1,36 @@ +use_gpu: True +seed: 1 +device: 1 +federate: + mode: standalone + total_round_num: 50 + client_num: 5 + sample_client_rate: 1.0 + method: global + restore_from: '../fedsimclr_on_Cifar4CL_lda0.1_lr0.05_lus5_rn200batch_seed2.ckpt' +data: + root: 'data' + type: 'Cifar4LP' + batch_size: 256 + splitter: 'lda' + splitter_args: [{'alpha': 0.1}] + num_workers: 4 +model: + type: 'SimCLR_linear' +train: + local_update_steps: 1 + batch_or_epoch: 'epoch' + optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0 +early_stop: + patience: 0 +criterion: + type: CrossEntropyLoss +trainer: + type: 'lptrainer' +eval: + freq: 2 + metrics: ['acc'] + split: ['val', 'test'] \ No newline at end of file diff --git a/federatedscope/cl/baseline/fedgc_on_cifar10.yaml b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml index 16a424d9b..af86dc369 
100644 --- a/federatedscope/cl/baseline/fedgc_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml @@ -1,20 +1,21 @@ use_gpu: True +seed: 1 device: 2 federate: mode: standalone - total_round_num: 300 + total_round_num: 200 client_num: 5 share_local_model: True online_aggr: True sample_client_rate: 1.0 method: fedgc - save_to: 'FedGC_on_Cifar4CL_lda0.5_lr0.05_lus5b_rn300.ckpt' + save_to: 'FedGC_on_Cifar4CL_lda0.1_lr0.25_lus5b_rn100.ckpt' data: root: 'data' type: 'Cifar4CL' batch_size: 256 splitter: 'lda' - splitter_args: [{'alpha': 0.5}] + splitter_args: [{'alpha': 0.1}] num_workers: 4 model: type: 'SimCLR' @@ -22,7 +23,7 @@ train: local_update_steps: 5 batch_or_epoch: 'batch' optimizer: - lr: 0.01 + lr: 0.25 momentum: 0.1 early_stop: patience: 0 diff --git a/federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh index 6561e5289..45d0c369c 100644 --- a/federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh +++ b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh @@ -2,7 +2,7 @@ # Fedgc # ---------------------------------------------------------------------- # -bash run_contrastive_learning.sh 1 fedgc cifar10 +bash run_contrastive_learning.sh 2 fedgc cifar10 0.3 diff --git a/federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh index 16f811cb3..ec7225e20 100644 --- a/federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh +++ b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh @@ -2,6 +2,6 @@ # Fedsimclr # ---------------------------------------------------------------------- # -bash run_contrastive_learning.sh 1 fedsimclr cifar10 +bash run_contrastive_learning.sh 0 fedsimclr cifar10 0.5 diff --git a/federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh b/federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh index 06fc40076..ba4cf3aa4 100644 --- 
a/federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh +++ b/federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh @@ -3,6 +3,7 @@ set -e cudaid=$1 method_name=$2 dataset=$3 +lda_alpha=$4 cd ../../../.. @@ -12,12 +13,12 @@ fi if [[ $method_name = 'fedgc' ]]; then method='fedgc' - total_round_num='300' + total_round_num='200' batch_or_epoch='batch' elif [[ $method_name = 'fedsimclr' ]]; then method='Fedavg' - total_round_num='100' - batch_or_epoch='epoch' + total_round_num='200' + batch_or_epoch='batch' fi echo "Fed Contrastive Learning starts..." @@ -30,13 +31,13 @@ for (( i=0; i<${#lrs[@]}; i++ )) do for (( j=0; j<${#local_updates[@]}; j++ )) do - for k in {1..5} + for k in {1..2} do train_yaml=${method_name}_on_${dataset}.yaml - save_path=${method_name}_on_Cifar4CL_lda0.5_lr${lrs[$i]}_lus${local_updates[$j]}_rn${total_round_num}${batch_or_epoch}.ckpt - python federatedscope/main.py --cfg federatedscope/cl/baseline/${train_yaml} device ${cudaid} federate.save_to ${save_path} train.optimizer.lr ${lrs[$i]} train.local_update_steps ${local_updates[$j]} seed $k >>out/${method_name}_on_Cifar4CL_lda0.5_lr${lrs[$i]}_lus${local_updates[$j]}_rn${total_round_num}${batch_or_epoch}.log 2>&1 + save_path=${method_name}_on_Cifar4CL_lda${lda_alpha}_lr${lrs[$i]}_lus${local_updates[$j]}_rn${total_round_num}${batch_or_epoch}_seed${k}.ckpt + python federatedscope/main.py --cfg federatedscope/cl/baseline/${train_yaml} device ${cudaid} federate.save_to ${save_path} federate.total_round_num ${total_round_num} data.splitter_args \[\{\'alpha\'\:${lda_alpha}\}\] train.optimizer.lr ${lrs[$i]} train.local_update_steps ${local_updates[$j]} train.batch_or_epoch ${batch_or_epoch} seed $k>>out/${method_name}_on_Cifar4CL_lda${lda_alpha}_lr${lrs[$i]}_lus${local_updates[$j]}_rn${total_round_num}${batch_or_epoch}.log 2>&1 linear_prob_yaml=fedcontrastlearning_linearprob_on_cifar10.yaml - python federatedscope/main.py --cfg federatedscope/cl/baseline/${linear_prob_yaml} device 
${cudaid} federate.restore_from ${save_path} >>out/${method_name}_on_Cifar4CL_lda0.5_lr${lrs[$i]}_lus${local_updates[$j]}_rn${total_round_num}${batch_or_epoch}.log 2>&1 + python federatedscope/main.py --cfg federatedscope/cl/baseline/${linear_prob_yaml} device ${cudaid} federate.restore_from ${save_path} data.splitter_args \[\{\'alpha\'\:${lda_alpha}\}\] seed $k>>out/${method_name}_on_Cifar4CL_lda${lda_alpha}_lr${lrs[$i]}_lus${local_updates[$j]}_rn${total_round_num}${batch_or_epoch}.log 2>&1 done done done diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 22bd75654..dede7e59f 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -69,6 +69,7 @@ def Cifar4CL(config): data_dict = dict() splitter = get_splitter(config) data_train = splitter(data_train) + print([len(i) for i in data_train]) data_val = data_train data_test = splitter(data_test) diff --git a/federatedscope/cl/fedgc/server.py b/federatedscope/cl/fedgc/server.py index 157da6e93..08843f64a 100644 --- a/federatedscope/cl/fedgc/server.py +++ b/federatedscope/cl/fedgc/server.py @@ -79,7 +79,7 @@ def check_and_move_on(self, check_eval_result=False): for other_client_id in train_msg_buffer if other_client_id != client_id] # print("start cal loss") - self.loss_list[client_id] = compute_global_NT_xentloss(z1, z2, others_z2) + self.loss_list[client_id] = compute_global_NT_xentloss(z1, z2, others_z2, device=self.device) print(self.loss_list[client_id]) # print("end cal loss") diff --git a/federatedscope/cl/fedgc/utils.py b/federatedscope/cl/fedgc/utils.py index 36e6313f9..19aa4e7df 100644 --- a/federatedscope/cl/fedgc/utils.py +++ b/federatedscope/cl/fedgc/utils.py @@ -9,7 +9,7 @@ def norm(w): return torch.norm(torch.cat([v.flatten() for v in w.values()])).item() -def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5): +def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5, device='cpu'): 
r""" global_NT_xentloss is federated NT_xentloss in server. It collect sample client embedding and calculate NT_xentloss from local client positive examples and @@ -23,6 +23,7 @@ def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5): :rtype: torch.FloatTensor """ + z1, z2 = z1.cuda(device=device), z2.cuda(device=device) N, Z = z1.shape representations = torch.cat([z1, z2], dim=0) similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) @@ -31,17 +32,18 @@ def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5): r_pos = torch.diag(similarity_matrix, -N) positives = torch.cat([l_pos, r_pos]).view(2 * N, 1) - diag = torch.eye(2*N, dtype=torch.bool) + diag = torch.eye(2*N, dtype=torch.bool, device=device) diag[N:,:N] = diag[:N,N:] = diag[:N,:N] negatives = similarity_matrix[~diag].view(2*N, -1) if len(others_z2) != 0: for z2_ in others_z2: + z2_ = z2_.cuda(device=device) N2, Z2 = z2_.shape representations = torch.cat([z1, z2_], dim=0) similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) - mask = torch.zeros_like(similarity_matrix, dtype=torch.bool) + mask = torch.zeros_like(similarity_matrix, dtype=torch.bool, device=device) mask[N:,:N] = True mask[:N,N:] = True negatives_other = similarity_matrix[mask].view(2*N, -1) @@ -49,7 +51,7 @@ def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5): logits = torch.cat([positives, negatives], dim=1) / temperature - labels = torch.zeros(2*N, dtype=torch.int64) # scalar label per sample + labels = torch.zeros(2*N, dtype=torch.int64, device=device) # scalar label per sample loss = F.cross_entropy(logits, labels, reduction='sum') return loss / (2 * N) diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index 5d4a5b671..c5d0ac8f0 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -85,7 +85,19 @@ def 
train_with_global_loss(self, model_para, loss): class LPTrainer(GeneralTorchTrainer): - pass + def __init__(self, + model, + data, + device, + config, + only_for_eval=False, + monitor=None): + super(LPTrainer, self).__init__(model, data, device, config, + only_for_eval, monitor) + + if config.federate.restore_from != '': + self.load_model(config.federate.restore_from) + def call_cl_trainer(trainer_type): if trainer_type == 'cltrainer': From 747f8e847635ed23d4c5903756092b99c6a64e62 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Sat, 17 Sep 2022 04:23:02 +0800 Subject: [PATCH 20/46] add paper list --- .../README.md | 17 ++++++++++++ .../README.md | 24 +++++++++++++++++ .../README.md | 26 +++++++++++++++++++ 3 files changed, 67 insertions(+) create mode 100644 materials/paper_list/Federated_Learning_for_Multi-Task/README.md create mode 100644 materials/paper_list/Federated_Learning_on_Medical_Data/README.md create mode 100644 materials/paper_list/Federated_Self-Supervised_Learning/README.md diff --git a/materials/paper_list/Federated_Learning_for_Multi-Task/README.md b/materials/paper_list/Federated_Learning_for_Multi-Task/README.md new file mode 100644 index 000000000..df8eaa1c5 --- /dev/null +++ b/materials/paper_list/Federated_Learning_for_Multi-Task/README.md @@ -0,0 +1,17 @@ +## Federated Learning for Multi-Task + +### 2022 + +| Title | Venue | Link | +| ------------------------------------------------------------ | ------ | ------------------------------------------------------------ | +| Dynamic Neural Graphs Based Federated Reptile for Semi-Supervised Multi-Tasking in Healthcare Applications | JBHI | [pdf](https://ieeexplore.ieee.org/abstract/document/9648036) | +| Privacy-Preserving Federated Multi-Task Linear Regression: A One-Shot Linear Mixing Approach Inspired By Graph Regularization | ICASSP | [pdf](https://ieeexplore.ieee.org/abstract/document/9746007) | +| Decentralized Graph Federated Multitask Learning for Streaming Data | CCIS 
| [pdf](https://ieeexplore.ieee.org/abstract/document/9751160) | + +### 2021 + +| Title | Venue | Link | +| ------------------------------------------------------------ | ------------ | ------------------------------------------------------------ | +| FL-DISCO: Federated Generative Adversarial Network for Graph-based Molecule Drug Discovery | ICCAD | [pdf](https://ieeexplore.ieee.org/abstract/document/9643440) | +| Splitting chemical structure data sets for federated privacy-preserving machine learning | J.Cheminform | [pdf](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-021-00576-2) | + diff --git a/materials/paper_list/Federated_Learning_on_Medical_Data/README.md b/materials/paper_list/Federated_Learning_on_Medical_Data/README.md new file mode 100644 index 000000000..a8453cdbf --- /dev/null +++ b/materials/paper_list/Federated_Learning_on_Medical_Data/README.md @@ -0,0 +1,24 @@ +## Federated Learning on Medical Data + +### 2022 + +| Title | Venue | Link | +| ------------------------------------------------------------ | -------- | ------------------------------------------------------------ | +| SpreadGNN: Decentralized Multi-Task Federated Learning for Graph Neural Networks on Molecular Data | AAAI | [pdf](https://www.aaai.org/AAAI22Papers/AAAI-4599.HeC.pdf) | +| Federated learning of molecular properties with graph neural networks in a heterogeneous setting | Patterns | [pdf](https://www.sciencedirect.com/science/article/pii/S2666389922001180) | +| Federated Learning of Oligonucleotide Drug Molecule Thermodynamics with Differentially Private ADMM-Based SVM | CCIS | [pdf](https://link.springer.com/chapter/10.1007/978-3-030-93733-1_34) | + +### 2021 + +| Title | Venue | Link | +| ------------------------------------------------------------ | --------------------------- | ------------------------------------------------------------ | +| FL-DISCO: Federated Generative Adversarial Network for Graph-based Molecule Drug Discovery | ICCAD | 
[pdf](https://ieeexplore.ieee.org/abstract/document/9643440) | +| Splitting chemical structure data sets for federated privacy-preserving machine learning | J.Cheminform | [pdf](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-021-00576-2) | +| Facing small and biased data dilemma in drug discovery with enhanced federated learning approaches | Science China Life Sciences | [pdf](https://link.springer.com/article/10.1007/s11427-021-1946-0) | +| FLOP: Federated Learning on Medical Datasets using Partial Networks | KDD | [pdf](https://dl.acm.org/doi/abs/10.1145/3447548.3467185) | + +### 2020 + +| Title | Venue | Link | +| ------------------------------------------------------------ | -------------- | ------------------------------------------------------------ | +| Secure multiparty computation for privacy-preserving drug discovery | Bioinformatics | [pdf](https://watermark.silverchair.com/btaa038.pdf?token=AQECAHi208BE49Ooan9kkhW_Ercy7Dm3ZL_9Cf3qfKAc485ysgAAAuQwggLgBgkqhkiG9w0BBwagggLRMIICzQIBADCCAsYGCSqGSIb3DQEHATAeBglghkgBZQMEAS4wEQQMdcoXxZP-7dBrdXEzAgEQgIIClyqcZJ5bk8WS94gxG1TLCJ4RIluD8isSk1mjG0UlZqpqbiE6Qo-woVPAfPw3cC8uozlbEq2Ubh6uN68GNY1CrpdZl6S25sYz99bR9o8AsA139JjCKvH6NtOrs7pDxCpOQgHORfIvSYXbsRvEc-vrx013j5n68Ewcs_xIK5E7GO6M5gMMl4GoTU54PMuj05dXMuh-kQjWDaTusH-v_DTldHHn6eRhYGgNtU4shGAinoCrQM5dpbnpwQiy2iusSDjPSEOcCtcy4C7v2xs0PzZurc_Woh7PE6pbM4KMSyUj1NICHn468bWjU0YBEVGUbZNqQbE1BWxE6j1ygV8r8UiXt8B6vxPsW8JzjSzYdMojA-oCmoZM8Ru0plrKka12h4st8P-bkzfPvK9y5F6oetdaZnGMGNA9MHXZ1SnC8Da2-WV8rA-g0OBlFcKdh9k5cgf0pnt7L569QGdd_frzpI6NgnqEZwY3INxxcd6ElMUXm7mOJbBECbkPmGqREG4J6fMF5wstT1nfafteWdmLhflHGxfwMTsGmlgBSzEKalFdUt3GNCEFyeJAy6D_5_mcb3m-X81fHLJARnzcVX6LV1CvBlMk7zkuitTfeW_ZXH2u3bzRRdeZqpHcfvLHCMOA4B7FSkFw-FNkCpn8odsSxJ3DXf_ZSKNyzw8PqlHadzp48YULho3jPDdmXOQEYEBdSuj3JR8eb-GhzoaWEA2W1pzz7nr_I98CBgrWztpsWqt2IOKgqe4bLzMVKoN7Oh0cMRxSeKTc_mujFb9t9W4gs07GcwBqNsqcevy5-iJmoPZf1c37qXHu6u0kZ4SprUSh4C-G6Hhe3AXbprcwggcwLzjlVcll2T5pU86JrrPf4Ia_jNQHHN3WKGJybw) | \ No newline at end of 
file diff --git a/materials/paper_list/Federated_Self-Supervised_Learning/README.md b/materials/paper_list/Federated_Self-Supervised_Learning/README.md new file mode 100644 index 000000000..feb5b8f28 --- /dev/null +++ b/materials/paper_list/Federated_Self-Supervised_Learning/README.md @@ -0,0 +1,26 @@ +## Federated Self-Supervised Learning + +### 2022 + +| Title | Venue | Link | +| ------------------------------------------------------------ | ------- | ------------------------------------------------------------ | +| SSFL: Tackling Label Deficiency in Federated Learning via Personalized Self-Supervision | FL-AAAI | [pdf](https://federated-learning.org/fl-aaai-2022/Papers/FL-AAAI-22_paper_36.pdf) | +| Orchestra: Unsupervised Federated Learning via Globally Consistent Clustering | ICML | [pdf](https://arxiv.org/pdf/2205.11506.pdf), [code](https://github.com/akhilmathurs/orchestra) | +| Divergence-aware Federated Self-Supervised Learning | ICLR | [pdf](https://arxiv.org/pdf/2204.04385.pdf) | +| Federated Learning from Only Unlabeled Data with Class-Conditional-Sharing Clients | ICLR | [pdf](https://niug1984.github.io/paper/lu_iclr22.pdf) | + +### 2021 + +| Title | Venue | Link | +| ------------------------------------------------------------ | ----- | ------------------------------------------------------------ | +| Collaborative unsupervised visual representation learning from decentralized data | ICCV | [pdf](https://openaccess.thecvf.com/content/ICCV2021/papers/Zhuang_Collaborative_Unsupervised_Visual_Representation_Learning_From_Decentralized_Data_ICCV_2021_paper.pdf) | + +### 2020 + +| Title | Venue | Link | +| ------------------------------------------------------------ | ---------- | --------------------------------------------------------- | +| Performance optimization of federated person re-identification via benchmark analysis | MM | [pdf](https://dl.acm.org/doi/abs/10.1145/3394171.3413814) | +| Federated Self-Supervised Learning of Multi-Sensor 
Representations for Embedded Intelligence | IEEE IoT-J | [pdf](https://arxiv.org/pdf/2007.13018.pdf) | +| Federated unsupervised representation learning | Arxiv | [pdf](https://arxiv.org/pdf/2010.08982.pdf) | +| Towards utilizing unlabeled data in federated learning: A survey and prospective | Arxiv | [pdf](https://arxiv.org/pdf/2002.11545.pdf) | +| Towards federated unsupervised representation learning | EdgeSys | [pdf](https://dl.acm.org/doi/abs/10.1145/3378679.3394530) | \ No newline at end of file From 241b42ee6380aee24a3671fafffb4e2599b42b75 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Mon, 19 Sep 2022 23:28:09 +0800 Subject: [PATCH 21/46] add global loss grad and computed graph, and keep the same Non-IID distribution of train and test set --- .../cl/baseline/fedgc_on_cifar10.yaml | 12 +- .../fedsimclr_linearprob_on_cifar10.yaml | 35 ------ federatedscope/cl/dataloader/Cifar10.py | 7 +- federatedscope/cl/fedgc/client.py | 3 +- federatedscope/cl/fedgc/server.py | 27 ++-- federatedscope/cl/fedgc/utils.py | 116 +++++++++++++----- federatedscope/cl/trainer/trainer.py | 23 ++-- 7 files changed, 120 insertions(+), 103 deletions(-) delete mode 100644 federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml diff --git a/federatedscope/cl/baseline/fedgc_on_cifar10.yaml b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml index af86dc369..4850bfbf9 100644 --- a/federatedscope/cl/baseline/fedgc_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml @@ -3,13 +3,13 @@ seed: 1 device: 2 federate: mode: standalone - total_round_num: 200 + total_round_num: 20 client_num: 5 - share_local_model: True + share_local_model: False online_aggr: True sample_client_rate: 1.0 method: fedgc - save_to: 'FedGC_on_Cifar4CL_lda0.1_lr0.25_lus5b_rn100.ckpt' + save_to: 'FedGC_on_Cifar4CL_lda0.1_lr0.05_lus2b_rn20.ckpt' data: root: 'data' type: 'Cifar4CL' @@ -20,10 +20,10 @@ data: model: type: 'SimCLR' train: - local_update_steps: 5 + 
local_update_steps: 2 batch_or_epoch: 'batch' optimizer: - lr: 0.25 + lr: 0.05 momentum: 0.1 early_stop: patience: 0 @@ -32,6 +32,6 @@ criterion: trainer: type: 'cltrainer' eval: - freq: 2 + freq: 5 metrics: ['loss'] split: ['val', 'test'] \ No newline at end of file diff --git a/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml b/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml deleted file mode 100644 index 02ed75804..000000000 --- a/federatedscope/cl/baseline/fedsimclr_linearprob_on_cifar10.yaml +++ /dev/null @@ -1,35 +0,0 @@ -use_gpu: True -device: 2 -federate: - mode: standalone - total_round_num: 50 - client_num: 5 - sample_client_rate: 1.0 - method: local - restore_from: 'SimCLR_on_Cifar4CL_lr0.1_lstep5_rn100.ckpt' -data: - root: 'data' - type: 'Cifar4LP' - batch_size: 256 - splitter: 'lda' - splitter_args: [{'alpha': 0.5}] - num_workers: 4 -model: - type: 'SimCLR_linear' -train: - local_update_steps: 1 - batch_or_epoch: 'epoch' - optimizer: - lr: 0.1 - momentum: 0.9 - weight_decay: 0.0 -early_stop: - patience: 0 -criterion: - type: CrossEntropyLoss -trainer: - type: general -eval: - freq: 2 - metrics: ['acc'] - split: ['val', 'test'] \ No newline at end of file diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index dede7e59f..71bbf41f7 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -70,8 +70,10 @@ def Cifar4CL(config): splitter = get_splitter(config) data_train = splitter(data_train) print([len(i) for i in data_train]) + label_data_train = [[i[1] for i in list_i] for list_i in data_train] data_val = data_train - data_test = splitter(data_test) + data_test = splitter(data_test, prior=label_data_train) + print([len(i) for i in data_test]) client_num = min(len(data_train), config.federate.client_num @@ -136,8 +138,9 @@ def Cifar4LP(config): # Splitter splitter = get_splitter(config) data_train = splitter(data_train) + 
label_data_train = [[i[1] for i in list_i] for list_i in data_train] data_val = splitter(data_val) - data_test = splitter(data_test) + data_test = splitter(data_test, prior=label_data_train) client_num = min(len(data_train), config.federate.client_num diff --git a/federatedscope/cl/fedgc/client.py b/federatedscope/cl/fedgc/client.py index 01f64d4ae..7eb489db6 100644 --- a/federatedscope/cl/fedgc/client.py +++ b/federatedscope/cl/fedgc/client.py @@ -33,8 +33,7 @@ def _register_default_handlers(self): def callback_funcs_for_global_loss(self, message: Message): round, sender, content = message.state, message.sender, message.content global_loss = content['global_loss'] - model_para_old = self.trainer.get_model_para() - model_para = self.trainer.train_with_global_loss(model_para_old, global_loss) + model_para = self.trainer.train_with_global_loss(global_loss) self.trainer.update(model_para) def callback_funcs_for_model_para(self, message: Message): diff --git a/federatedscope/cl/fedgc/server.py b/federatedscope/cl/fedgc/server.py index 08843f64a..2f4078f97 100644 --- a/federatedscope/cl/fedgc/server.py +++ b/federatedscope/cl/fedgc/server.py @@ -6,7 +6,8 @@ from federatedscope.core.message import Message from federatedscope.core.workers.server import Server from federatedscope.core.auxiliaries.utils import merge_dict -from federatedscope.cl.fedgc.utils import compute_global_NT_xentloss +from federatedscope.cl.fedgc.utils import global_NT_xentloss +from torchviz import make_dot, make_dot_from_trace logger = logging.getLogger(__name__) @@ -40,7 +41,6 @@ def __init__(self, idx: 0 for idx in range(1, self._cfg.federate.client_num + 1) } - def check_and_move_on(self, check_eval_result=False): @@ -73,28 +73,24 @@ def check_and_move_on(self, check_eval_result=False): raise ValueError( 'GlobalContrastFL server not support multi-model.') + global_loss_fn = global_NT_xentloss(device=self.device) for client_id in train_msg_buffer: z1, z2 = self.seqs_embedding[client_id][0], 
self.seqs_embedding[client_id][1] others_z2 = [self.seqs_embedding[other_client_id][1] for other_client_id in train_msg_buffer if other_client_id != client_id] # print("start cal loss") - self.loss_list[client_id] = compute_global_NT_xentloss(z1, z2, others_z2, device=self.device) + self.loss_list[client_id] = global_loss_fn(z1, z2, others_z2) print(self.loss_list[client_id]) + print('client {} global_loss:{}'.format(client_id, self.loss_list[client_id])) # print("end cal loss") self.state += 1 - if self.state % self._cfg.eval.freq == 0 and self.state != \ - self.total_round_num: - # Evaluate - logger.info( - 'Server: Starting evaluation at round {:d}.'.format( - self.state)) - self.eval() + if self.state < self.total_round_num: - self._start_new_training_round(aggregated_num) + for client_id in train_msg_buffer: msg_list = { @@ -108,8 +104,17 @@ def check_and_move_on(self, check_eval_result=False): receiver=[client_id], state=self.state, content=msg_list)) + + if self.state % self._cfg.eval.freq == 0 and self.state != \ + self.total_round_num: + # Evaluate + logger.info( + 'Server: Starting evaluation at round {:d}.'.format( + self.state)) + self.eval() + self._start_new_training_round(aggregated_num) # Move to next round of training logger.info( f'----------- Starting a new traininground(Round ' diff --git a/federatedscope/cl/fedgc/utils.py b/federatedscope/cl/fedgc/utils.py index 19aa4e7df..c5226ff29 100644 --- a/federatedscope/cl/fedgc/utils.py +++ b/federatedscope/cl/fedgc/utils.py @@ -8,50 +8,98 @@ def norm(w): return torch.norm(torch.cat([v.flatten() for v in w.values()])).item() - -def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5, device='cpu'): +class global_NT_xentloss(nn.Module): r""" - global_NT_xentloss is federated NT_xentloss in server. It collect sample client - embedding and calculate NT_xentloss from local client positive examples and - negative examples of local and other clients. 
+ NT_xentloss definition adapted from https://github.com/PatrickHua/SimSiam Arguments: - z1 (torch.tensor): the embedding of local model. - z2 (torch.tensor): the embedding of local model using another augmentation. - others_z2 (list[torch.tensor]): the embedding list of other clients, each client has two embedding. + z1 (torch.tensor): the embedding of model . + z2 (torch.tensor): the embedding of model using another augmentation. returns: - loss: the NT_xentloss loss for this aggregation of global clients + loss: the NT_xentloss loss for this batch data :rtype: torch.FloatTensor """ - z1, z2 = z1.cuda(device=device), z2.cuda(device=device) - N, Z = z1.shape - representations = torch.cat([z1, z2], dim=0) - similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) + def __init__(self, temperature=0.5, device=torch.device("cpu")): + super(global_NT_xentloss, self).__init__() + self.temperature = temperature + self.device = device + + def forward(self, z1, z2, others_z2=[]): + N, Z = z1.shape + z1, z2 = z1.to(self.device), z2.to(self.device) + representations = torch.cat([z1, z2], dim=0) + similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) + + l_pos = torch.diag(similarity_matrix, N) + r_pos = torch.diag(similarity_matrix, -N) + positives = torch.cat([l_pos, r_pos]).view(2 * N, 1) + + diag = torch.eye(2*N, dtype=torch.bool, device=self.device) + diag[N:,:N] = diag[:N,N:] = diag[:N,:N] + negatives = similarity_matrix[~diag].view(2*N, -1) + + if len(others_z2) != 0: + for z2_ in others_z2: + z2_ = z2_.detach().to(self.device) + N2, Z2 = z2_.shape + representations = torch.cat([z1, z2_], dim=0) + similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) + mask = torch.zeros_like(similarity_matrix, dtype=torch.bool, device=self.device) + mask[N:,:N] = True + mask[:N,N:] = True + negatives_other = 
similarity_matrix[mask].view(2*N, -1) + negatives = torch.cat([negatives, negatives_other], dim=1) + + + logits = torch.cat([positives, negatives], dim=1) / self.temperature + labels = torch.zeros(2*N, dtype=torch.int64, device=self.device) # scalar label per sample + loss = F.cross_entropy(logits, labels, reduction='sum') / (2 * N) + + return loss + +# def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5, device='cpu'): +# r""" +# global_NT_xentloss is federated NT_xentloss in server. It collect sample client +# embedding and calculate NT_xentloss from local client positive examples and +# negative examples of local and other clients. +# Arguments: +# z1 (torch.tensor): the embedding of local model. +# z2 (torch.tensor): the embedding of local model using another augmentation. +# others_z2 (list[torch.tensor]): the embedding list of other clients, each client has two embedding. +# returns: +# loss: the NT_xentloss loss for this aggregation of global clients +# :rtype: +# torch.FloatTensor +# """ +# z1, z2 = z1.cuda(device=device), z2.cuda(device=device) +# N, Z = z1.shape +# representations = torch.cat([z1, z2], dim=0) +# similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) - l_pos = torch.diag(similarity_matrix, N) - r_pos = torch.diag(similarity_matrix, -N) - positives = torch.cat([l_pos, r_pos]).view(2 * N, 1) +# l_pos = torch.diag(similarity_matrix, N) +# r_pos = torch.diag(similarity_matrix, -N) +# positives = torch.cat([l_pos, r_pos]).view(2 * N, 1) - diag = torch.eye(2*N, dtype=torch.bool, device=device) - diag[N:,:N] = diag[:N,N:] = diag[:N,:N] - negatives = similarity_matrix[~diag].view(2*N, -1) +# diag = torch.eye(2*N, dtype=torch.bool, device=device) +# diag[N:,:N] = diag[:N,N:] = diag[:N,:N] +# negatives = similarity_matrix[~diag].view(2*N, -1) - if len(others_z2) != 0: - for z2_ in others_z2: - z2_ = z2_.cuda(device=device) - N2, Z2 = z2_.shape - representations = torch.cat([z1, 
z2_], dim=0) - similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) - mask = torch.zeros_like(similarity_matrix, dtype=torch.bool, device=device) - mask[N:,:N] = True - mask[:N,N:] = True - negatives_other = similarity_matrix[mask].view(2*N, -1) - negatives = torch.cat([negatives, negatives_other], dim=1) +# if len(others_z2) != 0: +# for z2_ in others_z2: +# z2_ = z2_.cuda(device=device) +# N2, Z2 = z2_.shape +# representations = torch.cat([z1, z2_], dim=0) +# similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) +# mask = torch.zeros_like(similarity_matrix, dtype=torch.bool, device=device) +# mask[N:,:N] = True +# mask[:N,N:] = True +# negatives_other = similarity_matrix[mask].view(2*N, -1) +# negatives = torch.cat([negatives, negatives_other], dim=1) - logits = torch.cat([positives, negatives], dim=1) / temperature - labels = torch.zeros(2*N, dtype=torch.int64, device=device) # scalar label per sample - loss = F.cross_entropy(logits, labels, reduction='sum') +# logits = torch.cat([positives, negatives], dim=1) / temperature +# labels = torch.zeros(2*N, dtype=torch.int64, device=device) # scalar label per sample +# loss = F.cross_entropy(logits, labels, reduction='sum') - return loss / (2 * N) +# return loss / (2 * N) diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index c5d0ac8f0..98dd88bd9 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -6,8 +6,10 @@ from federatedscope.core.trainers.context import CtxVar from federatedscope.core.auxiliaries.enums import LIFECYCLE from federatedscope.core.auxiliaries import utils +from torchviz import make_dot, make_dot_from_trace import torch import numpy as np +import copy class CLTrainer(GeneralTorchTrainer): @@ -21,22 +23,22 @@ def __init__(self, super(CLTrainer, self).__init__(model, data, device, config, only_for_eval, monitor) 
self.batches_aug_data_1, self.batches_aug_data_2 = [], [] + self.z1, self.z2 = torch.empty(1), torch.empty(1) - @torch.no_grad() + def get_train_pred_embedding(self): model = self.ctx.model.to(self.ctx.device) ys_prob_1, ys_prob_2 = [], [] - x1, x2 = torch.cat(self.batches_aug_data_1, dim=0), torch.cat(self.batches_aug_data_2, dim=0) - z1, z2 = model(x1.to(self.ctx.device), x2.to(self.ctx.device)) - ys_prob_1 = z1.detach().cpu() - ys_prob_2 = z2.detach().cpu() + x1, x2 = torch.cat(self.batches_aug_data_1, dim=0).to(self.ctx.device), torch.cat(self.batches_aug_data_2, dim=0).to(self.ctx.device) + z1, z2 = model(x1, x2) self.batches_aug_data_1, self.batches_aug_data_2 = [], [] + self.z1, self.z2 = z1, z2 + self.ctx.model.to(torch.device('cpu')) - return [ys_prob_1, ys_prob_2] + return [self.z1, self.z2] def _hook_on_batch_forward(self, ctx): x, label = [utils.move_to(_, ctx.device) for _ in ctx.data_batch] -# print(len(x), x[0].size(), x[1].size(), label.size()) x1, x2 = x[0], x[1] if ctx.cur_mode in [MODE.TRAIN]: self.batches_aug_data_1.append(x1) @@ -60,19 +62,14 @@ def _hook_on_batch_end(self, ctx): ctx.ys_true.append(ctx.y_true.detach().cpu().numpy()) ctx.ys_prob.append(ctx.y_prob[0].detach().cpu().numpy()) - def train_with_global_loss(self, model_para, loss): + def train_with_global_loss(self, loss): """ Arguments: - model_para: model parameters loss: loss after global calculate :returns: grads: grads to optimize the model of other clients """ - for key in model_para.keys(): - if isinstance(model_para[key], list): - model_para[key] = torch.FloatTensor(model_para[key]) - self.ctx.model.load_state_dict(model_para) self.ctx.model = self.ctx.model.to(self.ctx.device) self.ctx.optimizer.zero_grad() From a9aebf5d4f1ba29e456a160c232ce94bfbd82abe Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 20 Sep 2022 04:17:42 +0800 Subject: [PATCH 22/46] debug worker_builder --- federatedscope/core/auxiliaries/worker_builder.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/federatedscope/core/auxiliaries/worker_builder.py b/federatedscope/core/auxiliaries/worker_builder.py index 7f2b84e95..827ccfb2e 100644 --- a/federatedscope/core/auxiliaries/worker_builder.py +++ b/federatedscope/core/auxiliaries/worker_builder.py @@ -106,7 +106,7 @@ def get_server_cls(cfg): elif server_type == 'gcflplus': from federatedscope.gfl.gcflplus.worker import GCFLPlusServer return GCFLPlusServer - elif client_type == 'fedgc': + elif server_type == 'fedgc': from federatedscope.cl.fedgc.server import GlobalContrastFLServer return GlobalContrastFLServer else: From cd171a0455a4f9888ea19d3d9fde0a8c53b5c5f5 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Fri, 23 Sep 2022 04:33:46 +0800 Subject: [PATCH 23/46] modify the worker for global loss backward once for both local train and global loss, use local ratio to control the gradient component --- ...ontrastlearning_linearprob_on_cifar10.yaml | 4 +- .../cl/baseline/fedgc_on_cifar10.yaml | 6 +- .../cl/baseline/fedsimclr_on_cifar10.yaml | 8 +- federatedscope/cl/dataloader/Cifar10.py | 2 +- federatedscope/cl/fedgc/client.py | 28 ++- federatedscope/cl/fedgc/server.py | 185 +++++++++--------- federatedscope/cl/trainer/trainer.py | 24 ++- 7 files changed, 144 insertions(+), 113 deletions(-) diff --git a/federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml b/federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml index 482c4400c..bb48d3cc3 100644 --- a/federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml @@ -6,8 +6,8 @@ federate: total_round_num: 50 client_num: 5 sample_client_rate: 1.0 - method: global - restore_from: '../fedsimclr_on_Cifar4CL_lda0.1_lr0.05_lus5_rn200batch_seed2.ckpt' + method: local + restore_from: '../fedsimclr_on_Cifar4CL_lda0.1_lr0.05_lus5_rn200batch_seed1.ckpt' data: root: 'data' type: 'Cifar4LP' 
diff --git a/federatedscope/cl/baseline/fedgc_on_cifar10.yaml b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml index 4850bfbf9..f007609ee 100644 --- a/federatedscope/cl/baseline/fedgc_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml @@ -3,13 +3,13 @@ seed: 1 device: 2 federate: mode: standalone - total_round_num: 20 + total_round_num: 200 client_num: 5 share_local_model: False online_aggr: True sample_client_rate: 1.0 method: fedgc - save_to: 'FedGC_on_Cifar4CL_lda0.1_lr0.05_lus2b_rn20.ckpt' + save_to: 'FedGC_on_Cifar4CL_lda0.1_lr0.05_lus2b_rn200.ckpt' data: root: 'data' type: 'Cifar4CL' @@ -20,7 +20,7 @@ data: model: type: 'SimCLR' train: - local_update_steps: 2 + local_update_steps: 1 batch_or_epoch: 'batch' optimizer: lr: 0.05 diff --git a/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml index 4667f41b5..8af52e6bc 100644 --- a/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml @@ -2,24 +2,24 @@ use_gpu: True device: 2 federate: mode: standalone - total_round_num: 100 + total_round_num: 200 client_num: 5 share_local_model: True online_aggr: True sample_client_rate: 1.0 - save_to: 'SimCLR_on_Cifar4CL_lda0.5_lr0.05_lus5_rn100.ckpt' + save_to: 'SimCLR_on_Cifar4CL_lda0.1_lr0.05_lus5_rn200.ckpt' data: root: 'data' type: 'Cifar4CL' batch_size: 256 splitter: 'lda' - splitter_args: [{'alpha': 0.5}] + splitter_args: [{'alpha': 0.1}] num_workers: 4 model: type: 'SimCLR' train: local_update_steps: 5 - batch_or_epoch: 'epoch' + batch_or_epoch: 'batch' optimizer: lr: 0.05 momentum: 0.1 diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 71bbf41f7..1a5e2ea96 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -139,7 +139,7 @@ def Cifar4LP(config): splitter = get_splitter(config) data_train = splitter(data_train) label_data_train = [[i[1] 
for i in list_i] for list_i in data_train] - data_val = splitter(data_val) + data_val = splitter(data_val, prior=label_data_train) data_test = splitter(data_test, prior=label_data_train) diff --git a/federatedscope/cl/fedgc/client.py b/federatedscope/cl/fedgc/client.py index 7eb489db6..56f36add7 100644 --- a/federatedscope/cl/fedgc/client.py +++ b/federatedscope/cl/fedgc/client.py @@ -21,8 +21,8 @@ def _register_default_handlers(self): self.callback_funcs_for_join_in_info) self.register_handlers('address', self.callback_funcs_for_address) self.register_handlers('model_para', - self.callback_funcs_for_model_para) - self.register_handlers('global_loss', self.callback_funcs_for_global_loss) + self.callback_funcs_for_pred_embedding) + self.register_handlers('global_loss', self.callback_funcs_for_local_backward) self.register_handlers('ss_model_para', self.callback_funcs_for_model_para) @@ -30,29 +30,37 @@ def _register_default_handlers(self): self.register_handlers('finish', self.callback_funcs_for_finish) self.register_handlers('converged', self.callback_funcs_for_converged) - def callback_funcs_for_global_loss(self, message: Message): + def callback_funcs_for_local_backward(self, message: Message): round, sender, content = message.state, message.sender, message.content global_loss = content['global_loss'] model_para = self.trainer.train_with_global_loss(global_loss) self.trainer.update(model_para) + self.state = round + sample_size, model_para= self.trainer.num_samples, self.trainer.get_model_para() + + self.comm_manager.send( + Message(msg_type='model_para', + sender=self.ID, + receiver=[sender], + state=self.state, + content=(sample_size, model_para))) - def callback_funcs_for_model_para(self, message: Message): + def callback_funcs_for_pred_embedding(self, message: Message): round, sender, content = message.state, message.sender, message.content self.trainer.update(content) - self.state = round sample_size, model_para, results = self.trainer.train() + self.state = 
round pred_embedding = self.trainer.get_train_pred_embedding() - if self._cfg.federate.share_local_model and not \ - self._cfg.federate.online_aggr: - model_para = copy.deepcopy(model_para) + logger.info( self._monitor.format_eval_res(results, rnd=self.state, role='Client #{}'.format(self.ID))) self.comm_manager.send( - Message(msg_type='model_para', + Message(msg_type='pred_embedding', sender=self.ID, receiver=[sender], state=self.state, - content=(sample_size, model_para, pred_embedding))) + content=(pred_embedding))) + diff --git a/federatedscope/cl/fedgc/server.py b/federatedscope/cl/fedgc/server.py index 2f4078f97..2d2800cca 100644 --- a/federatedscope/cl/fedgc/server.py +++ b/federatedscope/cl/fedgc/server.py @@ -42,103 +42,112 @@ def __init__(self, for idx in range(1, self._cfg.federate.client_num + 1) } - def check_and_move_on(self, check_eval_result=False): + def _register_default_handlers(self): + self.register_handlers('join_in', self.callback_funcs_for_join_in) + self.register_handlers('join_in_info', self.callback_funcs_for_join_in) + self.register_handlers('model_para', self.callback_funcs_model_para) + self.register_handlers('metrics', self.callback_funcs_for_metrics) + self.register_handlers('pred_embedding', self.callback_funcs_global_loss) + + def check_and_move_on_for_global_loss(self): + + minimal_number = self.sample_client_num + + if self.check_buffer(self.state, minimal_number, check_eval_result=False): + + # Receiving enough feedback in the training process + + # Get all the message + train_msg_buffer = self.msg_buffer['train'][self.state] + for model_idx in range(self.model_num): + model = self.models[model_idx] + msg_list = list() + for client_id in train_msg_buffer: + if self.model_num == 1: + pred_embedding = train_msg_buffer[client_id] + self.seqs_embedding[client_id] = pred_embedding + else: + raise ValueError( + 'GlobalContrastFL server not support multi-model.') + + global_loss_fn = global_NT_xentloss(device=self.device) + for 
client_id in train_msg_buffer: + z1, z2 = self.seqs_embedding[client_id][0], self.seqs_embedding[client_id][1] + others_z2 = [self.seqs_embedding[other_client_id][1] + for other_client_id in train_msg_buffer + if other_client_id != client_id] +# print("start cal loss") + self.loss_list[client_id] = global_loss_fn(z1, z2, others_z2) + print(self.loss_list[client_id]) + print('client {} global_loss:{}'.format(client_id, self.loss_list[client_id])) +# print("end cal loss") - if check_eval_result: - # all clients are participating in evaluation - minimal_number = self.client_num - else: - # sampled clients are participating in training - minimal_number = self.sample_client_num - if self.check_buffer(self.state, minimal_number, check_eval_result): + self.state += 1 + + if self.state < self.total_round_num: + + for client_id in train_msg_buffer: - if not check_eval_result: # in the training process - # Receiving enough feedback in the training process - aggregated_num = self._perform_federated_aggregation() + msg_list = { + 'global_loss': self.loss_list[client_id], + } + + # Send loss to Clients + self.comm_manager.send( + Message(msg_type='global_loss', + sender=self.ID, + receiver=[client_id], + state=self.state, + content=msg_list)) + # Clean the msg_buffer + self.msg_buffer['train'][self.state - 1].clear() + self.msg_buffer['train'][self.state] = dict() + self.staled_msg_buffer.clear() + # Start a new training round + - # Get all the message - train_msg_buffer = self.msg_buffer['train'][self.state] - for model_idx in range(self.model_num): - model = self.models[model_idx] - aggregator = self.aggregators[model_idx] - msg_list = list() - for client_id in train_msg_buffer: - if self.model_num == 1: - train_data_size, model_para, pred_embedding = \ - train_msg_buffer[client_id] - self.seqs_embedding[client_id] = pred_embedding - msg_list.append((train_data_size, model_para, pred_embedding)) - else: - raise ValueError( - 'GlobalContrastFL server not support multi-model.') 
- - global_loss_fn = global_NT_xentloss(device=self.device) - for client_id in train_msg_buffer: - z1, z2 = self.seqs_embedding[client_id][0], self.seqs_embedding[client_id][1] - others_z2 = [self.seqs_embedding[other_client_id][1] - for other_client_id in train_msg_buffer - if other_client_id != client_id] -# print("start cal loss") - self.loss_list[client_id] = global_loss_fn(z1, z2, others_z2) - print(self.loss_list[client_id]) - print('client {} global_loss:{}'.format(client_id, self.loss_list[client_id])) -# print("end cal loss") + def callback_funcs_global_loss(self, message: Message): + """ + The handling function for receiving model embeddings, which triggers + check_and_move_on (calculate global loss when enough feedback has + been received). + Arguments: + message: The received message, which includes sender, receiver, + state, and content. More detail can be found in + federatedscope.core.message + """ + if self.is_finish: + return 'finish' - self.state += 1 + round = message.state + sender = message.sender + timestamp = message.timestamp + content = message.content + self.sampler.change_state(sender, 'idle') + # update the currency timestamp according to the received message + assert timestamp >= self.cur_timestamp # for test + self.cur_timestamp = timestamp - if self.state < self.total_round_num: - - for client_id in train_msg_buffer: - - msg_list = { - 'global_loss': self.loss_list[client_id], - } - - # Send loss to Clients - self.comm_manager.send( - Message(msg_type='global_loss', - sender=self.ID, - receiver=[client_id], - state=self.state, - content=msg_list)) - - if self.state % self._cfg.eval.freq == 0 and self.state != \ - self.total_round_num: - # Evaluate - logger.info( - 'Server: Starting evaluation at round {:d}.'.format( - self.state)) - self.eval() - - - self._start_new_training_round(aggregated_num) - # Move to next round of training - logger.info( - f'----------- Starting a new traininground(Round ' - f'#{self.state}) -------------') - 
# Clean the msg_buffer - self.msg_buffer['train'][self.state - 1].clear() - self.msg_buffer['train'][self.state] = dict() - self.staled_msg_buffer.clear() - # Start a new training round - + if round == self.state: + if round not in self.msg_buffer['train']: + self.msg_buffer['train'][round] = dict() + # Save the messages in this round + self.msg_buffer['train'][round][sender] = content + elif round >= self.state - self.staleness_toleration: + # Save the staled messages + self.staled_msg_buffer.append((round, sender, content)) + else: + # Drop the out-of-date messages + logger.info(f'Drop a out-of-date message from round #{round}') + self.dropout_num += 1 - else: - # Final Evaluate - logger.info('Server: Training is finished! Starting ' - 'evaluation.') - self.eval() - - else: # in the evaluation process - # Get all the message & aggregate - formatted_eval_res = self.merge_eval_results_from_all_clients() - self.history_results = merge_dict(self.history_results, - formatted_eval_res) - self.check_and_save() - + move_on_flag = self.check_and_move_on_for_global_loss() + + return move_on_flag + def callback_funcs_model_para(self, message: Message): """ The handling function for receiving model parameters, which triggers diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index 98dd88bd9..91bc04bab 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -3,6 +3,7 @@ from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler from federatedscope.core.trainers import GeneralTorchTrainer +from federatedscope.core.trainers.context import Context from federatedscope.core.trainers.context import CtxVar from federatedscope.core.auxiliaries.enums import LIFECYCLE from federatedscope.core.auxiliaries import utils @@ -24,11 +25,13 @@ def __init__(self, only_for_eval, monitor) self.batches_aug_data_1, 
self.batches_aug_data_2 = [], [] self.z1, self.z2 = torch.empty(1), torch.empty(1) + self.num_samples = 0 + self.local_loss_ratio = 0.5 + self.global_loss_ratio = 1 - self.local_loss_ratio - def get_train_pred_embedding(self): + def get_train_pred_embedding(self): model = self.ctx.model.to(self.ctx.device) - ys_prob_1, ys_prob_2 = [], [] x1, x2 = torch.cat(self.batches_aug_data_1, dim=0).to(self.ctx.device), torch.cat(self.batches_aug_data_2, dim=0).to(self.ctx.device) z1, z2 = model(x1, x2) self.batches_aug_data_1, self.batches_aug_data_2 = [], [] @@ -51,11 +54,20 @@ def _hook_on_batch_forward(self, ctx): ctx.y_prob = CtxVar((z1, z2), LIFECYCLE.BATCH) ctx.loss_batch = CtxVar(ctx.criterion(z1, z2), LIFECYCLE.BATCH) ctx.batch_size = CtxVar(len(label), LIFECYCLE.BATCH) - + + def _hook_on_batch_backward(self, ctx): + ctx.optimizer.zero_grad() + ctx.loss_task = ctx.loss_task * self.local_loss_ratio + ctx.loss_task.backward() + if ctx.grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ctx.model.parameters(), + ctx.grad_clip) def _hook_on_batch_end(self, ctx): # update statistics ctx.num_samples += ctx.batch_size + if ctx.cur_mode in [MODE.TRAIN]: + self.num_samples = ctx.num_samples ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size ctx.loss_regular_total += float(ctx.get("loss_regular", 0.)) # cache label for evaluate @@ -72,10 +84,12 @@ def train_with_global_loss(self, loss): self.ctx.model = self.ctx.model.to(self.ctx.device) - self.ctx.optimizer.zero_grad() +# self.ctx.optimizer.zero_grad() - loss = loss.requires_grad_() + loss = loss.requires_grad_() * self.global_loss_ratio loss.backward() + + self.ctx.optimizer.step() return self.ctx.model.state_dict() From 7e07c801058076505e2ff05ab28fd130505e7411 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 11 Oct 2022 19:28:41 +0800 Subject: [PATCH 24/46] resolve review --- ...ontrastlearning_linearprob_on_cifar10.yaml | 17 ++-- .../cl/baseline/fedgc_on_cifar10.yaml | 11 +-- 
.../cl/baseline/fedsimclr_on_cifar10.yaml | 20 ++--- .../baseline/repro_exp/args_cifar10_fedgc.sh | 2 +- .../repro_exp/args_cifar10_fedsimclr.sh | 2 +- .../repro_exp/run_contrastive_learning.sh | 12 +-- .../supervised_fedavg_on_cifar10.yaml | 11 +-- federatedscope/cl/fedgc/client.py | 10 ++- federatedscope/cl/fedgc/server.py | 78 ++++++++++++++++--- federatedscope/cl/fedgc/utils.py | 2 +- federatedscope/cl/loss/NT_xentloss.py | 2 +- .../cl/lr_scheduler/ LR_Scheduler.py | 49 ++++++++++++ federatedscope/cl/model/SimCLR.py | 20 ++++- federatedscope/cl/trainer/trainer.py | 16 ++-- federatedscope/core/aggregators/aggregator.py | 2 + .../core/auxiliaries/trainer_builder.py | 3 +- federatedscope/core/workers/server.py | 1 + 17 files changed, 197 insertions(+), 61 deletions(-) create mode 100644 federatedscope/cl/lr_scheduler/ LR_Scheduler.py diff --git a/federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml b/federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml index bb48d3cc3..26ed8e3a7 100644 --- a/federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedcontrastlearning_linearprob_on_cifar10.yaml @@ -1,19 +1,19 @@ use_gpu: True seed: 1 -device: 1 +device: 2 federate: mode: standalone total_round_num: 50 - client_num: 5 + client_num: 1 sample_client_rate: 1.0 - method: local - restore_from: '../fedsimclr_on_Cifar4CL_lda0.1_lr0.05_lus5_rn200batch_seed1.ckpt' + method: global + restore_from: '../test_supervised.ckpt' data: root: 'data' type: 'Cifar4LP' batch_size: 256 splitter: 'lda' - splitter_args: [{'alpha': 0.1}] + splitter_args: [{'alpha': 0.5}] num_workers: 4 model: type: 'SimCLR_linear' @@ -21,9 +21,12 @@ train: local_update_steps: 1 batch_or_epoch: 'epoch' optimizer: - lr: 0.1 + lr: 0.01 momentum: 0.9 weight_decay: 0.0 + scheduler: + type: CosineAnnealingLR + T_max: 50 early_stop: patience: 0 criterion: @@ -31,6 +34,6 @@ criterion: trainer: type: 'lptrainer' eval: - freq: 
2 + freq: 5 metrics: ['acc'] split: ['val', 'test'] \ No newline at end of file diff --git a/federatedscope/cl/baseline/fedgc_on_cifar10.yaml b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml index f007609ee..ba89711d0 100644 --- a/federatedscope/cl/baseline/fedgc_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedgc_on_cifar10.yaml @@ -3,28 +3,29 @@ seed: 1 device: 2 federate: mode: standalone - total_round_num: 200 + total_round_num: 10 client_num: 5 share_local_model: False online_aggr: True sample_client_rate: 1.0 method: fedgc - save_to: 'FedGC_on_Cifar4CL_lda0.1_lr0.05_lus2b_rn200.ckpt' + save_to: 'test.ckpt' data: root: 'data' type: 'Cifar4CL' - batch_size: 256 + batch_size: 512 splitter: 'lda' splitter_args: [{'alpha': 0.1}] num_workers: 4 model: type: 'SimCLR' train: - local_update_steps: 1 + local_update_steps: 5 batch_or_epoch: 'batch' optimizer: lr: 0.05 - momentum: 0.1 + momentum: 0.9 + weight_decay: 0.0001 early_stop: patience: 0 criterion: diff --git a/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml index 8af52e6bc..680553cb6 100644 --- a/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml +++ b/federatedscope/cl/baseline/fedsimclr_on_cifar10.yaml @@ -2,27 +2,27 @@ use_gpu: True device: 2 federate: mode: standalone - total_round_num: 200 + total_round_num: 100 client_num: 5 - share_local_model: True - online_aggr: True sample_client_rate: 1.0 - save_to: 'SimCLR_on_Cifar4CL_lda0.1_lr0.05_lus5_rn200.ckpt' + method: FedAvg + save_to: '../SimCLR_on_Cifar4CL_global_lr0.03_lus10_rn100epoch_repairsave.ckpt' data: root: 'data' type: 'Cifar4CL' - batch_size: 256 + batch_size: 512 splitter: 'lda' splitter_args: [{'alpha': 0.1}] num_workers: 4 model: type: 'SimCLR' train: - local_update_steps: 5 - batch_or_epoch: 'batch' + local_update_steps: 10 + batch_or_epoch: 'epoch' optimizer: - lr: 0.05 - momentum: 0.1 + lr: 0.03 + momentum: 0.9 + weight_decay: 0.0001 early_stop: patience: 0 criterion: @@ 
-30,6 +30,6 @@ criterion: trainer: type: 'cltrainer' eval: - freq: 2 + freq: 5 metrics: ['loss'] split: ['val', 'test'] \ No newline at end of file diff --git a/federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh index 45d0c369c..78241e32a 100644 --- a/federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh +++ b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedgc.sh @@ -2,7 +2,7 @@ # Fedgc # ---------------------------------------------------------------------- # -bash run_contrastive_learning.sh 2 fedgc cifar10 0.3 +bash run_contrastive_learning.sh 3 fedgc cifar10 0.1 diff --git a/federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh index ec7225e20..749f8da99 100644 --- a/federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh +++ b/federatedscope/cl/baseline/repro_exp/args_cifar10_fedsimclr.sh @@ -2,6 +2,6 @@ # Fedsimclr # ---------------------------------------------------------------------- # -bash run_contrastive_learning.sh 0 fedsimclr cifar10 0.5 +bash run_contrastive_learning.sh 1 fedsimclr cifar10 0.1 diff --git a/federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh b/federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh index ba4cf3aa4..0520937fb 100644 --- a/federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh +++ b/federatedscope/cl/baseline/repro_exp/run_contrastive_learning.sh @@ -13,18 +13,18 @@ fi if [[ $method_name = 'fedgc' ]]; then method='fedgc' - total_round_num='200' - batch_or_epoch='batch' + total_round_num='100' + batch_or_epoch='epoch' elif [[ $method_name = 'fedsimclr' ]]; then method='Fedavg' - total_round_num='200' - batch_or_epoch='batch' + total_round_num='100' + batch_or_epoch='epoch' fi echo "Fed Contrastive Learning starts..." 
-lrs=(0.01 0.05 0.25) -local_updates=(1 3 5) +lrs=(0.003 0.01 0.03) +local_updates=(10) for (( i=0; i<${#lrs[@]}; i++ )) diff --git a/federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml b/federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml index 7a91c110e..3986eabec 100644 --- a/federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml +++ b/federatedscope/cl/baseline/supervised_fedavg_on_cifar10.yaml @@ -2,26 +2,27 @@ use_gpu: True device: 0 federate: mode: standalone - total_round_num: 50 + total_round_num: 100 client_num: 5 sample_client_rate: 1.0 share_local_model: True online_aggr: True method: FedAvg + save_to: '../test_supervised.ckpt' data: root: 'data' type: 'Cifar4LP' batch_size: 256 splitter: 'lda' - splitter_args: [{'alpha': 0.5}] + splitter_args: [{'alpha': 0.1}] num_workers: 4 model: type: 'supervised_fedavg' train: - local_update_steps: 1 + local_update_steps: 3 batch_or_epoch: 'epoch' optimizer: - lr: 0.1 + lr: 0.03 momentum: 0.9 weight_decay: 0.0 early_stop: @@ -31,6 +32,6 @@ criterion: trainer: type: general eval: - freq: 2 + freq: 10 metrics: ['acc'] split: ['val', 'test'] \ No newline at end of file diff --git a/federatedscope/cl/fedgc/client.py b/federatedscope/cl/fedgc/client.py index 56f36add7..8f88ef5ec 100644 --- a/federatedscope/cl/fedgc/client.py +++ b/federatedscope/cl/fedgc/client.py @@ -52,10 +52,12 @@ def callback_funcs_for_pred_embedding(self, message: Message): self.state = round pred_embedding = self.trainer.get_train_pred_embedding() - logger.info( - self._monitor.format_eval_res(results, - rnd=self.state, - role='Client #{}'.format(self.ID))) + train_log_res = self._monitor.format_eval_res( + results, + rnd=self.state, + role='Client #{}'.format(self.ID), + return_raw=True) + logger.info(train_log_res) self.comm_manager.send( Message(msg_type='pred_embedding', diff --git a/federatedscope/cl/fedgc/server.py b/federatedscope/cl/fedgc/server.py index 2d2800cca..eef0699a8 100644 --- 
a/federatedscope/cl/fedgc/server.py +++ b/federatedscope/cl/fedgc/server.py @@ -82,10 +82,8 @@ def check_and_move_on_for_global_loss(self): print('client {} global_loss:{}'.format(client_id, self.loss_list[client_id])) # print("end cal loss") - self.state += 1 - - if self.state < self.total_round_num: + if self.state <= self.total_round_num: for client_id in train_msg_buffer: @@ -100,13 +98,73 @@ def check_and_move_on_for_global_loss(self): receiver=[client_id], state=self.state, content=msg_list)) - # Clean the msg_buffer - self.msg_buffer['train'][self.state - 1].clear() - self.msg_buffer['train'][self.state] = dict() - self.staled_msg_buffer.clear() - # Start a new training round + def check_and_move_on(self, + check_eval_result=False, + min_received_num=None): + """ + To check the message_buffer. When enough messages are receiving, + some events (such as perform aggregation, evaluation, and move to + the next training round) would be triggered. + + Arguments: + check_eval_result (bool): If True, check the message buffer for + evaluation; and check the message buffer for training otherwise. 
+ """ + if min_received_num is None: + if self._cfg.asyn.use: + min_received_num = self._cfg.asyn.min_received_num + else: + min_received_num = self._cfg.federate.sample_client_num + assert min_received_num <= self.sample_client_num + + if check_eval_result and self._cfg.federate.mode.lower( + ) == "standalone": + # in evaluation stage and standalone simulation mode, we assume + # strong synchronization that receives responses from all clients + min_received_num = len(self.comm_manager.get_neighbors().keys()) + + move_on_flag = True # To record whether moving to a new training + # round or finishing the evaluation + if self.check_buffer(self.state, min_received_num, check_eval_result): + if not check_eval_result: + # Receiving enough feedback in the training process + aggregated_num = self._perform_federated_aggregation() + + if self.state % self._cfg.eval.freq == 0 and self.state != \ + self.total_round_num: + # Evaluate + logger.info(f'Server: Starting evaluation at the end ' + f'of round {self.state - 1}.') + self.eval() + + if self.state < self.total_round_num: + # Move to next round of training + logger.info( + f'----------- Starting a new training round (Round ' + f'#{self.state}) -------------') + # Clean the msg_buffer + self.msg_buffer['train'][self.state - 1].clear() + self.msg_buffer['train'][self.state] = dict() + self.staled_msg_buffer.clear() + # Start a new training round + self._start_new_training_round(aggregated_num) + else: + # Final Evaluate + logger.info('Server: Training is finished! 
Starting ' + 'evaluation.') + self.eval() + + else: + # Receiving enough feedback in the evaluation process + self._merge_and_format_eval_results() + + else: + move_on_flag = False + + return move_on_flag + def callback_funcs_global_loss(self, message: Message): """ The handling function for receiving model embeddings, which triggers @@ -139,10 +197,6 @@ def callback_funcs_global_loss(self, message: Message): elif round >= self.state - self.staleness_toleration: # Save the staled messages self.staled_msg_buffer.append((round, sender, content)) - else: - # Drop the out-of-date messages - logger.info(f'Drop a out-of-date message from round #{round}') - self.dropout_num += 1 move_on_flag = self.check_and_move_on_for_global_loss() diff --git a/federatedscope/cl/fedgc/utils.py b/federatedscope/cl/fedgc/utils.py index c5226ff29..3a0e51f1e 100644 --- a/federatedscope/cl/fedgc/utils.py +++ b/federatedscope/cl/fedgc/utils.py @@ -19,7 +19,7 @@ class global_NT_xentloss(nn.Module): :rtype: torch.FloatTensor """ - def __init__(self, temperature=0.5, device=torch.device("cpu")): + def __init__(self, temperature=0.1, device=torch.device("cpu")): super(global_NT_xentloss, self).__init__() self.temperature = temperature self.device = device diff --git a/federatedscope/cl/loss/NT_xentloss.py b/federatedscope/cl/loss/NT_xentloss.py index 6791cfe97..349946043 100644 --- a/federatedscope/cl/loss/NT_xentloss.py +++ b/federatedscope/cl/loss/NT_xentloss.py @@ -15,7 +15,7 @@ class NT_xentloss(nn.Module): :rtype: torch.FloatTensor """ - def __init__(self, temperature=0.5): + def __init__(self, temperature=0.1): super(NT_xentloss, self).__init__() self.temperature = temperature diff --git a/federatedscope/cl/lr_scheduler/ LR_Scheduler.py b/federatedscope/cl/lr_scheduler/ LR_Scheduler.py new file mode 100644 index 000000000..ce0c2584e --- /dev/null +++ b/federatedscope/cl/lr_scheduler/ LR_Scheduler.py @@ -0,0 +1,49 @@ +import numpy as np +from federatedscope.register import register_scheduler 
+ + +# LR Scheduler +class LR_Scheduler(object): + def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False): + self.base_lr = base_lr + self.constant_predictor_lr = constant_predictor_lr + warmup_iter = iter_per_epoch * warmup_epochs + warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter) + decay_iter = iter_per_epoch * (num_epochs - warmup_epochs) + cosine_lr_schedule = final_lr+0.5*(base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter)) + + self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) + self.optimizer = optimizer + self.iter = 0 + self.current_lr = 0 + def step(self): + for param_group in self.optimizer.param_groups: + + if self.constant_predictor_lr and param_group['name'] == 'predictor': + param_group['lr'] = self.base_lr + else: + lr = param_group['lr'] = self.lr_schedule[self.iter] + + self.iter += 1 + self.current_lr = lr + return lr + def get_lr(self): + return self.current_lr + +def get_scheduler(optimizer, type): + try: + import torch.optim as optim + except ImportError: + optim = None + scheduler = None + + if type == 'cos_lr_scheduler': + if optim is not None: + lr_lambda = [lambda epoch: epoch // 30] + scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup_epochs=0, warmup_lr=0, + num_epochs=50, base_lr=30, + final_lr=0, iter_per_epoch=int(50000/512)) + return scheduler + + +register_scheduler('cos_lr_scheduler', get_scheduler) \ No newline at end of file diff --git a/federatedscope/cl/model/SimCLR.py b/federatedscope/cl/model/SimCLR.py index efadfcf71..583dcd365 100644 --- a/federatedscope/cl/model/SimCLR.py +++ b/federatedscope/cl/model/SimCLR.py @@ -139,6 +139,21 @@ def __init__(self, bbone_arch, num_classes=10): self.backbone = create_backbone(bbone_arch, num_classes=0) self.linear = nn.Linear(512, num_classes, bias=True) + def forward(self, x): + with torch.no_grad(): + out = self.backbone(x) + out = 
self.linear(out) + + return out + +class simclr_supervised(nn.Module): + def __init__(self, bbone_arch, num_classes=10): + super(simclr_supervised, self).__init__() + self.register_buffer("rounds_done", torch.zeros(1)) + + self.backbone = create_backbone(bbone_arch, num_classes=0) + self.linear = nn.Linear(512, num_classes, bias=True) + def forward(self, x): out = self.backbone(x) out = self.linear(out) @@ -150,9 +165,12 @@ def ModelBuilder(model_config, local_data): if model_config.type == "SimCLR": model = simclr(bbone_arch='res18') return model - if model_config.type in ["SimCLR_linear","supervised_local","supervised_fedavg"]: + if model_config.type in ["SimCLR_linear"]: model = simclr_linearprob(bbone_arch='res18', num_classes=10) return model + if model_config.type in ["supervised_local","supervised_fedavg"]: + model = simclr_supervised(bbone_arch='res18', num_classes=10) + return model from federatedscope.register import register_model diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index 91bc04bab..18a55450b 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -23,16 +23,16 @@ def __init__(self, monitor=None): super(CLTrainer, self).__init__(model, data, device, config, only_for_eval, monitor) - self.batches_aug_data_1, self.batches_aug_data_2 = [], [] + self.batches_aug_data_1, self.batches_aug_data_2 = torch.empty(1), torch.empty(1) self.z1, self.z2 = torch.empty(1), torch.empty(1) self.num_samples = 0 - self.local_loss_ratio = 0.5 - self.global_loss_ratio = 1 - self.local_loss_ratio + self.local_loss_ratio = 1 + self.global_loss_ratio = 5 def get_train_pred_embedding(self): model = self.ctx.model.to(self.ctx.device) - x1, x2 = torch.cat(self.batches_aug_data_1, dim=0).to(self.ctx.device), torch.cat(self.batches_aug_data_2, dim=0).to(self.ctx.device) + x1, x2 = self.batches_aug_data_1.to(self.ctx.device), self.batches_aug_data_2.to(self.ctx.device) z1, z2 = model(x1, x2) 
self.batches_aug_data_1, self.batches_aug_data_2 = [], [] self.z1, self.z2 = z1, z2 @@ -44,8 +44,8 @@ def _hook_on_batch_forward(self, ctx): x, label = [utils.move_to(_, ctx.device) for _ in ctx.data_batch] x1, x2 = x[0], x[1] if ctx.cur_mode in [MODE.TRAIN]: - self.batches_aug_data_1.append(x1) - self.batches_aug_data_2.append(x2) + self.batches_aug_data_1 = x1 + self.batches_aug_data_2 = x2 z1, z2 = ctx.model(x1, x2) if len(label.size()) == 0: label = label.unsqueeze(0) @@ -62,6 +62,10 @@ def _hook_on_batch_backward(self, ctx): if ctx.grad_clip > 0: torch.nn.utils.clip_grad_norm_(ctx.model.parameters(), ctx.grad_clip) + + ctx.optimizer.step() + if ctx.scheduler is not None: + ctx.scheduler.step() def _hook_on_batch_end(self, ctx): # update statistics diff --git a/federatedscope/core/aggregators/aggregator.py b/federatedscope/core/aggregators/aggregator.py index af9c842c0..d633df090 100644 --- a/federatedscope/core/aggregators/aggregator.py +++ b/federatedscope/core/aggregators/aggregator.py @@ -1,3 +1,5 @@ +import os +import torch from abc import ABC, abstractmethod diff --git a/federatedscope/core/auxiliaries/trainer_builder.py b/federatedscope/core/auxiliaries/trainer_builder.py index 9fbae8259..187de4949 100644 --- a/federatedscope/core/auxiliaries/trainer_builder.py +++ b/federatedscope/core/auxiliaries/trainer_builder.py @@ -26,6 +26,7 @@ "fedfocaltrainer": "FedFocalTrainer", "mftrainer": "MFTrainer", "cltrainer": "CLTrainer", + "lptrainer": "LPTrainer", } @@ -62,7 +63,7 @@ def get_trainer(model=None, dict_path = "federatedscope.cv.trainer.trainer" elif config.trainer.type.lower() in ['nlptrainer']: dict_path = "federatedscope.nlp.trainer.trainer" - elif config.trainer.type.lower() in ['cltrainer']: + elif config.trainer.type.lower() in ['cltrainer', 'lptrainer']: dict_path = "federatedscope.cl.trainer.trainer" elif config.trainer.type.lower() in [ 'graphminibatch_trainer', diff --git a/federatedscope/core/workers/server.py 
b/federatedscope/core/workers/server.py index 38ceb6435..5a52d89d3 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -456,6 +456,7 @@ def _perform_federated_aggregation(self): # Due to lazy load, we merge two state dict merged_param = merge_param_dict(model.state_dict().copy(), result) model.load_state_dict(merged_param, strict=False) + aggregator.update(merged_param) return aggregated_num From 6bde8df70f21c066c7baf9711e1cb9ea19f9cc35 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Thu, 13 Oct 2022 03:56:46 +0800 Subject: [PATCH 25/46] create unittest of fedsimclr in cifar10 --- tests/test_simclr_cifar10.py | 76 ++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/test_simclr_cifar10.py diff --git a/tests/test_simclr_cifar10.py b/tests/test_simclr_cifar10.py new file mode 100644 index 000000000..0102f090f --- /dev/null +++ b/tests/test_simclr_cifar10.py @@ -0,0 +1,76 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +from federatedscope.core.auxiliaries.data_builder import get_data +from federatedscope.core.auxiliaries.utils import setup_seed, update_logger +from federatedscope.core.configs.config import global_cfg +from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls + +SAMPLE_CLIENT_NUM = 5 + + +class SimCLR_CIFAR10Test(unittest.TestCase): + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def set_config_simclr_cifar10(self, cfg): + backup_cfg = cfg.clone() + + import torch + cfg.use_gpu = torch.cuda.is_available() + cfg.eval.freq = 5 + cfg.eval.metrics = ['loss'] + + cfg.federate.mode = 'standalone' + cfg.train.local_update_steps = 5 + cfg.federate.total_round_num = 20 + cfg.federate.sample_client_num = 5 + + cfg.data.root = 'data/' + cfg.data.type = 'Cifar4CL' + cfg.data.splits = [0.8, 0.1, 0.1] + cfg.data.batch_size = 10 + cfg.data.subsample = 1.0 + + cfg.model.type = 'SimCLR' + cfg.model.hidden = 256 + cfg.model.out_channels = 1 + + cfg.train.optimizer.lr = 0.01 + cfg.train.optimizer.weight_decay = 0.0001 + cfg.train.optimizer.momentum = 0.9 + + cfg.criterion.type = 'NT_xentloss' + cfg.trainer.type = 'cltrainer' + cfg.seed = 1 + + return backup_cfg + + def test_simclr_cifar10_standalone(self): + init_cfg = global_cfg.clone() + backup_cfg = self.set_config_simclr_cifar10(init_cfg) + setup_seed(init_cfg.seed) + update_logger(init_cfg, True) + + data, modified_cfg = get_data(init_cfg.clone()) + init_cfg.merge_from_other_cfg(modified_cfg) + self.assertIsNotNone(data) + self.assertEqual(init_cfg.federate.sample_client_num, + SAMPLE_CLIENT_NUM) + + Fed_runner = FedRunner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) + self.assertIsNotNone(Fed_runner) + test_best_results = Fed_runner.run() + print(test_best_results) + 
init_cfg.merge_from_other_cfg(backup_cfg) + self.assertLess( + test_best_results["client_summarized_weighted_avg"]['test_loss'], + 100) + + +if __name__ == '__main__': + unittest.main() From 7508570b954ebae57e43a42d5adc6954f61fa688 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Thu, 13 Oct 2022 04:42:16 +0800 Subject: [PATCH 26/46] Update test_simclr_cifar10.py --- tests/test_simclr_cifar10.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_simclr_cifar10.py b/tests/test_simclr_cifar10.py index 0102f090f..bac86ac47 100644 --- a/tests/test_simclr_cifar10.py +++ b/tests/test_simclr_cifar10.py @@ -2,7 +2,8 @@ import unittest from federatedscope.core.auxiliaries.data_builder import get_data -from federatedscope.core.auxiliaries.utils import setup_seed, update_logger +from federatedscope.core.auxiliaries.utils import setup_seed +from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg from federatedscope.core.fed_runner import FedRunner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls From ee465582eeffc64ad69fdbe45eea28b5e8e7f357 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Thu, 13 Oct 2022 05:35:52 +0800 Subject: [PATCH 27/46] Update test_simclr_cifar10.py --- tests/test_simclr_cifar10.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_simclr_cifar10.py b/tests/test_simclr_cifar10.py index bac86ac47..31bff3031 100644 --- a/tests/test_simclr_cifar10.py +++ b/tests/test_simclr_cifar10.py @@ -22,16 +22,21 @@ def set_config_simclr_cifar10(self, cfg): cfg.use_gpu = torch.cuda.is_available() cfg.eval.freq = 5 cfg.eval.metrics = ['loss'] + cfg.eval.split = ['val', 'test'] cfg.federate.mode = 'standalone' cfg.train.local_update_steps = 5 + cfg.train.batch_or_epoch = 'batch' cfg.federate.total_round_num = 20 cfg.federate.sample_client_num = 5 cfg.data.root = 
'data/' cfg.data.type = 'Cifar4CL' cfg.data.splits = [0.8, 0.1, 0.1] - cfg.data.batch_size = 10 + cfg.data.batch_size = 256 + cfg.data.splitter = 'lda' + cfg.data.splitter_args = [{'alpha': 0.1}] + cfg.data.num_workers = 4 cfg.data.subsample = 1.0 cfg.model.type = 'SimCLR' From 75526f40861666048819b73d6d8a40b81385e2c9 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Thu, 13 Oct 2022 15:08:30 +0800 Subject: [PATCH 28/46] Update test_simclr_cifar10.py --- tests/test_simclr_cifar10.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_simclr_cifar10.py b/tests/test_simclr_cifar10.py index 31bff3031..6d851f6d3 100644 --- a/tests/test_simclr_cifar10.py +++ b/tests/test_simclr_cifar10.py @@ -30,7 +30,7 @@ def set_config_simclr_cifar10(self, cfg): cfg.federate.total_round_num = 20 cfg.federate.sample_client_num = 5 - cfg.data.root = 'data/' + cfg.data.root = 'test_data/' cfg.data.type = 'Cifar4CL' cfg.data.splits = [0.8, 0.1, 0.1] cfg.data.batch_size = 256 From 7934b1378367ac42a47fc63ba2b0be801e30921a Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Thu, 13 Oct 2022 19:19:41 +0800 Subject: [PATCH 29/46] Update test_simclr_cifar10.py --- tests/test_simclr_cifar10.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_simclr_cifar10.py b/tests/test_simclr_cifar10.py index 6d851f6d3..b086efbd2 100644 --- a/tests/test_simclr_cifar10.py +++ b/tests/test_simclr_cifar10.py @@ -28,6 +28,7 @@ def set_config_simclr_cifar10(self, cfg): cfg.train.local_update_steps = 5 cfg.train.batch_or_epoch = 'batch' cfg.federate.total_round_num = 20 + cfg.federate.client_num = 5 cfg.federate.sample_client_num = 5 cfg.data.root = 'test_data/' @@ -75,7 +76,7 @@ def test_simclr_cifar10_standalone(self): init_cfg.merge_from_other_cfg(backup_cfg) self.assertLess( test_best_results["client_summarized_weighted_avg"]['test_loss'], - 100) + 10000) if __name__ == '__main__': From 
7b4a1a613a87b83b7f671a43a95187d5ecac39e8 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Fri, 14 Oct 2022 01:25:35 +0800 Subject: [PATCH 30/46] delete print and repair for unit-test failing --- federatedscope/cl/dataloader/Cifar10.py | 2 -- federatedscope/cl/fedgc/server.py | 3 +-- federatedscope/cl/trainer/trainer.py | 10 ---------- 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 1a5e2ea96..65238428f 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -69,11 +69,9 @@ def Cifar4CL(config): data_dict = dict() splitter = get_splitter(config) data_train = splitter(data_train) - print([len(i) for i in data_train]) label_data_train = [[i[1] for i in list_i] for list_i in data_train] data_val = data_train data_test = splitter(data_test, prior=label_data_train) - print([len(i) for i in data_test]) client_num = min(len(data_train), config.federate.client_num diff --git a/federatedscope/cl/fedgc/server.py b/federatedscope/cl/fedgc/server.py index eef0699a8..6b75c776d 100644 --- a/federatedscope/cl/fedgc/server.py +++ b/federatedscope/cl/fedgc/server.py @@ -78,8 +78,7 @@ def check_and_move_on_for_global_loss(self): if other_client_id != client_id] # print("start cal loss") self.loss_list[client_id] = global_loss_fn(z1, z2, others_z2) - print(self.loss_list[client_id]) - print('client {} global_loss:{}'.format(client_id, self.loss_list[client_id])) + logger.info(f'client {client_id} global_loss:{self.loss_list[client_id]}') # print("end cal loss") self.state += 1 diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index 18a55450b..40716bda1 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -7,7 +7,6 @@ from federatedscope.core.trainers.context import CtxVar from federatedscope.core.auxiliaries.enums import LIFECYCLE from 
federatedscope.core.auxiliaries import utils -from torchviz import make_dot, make_dot_from_trace import torch import numpy as np import copy @@ -79,21 +78,12 @@ def _hook_on_batch_end(self, ctx): ctx.ys_prob.append(ctx.y_prob[0].detach().cpu().numpy()) def train_with_global_loss(self, loss): - """ - Arguments: - loss: loss after global calculate - :returns: - grads: grads to optimize the model of other clients - """ self.ctx.model = self.ctx.model.to(self.ctx.device) -# self.ctx.optimizer.zero_grad() - loss = loss.requires_grad_() * self.global_loss_ratio loss.backward() - self.ctx.optimizer.step() return self.ctx.model.state_dict() From 3810fa5e87d8748100ba736e70d6e68ab2306942 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Fri, 14 Oct 2022 16:29:46 +0800 Subject: [PATCH 31/46] re-try for unit-test timeout error after extending waiting time --- tests/test_simclr_cifar10.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_simclr_cifar10.py b/tests/test_simclr_cifar10.py index b086efbd2..82e54f814 100644 --- a/tests/test_simclr_cifar10.py +++ b/tests/test_simclr_cifar10.py @@ -50,7 +50,7 @@ def set_config_simclr_cifar10(self, cfg): cfg.criterion.type = 'NT_xentloss' cfg.trainer.type = 'cltrainer' - cfg.seed = 1 + cfg.seed = 2 return backup_cfg From a474986551473989989ced540c94184076647ca8 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Fri, 14 Oct 2022 17:31:01 +0800 Subject: [PATCH 32/46] re-try for unit-test timeout error with adding shared memory --- tests/test_simclr_cifar10.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_simclr_cifar10.py b/tests/test_simclr_cifar10.py index 82e54f814..b086efbd2 100644 --- a/tests/test_simclr_cifar10.py +++ b/tests/test_simclr_cifar10.py @@ -50,7 +50,7 @@ def set_config_simclr_cifar10(self, cfg): cfg.criterion.type = 'NT_xentloss' cfg.trainer.type = 'cltrainer' - cfg.seed = 2 + cfg.seed = 1 return backup_cfg From 
1eed27e0da719c31c3b5f3ce8e42a5fe44395a28 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 18 Oct 2022 04:35:16 +0800 Subject: [PATCH 33/46] delete never used --- federatedscope/cl/fedgc/server.py | 1 - federatedscope/cl/fedgc/utils.py | 49 ------------------------------- 2 files changed, 50 deletions(-) diff --git a/federatedscope/cl/fedgc/server.py b/federatedscope/cl/fedgc/server.py index 6b75c776d..0c89fd066 100644 --- a/federatedscope/cl/fedgc/server.py +++ b/federatedscope/cl/fedgc/server.py @@ -7,7 +7,6 @@ from federatedscope.core.workers.server import Server from federatedscope.core.auxiliaries.utils import merge_dict from federatedscope.cl.fedgc.utils import global_NT_xentloss -from torchviz import make_dot, make_dot_from_trace logger = logging.getLogger(__name__) diff --git a/federatedscope/cl/fedgc/utils.py b/federatedscope/cl/fedgc/utils.py index 3a0e51f1e..ac3aee3fa 100644 --- a/federatedscope/cl/fedgc/utils.py +++ b/federatedscope/cl/fedgc/utils.py @@ -1,8 +1,6 @@ import torch -import numpy as np import torch.nn as nn import torch.nn.functional as F -import networkx as nx def norm(w): @@ -56,50 +54,3 @@ def forward(self, z1, z2, others_z2=[]): loss = F.cross_entropy(logits, labels, reduction='sum') / (2 * N) return loss - -# def compute_global_NT_xentloss(z1, z2, others_z2=[], temperature=0.5, device='cpu'): -# r""" -# global_NT_xentloss is federated NT_xentloss in server. It collect sample client -# embedding and calculate NT_xentloss from local client positive examples and -# negative examples of local and other clients. -# Arguments: -# z1 (torch.tensor): the embedding of local model. -# z2 (torch.tensor): the embedding of local model using another augmentation. -# others_z2 (list[torch.tensor]): the embedding list of other clients, each client has two embedding. 
-# returns: -# loss: the NT_xentloss loss for this aggregation of global clients -# :rtype: -# torch.FloatTensor -# """ -# z1, z2 = z1.cuda(device=device), z2.cuda(device=device) -# N, Z = z1.shape -# representations = torch.cat([z1, z2], dim=0) -# similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) - -# l_pos = torch.diag(similarity_matrix, N) -# r_pos = torch.diag(similarity_matrix, -N) -# positives = torch.cat([l_pos, r_pos]).view(2 * N, 1) - -# diag = torch.eye(2*N, dtype=torch.bool, device=device) -# diag[N:,:N] = diag[:N,N:] = diag[:N,:N] -# negatives = similarity_matrix[~diag].view(2*N, -1) - - -# if len(others_z2) != 0: -# for z2_ in others_z2: -# z2_ = z2_.cuda(device=device) -# N2, Z2 = z2_.shape -# representations = torch.cat([z1, z2_], dim=0) -# similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) -# mask = torch.zeros_like(similarity_matrix, dtype=torch.bool, device=device) -# mask[N:,:N] = True -# mask[:N,N:] = True -# negatives_other = similarity_matrix[mask].view(2*N, -1) -# negatives = torch.cat([negatives, negatives_other], dim=1) - - -# logits = torch.cat([positives, negatives], dim=1) / temperature -# labels = torch.zeros(2*N, dtype=torch.int64, device=device) # scalar label per sample -# loss = F.cross_entropy(logits, labels, reduction='sum') - -# return loss / (2 * N) From 0a395c5c96ee644dad34a4f25bdcc29c80cd81ba Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 25 Oct 2022 03:45:00 +0800 Subject: [PATCH 34/46] Update utils.py --- federatedscope/core/data/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/federatedscope/core/data/utils.py b/federatedscope/core/data/utils.py index 1b46cac67..6066ac855 100644 --- a/federatedscope/core/data/utils.py +++ b/federatedscope/core/data/utils.py @@ -38,6 +38,8 @@ def load_dataset(config): dataset, modified_config = load_quadratic_dataset(config) elif 
config.data.type.lower() in ['femnist', 'celeba']: from federatedscope.cv.dataloader import load_cv_dataset + elif config.data.type.lower() in ['cifar4cl', 'cifar4lp']: + from federatedscope.cl.dataloader import load_cifar_dataset dataset, modified_config = load_cv_dataset(config) elif config.data.type.lower() in [ 'shakespeare', 'twitter', 'subreddit', 'synthetic' From 4cf9a96b55377f8f8d348903d18505a397e2ba91 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 25 Oct 2022 03:46:19 +0800 Subject: [PATCH 35/46] Update SimCLR.py --- federatedscope/cl/model/SimCLR.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/federatedscope/cl/model/SimCLR.py b/federatedscope/cl/model/SimCLR.py index 583dcd365..d795c8567 100644 --- a/federatedscope/cl/model/SimCLR.py +++ b/federatedscope/cl/model/SimCLR.py @@ -98,8 +98,6 @@ def create_backbone(name, num_classes=10, block='BasicBlock'): net = ResNet18(num_classes=num_classes, block=block) elif(name == 'res34'): net = ResNet34(num_classes=num_classes, block=block) - elif(name == 'res56'): - net = ResNet56(num_classes=num_classes, block=block) return net From 09992884ae60bf5b9992fd99377617fb0a2b4146 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 25 Oct 2022 04:56:46 +0800 Subject: [PATCH 36/46] modify format --- federatedscope/cl/dataloader/Cifar10.py | 12 ++++++++---- federatedscope/cl/fedgc/client.py | 12 +++++++----- federatedscope/cl/fedgc/server.py | 20 +++++++++----------- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 65238428f..3ec79b95c 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -17,9 +17,11 @@ class SimCLRTransform(): r""" - Data Augmentations of SimCLR refer from https://github.com/akhilmathurs/orchestra/blob/main/utils.py + Data Augmentations of SimCLR refer from + 
https://github.com/akhilmathurs/orchestra/blob/main/utils.py Arguments: - is_sup (bool): the transform for supervised learning or contrastive learning. + is_sup (bool): the transform for supervised learning + or contrastive learning. :returns: torch.tensor: one output for supervised learning. :returns: @@ -28,7 +30,8 @@ class SimCLRTransform(): """ def __init__(self, is_sup, image_size=32): self.transform = T.Compose([ - T.RandomResizedCrop(image_size, scale=(0.5, 1.0), interpolation=T.InterpolationMode.BICUBIC), + T.RandomResizedCrop(image_size, scale=(0.5, 1.0), + interpolation=T.InterpolationMode.BICUBIC), T.RandomHorizontalFlip(p=0.5), T.RandomApply([T.ColorJitter(0.4,0.4,0.2,0.1)], p=0.8), T.RandomGrayscale(p=0.2), @@ -104,7 +107,8 @@ def Cifar4CL(config): def Cifar4LP(config): r""" - generate Cifar10 Dataset transform and split dict for linear prob evaluation of contrastive learning + generate Cifar10 Dataset transform and split dict for linear prob + evaluation of contrastive learning return { 'client_id': { 'train': DataLoader(), diff --git a/federatedscope/cl/fedgc/client.py b/federatedscope/cl/fedgc/client.py index 8f88ef5ec..59801e551 100644 --- a/federatedscope/cl/fedgc/client.py +++ b/federatedscope/cl/fedgc/client.py @@ -11,8 +11,9 @@ class GlobalContrastFLClient(Client): r""" - GlobalContrastFL(Fedgc) Client receive aggregated model weight from server then update local - weight; it also receive global loss from server to train model and update weight locally. + GlobalContrastFL(Fedgc) Client receive aggregated model weight from + server then update local weight; it also receive global loss from server + to train model and update weight locally. 
""" def _register_default_handlers(self): self.register_handlers('assign_client_id', @@ -25,18 +26,19 @@ def _register_default_handlers(self): self.register_handlers('global_loss', self.callback_funcs_for_local_backward) self.register_handlers('ss_model_para', self.callback_funcs_for_model_para) - + self.register_handlers('evaluate', self.callback_funcs_for_evaluate) self.register_handlers('finish', self.callback_funcs_for_finish) self.register_handlers('converged', self.callback_funcs_for_converged) - + def callback_funcs_for_local_backward(self, message: Message): round, sender, content = message.state, message.sender, message.content global_loss = content['global_loss'] model_para = self.trainer.train_with_global_loss(global_loss) self.trainer.update(model_para) self.state = round - sample_size, model_para= self.trainer.num_samples, self.trainer.get_model_para() + sample_size= self.trainer.num_samples + model_para = self.trainer.get_model_para() self.comm_manager.send( Message(msg_type='model_para', diff --git a/federatedscope/cl/fedgc/server.py b/federatedscope/cl/fedgc/server.py index 0c89fd066..a6c740a52 100644 --- a/federatedscope/cl/fedgc/server.py +++ b/federatedscope/cl/fedgc/server.py @@ -13,9 +13,9 @@ class GlobalContrastFLServer(Server): r""" - GlobalContrastFL(Fedgc) Server contain two part in training: Fedavg aggragator - for client model weight and calculate global loss from all sampled client embedding - then broadcast all client to train model. + GlobalContrastFL(Fedgc) Server contain two part in training: Fedavg + aggragator for client model weight and calculate global loss from + all sampled client embedding then broadcast all client to train model. 
""" def __init__(self, ID=-1, @@ -71,14 +71,14 @@ def check_and_move_on_for_global_loss(self): global_loss_fn = global_NT_xentloss(device=self.device) for client_id in train_msg_buffer: - z1, z2 = self.seqs_embedding[client_id][0], self.seqs_embedding[client_id][1] + z1 = self.seqs_embedding[client_id][0] + z2 = self.seqs_embedding[client_id][1] others_z2 = [self.seqs_embedding[other_client_id][1] - for other_client_id in train_msg_buffer - if other_client_id != client_id] -# print("start cal loss") + for other_client_id in train_msg_buffer + if other_client_id != client_id] self.loss_list[client_id] = global_loss_fn(z1, z2, others_z2) - logger.info(f'client {client_id} global_loss:{self.loss_list[client_id]}') -# print("end cal loss") + logger.info(f'client {client_id}' + f'global_loss:{self.loss_list[client_id]}') self.state += 1 if self.state <= self.total_round_num: @@ -89,7 +89,6 @@ def check_and_move_on_for_global_loss(self): 'global_loss': self.loss_list[client_id], } - # Send loss to Clients self.comm_manager.send( Message(msg_type='global_loss', sender=self.ID, @@ -128,7 +127,6 @@ def check_and_move_on(self, if not check_eval_result: # Receiving enough feedback in the training process aggregated_num = self._perform_federated_aggregation() - if self.state % self._cfg.eval.freq == 0 and self.state != \ self.total_round_num: From a18aa3f9b8fef174fdd0231771c55ced34975643 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 25 Oct 2022 15:41:39 +0800 Subject: [PATCH 37/46] modify format --- federatedscope/cl/dataloader/Cifar10.py | 144 ++++++++++-------- federatedscope/cl/fedgc/client.py | 31 ++-- federatedscope/cl/fedgc/server.py | 48 +++--- federatedscope/cl/fedgc/utils.py | 40 +++-- federatedscope/cl/loss/NT_xentloss.py | 24 +-- .../cl/lr_scheduler/ LR_Scheduler.py | 41 +++-- federatedscope/cl/model/SimCLR.py | 85 +++++++---- federatedscope/cl/trainer/trainer.py | 37 ++--- federatedscope/core/aggregator.py | 8 +- 
federatedscope/core/aggregators/aggregator.py | 4 +- .../core/auxiliaries/aggregator_builder.py | 4 +- .../core/auxiliaries/model_builder.py | 6 +- .../core/auxiliaries/trainer_builder.py | 2 +- 13 files changed, 276 insertions(+), 198 deletions(-) diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 65238428f..4e5070b83 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -14,12 +14,13 @@ from federatedscope.core.auxiliaries.splitter_builder import get_splitter - class SimCLRTransform(): r""" - Data Augmentations of SimCLR refer from https://github.com/akhilmathurs/orchestra/blob/main/utils.py + Data Augmentations of SimCLR refer from + https://github.com/akhilmathurs/orchestra/blob/main/utils.py Arguments: - is_sup (bool): the transform for supervised learning or contrastive learning. + is_sup (bool): the transform for supervised learning + or contrastive learning. :returns: torch.tensor: one output for supervised learning. 
:returns: @@ -28,11 +29,14 @@ class SimCLRTransform(): """ def __init__(self, is_sup, image_size=32): self.transform = T.Compose([ - T.RandomResizedCrop(image_size, scale=(0.5, 1.0), interpolation=T.InterpolationMode.BICUBIC), + T.RandomResizedCrop(image_size, + scale=(0.5, 1.0), + interpolation=T.InterpolationMode.BICUBIC), T.RandomHorizontalFlip(p=0.5), - T.RandomApply([T.ColorJitter(0.4,0.4,0.2,0.1)], p=0.8), + T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), T.RandomGrayscale(p=0.2), - T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.5), + T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], + p=0.5), T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) @@ -40,12 +44,13 @@ def __init__(self, is_sup, image_size=32): self.mode = is_sup def __call__(self, x): - if(self.mode): + if (self.mode): return self.transform(x) else: x1 = self.transform(x) x2 = self.transform(x) - return x1, x2 + return x1, x2 + def Cifar4CL(config): r""" @@ -61,10 +66,16 @@ def Cifar4CL(config): transform_train = SimCLRTransform(is_sup=False, image_size=32) path = config.data.root - - data_train = CIFAR10(path, train=True, download=True, transform=transform_train) - data_test = CIFAR10(path, train=False, download=True, transform=transform_train) - + + data_train = CIFAR10(path, + train=True, + download=True, + transform=transform_train) + data_test = CIFAR10(path, + train=False, + download=True, + transform=transform_train) + # Split data into dict data_dict = dict() splitter = get_splitter(config) @@ -73,38 +84,35 @@ def Cifar4CL(config): data_val = data_train data_test = splitter(data_test, prior=label_data_train) - client_num = min(len(data_train), config.federate.client_num ) if config.federate.client_num > 0 else len(data_train) config.merge_from_list(['federate.client_num', client_num]) - for client_idx in range(1, client_num + 1): dataloader_dict = { - 'train': - DataLoader(data_train[client_idx - 1], - config.data.batch_size, - 
shuffle=config.data.shuffle, - num_workers=config.data.num_workers), - 'val': - DataLoader(data_val[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - 'test': - DataLoader(data_test[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - } + 'train': DataLoader(data_train[client_idx - 1], + config.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers), + 'val': DataLoader(data_val[client_idx - 1], + config.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers), + 'test': DataLoader(data_test[client_idx - 1], + config.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers), + } data_dict[client_idx] = dataloader_dict config = config return data_dict, config + def Cifar4LP(config): r""" - generate Cifar10 Dataset transform and split dict for linear prob evaluation of contrastive learning + generate Cifar10 Dataset transform and split dict for linear prob + evaluation of contrastive learning return { 'client_id': { 'train': DataLoader(), @@ -114,23 +122,33 @@ def Cifar4LP(config): } """ transform_train = T.Compose([ - T.RandomResizedCrop(32, scale=(0.5, 1.0), interpolation=T.InterpolationMode.BICUBIC), - T.RandomHorizontalFlip(p=0.5), - T.ToTensor(), - T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - ]) - transform_test = T.Compose([ - T.ToTensor(), - T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - ) + T.RandomResizedCrop(32, + scale=(0.5, 1.0), + interpolation=T.InterpolationMode.BICUBIC), + T.RandomHorizontalFlip(p=0.5), + T.ToTensor(), + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + transform_test = T.Compose( + [T.ToTensor(), + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) path = config.data.root - - data_train = CIFAR10(path, train=True, download=True, transform=transform_train) - data_val = CIFAR10(path, train=True, download=True, transform=transform_test) - data_test = CIFAR10(path, 
train=False, download=True, transform=transform_test) - - # Split data into dict + + data_train = CIFAR10(path, + train=True, + download=True, + transform=transform_train) + data_val = CIFAR10(path, + train=True, + download=True, + transform=transform_test) + data_test = CIFAR10(path, + train=False, + download=True, + transform=transform_test) + + # Split data into dict data_dict = dict() # Splitter @@ -140,37 +158,31 @@ def Cifar4LP(config): data_val = splitter(data_val, prior=label_data_train) data_test = splitter(data_test, prior=label_data_train) - client_num = min(len(data_train), config.federate.client_num ) if config.federate.client_num > 0 else len(data_train) config.merge_from_list(['federate.client_num', client_num]) - for client_idx in range(1, client_num + 1): dataloader_dict = { - 'train': - DataLoader(data_train[client_idx - 1], - config.data.batch_size, - shuffle=config.data.shuffle, - num_workers=config.data.num_workers), - 'val': - DataLoader(data_val[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - 'test': - DataLoader(data_test[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - } + 'train': DataLoader(data_train[client_idx - 1], + config.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers), + 'val': DataLoader(data_val[client_idx - 1], + config.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers), + 'test': DataLoader(data_test[client_idx - 1], + config.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers), + } data_dict[client_idx] = dataloader_dict - + config = config return data_dict, config - def load_cifar_dataset(config): if config.data.type == "Cifar4CL": data, modified_config = Cifar4CL(config) diff --git a/federatedscope/cl/fedgc/client.py b/federatedscope/cl/fedgc/client.py index 8f88ef5ec..fd4b471af 100644 --- a/federatedscope/cl/fedgc/client.py +++ 
b/federatedscope/cl/fedgc/client.py @@ -9,10 +9,12 @@ logger = logging.getLogger(__name__) + class GlobalContrastFLClient(Client): r""" - GlobalContrastFL(Fedgc) Client receive aggregated model weight from server then update local - weight; it also receive global loss from server to train model and update weight locally. + GlobalContrastFL(Fedgc) Client receive aggregated model weight from + server then update local weight; it also receive global loss from server + to train model and update weight locally. """ def _register_default_handlers(self): self.register_handlers('assign_client_id', @@ -22,21 +24,23 @@ def _register_default_handlers(self): self.register_handlers('address', self.callback_funcs_for_address) self.register_handlers('model_para', self.callback_funcs_for_pred_embedding) - self.register_handlers('global_loss', self.callback_funcs_for_local_backward) + self.register_handlers('global_loss', + self.callback_funcs_for_local_backward) self.register_handlers('ss_model_para', self.callback_funcs_for_model_para) - + self.register_handlers('evaluate', self.callback_funcs_for_evaluate) self.register_handlers('finish', self.callback_funcs_for_finish) self.register_handlers('converged', self.callback_funcs_for_converged) - + def callback_funcs_for_local_backward(self, message: Message): round, sender, content = message.state, message.sender, message.content global_loss = content['global_loss'] model_para = self.trainer.train_with_global_loss(global_loss) self.trainer.update(model_para) self.state = round - sample_size, model_para= self.trainer.num_samples, self.trainer.get_model_para() + sample_size = self.trainer.num_samples + model_para = self.trainer.get_model_para() self.comm_manager.send( Message(msg_type='model_para', @@ -44,19 +48,19 @@ def callback_funcs_for_local_backward(self, message: Message): receiver=[sender], state=self.state, content=(sample_size, model_para))) - + def callback_funcs_for_pred_embedding(self, message: Message): round, sender, 
content = message.state, message.sender, message.content self.trainer.update(content) sample_size, model_para, results = self.trainer.train() self.state = round pred_embedding = self.trainer.get_train_pred_embedding() - - train_log_res = self._monitor.format_eval_res( - results, - rnd=self.state, - role='Client #{}'.format(self.ID), - return_raw=True) + + train_log_res = self._monitor.format_eval_res(results, + rnd=self.state, + role='Client #{}'.format( + self.ID), + return_raw=True) logger.info(train_log_res) self.comm_manager.send( @@ -65,4 +69,3 @@ def callback_funcs_for_pred_embedding(self, message: Message): receiver=[sender], state=self.state, content=(pred_embedding))) - diff --git a/federatedscope/cl/fedgc/server.py b/federatedscope/cl/fedgc/server.py index 0c89fd066..c11ec883b 100644 --- a/federatedscope/cl/fedgc/server.py +++ b/federatedscope/cl/fedgc/server.py @@ -13,9 +13,9 @@ class GlobalContrastFLServer(Server): r""" - GlobalContrastFL(Fedgc) Server contain two part in training: Fedavg aggragator - for client model weight and calculate global loss from all sampled client embedding - then broadcast all client to train model. + GlobalContrastFL(Fedgc) Server contain two part in training: Fedavg + aggragator for client model weight and calculate global loss from + all sampled client embedding then broadcast all client to train model. 
""" def __init__(self, ID=-1, @@ -36,23 +36,26 @@ def __init__(self, idx: () for idx in range(1, self._cfg.federate.client_num + 1) } - self.loss_list= { + self.loss_list = { idx: 0 for idx in range(1, self._cfg.federate.client_num + 1) } - + def _register_default_handlers(self): self.register_handlers('join_in', self.callback_funcs_for_join_in) self.register_handlers('join_in_info', self.callback_funcs_for_join_in) self.register_handlers('model_para', self.callback_funcs_model_para) self.register_handlers('metrics', self.callback_funcs_for_metrics) - self.register_handlers('pred_embedding', self.callback_funcs_global_loss) - + self.register_handlers('pred_embedding', + self.callback_funcs_global_loss) + def check_and_move_on_for_global_loss(self): minimal_number = self.sample_client_num - if self.check_buffer(self.state, minimal_number, check_eval_result=False): + if self.check_buffer(self.state, + minimal_number, + check_eval_result=False): # Receiving enough feedback in the training process @@ -71,14 +74,17 @@ def check_and_move_on_for_global_loss(self): global_loss_fn = global_NT_xentloss(device=self.device) for client_id in train_msg_buffer: - z1, z2 = self.seqs_embedding[client_id][0], self.seqs_embedding[client_id][1] - others_z2 = [self.seqs_embedding[other_client_id][1] - for other_client_id in train_msg_buffer - if other_client_id != client_id] -# print("start cal loss") - self.loss_list[client_id] = global_loss_fn(z1, z2, others_z2) - logger.info(f'client {client_id} global_loss:{self.loss_list[client_id]}') -# print("end cal loss") + z1 = self.seqs_embedding[client_id][0] + z2 = self.seqs_embedding[client_id][1] + others_z2 = [ + self.seqs_embedding[other_client_id][1] + for other_client_id in train_msg_buffer + if other_client_id != client_id + ] + self.loss_list[client_id] = global_loss_fn( + z1, z2, others_z2) + logger.info(f'client {client_id}' + f'global_loss:{self.loss_list[client_id]}') self.state += 1 if self.state <= self.total_round_num: @@ 
-89,14 +95,13 @@ def check_and_move_on_for_global_loss(self): 'global_loss': self.loss_list[client_id], } - # Send loss to Clients self.comm_manager.send( Message(msg_type='global_loss', sender=self.ID, receiver=[client_id], state=self.state, content=msg_list)) - + def check_and_move_on(self, check_eval_result=False, min_received_num=None): @@ -129,7 +134,6 @@ def check_and_move_on(self, # Receiving enough feedback in the training process aggregated_num = self._perform_federated_aggregation() - if self.state % self._cfg.eval.freq == 0 and self.state != \ self.total_round_num: # Evaluate @@ -161,8 +165,8 @@ def check_and_move_on(self, else: move_on_flag = False - return move_on_flag - + return move_on_flag + def callback_funcs_global_loss(self, message: Message): """ The handling function for receiving model embeddings, which triggers @@ -199,7 +203,7 @@ def callback_funcs_global_loss(self, message: Message): move_on_flag = self.check_and_move_on_for_global_loss() return move_on_flag - + def callback_funcs_model_para(self, message: Message): """ The handling function for receiving model parameters, which triggers diff --git a/federatedscope/cl/fedgc/utils.py b/federatedscope/cl/fedgc/utils.py index ac3aee3fa..7deb811e4 100644 --- a/federatedscope/cl/fedgc/utils.py +++ b/federatedscope/cl/fedgc/utils.py @@ -6,6 +6,7 @@ def norm(w): return torch.norm(torch.cat([v.flatten() for v in w.values()])).item() + class global_NT_xentloss(nn.Module): r""" NT_xentloss definition adapted from https://github.com/PatrickHua/SimSiam @@ -21,36 +22,43 @@ def __init__(self, temperature=0.1, device=torch.device("cpu")): super(global_NT_xentloss, self).__init__() self.temperature = temperature self.device = device - + def forward(self, z1, z2, others_z2=[]): - N, Z = z1.shape - z1, z2 = z1.to(self.device), z2.to(self.device) + N, Z = z1.shape + z1, z2 = z1.to(self.device), z2.to(self.device) representations = torch.cat([z1, z2], dim=0) - similarity_matrix = 
F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) + similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), + representations.unsqueeze(0), + dim=-1) l_pos = torch.diag(similarity_matrix, N) r_pos = torch.diag(similarity_matrix, -N) positives = torch.cat([l_pos, r_pos]).view(2 * N, 1) - diag = torch.eye(2*N, dtype=torch.bool, device=self.device) - diag[N:,:N] = diag[:N,N:] = diag[:N,:N] - negatives = similarity_matrix[~diag].view(2*N, -1) + diag = torch.eye(2 * N, dtype=torch.bool, device=self.device) + diag[N:, :N] = diag[:N, N:] = diag[:N, :N] + negatives = similarity_matrix[~diag].view(2 * N, -1) if len(others_z2) != 0: for z2_ in others_z2: z2_ = z2_.detach().to(self.device) N2, Z2 = z2_.shape representations = torch.cat([z1, z2_], dim=0) - similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) - mask = torch.zeros_like(similarity_matrix, dtype=torch.bool, device=self.device) - mask[N:,:N] = True - mask[:N,N:] = True - negatives_other = similarity_matrix[mask].view(2*N, -1) + similarity_matrix = F.cosine_similarity( + representations.unsqueeze(1), + representations.unsqueeze(0), + dim=-1) + mask = torch.zeros_like(similarity_matrix, + dtype=torch.bool, + device=self.device) + mask[N:, :N] = True + mask[:N, N:] = True + negatives_other = similarity_matrix[mask].view(2 * N, -1) negatives = torch.cat([negatives, negatives_other], dim=1) - logits = torch.cat([positives, negatives], dim=1) / self.temperature - labels = torch.zeros(2*N, dtype=torch.int64, device=self.device) # scalar label per sample + labels = torch.zeros(2 * N, dtype=torch.int64, + device=self.device) # scalar label per sample loss = F.cross_entropy(logits, labels, reduction='sum') / (2 * N) - - return loss + + return loss diff --git a/federatedscope/cl/loss/NT_xentloss.py b/federatedscope/cl/loss/NT_xentloss.py index 349946043..180128cb1 100644 --- a/federatedscope/cl/loss/NT_xentloss.py +++ 
b/federatedscope/cl/loss/NT_xentloss.py @@ -4,6 +4,7 @@ from federatedscope.register import register_criterion + class NT_xentloss(nn.Module): r""" NT_xentloss definition adapted from https://github.com/PatrickHua/SimSiam @@ -18,26 +19,29 @@ class NT_xentloss(nn.Module): def __init__(self, temperature=0.1): super(NT_xentloss, self).__init__() self.temperature = temperature - + def forward(self, z1, z2): - N, Z = z1.shape - device = z1.device + N, Z = z1.shape + device = z1.device representations = torch.cat([z1, z2], dim=0) - similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) + similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), + representations.unsqueeze(0), + dim=-1) l_pos = torch.diag(similarity_matrix, N) r_pos = torch.diag(similarity_matrix, -N) positives = torch.cat([l_pos, r_pos]).view(2 * N, 1) - diag = torch.eye(2*N, dtype=torch.bool, device=device) - diag[N:,:N] = diag[:N,N:] = diag[:N,:N] - negatives = similarity_matrix[~diag].view(2*N, -1) + diag = torch.eye(2 * N, dtype=torch.bool, device=device) + diag[N:, :N] = diag[:N, N:] = diag[:N, :N] + negatives = similarity_matrix[~diag].view(2 * N, -1) logits = torch.cat([positives, negatives], dim=1) / self.temperature - labels = torch.zeros(2*N, device=device, dtype=torch.int64) # scalar label per sample + labels = torch.zeros(2 * N, device=device, + dtype=torch.int64) # scalar label per sample loss = F.cross_entropy(logits, labels, reduction='sum') / (2 * N) - - return loss + + return loss def create_NT_xentloss(type, device): diff --git a/federatedscope/cl/lr_scheduler/ LR_Scheduler.py b/federatedscope/cl/lr_scheduler/ LR_Scheduler.py index ce0c2584e..90b82a51e 100644 --- a/federatedscope/cl/lr_scheduler/ LR_Scheduler.py +++ b/federatedscope/cl/lr_scheduler/ LR_Scheduler.py @@ -4,32 +4,46 @@ # LR Scheduler class LR_Scheduler(object): - def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, 
iter_per_epoch, constant_predictor_lr=False): + def __init__(self, + optimizer, + warmup_epochs, + warmup_lr, + num_epochs, + base_lr, + final_lr, + iter_per_epoch, + constant_predictor_lr=False): self.base_lr = base_lr self.constant_predictor_lr = constant_predictor_lr warmup_iter = iter_per_epoch * warmup_epochs warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter) decay_iter = iter_per_epoch * (num_epochs - warmup_epochs) - cosine_lr_schedule = final_lr+0.5*(base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter)) - - self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) + cosine_lr_schedule = final_lr + 0.5 * (base_lr - final_lr) * ( + 1 + np.cos(np.pi * np.arange(decay_iter) / decay_iter)) + + self.lr_schedule = np.concatenate( + (warmup_lr_schedule, cosine_lr_schedule)) self.optimizer = optimizer self.iter = 0 self.current_lr = 0 + def step(self): for param_group in self.optimizer.param_groups: - if self.constant_predictor_lr and param_group['name'] == 'predictor': + if self.constant_predictor_lr and param_group[ + 'name'] == 'predictor': param_group['lr'] = self.base_lr else: lr = param_group['lr'] = self.lr_schedule[self.iter] - + self.iter += 1 self.current_lr = lr return lr + def get_lr(self): return self.current_lr - + + def get_scheduler(optimizer, type): try: import torch.optim as optim @@ -40,10 +54,15 @@ def get_scheduler(optimizer, type): if type == 'cos_lr_scheduler': if optim is not None: lr_lambda = [lambda epoch: epoch // 30] - scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup_epochs=0, warmup_lr=0, - num_epochs=50, base_lr=30, - final_lr=0, iter_per_epoch=int(50000/512)) + scheduler = optim.lr_scheduler.LambdaLR(optimizer, + warmup_epochs=0, + warmup_lr=0, + num_epochs=50, + base_lr=30, + final_lr=0, + iter_per_epoch=int(50000 / + 512)) return scheduler -register_scheduler('cos_lr_scheduler', get_scheduler) \ No newline at end of file +register_scheduler('cos_lr_scheduler', get_scheduler) 
diff --git a/federatedscope/cl/model/SimCLR.py b/federatedscope/cl/model/SimCLR.py index d795c8567..a23ffd4cf 100644 --- a/federatedscope/cl/model/SimCLR.py +++ b/federatedscope/cl/model/SimCLR.py @@ -7,6 +7,7 @@ from collections import OrderedDict from federatedscope.contrib.model.resnet import BasicBlock, Bottleneck + # Model class class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes=10, cfg=None): @@ -14,18 +15,23 @@ def __init__(self, block, num_blocks, num_classes=10, cfg=None): self.train_sup = (num_classes > 0) self.in_planes = 64 - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=False) self.bn1 = nn.BatchNorm2d(64, affine=True) self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) - self.output_dim = 512*block.expansion - if(self.train_sup): - self.linear = nn.Linear(512*block.expansion, num_classes) + self.output_dim = 512 * block.expansion + if (self.train_sup): + self.linear = nn.Linear(512 * block.expansion, num_classes) def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1]*(num_blocks-1) + strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride)) @@ -33,34 +39,42 @@ def _make_layer(self, block, planes, num_blocks, stride): return nn.Sequential(*layers) def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn1(self.conv1(x))) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = F.adaptive_avg_pool2d(out, (1, 1)) out = out.view(out.size(0), -1) - if(self.train_sup): + if (self.train_sup): out = self.linear(out) return out + class 
ResNet_basic(nn.Module): def __init__(self, block, num_blocks, num_classes=10, cfg=None): super(ResNet_basic, self).__init__() self.train_sup = (num_classes > 0) self.in_planes = 16 - self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.conv1 = nn.Conv2d(3, + 16, + kernel_size=3, + stride=1, + padding=1, + bias=False) self.bn1 = nn.BatchNorm2d(16, affine=True) self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) - self.output_dim = 512*block.expansion - if(self.train_sup): - self.linear = nn.Linear(64*block.expansion, num_classes, bias=True) + self.output_dim = 512 * block.expansion + if (self.train_sup): + self.linear = nn.Linear(64 * block.expansion, + num_classes, + bias=True) def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1]*(num_blocks-1) + strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride)) @@ -68,40 +82,42 @@ def _make_layer(self, block, planes, num_blocks, stride): return nn.Sequential(*layers) def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn1(self.conv1(x))) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = F.adaptive_avg_pool2d(out, (1, 1)) out = out.view(out.size(0), -1) - if(self.train_sup): + if (self.train_sup): out = self.linear(out) return out def get_block(block): - if(block=="BasicBlock"): + if (block == "BasicBlock"): return BasicBlock - elif(block=="Bottleneck"): + elif (block == "Bottleneck"): return Bottleneck + def ResNet18(num_classes=10, block="BasicBlock"): - return ResNet(get_block(block), [2,2,2,2], num_classes=num_classes) + return ResNet(get_block(block), [2, 2, 2, 2], num_classes=num_classes) -def ResNet34(num_classes=10, block="BasicBlock"): - return ResNet(get_block(block), 
[3,4,6,3], num_classes=num_classes) +def ResNet34(num_classes=10, block="BasicBlock"): + return ResNet(get_block(block), [3, 4, 6, 3], num_classes=num_classes) def create_backbone(name, num_classes=10, block='BasicBlock'): - if(name == 'res18'): + if (name == 'res18'): net = ResNet18(num_classes=num_classes, block=block) - elif(name == 'res34'): + elif (name == 'res34'): net = ResNet34(num_classes=num_classes, block=block) return net -# Projector + +# Projector class projection_MLP_simclr(nn.Module): def __init__(self, in_dim, hidden_dim=512, out_dim=512): super(projection_MLP_simclr, self).__init__() @@ -111,10 +127,11 @@ def __init__(self, in_dim, hidden_dim=512, out_dim=512): self.layer2_bn = nn.BatchNorm1d(out_dim, affine=False) def forward(self, x): - x = F.relu(self.layer1_bn(self.layer1(x))) - x = self.layer2_bn(self.layer2(x)) + x = F.relu(self.layer1_bn(self.layer1(x))) + x = self.layer2_bn(self.layer2(x)) return x + # SimCLR class simclr(nn.Module): def __init__(self, bbone_arch): @@ -122,13 +139,17 @@ def __init__(self, bbone_arch): self.register_buffer("rounds_done", torch.zeros(1)) self.backbone = create_backbone(bbone_arch, num_classes=0) - self.projector = projection_MLP_simclr(self.backbone.output_dim, hidden_dim=512, out_dim=512) + self.projector = projection_MLP_simclr(self.backbone.output_dim, + hidden_dim=512, + out_dim=512) def forward(self, x1, x2, x3=None, deg_labels=None): - z1, z2 = self.projector(self.backbone(x1)), self.projector(self.backbone(x2)) + z1, z2 = self.projector(self.backbone(x1)), self.projector( + self.backbone(x2)) return z1, z2 + class simclr_linearprob(nn.Module): def __init__(self, bbone_arch, num_classes=10): super(simclr_linearprob, self).__init__() @@ -143,7 +164,8 @@ def forward(self, x): out = self.linear(out) return out - + + class simclr_supervised(nn.Module): def __init__(self, bbone_arch, num_classes=10): super(simclr_supervised, self).__init__() @@ -157,21 +179,24 @@ def forward(self, x): out = self.linear(out) 
return out - + + def ModelBuilder(model_config, local_data): - # You can also build models without local_data + # You can also build models without local_data if model_config.type == "SimCLR": model = simclr(bbone_arch='res18') return model if model_config.type in ["SimCLR_linear"]: model = simclr_linearprob(bbone_arch='res18', num_classes=10) return model - if model_config.type in ["supervised_local","supervised_fedavg"]: + if model_config.type in ["supervised_local", "supervised_fedavg"]: model = simclr_supervised(bbone_arch='res18', num_classes=10) return model + from federatedscope.register import register_model + def get_simclr(model_config, local_data): model = ModelBuilder(model_config, local_data) return model diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index 40716bda1..2232815ff 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -21,24 +21,25 @@ def __init__(self, only_for_eval=False, monitor=None): super(CLTrainer, self).__init__(model, data, device, config, - only_for_eval, monitor) - self.batches_aug_data_1, self.batches_aug_data_2 = torch.empty(1), torch.empty(1) + only_for_eval, monitor) + self.batches_aug_data_1, self.batches_aug_data_2 = torch.empty( + 1), torch.empty(1) self.z1, self.z2 = torch.empty(1), torch.empty(1) self.num_samples = 0 self.local_loss_ratio = 1 self.global_loss_ratio = 5 - - def get_train_pred_embedding(self): + def get_train_pred_embedding(self): model = self.ctx.model.to(self.ctx.device) - x1, x2 = self.batches_aug_data_1.to(self.ctx.device), self.batches_aug_data_2.to(self.ctx.device) + x1, x2 = self.batches_aug_data_1.to( + self.ctx.device), self.batches_aug_data_2.to(self.ctx.device) z1, z2 = model(x1, x2) self.batches_aug_data_1, self.batches_aug_data_2 = [], [] self.z1, self.z2 = z1, z2 self.ctx.model.to(torch.device('cpu')) - + return [self.z1, self.z2] - + def _hook_on_batch_forward(self, ctx): x, label = [utils.move_to(_, 
ctx.device) for _ in ctx.data_batch] x1, x2 = x[0], x[1] @@ -48,7 +49,7 @@ def _hook_on_batch_forward(self, ctx): z1, z2 = ctx.model(x1, x2) if len(label.size()) == 0: label = label.unsqueeze(0) - + ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) ctx.y_prob = CtxVar((z1, z2), LIFECYCLE.BATCH) ctx.loss_batch = CtxVar(ctx.criterion(z1, z2), LIFECYCLE.BATCH) @@ -61,11 +62,11 @@ def _hook_on_batch_backward(self, ctx): if ctx.grad_clip > 0: torch.nn.utils.clip_grad_norm_(ctx.model.parameters(), ctx.grad_clip) - + ctx.optimizer.step() if ctx.scheduler is not None: ctx.scheduler.step() - + def _hook_on_batch_end(self, ctx): # update statistics ctx.num_samples += ctx.batch_size @@ -76,19 +77,19 @@ def _hook_on_batch_end(self, ctx): # cache label for evaluate ctx.ys_true.append(ctx.y_true.detach().cpu().numpy()) ctx.ys_prob.append(ctx.y_prob[0].detach().cpu().numpy()) - + def train_with_global_loss(self, loss): self.ctx.model = self.ctx.model.to(self.ctx.device) loss = loss.requires_grad_() * self.global_loss_ratio loss.backward() - + self.ctx.optimizer.step() - + return self.ctx.model.state_dict() - + class LPTrainer(GeneralTorchTrainer): def __init__(self, model, @@ -98,19 +99,19 @@ def __init__(self, only_for_eval=False, monitor=None): super(LPTrainer, self).__init__(model, data, device, config, - only_for_eval, monitor) - + only_for_eval, monitor) + if config.federate.restore_from != '': self.load_model(config.federate.restore_from) - + def call_cl_trainer(trainer_type): if trainer_type == 'cltrainer': trainer_builder = CLTrainer return trainer_builder elif trainer_type == 'lptrainer': trainer_builder = LPTrainer - return trainer_builder + return trainer_builder register_trainer('cltrainer', call_cl_trainer) diff --git a/federatedscope/core/aggregator.py b/federatedscope/core/aggregator.py index 584d8c7c2..68c7ba65a 100644 --- a/federatedscope/core/aggregator.py +++ b/federatedscope/core/aggregator.py @@ -109,14 +109,14 @@ class 
NoCommunicationAggregator(ClientsAvgAggregator): """ def __init__(self, model=None, device='cpu', config=None): super(NoCommunicationAggregator, self).__init__(model, device, config) - + def aggregate(self, agg_info): # do nothing return {} - + def update(self, model_parameters): pass - + def save_model(self, path, cur_round=-1): assert self.model is not None @@ -132,7 +132,7 @@ def load_model(self, path): return ckpt['cur_round'] else: raise ValueError("The file {} does NOT exist".format(path)) - + def _para_weighted_avg(self, models, recover_fun=None): pass diff --git a/federatedscope/core/aggregators/aggregator.py b/federatedscope/core/aggregators/aggregator.py index d633df090..12a132890 100644 --- a/federatedscope/core/aggregators/aggregator.py +++ b/federatedscope/core/aggregators/aggregator.py @@ -20,7 +20,7 @@ def __init__(self, model=None, device='cpu', config=None): self.model = model self.device = device self.cfg = config - + def update(self, model_parameters): ''' Arguments: @@ -43,7 +43,7 @@ def load_model(self, path): return ckpt['cur_round'] else: raise ValueError("The file {} does NOT exist".format(path)) - + def aggregate(self, agg_info): # do nothing return {} diff --git a/federatedscope/core/auxiliaries/aggregator_builder.py b/federatedscope/core/auxiliaries/aggregator_builder.py index 8b7fad6e7..e43e173f3 100644 --- a/federatedscope/core/auxiliaries/aggregator_builder.py +++ b/federatedscope/core/auxiliaries/aggregator_builder.py @@ -49,8 +49,8 @@ def get_aggregator(method, model=None, device=None, online=False, config=None): beta=config.personalization.beta) elif aggregator_type == 'no_communication': return NoCommunicationAggregator(model=model, - device=device, - config=config) + device=device, + config=config) else: raise NotImplementedError( "Aggregator {} is not implemented.".format(aggregator_type)) diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index a1527cd40..e819ace9a 
100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -144,12 +144,14 @@ def get_model(model_config, local_data=None, backend='torch'): elif model_config.type.lower() in ['convnet2', 'convnet5', 'vgg11', 'lr']: from federatedscope.cv.model import get_cnn model = get_cnn(model_config, input_shape) - elif model_config.type.lower() in ['simclr', 'simclr_linear',"supervised_local","supervised_fedavg"]: + elif model_config.type.lower() in [ + 'simclr', 'simclr_linear', "supervised_local", "supervised_fedavg" + ]: from federatedscope.cl.model import get_simclr model = get_simclr(model_config, input_shape) if model_config.type.lower().endswith('linear'): for name, value in model.named_parameters(): - if not name.startswith('linear') : + if not name.startswith('linear'): value.requires_grad = False elif model_config.type.lower() in ['lstm']: from federatedscope.nlp.model import get_rnn diff --git a/federatedscope/core/auxiliaries/trainer_builder.py b/federatedscope/core/auxiliaries/trainer_builder.py index 6ab82e4e2..b7c0e1c7c 100644 --- a/federatedscope/core/auxiliaries/trainer_builder.py +++ b/federatedscope/core/auxiliaries/trainer_builder.py @@ -65,7 +65,7 @@ def get_trainer(model=None, elif config.trainer.type.lower() in ['nlptrainer']: dict_path = "federatedscope.nlp.trainer.trainer" elif config.trainer.type.lower() in ['cltrainer', 'lptrainer']: - dict_path = "federatedscope.cl.trainer.trainer" + dict_path = "federatedscope.cl.trainer.trainer" elif config.trainer.type.lower() in [ 'graphminibatch_trainer', ]: From d213297e7e4c48ff0bcb82a9cca6f32aa9c20ab8 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 25 Oct 2022 22:10:20 +0800 Subject: [PATCH 38/46] modify yapf format --- federatedscope/cl/dataloader/Cifar10.py | 134 +++++++++++++----------- federatedscope/cl/fedgc/client.py | 21 ++-- federatedscope/cl/fedgc/server.py | 36 ++++--- 3 files changed, 103 insertions(+), 88 
deletions(-) diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 3ec79b95c..4e5070b83 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -14,7 +14,6 @@ from federatedscope.core.auxiliaries.splitter_builder import get_splitter - class SimCLRTransform(): r""" Data Augmentations of SimCLR refer from @@ -30,12 +29,14 @@ class SimCLRTransform(): """ def __init__(self, is_sup, image_size=32): self.transform = T.Compose([ - T.RandomResizedCrop(image_size, scale=(0.5, 1.0), + T.RandomResizedCrop(image_size, + scale=(0.5, 1.0), interpolation=T.InterpolationMode.BICUBIC), T.RandomHorizontalFlip(p=0.5), - T.RandomApply([T.ColorJitter(0.4,0.4,0.2,0.1)], p=0.8), + T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), T.RandomGrayscale(p=0.2), - T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.5), + T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], + p=0.5), T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) @@ -43,12 +44,13 @@ def __init__(self, is_sup, image_size=32): self.mode = is_sup def __call__(self, x): - if(self.mode): + if (self.mode): return self.transform(x) else: x1 = self.transform(x) x2 = self.transform(x) - return x1, x2 + return x1, x2 + def Cifar4CL(config): r""" @@ -64,10 +66,16 @@ def Cifar4CL(config): transform_train = SimCLRTransform(is_sup=False, image_size=32) path = config.data.root - - data_train = CIFAR10(path, train=True, download=True, transform=transform_train) - data_test = CIFAR10(path, train=False, download=True, transform=transform_train) - + + data_train = CIFAR10(path, + train=True, + download=True, + transform=transform_train) + data_test = CIFAR10(path, + train=False, + download=True, + transform=transform_train) + # Split data into dict data_dict = dict() splitter = get_splitter(config) @@ -76,35 +84,31 @@ def Cifar4CL(config): data_val = data_train data_test = splitter(data_test, 
prior=label_data_train) - client_num = min(len(data_train), config.federate.client_num ) if config.federate.client_num > 0 else len(data_train) config.merge_from_list(['federate.client_num', client_num]) - for client_idx in range(1, client_num + 1): dataloader_dict = { - 'train': - DataLoader(data_train[client_idx - 1], - config.data.batch_size, - shuffle=config.data.shuffle, - num_workers=config.data.num_workers), - 'val': - DataLoader(data_val[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - 'test': - DataLoader(data_test[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - } + 'train': DataLoader(data_train[client_idx - 1], + config.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers), + 'val': DataLoader(data_val[client_idx - 1], + config.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers), + 'test': DataLoader(data_test[client_idx - 1], + config.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers), + } data_dict[client_idx] = dataloader_dict config = config return data_dict, config + def Cifar4LP(config): r""" generate Cifar10 Dataset transform and split dict for linear prob @@ -118,23 +122,33 @@ def Cifar4LP(config): } """ transform_train = T.Compose([ - T.RandomResizedCrop(32, scale=(0.5, 1.0), interpolation=T.InterpolationMode.BICUBIC), - T.RandomHorizontalFlip(p=0.5), - T.ToTensor(), - T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - ]) - transform_test = T.Compose([ - T.ToTensor(), - T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - ) + T.RandomResizedCrop(32, + scale=(0.5, 1.0), + interpolation=T.InterpolationMode.BICUBIC), + T.RandomHorizontalFlip(p=0.5), + T.ToTensor(), + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + transform_test = T.Compose( + [T.ToTensor(), + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) path = config.data.root - - data_train = CIFAR10(path, 
train=True, download=True, transform=transform_train) - data_val = CIFAR10(path, train=True, download=True, transform=transform_test) - data_test = CIFAR10(path, train=False, download=True, transform=transform_test) - - # Split data into dict + + data_train = CIFAR10(path, + train=True, + download=True, + transform=transform_train) + data_val = CIFAR10(path, + train=True, + download=True, + transform=transform_test) + data_test = CIFAR10(path, + train=False, + download=True, + transform=transform_test) + + # Split data into dict data_dict = dict() # Splitter @@ -144,37 +158,31 @@ def Cifar4LP(config): data_val = splitter(data_val, prior=label_data_train) data_test = splitter(data_test, prior=label_data_train) - client_num = min(len(data_train), config.federate.client_num ) if config.federate.client_num > 0 else len(data_train) config.merge_from_list(['federate.client_num', client_num]) - for client_idx in range(1, client_num + 1): dataloader_dict = { - 'train': - DataLoader(data_train[client_idx - 1], - config.data.batch_size, - shuffle=config.data.shuffle, - num_workers=config.data.num_workers), - 'val': - DataLoader(data_val[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - 'test': - DataLoader(data_test[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - } + 'train': DataLoader(data_train[client_idx - 1], + config.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers), + 'val': DataLoader(data_val[client_idx - 1], + config.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers), + 'test': DataLoader(data_test[client_idx - 1], + config.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers), + } data_dict[client_idx] = dataloader_dict - + config = config return data_dict, config - def load_cifar_dataset(config): if config.data.type == "Cifar4CL": data, modified_config = Cifar4CL(config) diff --git 
a/federatedscope/cl/fedgc/client.py b/federatedscope/cl/fedgc/client.py index 59801e551..fd4b471af 100644 --- a/federatedscope/cl/fedgc/client.py +++ b/federatedscope/cl/fedgc/client.py @@ -9,6 +9,7 @@ logger = logging.getLogger(__name__) + class GlobalContrastFLClient(Client): r""" GlobalContrastFL(Fedgc) Client receive aggregated model weight from @@ -23,7 +24,8 @@ def _register_default_handlers(self): self.register_handlers('address', self.callback_funcs_for_address) self.register_handlers('model_para', self.callback_funcs_for_pred_embedding) - self.register_handlers('global_loss', self.callback_funcs_for_local_backward) + self.register_handlers('global_loss', + self.callback_funcs_for_local_backward) self.register_handlers('ss_model_para', self.callback_funcs_for_model_para) @@ -37,7 +39,7 @@ def callback_funcs_for_local_backward(self, message: Message): model_para = self.trainer.train_with_global_loss(global_loss) self.trainer.update(model_para) self.state = round - sample_size= self.trainer.num_samples + sample_size = self.trainer.num_samples model_para = self.trainer.get_model_para() self.comm_manager.send( @@ -46,19 +48,19 @@ def callback_funcs_for_local_backward(self, message: Message): receiver=[sender], state=self.state, content=(sample_size, model_para))) - + def callback_funcs_for_pred_embedding(self, message: Message): round, sender, content = message.state, message.sender, message.content self.trainer.update(content) sample_size, model_para, results = self.trainer.train() self.state = round pred_embedding = self.trainer.get_train_pred_embedding() - - train_log_res = self._monitor.format_eval_res( - results, - rnd=self.state, - role='Client #{}'.format(self.ID), - return_raw=True) + + train_log_res = self._monitor.format_eval_res(results, + rnd=self.state, + role='Client #{}'.format( + self.ID), + return_raw=True) logger.info(train_log_res) self.comm_manager.send( @@ -67,4 +69,3 @@ def callback_funcs_for_pred_embedding(self, message: Message): 
receiver=[sender], state=self.state, content=(pred_embedding))) - diff --git a/federatedscope/cl/fedgc/server.py b/federatedscope/cl/fedgc/server.py index a6c740a52..c11ec883b 100644 --- a/federatedscope/cl/fedgc/server.py +++ b/federatedscope/cl/fedgc/server.py @@ -36,23 +36,26 @@ def __init__(self, idx: () for idx in range(1, self._cfg.federate.client_num + 1) } - self.loss_list= { + self.loss_list = { idx: 0 for idx in range(1, self._cfg.federate.client_num + 1) } - + def _register_default_handlers(self): self.register_handlers('join_in', self.callback_funcs_for_join_in) self.register_handlers('join_in_info', self.callback_funcs_for_join_in) self.register_handlers('model_para', self.callback_funcs_model_para) self.register_handlers('metrics', self.callback_funcs_for_metrics) - self.register_handlers('pred_embedding', self.callback_funcs_global_loss) - + self.register_handlers('pred_embedding', + self.callback_funcs_global_loss) + def check_and_move_on_for_global_loss(self): minimal_number = self.sample_client_num - if self.check_buffer(self.state, minimal_number, check_eval_result=False): + if self.check_buffer(self.state, + minimal_number, + check_eval_result=False): # Receiving enough feedback in the training process @@ -73,11 +76,14 @@ def check_and_move_on_for_global_loss(self): for client_id in train_msg_buffer: z1 = self.seqs_embedding[client_id][0] z2 = self.seqs_embedding[client_id][1] - others_z2 = [self.seqs_embedding[other_client_id][1] - for other_client_id in train_msg_buffer - if other_client_id != client_id] - self.loss_list[client_id] = global_loss_fn(z1, z2, others_z2) - logger.info(f'client {client_id}' + others_z2 = [ + self.seqs_embedding[other_client_id][1] + for other_client_id in train_msg_buffer + if other_client_id != client_id + ] + self.loss_list[client_id] = global_loss_fn( + z1, z2, others_z2) + logger.info(f'client {client_id}' f'global_loss:{self.loss_list[client_id]}') self.state += 1 @@ -95,7 +101,7 @@ def 
check_and_move_on_for_global_loss(self): receiver=[client_id], state=self.state, content=msg_list)) - + def check_and_move_on(self, check_eval_result=False, min_received_num=None): @@ -127,7 +133,7 @@ def check_and_move_on(self, if not check_eval_result: # Receiving enough feedback in the training process aggregated_num = self._perform_federated_aggregation() - + if self.state % self._cfg.eval.freq == 0 and self.state != \ self.total_round_num: # Evaluate @@ -159,8 +165,8 @@ def check_and_move_on(self, else: move_on_flag = False - return move_on_flag - + return move_on_flag + def callback_funcs_global_loss(self, message: Message): """ The handling function for receiving model embeddings, which triggers @@ -197,7 +203,7 @@ def callback_funcs_global_loss(self, message: Message): move_on_flag = self.check_and_move_on_for_global_loss() return move_on_flag - + def callback_funcs_model_para(self, message: Message): """ The handling function for receiving model parameters, which triggers From d6fadbc9e88666eec3693061d90913d6e3a70049 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 25 Oct 2022 23:12:46 +0800 Subject: [PATCH 39/46] debug for unit-test --- federatedscope/cl/dataloader/Cifar10.py | 2 +- federatedscope/core/aggregator.py | 300 ------------------ .../core/auxiliaries/data_builder.py | 2 +- .../core/auxiliaries/worker_builder.py | 4 +- federatedscope/core/data/utils.py | 3 +- federatedscope/core/workers/server.py | 1 - 6 files changed, 6 insertions(+), 306 deletions(-) delete mode 100644 federatedscope/core/aggregator.py diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 4e5070b83..d57eaf90c 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -183,7 +183,7 @@ def Cifar4LP(config): return data_dict, config -def load_cifar_dataset(config): +def load_cifar_dataset_for_contrast_learning(config): if config.data.type == "Cifar4CL": data, 
modified_config = Cifar4CL(config) return data, modified_config diff --git a/federatedscope/core/aggregator.py b/federatedscope/core/aggregator.py deleted file mode 100644 index 68c7ba65a..000000000 --- a/federatedscope/core/aggregator.py +++ /dev/null @@ -1,300 +0,0 @@ -from abc import ABC, abstractmethod -from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer -from federatedscope.core.auxiliaries.utils import param2tensor - -import torch -import os -import copy - - -class Aggregator(ABC): - def __init__(self): - pass - - @abstractmethod - def aggregate(self, agg_info): - pass - - -class ClientsAvgAggregator(Aggregator): - """Implementation of vanilla FedAvg refer to `Communication-efficient - learning of deep networks from decentralized data` [McMahan et al., 2017] - (http://proceedings.mlr.press/v54/mcmahan17a.html) - """ - def __init__(self, model=None, device='cpu', config=None): - super(Aggregator, self).__init__() - self.model = model - self.device = device - self.cfg = config - - def aggregate(self, agg_info): - """ - To preform aggregation - - Arguments: - agg_info (dict): the feedbacks from clients - :returns: the aggregated results - :rtype: dict - """ - - models = agg_info["client_feedback"] - recover_fun = agg_info['recover_fun'] if ( - 'recover_fun' in agg_info and self.cfg.federate.use_ss) else None - avg_model = self._para_weighted_avg(models, recover_fun=recover_fun) - - return avg_model - - def update(self, model_parameters): - ''' - Arguments: - model_parameters (dict): PyTorch Module object's state_dict. 
- ''' - self.model.load_state_dict(model_parameters, strict=False) - - def save_model(self, path, cur_round=-1): - assert self.model is not None - - ckpt = {'cur_round': cur_round, 'model': self.model.state_dict()} - torch.save(ckpt, path) - - def load_model(self, path): - assert self.model is not None - - if os.path.exists(path): - ckpt = torch.load(path, map_location=self.device) - self.model.load_state_dict(ckpt['model']) - return ckpt['cur_round'] - else: - raise ValueError("The file {} does NOT exist".format(path)) - - def _para_weighted_avg(self, models, recover_fun=None): - training_set_size = 0 - for i in range(len(models)): - sample_size, _ = models[i] - training_set_size += sample_size - - sample_size, avg_model = models[0] - for key in avg_model: - for i in range(len(models)): - local_sample_size, local_model = models[i] - - if self.cfg.federate.ignore_weight: - weight = 1.0 / len(models) - elif self.cfg.federate.use_ss: - # When using secret sharing, what the server receives - # are sample_size * model_para - weight = 1.0 - else: - weight = local_sample_size / training_set_size - - if not self.cfg.federate.use_ss: - local_model[key] = param2tensor(local_model[key]) - if i == 0: - avg_model[key] = local_model[key] * weight - else: - avg_model[key] += local_model[key] * weight - - if self.cfg.federate.use_ss and recover_fun: - avg_model[key] = recover_fun(avg_model[key]) - # When using secret sharing, what the server receives are - # sample_size * model_para - avg_model[key] /= training_set_size - avg_model[key] = torch.FloatTensor(avg_model[key]) - - return avg_model - - -class NoCommunicationAggregator(ClientsAvgAggregator): - """"Clients do not communicate. 
Each client work locally - """ - def __init__(self, model=None, device='cpu', config=None): - super(NoCommunicationAggregator, self).__init__(model, device, config) - - def aggregate(self, agg_info): - # do nothing - return {} - - def update(self, model_parameters): - pass - - def save_model(self, path, cur_round=-1): - assert self.model is not None - - ckpt = {'cur_round': cur_round, 'model': self.model.state_dict()} - torch.save(ckpt, path) - - def load_model(self, path): - assert self.model is not None - - if os.path.exists(path): - ckpt = torch.load(path, map_location=self.device) - self.model.load_state_dict(ckpt['model'], strict=False) - return ckpt['cur_round'] - else: - raise ValueError("The file {} does NOT exist".format(path)) - - def _para_weighted_avg(self, models, recover_fun=None): - pass - - -class AsynClientsAvgAggregator(ClientsAvgAggregator): - """The aggregator used in asynchronous training, which discounts the - staled model updates - """ - def __init__(self, model=None, device='cpu', config=None): - super(AsynClientsAvgAggregator, self).__init__(model, device, config) - - def aggregate(self, agg_info): - """ - To preform aggregation - - Arguments: - agg_info (dict): the feedbacks from clients - :returns: the aggregated results - :rtype: dict - """ - - models = agg_info["client_feedback"] - recover_fun = agg_info['recover_fun'] if ( - 'recover_fun' in agg_info and self.cfg.federate.use_ss) else None - staleness = [x[1] - for x in agg_info['staleness']] # (client_id, staleness) - avg_model = self._para_weighted_avg(models, - recover_fun=recover_fun, - staleness=staleness) - - # When using asynchronous training, the return feedback is model delta - # rather than the model param - updated_model = copy.deepcopy(avg_model) - init_model = self.model.state_dict() - for key in avg_model: - updated_model[key] = init_model[key] + avg_model[key] - return updated_model - - def discount_func(self, staleness): - """ - Served as an example, we discount the 
model update with staleness \tau - as: (1.0/((1.0+\tau)**factor)), - which has been used in previous studies such as FedAsync (Asynchronous - Federated Optimization) and FedBuff - (Federated Learning with Buffered Asynchronous Aggregation). - """ - return (1.0 / - ((1.0 + staleness)**self.cfg.asyn.staleness_discount_factor)) - - def _para_weighted_avg(self, models, recover_fun=None, staleness=None): - training_set_size = 0 - for i in range(len(models)): - sample_size, _ = models[i] - training_set_size += sample_size - - sample_size, avg_model = models[0] - for key in avg_model: - for i in range(len(models)): - local_sample_size, local_model = models[i] - - if self.cfg.federate.ignore_weight: - weight = 1.0 / len(models) - else: - weight = local_sample_size / training_set_size - - assert staleness is not None - weight *= self.discount_func(staleness[i]) - if isinstance(local_model[key], torch.Tensor): - local_model[key] = local_model[key].float() - else: - local_model[key] = torch.FloatTensor(local_model[key]) - - if i == 0: - avg_model[key] = local_model[key] * weight - else: - avg_model[key] += local_model[key] * weight - - return avg_model - - -class OnlineClientsAvgAggregator(ClientsAvgAggregator): - def __init__(self, - model=None, - device='cpu', - src_device='cpu', - config=None): - super(OnlineClientsAvgAggregator, self).__init__(model, device, config) - self.src_device = src_device - - def reset(self): - self.maintained = self.model.state_dict() - for key in self.maintained: - self.maintained[key].data = torch.zeros_like( - self.maintained[key], device=self.src_device) - self.cnt = 0 - - def inc(self, content): - if isinstance(content, tuple): - sample_size, model_params = content - for key in self.maintained: - # if model_params[key].device != self.maintained[key].device: - # model_params[key].to(self.maintained[key].device) - self.maintained[key] = (self.cnt * self.maintained[key] + - sample_size * model_params[key]) / ( - self.cnt + sample_size) - 
self.cnt += sample_size - else: - raise TypeError( - "{} is not a tuple (sample_size, model_para)".format(content)) - - def aggregate(self, agg_info): - return self.maintained - - -class ServerClientsInterpolateAggregator(ClientsAvgAggregator): - """" - # conduct aggregation by interpolating global model from server and - local models from clients - """ - def __init__(self, model=None, device='cpu', config=None, beta=1.0): - super(ServerClientsInterpolateAggregator, - self).__init__(model, device, config) - self.beta = beta # the weight for local models used in interpolation - - def aggregate(self, agg_info): - models = agg_info["client_feedback"] - global_model = self.model - elem_each_client = next(iter(models)) - assert len(elem_each_client) == 2, f"Require (sample_size, " \ - f"model_para) tuple for each " \ - f"client, i.e., len=2, but got " \ - f"len={len(elem_each_client)}" - avg_model_by_clients = self._para_weighted_avg(models) - global_local_models = [((1 - self.beta), global_model.state_dict()), - (self.beta, avg_model_by_clients)] - - avg_model_by_interpolate = self._para_weighted_avg(global_local_models) - return avg_model_by_interpolate - - -class FedOptAggregator(ClientsAvgAggregator): - """Implementation of FedOpt refer to `Adaptive Federated Optimization` [ - Reddi et al., 2021] - (https://openreview.net/forum?id=LkFG3lB13U5) - - """ - def __init__(self, config, model, device='cpu'): - super(FedOptAggregator, self).__init__(model, device, config) - self.optimizer = get_optimizer(model=self.model, - **config.fedopt.optimizer) - - def aggregate(self, agg_info): - new_model = super().aggregate(agg_info) - - model = self.model.cpu().state_dict() - with torch.no_grad(): - grads = {key: model[key] - new_model[key] for key in new_model} - - self.optimizer.zero_grad() - for key, p in self.model.named_parameters(): - if key in new_model.keys(): - p.grad = grads[key] - self.optimizer.step() - - return self.model.state_dict() diff --git 
a/federatedscope/core/auxiliaries/data_builder.py b/federatedscope/core/auxiliaries/data_builder.py index 617d9689e..824463bca 100644 --- a/federatedscope/core/auxiliaries/data_builder.py +++ b/federatedscope/core/auxiliaries/data_builder.py @@ -28,7 +28,7 @@ 'subreddit', 'synthetic', 'ciao', 'epinions', '.*?vertical_fl_data.*?', '.*?movielens.*?', '.*?cikmcup.*?', 'graph_multi_domain.*?', 'cora', 'citeseer', 'pubmed', 'dblp_conf', 'dblp_org', 'csbm.*?', 'fb15k-237', - 'wn18' + 'wn18', 'cifar4cl', 'cifar4lp' ], # Dummy for FL dataset } DATA_TRANS_MAP = RegexInverseMap(TRANS_DATA_MAP, None) diff --git a/federatedscope/core/auxiliaries/worker_builder.py b/federatedscope/core/auxiliaries/worker_builder.py index 827ccfb2e..7a5766b4d 100644 --- a/federatedscope/core/auxiliaries/worker_builder.py +++ b/federatedscope/core/auxiliaries/worker_builder.py @@ -105,10 +105,10 @@ def get_server_cls(cfg): server_class = FedSagePlusServer elif server_type == 'gcflplus': from federatedscope.gfl.gcflplus.worker import GCFLPlusServer - return GCFLPlusServer + server_class = GCFLPlusServer elif server_type == 'fedgc': from federatedscope.cl.fedgc.server import GlobalContrastFLServer - return GlobalContrastFLServer + server_class = GlobalContrastFLServer else: server_class = Server diff --git a/federatedscope/core/data/utils.py b/federatedscope/core/data/utils.py index 6066ac855..f4f96b15d 100644 --- a/federatedscope/core/data/utils.py +++ b/federatedscope/core/data/utils.py @@ -38,8 +38,9 @@ def load_dataset(config): dataset, modified_config = load_quadratic_dataset(config) elif config.data.type.lower() in ['femnist', 'celeba']: from federatedscope.cv.dataloader import load_cv_dataset + dataset, modified_config = load_cv_dataset(config) elif config.data.type.lower() in ['cifar4cl', 'cifar4lp']: - from federatedscope.cl.dataloader import load_cifar_dataset + from federatedscope.cl.dataloader import load_cifar_dataset_for_contrast_learning dataset, modified_config = 
load_cv_dataset(config) elif config.data.type.lower() in [ 'shakespeare', 'twitter', 'subreddit', 'synthetic' diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index 09f0359d6..0bb9a0d82 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -438,7 +438,6 @@ def _perform_federated_aggregation(self): # Due to lazy load, we merge two state dict merged_param = merge_param_dict(model.state_dict().copy(), result) model.load_state_dict(merged_param, strict=False) - aggregator.update(merged_param) return aggregated_num From a250837d392231128f065efe8cfd0a78e7125f27 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Wed, 26 Oct 2022 19:00:49 +0800 Subject: [PATCH 40/46] modify for unit-test --- federatedscope/cl/dataloader/Cifar10.py | 66 ++++--------------- .../core/auxiliaries/data_builder.py | 4 +- federatedscope/core/data/utils.py | 4 +- tests/test_simclr_cifar10.py | 3 +- 4 files changed, 17 insertions(+), 60 deletions(-) diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index d57eaf90c..85dbbcc70 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -78,35 +78,14 @@ def Cifar4CL(config): # Split data into dict data_dict = dict() - splitter = get_splitter(config) - data_train = splitter(data_train) - label_data_train = [[i[1] for i in list_i] for list_i in data_train] data_val = data_train - data_test = splitter(data_test, prior=label_data_train) - - client_num = min(len(data_train), config.federate.client_num - ) if config.federate.client_num > 0 else len(data_train) - config.merge_from_list(['federate.client_num', client_num]) - - for client_idx in range(1, client_num + 1): - dataloader_dict = { - 'train': DataLoader(data_train[client_idx - 1], - config.data.batch_size, - shuffle=config.data.shuffle, - num_workers=config.data.num_workers), - 'val': 
DataLoader(data_val[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - 'test': DataLoader(data_test[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - } - data_dict[client_idx] = dataloader_dict + + data_dict = {'train': data_train, 'val': data_val, 'test': data_test} + data_split_tuple = (data_dict.get('train'), data_dict.get('val'), + data_dict.get('test')) config = config - return data_dict, config + return data_split_tuple, config def Cifar4LP(config): @@ -150,40 +129,17 @@ def Cifar4LP(config): # Split data into dict data_dict = dict() + data_val = data_train - # Splitter - splitter = get_splitter(config) - data_train = splitter(data_train) - label_data_train = [[i[1] for i in list_i] for list_i in data_train] - data_val = splitter(data_val, prior=label_data_train) - data_test = splitter(data_test, prior=label_data_train) - - client_num = min(len(data_train), config.federate.client_num - ) if config.federate.client_num > 0 else len(data_train) - config.merge_from_list(['federate.client_num', client_num]) - - for client_idx in range(1, client_num + 1): - dataloader_dict = { - 'train': DataLoader(data_train[client_idx - 1], - config.data.batch_size, - shuffle=config.data.shuffle, - num_workers=config.data.num_workers), - 'val': DataLoader(data_val[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - 'test': DataLoader(data_test[client_idx - 1], - config.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers), - } - data_dict[client_idx] = dataloader_dict + data_dict = {'train': data_train, 'val': data_val, 'test': data_test} + data_split_tuple = (data_dict.get('train'), data_dict.get('val'), + data_dict.get('test')) config = config - return data_dict, config + return data_split_tuple, config -def load_cifar_dataset_for_contrast_learning(config): +def load_cifar_dataset(config): if config.data.type 
== "Cifar4CL": data, modified_config = Cifar4CL(config) return data, modified_config diff --git a/federatedscope/core/auxiliaries/data_builder.py b/federatedscope/core/auxiliaries/data_builder.py index 824463bca..5e094786c 100644 --- a/federatedscope/core/auxiliaries/data_builder.py +++ b/federatedscope/core/auxiliaries/data_builder.py @@ -21,14 +21,14 @@ TRANS_DATA_MAP = { 'BaseDataTranslator': [ '.*?@.*?', 'hiv', 'proteins', 'imdb-binary', 'bbbp', 'tox21', 'bace', - 'sider', 'clintox', 'esol', 'freesolv', 'lipo' + 'sider', 'clintox', 'esol', 'freesolv', 'lipo', 'cifar4cl', 'cifar4lp' ], 'DummyDataTranslator': [ 'toy', 'quadratic', 'femnist', 'celeba', 'shakespeare', 'twitter', 'subreddit', 'synthetic', 'ciao', 'epinions', '.*?vertical_fl_data.*?', '.*?movielens.*?', '.*?cikmcup.*?', 'graph_multi_domain.*?', 'cora', 'citeseer', 'pubmed', 'dblp_conf', 'dblp_org', 'csbm.*?', 'fb15k-237', - 'wn18', 'cifar4cl', 'cifar4lp' + 'wn18' ], # Dummy for FL dataset } DATA_TRANS_MAP = RegexInverseMap(TRANS_DATA_MAP, None) diff --git a/federatedscope/core/data/utils.py b/federatedscope/core/data/utils.py index f4f96b15d..5126118c9 100644 --- a/federatedscope/core/data/utils.py +++ b/federatedscope/core/data/utils.py @@ -40,8 +40,8 @@ def load_dataset(config): from federatedscope.cv.dataloader import load_cv_dataset dataset, modified_config = load_cv_dataset(config) elif config.data.type.lower() in ['cifar4cl', 'cifar4lp']: - from federatedscope.cl.dataloader import load_cifar_dataset_for_contrast_learning - dataset, modified_config = load_cv_dataset(config) + from federatedscope.cl.dataloader import load_cifar_dataset + dataset, modified_config = load_cifar_dataset(config) elif config.data.type.lower() in [ 'shakespeare', 'twitter', 'subreddit', 'synthetic' ]: diff --git a/tests/test_simclr_cifar10.py b/tests/test_simclr_cifar10.py index b086efbd2..f46696aaf 100644 --- a/tests/test_simclr_cifar10.py +++ b/tests/test_simclr_cifar10.py @@ -37,7 +37,8 @@ def 
set_config_simclr_cifar10(self, cfg): cfg.data.batch_size = 256 cfg.data.splitter = 'lda' cfg.data.splitter_args = [{'alpha': 0.1}] - cfg.data.num_workers = 4 + cfg.data.consistent_label_distribution = True + cfg.data.num_workers = 0 cfg.data.subsample = 1.0 cfg.model.type = 'SimCLR' From d306dc0d81cf7654abb8ebbc8ebb4b4f63ee295d Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Fri, 28 Oct 2022 17:03:43 +0800 Subject: [PATCH 41/46] Update test_simclr_cifar10.py --- tests/test_simclr_cifar10.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_simclr_cifar10.py b/tests/test_simclr_cifar10.py index f46696aaf..e503f8d08 100644 --- a/tests/test_simclr_cifar10.py +++ b/tests/test_simclr_cifar10.py @@ -38,7 +38,7 @@ def set_config_simclr_cifar10(self, cfg): cfg.data.splitter = 'lda' cfg.data.splitter_args = [{'alpha': 0.1}] cfg.data.consistent_label_distribution = True - cfg.data.num_workers = 0 + cfg.data.num_workers = 4 cfg.data.subsample = 1.0 cfg.model.type = 'SimCLR' From 99915fa55c304c4e292588c8b50381b16c2d2eeb Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Fri, 28 Oct 2022 17:48:19 +0800 Subject: [PATCH 42/46] Update Cifar10.py --- federatedscope/cl/dataloader/Cifar10.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/federatedscope/cl/dataloader/Cifar10.py b/federatedscope/cl/dataloader/Cifar10.py index 85dbbcc70..d519fc0d5 100644 --- a/federatedscope/cl/dataloader/Cifar10.py +++ b/federatedscope/cl/dataloader/Cifar10.py @@ -146,7 +146,3 @@ def load_cifar_dataset(config): elif config.data.type == "Cifar4LP": data, modified_config = Cifar4LP(config) return data, modified_config - - -register_data("Cifar4CL", load_cifar_dataset) -register_data("Cifar4LP", load_cifar_dataset) From 2c8112f5164678e10dc8d6ef49d11e2a65a734b0 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Thu, 24 Nov 2022 07:09:06 +0800 Subject: [PATCH 43/46] Create dataloader_molecule.py --- 
.../gfkd/dataloader/dataloader_molecule.py | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 federatedscope/gfkd/dataloader/dataloader_molecule.py diff --git a/federatedscope/gfkd/dataloader/dataloader_molecule.py b/federatedscope/gfkd/dataloader/dataloader_molecule.py new file mode 100644 index 000000000..b5e031d00 --- /dev/null +++ b/federatedscope/gfkd/dataloader/dataloader_molecule.py @@ -0,0 +1,86 @@ +from torch_geometric import transforms +from torch_geometric.datasets import TUDataset, MoleculeNet, QM7b +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit.Chem import Draw +m = Chem.MolFromSmiles('c1ccccc1') +m3d=Chem.AddHs(m) +AllChem.EmbedMolecule(m3d, randomSeed=1) + + + +from federatedscope.core.auxiliaries.transform_builder import get_transform + + +def load_heteromolecule_dataset(config=None): + r"""Convert dataset to Dataloader. + :returns: + data_local_dict + :rtype: Dict { + 'client_id': { + 'train': DataLoader(), + 'val': DataLoader(), + 'test': DataLoader() + } + } + """ + splits = config.data.splits + path = config.data.root + name = config.data.type.upper() + + # Transforms + transforms_funcs = get_transform(config, 'torch_geometric') + + if name.startswith('heterogeneous molecule dataset'.upper()): + dataset = [] + + TUDdataset_names = ['BZR', 'ENZYMES', 'MUTAG'] + MoleculeNet_names = ['ESOL', 'FreeSolv', 'BACE'] + for dname in TUDdataset_names: + tmp_dataset = TUDataset(path, dname, **transforms_funcs) + dataset.append(tmp_dataset) + for dname in MoleculeNet_names: + tmp_dataset = MoleculeNet(path, dname, **transforms_funcs) + if dname in ['FreeSolv', 'BACE']: + for i in len(tmp_dataset): + smiles = dataset[i].smiles + mol = Chem.MolFromSmiles(smiles) + mol = AllChem.AddHs(mol) + res = AllChem.EmbedMolecule(mol, randomSeed=1) + # will random generate conformer with seed equal to -1. else fixed random seed. 
+ if res == 0: + try: + AllChem.MMFFOptimizeMolecule(mol)# some conformer can not use MMFF optimize + except: + pass + mol = AllChem.RemoveHs(mol) + coordinates = mol.GetConformer().GetPositions() + + elif res == -1: + mol_tmp = Chem.MolFromSmiles(smiles) + AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000, randomSeed=1) + mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True) + try: + AllChem.MMFFOptimizeMolecule(mol_tmp)# some conformer can not use MMFF optimize + except: + pass + mol_tmp = AllChem.RemoveHs(mol_tmp) + coordinates = mol_tmp.GetConformer().GetPositions() + + assert dataset[i].x.shape[0] == len(coordinates), "coordinates shape is not align with {}".format(smiles) + tmp_dataset[i] = [tmp_dataset[i], coordinates] + dataset.append(tmp_dataset) + tmp_dataset = QM7b(path, dname, **transforms_funcs) + dataset.append(tmp_dataset) + else: + raise ValueError(f'No dataset named: {name}!') + + client_num = min(len(dataset), config.federate.client_num + ) if config.federate.client_num > 0 else len(dataset) + config.merge_from_list(['federate.client_num', client_num]) + + # get local dataset + data_dict = dict() + for client_idx in range(1, len(dataset) + 1): + data_dict[client_idx] = dataset[client_idx - 1] + return data_dict, config From 1e63ddbe4ab2d351c748757b1c0e2d6ba85d9938 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Mon, 28 Nov 2022 06:39:18 +0800 Subject: [PATCH 44/46] add models for different datasets --- .../gfkd/dataloader/dataloader_molecule.py | 1 + federatedscope/gfkd/model/2Dgraph_model.py | 61 ++++++++++++++++ federatedscope/gfkd/model/3Dgraph_model.py | 4 ++ federatedscope/gfkd/model/SMILES_model.py | 71 +++++++++++++++++++ 4 files changed, 137 insertions(+) create mode 100644 federatedscope/gfkd/model/2Dgraph_model.py create mode 100644 federatedscope/gfkd/model/3Dgraph_model.py create mode 100644 federatedscope/gfkd/model/SMILES_model.py diff --git a/federatedscope/gfkd/dataloader/dataloader_molecule.py 
b/federatedscope/gfkd/dataloader/dataloader_molecule.py index b5e031d00..7fb92d6bb 100644 --- a/federatedscope/gfkd/dataloader/dataloader_molecule.py +++ b/federatedscope/gfkd/dataloader/dataloader_molecule.py @@ -41,6 +41,7 @@ def load_heteromolecule_dataset(config=None): dataset.append(tmp_dataset) for dname in MoleculeNet_names: tmp_dataset = MoleculeNet(path, dname, **transforms_funcs) + if dname in ['FreeSolv', 'BACE']: for i in len(tmp_dataset): smiles = dataset[i].smiles diff --git a/federatedscope/gfkd/model/2Dgraph_model.py b/federatedscope/gfkd/model/2Dgraph_model.py new file mode 100644 index 000000000..c30b4b5fa --- /dev/null +++ b/federatedscope/gfkd/model/2Dgraph_model.py @@ -0,0 +1,61 @@ +import torch +import torch.nn.functional as F +from torch.nn import ModuleList +from torch_geometric.data import Data +from torch_geometric.data.batch import Batch + +from federatedscope.gfl.model.graph_level import GNN_Net_Graph + + +class GNN_Net(GNN_Net_Graph): + r"""GNN model with pre-linear layer, pooling layer + and output layer for graph classification tasks. + + Arguments: + in_channels (int): input channels. + out_channels (int): output channels. + hidden (int): hidden dim for all modules. + max_depth (int): number of layers for gnn. + dropout (float): dropout probability. + gnn (str): name of gnn type, use ("gcn" or "gin"). + pooling (str): pooling method, use ("add", "mean" or "max"). 
+ """ + def __init__(self, + in_channels, + out_channels, + hidden=64, + max_depth=2, + dropout=.0, + gnn='gcn', + pooling='add', + conformer=False): + super(GNN_Net, self).__init__() + self.conformer = conformer + + def forward(self, data): + if self.conformer == False: + if isinstance(data, Batch): + x, edge_index, batch = data.x, data.edge_index, data.batch + elif isinstance(data, tuple): + x, edge_index, batch = data + else: + raise TypeError('Unsupported data type!') + else: + if isinstance(data, Batch): + x, edge_index, batch, pos = data.x, data.edge_index, data.batch + elif isinstance(data, tuple): + x, edge_index, batch = data + else: + raise TypeError('Unsupported data type!') + + if x.dtype == torch.int64: + x = self.encoder_atom(x) + else: + x = self.encoder(x) + + x = self.gnn((x, edge_index)) + x = self.pooling(x, batch) + x = self.linear(x) + x = F.dropout(x, self.dropout, training=self.training) + x = self.clf(x) + return x \ No newline at end of file diff --git a/federatedscope/gfkd/model/3Dgraph_model.py b/federatedscope/gfkd/model/3Dgraph_model.py new file mode 100644 index 000000000..a22513e4f --- /dev/null +++ b/federatedscope/gfkd/model/3Dgraph_model.py @@ -0,0 +1,4 @@ +import torch +import torch.nn.functional as F +from torch_geometric.nn import DimeNet + diff --git a/federatedscope/gfkd/model/SMILES_model.py b/federatedscope/gfkd/model/SMILES_model.py new file mode 100644 index 000000000..5fd6e58a4 --- /dev/null +++ b/federatedscope/gfkd/model/SMILES_model.py @@ -0,0 +1,71 @@ +import torch +import torch.nn +import random +import math + +class SMILESTransformer(torch.nn.Module): + def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.1): + super(SMILESTransformer, self).__init__() + from torch.nn import TransformerEncoder, TransformerEncoderLayer + self.model_type = 'Transformer' + self.ninp = ninp + + self.pos_encoder = torch.nn.Embedding(100, ninp) + self.encoder = torch.nn.Embedding(ntoken, ninp) + + self.layer_norm = 
torch.nn.LayerNorm([ninp]) + self.output_layer_norm = torch.nn.LayerNorm([ntoken]) + self.input_layer_norm = torch.nn.LayerNorm([ninp]) + + encoder_layers = TransformerEncoderLayer(d_model=ninp, + nhead=nhead, + dim_feedforward=nhid, + dropout=dropout, + activation='gelu') + + self.transformer_encoder = TransformerEncoder(encoder_layers, + nlayers, + norm=self.layer_norm) + + self.dropout = torch.nn.Dropout(dropout) + + self.decoder = torch.nn.Linear(ninp, ntoken, bias=False) + self.decoder_bias = torch.nn.Parameter(torch.zeros(ntoken)) + self.init_weights() + + + def init_weights(self): + initrange = 0.1 + self.encoder.weight.data.normal_(mean=0.0, std=1.0) + self.decoder.weight.data.normal_(mean=0.0, std=1.0) + self.decoder_bias.data.zero_() + + self.input_layer_norm.weight.data.fill_(1.0) + self.input_layer_norm.bias.data.zero_() + self.output_layer_norm.weight.data.fill_(1.0) + self.output_layer_norm.bias.data.zero_() + self.layer_norm.weight.data.fill_(1.0) + self.layer_norm.bias.data.zero_() + + + def forward(self, src, latent_out=False): + pos = torch.arange(0,100).long().to(src.device) + + mol_token_emb = self.encoder(src) + pos_emb = self.pos_encoder(pos) + input_emb = pos_emb + mol_token_emb + input_emb = self.input_layer_norm(input_emb) + input_emb = self.dropout(input_emb) + input_emb = input_emb.transpose(0, 1) + + attention_mask = torch.ones_like(src).to(src.device) + attention_mask = attention_mask.masked_fill(src!=1., 0.) 
+ attention_mask = attention_mask.bool().to(src.device) + + output = self.transformer_encoder(input_emb) + + if latent_out: + return output + output = self.decoder(output) + self.decoder_bias + + return output \ No newline at end of file From 87f5a5438c52af69ca29bb4350597e6291423888 Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 29 Nov 2022 05:48:14 +0800 Subject: [PATCH 45/46] define DimeNet++ for QM7b --- federatedscope/gfkd/model/2Dgraph_model.py | 6 ++-- federatedscope/gfkd/model/3Dgraph_model.py | 40 +++++++++++++++++++++- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/federatedscope/gfkd/model/2Dgraph_model.py b/federatedscope/gfkd/model/2Dgraph_model.py index c30b4b5fa..b30c27626 100644 --- a/federatedscope/gfkd/model/2Dgraph_model.py +++ b/federatedscope/gfkd/model/2Dgraph_model.py @@ -41,10 +41,10 @@ def forward(self, data): else: raise TypeError('Unsupported data type!') else: - if isinstance(data, Batch): + if isinstance(data, Batch):# position is as attr in x x, edge_index, batch, pos = data.x, data.edge_index, data.batch - elif isinstance(data, tuple): - x, edge_index, batch = data + elif isinstance(data, tuple):# position + x, edge_index, batch, pos = data.x, data.edge_index, data.batch, data.pos else: raise TypeError('Unsupported data type!') diff --git a/federatedscope/gfkd/model/3Dgraph_model.py b/federatedscope/gfkd/model/3Dgraph_model.py index a22513e4f..279e2a9bb 100644 --- a/federatedscope/gfkd/model/3Dgraph_model.py +++ b/federatedscope/gfkd/model/3Dgraph_model.py @@ -1,4 +1,42 @@ import torch import torch.nn.functional as F -from torch_geometric.nn import DimeNet +from torch_geometric.nn import DimeNetPlusPlus +class DimeNetPlusPlus_for_QM7b(DimeNetPlusPlus): + def __init__( + self, + hidden_channels: int = 256, + out_channels: int = 14, + num_blocks: int = 3, + int_emb_size: int = 64, + basis_emb_size: int = 8, + out_emb_channels: int = 256, + num_spherical: int = 7, + num_radial: int = 6, + cutoff: 
float = 5.0, + max_num_neighbors: int = 32, + envelope_exponent: int = 5, + num_before_skip: int = 1, + num_after_skip: int = 2, + num_output_layers: int = 3, + act: Union[str, Callable] = 'swish', + ): + super().__init__( + hidden_channels=hidden_channels, + out_channels=out_channels, + num_blocks=num_blocks, + int_emb_size=int_emb_size, + basis_emb_size=basis_emb_size, + out_emb_channels=out_emb_channels, + num_bilinear=1, + num_spherical=num_spherical, + num_radial=num_radial, + cutoff=cutoff, + max_num_neighbors=max_num_neighbors, + envelope_exponent=envelope_exponent, + num_before_skip=num_before_skip, + num_after_skip=num_after_skip, + num_output_layers=num_output_layers, + act=act, + ) + \ No newline at end of file From d1404f7a55d5f05822e255a0417bd36193d2571f Mon Sep 17 00:00:00 2001 From: xkxxfyf <2014201870@ruc.edu.cn> Date: Tue, 6 Dec 2022 08:24:29 +0800 Subject: [PATCH 46/46] add trainer modelbuilder and generalmodel --- federatedscope/gfkd/model/general_model,py | 266 ++++++++++++++++++ .../{2Dgraph_model.py => graph2D_model.py} | 0 .../{3Dgraph_model.py => graph3D_model.py} | 0 federatedscope/gfkd/model/model_builder.py | 43 +++ federatedscope/gfkd/trainer/graphtrainer.py | 81 ++++++ federatedscope/gfl/trainer/graphtrainer.py | 20 -- 6 files changed, 390 insertions(+), 20 deletions(-) create mode 100644 federatedscope/gfkd/model/general_model,py rename federatedscope/gfkd/model/{2Dgraph_model.py => graph2D_model.py} (100%) rename federatedscope/gfkd/model/{3Dgraph_model.py => graph3D_model.py} (100%) create mode 100644 federatedscope/gfkd/model/model_builder.py create mode 100644 federatedscope/gfkd/trainer/graphtrainer.py diff --git a/federatedscope/gfkd/model/general_model,py b/federatedscope/gfkd/model/general_model,py new file mode 100644 index 000000000..be12bdee3 --- /dev/null +++ b/federatedscope/gfkd/model/general_model,py @@ -0,0 +1,266 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from data import get_dataset +import torch +import math +import torch.nn as nn +import pytorch_lightning as pl +from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool + +def init_params(module, n_layers): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(n_layers)) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02) + + +class Graphormer(pl.LightningModule): + def __init__( + self, + n_layers, + num_heads, + hidden_dim, + dropout_rate, + intput_dropout_rate, + num_class, + weight_decay, + ffn_dim, + warmup_updates, + tot_updates, + peak_lr, + end_lr, + edge_type, + multi_hop_max_dist, + attention_dropout_rate, + flag=False, + flag_m=3, + flag_step_size=1e-3, + flag_mag=1e-3, + ): + super().__init__() + self.save_hyperparameters() + + self.num_heads = num_heads + + self.atom_encoder = nn.Embedding( + 512 * 9 + 1, hidden_dim, padding_idx=0) + self.edge_encoder = nn.Embedding( + 512 * 3 + 1, hidden_dim, padding_idx=0) + self.edge_type = edge_type + if self.edge_type == 'multi_hop': + self.edge_dis_encoder = nn.Embedding( + 128 * num_heads * num_heads, 1) + self.no_masked_spatial_pos_encoder = nn.Linear(1, num_heads) + self.masked_spatial_pos_encoder = nn.Linear(1, num_heads) + + + self.input_dropout = nn.Dropout(intput_dropout_rate) + encoders = [EncoderLayer(hidden_dim, ffn_dim, dropout_rate, attention_dropout_rate, num_heads) + for _ in range(n_layers)] + self.decoders = EncoderLayer(hidden_dim, ffn_dim, dropout_rate, attention_dropout_rate, num_heads) + self.layers = nn.ModuleList(encoders) + self.final_ln = nn.LayerNorm(hidden_dim) + self.weight2ppr = nn.Linear(num_heads, 1) + + + self.downstream_out_proj = nn.Linear( + hidden_dim, num_class) + + self.graph_token = nn.Embedding(1, hidden_dim) + self.graph_token_virtual_distance = nn.Embedding(1, num_heads) + + self.warmup_updates = warmup_updates + 
self.tot_updates = tot_updates + self.peak_lr = peak_lr + self.end_lr = end_lr + self.weight_decay = weight_decay + self.multi_hop_max_dist = multi_hop_max_dist + + self.flag = flag + self.flag_m = flag_m + self.flag_step_size = flag_step_size + self.flag_mag = flag_mag + self.hidden_dim = hidden_dim + self.automatic_optimization = not self.flag + self.apply(lambda module: init_params(module, n_layers=n_layers)) + + def forward(self, batched_data, perturb=None): + attn_bias, spatial_pos_mask = batched_data.attn_bias, batched_data.spatial_pos_mask + no_mask_ppr = spatial_pos_mask.clone() + no_mask_ppr[spatial_pos_mask == 128] = 0 +# print(spatial_pos_mask.dtype, no_mask_ppr.dtype) +# no_mask_coordinate = torch.nonzero(mask) +# no_mask_seq = spatial_pos_mask.view(-1) + x_n, x_e = batched_data.x_n, batched_data.x_e + edge_index = batched_data.edge_index +# in_degree, out_degree = batched_data.in_degree, batched_data.out_degree + + with torch.no_grad(): + # graph_attn_bias + n_graph, n_node = x_n.size()[0], x_n.size()[1] + x_e.size()[1] + graph_attn_bias = attn_bias.clone() + graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat( + 1, self.num_heads, 1, 1) # [n_graph, n_head, n_node+1, n_node+1] + + # spatial pos + # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node] + no_mask_ppr = no_mask_ppr.unsqueeze(-1) + no_masked_spatial_pos_bias = self.no_masked_spatial_pos_encoder(no_mask_ppr).permute(0, 3, 1, 2) + graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, + :, 1:, 1:] + no_masked_spatial_pos_bias + # reset spatial pos here + t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1) + graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t + graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t + + graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, + :, 1:, 1:] + graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1) # reset + + # node feauture + graph token + node_feature = 
torch.cat([self.atom_encoder(x_n).sum(dim=-2),self.edge_encoder(x_e).sum(dim=-2)],dim=1) # [n_graph, n_node+e_node, n_hidden] + + # node_feature = node_feature + \ + # self.in_degree_encoder(in_degree) + \ + # self.out_degree_encoder(out_degree) + graph_token_feature = self.graph_token.weight.unsqueeze( + 0).repeat(n_graph, 1, 1) + graph_node_feature = torch.cat( + [graph_token_feature, node_feature], dim=1) + + # transformer encoder + output = self.input_dropout(graph_node_feature) + weight = graph_attn_bias + for enc_layer in self.layers: + output, weight = enc_layer(output, weight) + weight = torch.softmax(weight,dim=3) + + global_output = self.final_ln(output) + global_output = self.out_proj(global_output[:, 0, :]) + + return global_output + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("Graphormer") + parser.add_argument('--n_layers', type=int, default=6) + parser.add_argument('--num_heads', type=int, default=16) + parser.add_argument('--hidden_dim', type=int, default=512) + parser.add_argument('--ffn_dim', type=int, default=512) + parser.add_argument('--intput_dropout_rate', type=float, default=0.1) + parser.add_argument('--dropout_rate', type=float, default=0.1) + parser.add_argument('--weight_decay', type=float, default=0.01) + parser.add_argument('--attention_dropout_rate', + type=float, default=0.1) + parser.add_argument('--checkpoint_path', type=str, default='') + parser.add_argument('--warmup_updates', type=int, default=60000) + parser.add_argument('--tot_updates', type=int, default=1000000) + parser.add_argument('--peak_lr', type=float, default=2e-4) + parser.add_argument('--end_lr', type=float, default=1e-9) + parser.add_argument('--edge_type', type=str, default='multi_hop') + parser.add_argument('--validate', action='store_true', default=False) + parser.add_argument('--test', action='store_true', default=False) + parser.add_argument('--flag', action='store_true') + 
parser.add_argument('--flag_m', type=int, default=3) + parser.add_argument('--flag_step_size', type=float, default=1e-3) + parser.add_argument('--flag_mag', type=float, default=1e-3) + return parent_parser + + +class FeedForwardNetwork(nn.Module): + def __init__(self, hidden_size, ffn_size, dropout_rate): + super(FeedForwardNetwork, self).__init__() + + self.layer1 = nn.Linear(hidden_size, ffn_size) + self.gelu = nn.GELU() + self.layer2 = nn.Linear(ffn_size, hidden_size) + + def forward(self, x): + x = self.layer1(x) + x = self.gelu(x) + x = self.layer2(x) + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, hidden_size, attention_dropout_rate, num_heads): + super(MultiHeadAttention, self).__init__() + + self.num_heads = num_heads + + self.att_size = att_size = hidden_size // num_heads + self.scale = att_size ** -0.5 + + self.linear_q = nn.Linear(hidden_size, num_heads * att_size) + self.linear_k = nn.Linear(hidden_size, num_heads * att_size) + self.linear_v = nn.Linear(hidden_size, num_heads * att_size) + self.att_dropout = nn.Dropout(attention_dropout_rate) + + self.output_layer = nn.Linear(num_heads * att_size, hidden_size) + + def forward(self, q, k, v, attn_bias=None): + orig_q_size = q.size() + + d_k = self.att_size + d_v = self.att_size + batch_size = q.size(0) + + # head_i = Attention(Q(W^Q)_i, K(W^K)_i, V(W^V)_i) + q = self.linear_q(q).view(batch_size, -1, self.num_heads, d_k) + k = self.linear_k(k).view(batch_size, -1, self.num_heads, d_k) + v = self.linear_v(v).view(batch_size, -1, self.num_heads, d_v) + + q = q.transpose(1, 2) # [b, h, q_len, d_k] + v = v.transpose(1, 2) # [b, h, v_len, d_v] + k = k.transpose(1, 2).transpose(2, 3) # [b, h, d_k, k_len] + + # Scaled Dot-Product Attention. 
+ # Attention(Q, K, V) = softmax((QK^T)/sqrt(d_k))V + q = q * self.scale + x = torch.matmul(q, k) # [b, h, q_len, k_len] + if attn_bias is not None: + x = x + attn_bias + + weight_matrix = x + + x = torch.softmax(x, dim=3) + x = self.att_dropout(x) + x = x.matmul(v) # [b, h, q_len, attn] + + x = x.transpose(1, 2).contiguous() # [b, q_len, h, attn] + x = x.view(batch_size, -1, self.num_heads * d_v) + + x = self.output_layer(x) + + assert x.size() == orig_q_size + return x, weight_matrix + + +class EncoderLayer(nn.Module): + def __init__(self, hidden_size, ffn_size, dropout_rate, attention_dropout_rate, num_heads): + super(EncoderLayer, self).__init__() + + self.self_attention_norm = nn.LayerNorm(hidden_size) + self.self_attention = MultiHeadAttention( + hidden_size, attention_dropout_rate, num_heads) + self.self_attention_dropout = nn.Dropout(dropout_rate) + + self.ffn_norm = nn.LayerNorm(hidden_size) + self.ffn = FeedForwardNetwork(hidden_size, ffn_size, dropout_rate) + self.ffn_dropout = nn.Dropout(dropout_rate) + + def forward(self, x, attn_bias=None): + y = self.self_attention_norm(x) + y, weight = self.self_attention(y, y, y, attn_bias) + y = self.self_attention_dropout(y) + x = x + y + + y = self.ffn_norm(x) + y = self.ffn(y) + y = self.ffn_dropout(y) + x = x + y + weight = weight + attn_bias + return x, weight diff --git a/federatedscope/gfkd/model/2Dgraph_model.py b/federatedscope/gfkd/model/graph2D_model.py similarity index 100% rename from federatedscope/gfkd/model/2Dgraph_model.py rename to federatedscope/gfkd/model/graph2D_model.py diff --git a/federatedscope/gfkd/model/3Dgraph_model.py b/federatedscope/gfkd/model/graph3D_model.py similarity index 100% rename from federatedscope/gfkd/model/3Dgraph_model.py rename to federatedscope/gfkd/model/graph3D_model.py diff --git a/federatedscope/gfkd/model/model_builder.py b/federatedscope/gfkd/model/model_builder.py new file mode 100644 index 000000000..335d5d8f7 --- /dev/null +++ 
b/federatedscope/gfkd/model/model_builder.py @@ -0,0 +1,43 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +from federatedscope.gfkd.model.graph2D_model import GNN_Net +from federatedscope.gfkd.model.graph3D_model import DimeNetPlusPlus_for_QM7b +from federatedscope.gfkd.model.SMILES_model import SMILESTransformer + + +def get_gnn(model_config, input_shape): + + x_shape, num_label, num_edge_features = input_shape + if not num_label: + num_label = 0 + if model_config.type == 'SMILESTransformer': + # assume `data` is a dict where key is the client index, + # and value is a PyG object + model = SMILESTransformer(ntoken=415, + ninp=128, + nhead=8, + nhid=model_config.hidden, + nlayers=model_config.layer, + dropout=model_config.dropout) + elif model_config.type == 'GNN_Net': + model = GNN_Net(x_shape[-1], + max(model_config.out_channels, num_label), + hidden=model_config.hidden, + max_depth=model_config.layer, + dropout=model_config.dropout, + gnn=model_config.type, + pooling=model_config.graph_pooling) + elif model_config.type == 'DimeNetPlusPlus': + model = DimeNetPlusPlus_for_QM7b(hidden_channels=model_config.hidden, + out_channels=model_config.out_channels, + num_blocks=model_config.num_blocks, + int_emb_size=model_config.int_emb_size, + basis_emb_size=model_config.basis_emb_size, + out_emb_channels=model_config.out_emb_channels) + else: + raise ValueError('not recognized model {}'.format( + model_config.type)) + + return model diff --git a/federatedscope/gfkd/trainer/graphtrainer.py b/federatedscope/gfkd/trainer/graphtrainer.py new file mode 100644 index 000000000..8f806c1a9 --- /dev/null +++ b/federatedscope/gfkd/trainer/graphtrainer.py @@ -0,0 +1,81 @@ +import logging + +from federatedscope.core.monitors import Monitor +from federatedscope.register import register_trainer +from federatedscope.core.trainers import GeneralTorchTrainer +from federatedscope.core.trainers.context import CtxVar +from 
federatedscope.core.trainers.enums import LIFECYCLE + +logger = logging.getLogger(__name__) + + +class GraphMiniBatchTrainer(GeneralTorchTrainer): + def _hook_on_batch_forward(self, ctx): + batch = ctx.data_batch.to(ctx.device) + pred = ctx.model(batch) + # TODO: deal with the type of data within the dataloader or dataset + if 'regression' in ctx.cfg.model.task.lower(): + label = batch.y + else: + label = batch.y.squeeze(-1).long() + if len(label.size()) == 0: + label = label.unsqueeze(0) + ctx.loss_batch = ctx.criterion(pred, label) + + ctx.batch_size = len(label) + ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) + ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH) + + def _hook_on_batch_forward_flop_count(self, ctx): + if not isinstance(self.ctx.monitor, Monitor): + logger.warning( + f"The trainer {type(self)} does contain a valid monitor, " + f"this may be caused by initializing trainer subclasses " + f"without passing a valid monitor instance." + f"Plz check whether this is you want.") + return + + if self.cfg.eval.count_flops and self.ctx.monitor.flops_per_sample \ + == 0: + # calculate the flops_per_sample + try: + batch = ctx.data_batch.to(ctx.device) + from torch_geometric.data import Data + if isinstance(batch, Data): + x, edge_index = batch.x, batch.edge_index + from fvcore.nn import FlopCountAnalysis + flops_one_batch = FlopCountAnalysis(ctx.model, + (x, edge_index)).total() + if self.model_nums > 1 and ctx.mirrored_models: + flops_one_batch *= self.model_nums + logger.warning( + "the flops_per_batch is multiplied by " + "internal model nums as self.mirrored_models=True." + "if this is not the case you want, " + "please customize the count hook") + self.ctx.monitor.track_avg_flops(flops_one_batch, + ctx.batch_size) + except: + logger.warning( + "current flop count implementation is for general " + "GraphMiniBatchTrainer case: " + "1) the ctx.model takes only batch = ctx.data_batch as " + "input." 
+ "Please check the forward format or implement your own " + "flop_count function") + self.ctx.monitor.flops_per_sample = -1 # warning at the + # first failure + + # by default, we assume the data has the same input shape, + # thus simply multiply the flops to avoid redundant forward + self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \ + ctx.batch_size + + +def call_graph_level_trainer(trainer_type): + if trainer_type == 'graphminibatch_trainer': + trainer_builder = GraphMiniBatchTrainer + return trainer_builder + + +register_trainer('graphminibatch_trainer', call_graph_level_trainer) diff --git a/federatedscope/gfl/trainer/graphtrainer.py b/federatedscope/gfl/trainer/graphtrainer.py index 8f806c1a9..7d200d062 100644 --- a/federatedscope/gfl/trainer/graphtrainer.py +++ b/federatedscope/gfl/trainer/graphtrainer.py @@ -13,7 +13,6 @@ class GraphMiniBatchTrainer(GeneralTorchTrainer): def _hook_on_batch_forward(self, ctx): batch = ctx.data_batch.to(ctx.device) pred = ctx.model(batch) - # TODO: deal with the type of data within the dataloader or dataset if 'regression' in ctx.cfg.model.task.lower(): label = batch.y else: @@ -28,11 +27,6 @@ def _hook_on_batch_forward(self, ctx): def _hook_on_batch_forward_flop_count(self, ctx): if not isinstance(self.ctx.monitor, Monitor): - logger.warning( - f"The trainer {type(self)} does contain a valid monitor, " - f"this may be caused by initializing trainer subclasses " - f"without passing a valid monitor instance." - f"Plz check whether this is you want.") return if self.cfg.eval.count_flops and self.ctx.monitor.flops_per_sample \ @@ -48,26 +42,12 @@ def _hook_on_batch_forward_flop_count(self, ctx): (x, edge_index)).total() if self.model_nums > 1 and ctx.mirrored_models: flops_one_batch *= self.model_nums - logger.warning( - "the flops_per_batch is multiplied by " - "internal model nums as self.mirrored_models=True." 
- "if this is not the case you want, " - "please customize the count hook") self.ctx.monitor.track_avg_flops(flops_one_batch, ctx.batch_size) except: - logger.warning( - "current flop count implementation is for general " - "GraphMiniBatchTrainer case: " - "1) the ctx.model takes only batch = ctx.data_batch as " - "input." - "Please check the forward format or implement your own " - "flop_count function") self.ctx.monitor.flops_per_sample = -1 # warning at the # first failure - # by default, we assume the data has the same input shape, - # thus simply multiply the flops to avoid redundant forward self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \ ctx.batch_size