diff --git a/federatedscope/gfkd/dataloader/dataloader_molecule.py b/federatedscope/gfkd/dataloader/dataloader_molecule.py
new file mode 100644
index 000000000..7fb92d6bb
--- /dev/null
+++ b/federatedscope/gfkd/dataloader/dataloader_molecule.py
@@ -0,0 +1,99 @@
+from torch_geometric import transforms
+from torch_geometric.datasets import TUDataset, MoleculeNet, QM7b
+from rdkit import Chem
+from rdkit.Chem import AllChem
+from rdkit.Chem import Draw
+m = Chem.MolFromSmiles('c1ccccc1')
+m3d=Chem.AddHs(m)
+AllChem.EmbedMolecule(m3d, randomSeed=1)
+
+
+
+from federatedscope.core.auxiliaries.transform_builder import get_transform
+
+
+def _embed_heavy_atom_coordinates(smiles):
+    r"""Embed a 3D conformer and return its heavy-atom positions.
+
+    Arguments:
+        smiles (str): SMILES string of the molecule.
+
+    :returns: ``GetPositions()`` of the conformer after hydrogens are removed
+    """
+    mol = AllChem.AddHs(Chem.MolFromSmiles(smiles))
+    # A fixed seed keeps the conformer reproducible (-1 would randomise it).
+    res = AllChem.EmbedMolecule(mol, randomSeed=1)
+    if res == -1:
+        # Embedding with explicit hydrogens failed: retry on the bare
+        # molecule with more attempts, then add hydrogens with coordinates.
+        mol = Chem.MolFromSmiles(smiles)
+        AllChem.EmbedMolecule(mol, maxAttempts=5000, randomSeed=1)
+        mol = AllChem.AddHs(mol, addCoords=True)
+    try:
+        # Best effort: some conformers cannot be optimised with MMFF.
+        AllChem.MMFFOptimizeMolecule(mol)
+    except Exception:
+        pass
+    mol = AllChem.RemoveHs(mol)
+    return mol.GetConformer().GetPositions()
+
+
+def load_heteromolecule_dataset(config=None):
+    r"""Load the heterogeneous molecule datasets, one dataset per client.
+
+    Arguments:
+        config: configuration read for ``data.root``, ``data.type`` and
+            ``federate.client_num``.
+
+    :returns: ``(data_dict, config)``; ``data_dict`` maps the 1-based client
+        index to that client's dataset and ``federate.client_num`` is
+        merged into ``config``
+    :rtype: tuple
+    :raises ValueError: if ``config.data.type`` names another dataset
+    """
+    path = config.data.root
+    name = config.data.type.upper()
+
+    # Transforms
+    transforms_funcs = get_transform(config, 'torch_geometric')
+
+    if name.startswith('heterogeneous molecule dataset'.upper()):
+        dataset = []
+
+        TUDdataset_names = ['BZR', 'ENZYMES', 'MUTAG']
+        MoleculeNet_names = ['ESOL', 'FreeSolv', 'BACE']
+        for dname in TUDdataset_names:
+            tmp_dataset = TUDataset(path, dname, **transforms_funcs)
+            dataset.append(tmp_dataset)
+        for dname in MoleculeNet_names:
+            tmp_dataset = MoleculeNet(path, dname, **transforms_funcs)
+
+            if dname in ['FreeSolv', 'BACE']:
+                # Pair every graph with its conformer coordinates in a new
+                # list rather than assigning back into the dataset object.
+                paired = []
+                for i in range(len(tmp_dataset)):
+                    data = tmp_dataset[i]
+                    coordinates = _embed_heavy_atom_coordinates(data.smiles)
+                    assert data.x.shape[0] == len(coordinates), \
+                        "coordinates shape is not align with {}".format(
+                            data.smiles)
+                    paired.append([data, coordinates])
+                tmp_dataset = paired
+            dataset.append(tmp_dataset)
+        # QM7b is constructed from the root and the transforms only; it
+        # takes no dataset name.
+        tmp_dataset = QM7b(path, **transforms_funcs)
+        dataset.append(tmp_dataset)
+    else:
+        raise ValueError(f'No dataset named: {name}!')
+
+    client_num = min(len(dataset), config.federate.client_num
+                     ) if config.federate.client_num > 0 else len(dataset)
+    config.merge_from_list(['federate.client_num', client_num])
+
+    # get local dataset
+    data_dict = dict()
+    for client_idx in range(1, len(dataset) + 1):
+        data_dict[client_idx] = dataset[client_idx - 1]
+    return data_dict, config
diff --git a/federatedscope/gfkd/model/SMILES_model.py b/federatedscope/gfkd/model/SMILES_model.py
new file mode 100644
index 000000000..5fd6e58a4
--- /dev/null
+++ b/federatedscope/gfkd/model/SMILES_model.py
@@ -0,0 +1,94 @@
+import torch
+import torch.nn
+import random
+import math
+
+class SMILESTransformer(torch.nn.Module):
+    r"""Transformer encoder over SMILES token ids with a linear decoder.
+
+    Arguments:
+        ntoken (int): vocabulary size.
+        ninp (int): embedding / model dimension.
+        nhead (int): number of attention heads.
+        nhid (int): feed-forward dimension of each encoder layer.
+        nlayers (int): number of encoder layers.
+        dropout (float): dropout probability.
+        max_len (int): size of the learned position table.
+    """
+    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.1,
+                 max_len=100):
+        super(SMILESTransformer, self).__init__()
+        from torch.nn import TransformerEncoder, TransformerEncoderLayer
+        self.model_type = 'Transformer'
+        self.ninp = ninp
+
+        self.pos_encoder = torch.nn.Embedding(max_len, ninp)
+        self.encoder = torch.nn.Embedding(ntoken, ninp)
+
+        self.layer_norm = torch.nn.LayerNorm([ninp])
+        self.output_layer_norm = torch.nn.LayerNorm([ntoken])
+        self.input_layer_norm = torch.nn.LayerNorm([ninp])
+
+        encoder_layers = TransformerEncoderLayer(d_model=ninp,
+                                                 nhead=nhead,
+                                                 dim_feedforward=nhid,
+                                                 dropout=dropout,
+                                                 activation='gelu')
+
+        self.transformer_encoder = TransformerEncoder(encoder_layers,
+                                                      nlayers,
+                                                      norm=self.layer_norm)
+
+        self.dropout = torch.nn.Dropout(dropout)
+
+        self.decoder = torch.nn.Linear(ninp, ntoken, bias=False)
+        self.decoder_bias = torch.nn.Parameter(torch.zeros(ntoken))
+        self.init_weights()
+
+    def init_weights(self):
+        """Draw token-embedding and decoder weights from N(0, 1), zero the
+        decoder bias and reset the three layer norms to weight 1, bias 0."""
+        self.encoder.weight.data.normal_(mean=0.0, std=1.0)
+        self.decoder.weight.data.normal_(mean=0.0, std=1.0)
+        self.decoder_bias.data.zero_()
+
+        self.input_layer_norm.weight.data.fill_(1.0)
+        self.input_layer_norm.bias.data.zero_()
+        self.output_layer_norm.weight.data.fill_(1.0)
+        self.output_layer_norm.bias.data.zero_()
+        self.layer_norm.weight.data.fill_(1.0)
+        self.layer_norm.bias.data.zero_()
+
+    def forward(self, src, latent_out=False):
+        r"""Encode ``src`` and return token logits or the encoder output.
+
+        Arguments:
+            src (Tensor): token ids. The embedding is transposed on its
+                first two dims before encoding, so ``(batch, seq)`` is
+                assumed -- confirm against the caller.
+            latent_out (bool): return the encoder output instead of logits.
+        """
+        # One position per token instead of a hard-coded 100, so sequences
+        # shorter than the position table also work.
+        pos = torch.arange(src.size(-1), device=src.device)
+
+        mol_token_emb = self.encoder(src)
+        pos_emb = self.pos_encoder(pos)
+        input_emb = pos_emb + mol_token_emb
+        input_emb = self.input_layer_norm(input_emb)
+        input_emb = self.dropout(input_emb)
+        input_emb = input_emb.transpose(0, 1)
+
+        # True where the token id is 1 (presumably the padding id -- confirm
+        # against the tokenizer); handed to the encoder so that attention
+        # skips those positions.
+        padding_mask = src == 1
+
+        output = self.transformer_encoder(
+            input_emb, src_key_padding_mask=padding_mask)
+
+        if latent_out:
+            return output
+        output = self.decoder(output) + self.decoder_bias
+
+        return output
\ No newline at end of file
diff --git a/federatedscope/gfkd/model/general_model.py b/federatedscope/gfkd/model/general_model.py
new file mode 100644
index 000000000..be12bdee3
--- /dev/null
+++ b/federatedscope/gfkd/model/general_model.py
@@ -0,0 +1,266 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from data import get_dataset
+import torch
+import math
+import torch.nn as nn
+import pytorch_lightning as pl
+from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
+
+def init_params(module, n_layers):
+    """Initialise Linear (std scaled by depth) and Embedding weights."""
+    if isinstance(module, nn.Linear):
+        module.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(n_layers))
+        if module.bias is not None:
+            module.bias.data.zero_()
+    if isinstance(module, nn.Embedding):
+        module.weight.data.normal_(mean=0.0, std=0.02)
+
+
+class Graphormer(pl.LightningModule):
+    """Transformer over atom and edge tokens plus a leading graph token."""
+    def __init__(
+        self,
+        n_layers,
+        num_heads,
+        hidden_dim,
+        dropout_rate,
+        intput_dropout_rate,
+        num_class,
+        weight_decay,
+        ffn_dim,
+        warmup_updates,
+        tot_updates,
+        peak_lr,
+        end_lr,
+        edge_type,
+        multi_hop_max_dist,
+        attention_dropout_rate,
+        flag=False,
+        flag_m=3,
+        flag_step_size=1e-3,
+        flag_mag=1e-3,
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+
+        self.num_heads = num_heads
+
+        self.atom_encoder = nn.Embedding(
+            512 * 9 + 1, hidden_dim, padding_idx=0)
+        self.edge_encoder = nn.Embedding(
+            512 * 3 + 1, hidden_dim, padding_idx=0)
+        self.edge_type = edge_type
+        if self.edge_type == 'multi_hop':
+            self.edge_dis_encoder = nn.Embedding(
+                128 * num_heads * num_heads, 1)
+        self.no_masked_spatial_pos_encoder = nn.Linear(1, num_heads)
+        self.masked_spatial_pos_encoder = nn.Linear(1, num_heads)
+
+        self.input_dropout = nn.Dropout(intput_dropout_rate)
+        encoders = [EncoderLayer(hidden_dim, ffn_dim, dropout_rate, attention_dropout_rate, num_heads)
+                    for _ in range(n_layers)]
+        self.decoders = EncoderLayer(hidden_dim, ffn_dim, dropout_rate, attention_dropout_rate, num_heads)
+        self.layers = nn.ModuleList(encoders)
+        self.final_ln = nn.LayerNorm(hidden_dim)
+        self.weight2ppr = nn.Linear(num_heads, 1)
+
+        self.downstream_out_proj = nn.Linear(
+            hidden_dim, num_class)
+
+        self.graph_token = nn.Embedding(1, hidden_dim)
+        self.graph_token_virtual_distance = nn.Embedding(1, num_heads)
+
+        self.warmup_updates = warmup_updates
+        self.tot_updates = tot_updates
+        self.peak_lr = peak_lr
+        self.end_lr = end_lr
+        self.weight_decay = weight_decay
+        self.multi_hop_max_dist = multi_hop_max_dist
+
+        self.flag = flag
+        self.flag_m = flag_m
+        self.flag_step_size = flag_step_size
+        self.flag_mag = flag_mag
+        self.hidden_dim = hidden_dim
+        self.automatic_optimization = not self.flag
+        self.apply(lambda module: init_params(module, n_layers=n_layers))
+
+    def forward(self, batched_data, perturb=None):
+        """Return the projection of the graph-token output for each graph.
+
+        ``batched_data`` must provide ``attn_bias``, ``spatial_pos_mask``,
+        ``x_n`` and ``x_e``; ``perturb`` is accepted but not used here.
+        """
+        attn_bias, spatial_pos_mask = batched_data.attn_bias, batched_data.spatial_pos_mask
+        no_mask_ppr = spatial_pos_mask.clone()
+        no_mask_ppr[spatial_pos_mask == 128] = 0
+        x_n, x_e = batched_data.x_n, batched_data.x_e
+
+        with torch.no_grad():
+            # Only the constant bias template is built without autograd.
+            n_graph = x_n.size()[0]
+            graph_attn_bias = attn_bias.clone()
+            graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(
+                1, self.num_heads, 1, 1)  # [n_graph, n_head, n_node+1, n_node+1]
+
+        # spatial pos
+        # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
+        no_mask_ppr = no_mask_ppr.unsqueeze(-1)
+        no_masked_spatial_pos_bias = self.no_masked_spatial_pos_encoder(no_mask_ppr).permute(0, 3, 1, 2)
+        graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:,
+                                                        :, 1:, 1:] + no_masked_spatial_pos_bias
+        # reset spatial pos here
+        t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1)
+        graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t
+        graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t
+        graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1)  # reset
+
+        # node feature + graph token
+        node_feature = torch.cat([self.atom_encoder(x_n).sum(dim=-2),self.edge_encoder(x_e).sum(dim=-2)],dim=1)  # [n_graph, n_node+e_node, n_hidden]
+
+        graph_token_feature = self.graph_token.weight.unsqueeze(
+            0).repeat(n_graph, 1, 1)
+        graph_node_feature = torch.cat(
+            [graph_token_feature, node_feature], dim=1)
+
+        # transformer encoder
+        output = self.input_dropout(graph_node_feature)
+        weight = graph_attn_bias
+        for enc_layer in self.layers:
+            output, weight = enc_layer(output, weight)
+            weight = torch.softmax(weight,dim=3)
+
+        global_output = self.final_ln(output)
+        global_output = self.downstream_out_proj(global_output[:, 0, :])
+
+        return global_output
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        """Add the Graphormer argument group and return ``parent_parser``."""
+        parser = parent_parser.add_argument_group("Graphormer")
+        parser.add_argument('--n_layers', type=int, default=6)
+        parser.add_argument('--num_heads', type=int, default=16)
+        parser.add_argument('--hidden_dim', type=int, default=512)
+        parser.add_argument('--ffn_dim', type=int, default=512)
+        parser.add_argument('--intput_dropout_rate', type=float, default=0.1)
+        parser.add_argument('--dropout_rate', type=float, default=0.1)
+        parser.add_argument('--weight_decay', type=float, default=0.01)
+        parser.add_argument('--attention_dropout_rate',
+                            type=float, default=0.1)
+        parser.add_argument('--checkpoint_path', type=str, default='')
+        parser.add_argument('--warmup_updates', type=int, default=60000)
+        parser.add_argument('--tot_updates', type=int, default=1000000)
+        parser.add_argument('--peak_lr', type=float, default=2e-4)
+        parser.add_argument('--end_lr', type=float, default=1e-9)
+        parser.add_argument('--edge_type', type=str, default='multi_hop')
+        parser.add_argument('--validate', action='store_true', default=False)
+        parser.add_argument('--test', action='store_true', default=False)
+        parser.add_argument('--flag', action='store_true')
+        parser.add_argument('--flag_m', type=int, default=3)
+        parser.add_argument('--flag_step_size', type=float, default=1e-3)
+        parser.add_argument('--flag_mag', type=float, default=1e-3)
+        return parent_parser
+
+
+class FeedForwardNetwork(nn.Module):
+    """Two linear layers with a GELU in between."""
+    def __init__(self, hidden_size, ffn_size, dropout_rate):
+        super(FeedForwardNetwork, self).__init__()
+
+        self.layer1 = nn.Linear(hidden_size, ffn_size)
+        self.gelu = nn.GELU()
+        self.layer2 = nn.Linear(ffn_size, hidden_size)
+
+    def forward(self, x):
+        x = self.layer1(x)
+        x = self.gelu(x)
+        x = self.layer2(x)
+        return x
+
+
+class MultiHeadAttention(nn.Module):
+    """Multi-head attention that also returns its pre-softmax scores."""
+    def __init__(self, hidden_size, attention_dropout_rate, num_heads):
+        super(MultiHeadAttention, self).__init__()
+
+        self.num_heads = num_heads
+
+        self.att_size = att_size = hidden_size // num_heads
+        self.scale = att_size ** -0.5
+
+        self.linear_q = nn.Linear(hidden_size, num_heads * att_size)
+        self.linear_k = nn.Linear(hidden_size, num_heads * att_size)
+        self.linear_v = nn.Linear(hidden_size, num_heads * att_size)
+        self.att_dropout = nn.Dropout(attention_dropout_rate)
+
+        self.output_layer = nn.Linear(num_heads * att_size, hidden_size)
+
+    def forward(self, q, k, v, attn_bias=None):
+        """Return ``(output, scores)``; scores include ``attn_bias`` if set."""
+        orig_q_size = q.size()
+
+        d_k = self.att_size
+        d_v = self.att_size
+        batch_size = q.size(0)
+
+        # head_i = Attention(Q(W^Q)_i, K(W^K)_i, V(W^V)_i)
+        q = self.linear_q(q).view(batch_size, -1, self.num_heads, d_k)
+        k = self.linear_k(k).view(batch_size, -1, self.num_heads, d_k)
+        v = self.linear_v(v).view(batch_size, -1, self.num_heads, d_v)
+
+        q = q.transpose(1, 2)  # [b, h, q_len, d_k]
+        v = v.transpose(1, 2)  # [b, h, v_len, d_v]
+        k = k.transpose(1, 2).transpose(2, 3)  # [b, h, d_k, k_len]
+
+        # Scaled Dot-Product Attention.
+        # Attention(Q, K, V) = softmax((QK^T)/sqrt(d_k))V
+        q = q * self.scale
+        x = torch.matmul(q, k)  # [b, h, q_len, k_len]
+        if attn_bias is not None:
+            x = x + attn_bias
+
+        weight_matrix = x
+
+        x = torch.softmax(x, dim=3)
+        x = self.att_dropout(x)
+        x = x.matmul(v)  # [b, h, q_len, attn]
+
+        x = x.transpose(1, 2).contiguous()  # [b, q_len, h, attn]
+        x = x.view(batch_size, -1, self.num_heads * d_v)
+
+        x = self.output_layer(x)
+
+        assert x.size() == orig_q_size
+        return x, weight_matrix
+
+
+class EncoderLayer(nn.Module):
+    """Pre-norm layer: residual self-attention, then residual feed-forward."""
+    def __init__(self, hidden_size, ffn_size, dropout_rate, attention_dropout_rate, num_heads):
+        super(EncoderLayer, self).__init__()
+
+        self.self_attention_norm = nn.LayerNorm(hidden_size)
+        self.self_attention = MultiHeadAttention(
+            hidden_size, attention_dropout_rate, num_heads)
+        self.self_attention_dropout = nn.Dropout(dropout_rate)
+
+        self.ffn_norm = nn.LayerNorm(hidden_size)
+        self.ffn = FeedForwardNetwork(hidden_size, ffn_size, dropout_rate)
+        self.ffn_dropout = nn.Dropout(dropout_rate)
+
+    def forward(self, x, attn_bias=None):
+        """Return ``(x, weight)``: new states and attention scores plus bias."""
+        y = self.self_attention_norm(x)
+        y, weight = self.self_attention(y, y, y, attn_bias)
+        y = self.self_attention_dropout(y)
+        x = x + y
+
+        y = self.ffn_norm(x)
+        y = self.ffn(y)
+        y = self.ffn_dropout(y)
+        x = x + y
+        weight = weight if attn_bias is None else weight + attn_bias
+        return x, weight
diff --git a/federatedscope/gfkd/model/graph2D_model.py b/federatedscope/gfkd/model/graph2D_model.py
new file mode 100644
index 000000000..b30c27626
--- /dev/null
+++ b/federatedscope/gfkd/model/graph2D_model.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import ModuleList
+from torch_geometric.data import Data
+from torch_geometric.data.batch import Batch
+
+from federatedscope.gfl.model.graph_level import GNN_Net_Graph
+
+
+class GNN_Net(GNN_Net_Graph):
+    r"""GNN model with pre-linear layer, pooling layer
+        and output layer for graph classification tasks.
+
+    Arguments:
+        in_channels (int): input channels.
+        out_channels (int): output channels.
+        hidden (int): hidden dim for all modules.
+        max_depth (int): number of layers for gnn.
+        dropout (float): dropout probability.
+        gnn (str): name of gnn type, use ("gcn" or "gin").
+        pooling (str): pooling method, use ("add", "mean" or "max").
+        conformer (bool): whether inputs also carry atom positions.
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 hidden=64,
+                 max_depth=2,
+                 dropout=.0,
+                 gnn='gcn',
+                 pooling='add',
+                 conformer=False):
+        # The parent constructor needs the architecture arguments; it is the
+        # only place the layers used in ``forward`` can come from.
+        super(GNN_Net, self).__init__(in_channels,
+                                      out_channels,
+                                      hidden=hidden,
+                                      max_depth=max_depth,
+                                      dropout=dropout,
+                                      gnn=gnn,
+                                      pooling=pooling)
+        self.conformer = conformer
+
+    def forward(self, data):
+        r"""Compute graph-level predictions.
+
+        Arguments:
+            data: a ``Batch`` or a tuple ``(x, edge_index, batch)``, with a
+                trailing ``pos`` element/attribute when ``conformer`` is set.
+
+        :raises TypeError: if ``data`` is neither a ``Batch`` nor a tuple
+        """
+        if isinstance(data, Batch):
+            x, edge_index, batch = data.x, data.edge_index, data.batch
+            pos = data.pos if self.conformer else None
+        elif isinstance(data, tuple):
+            if self.conformer:
+                x, edge_index, batch, pos = data
+            else:
+                x, edge_index, batch = data
+        else:
+            raise TypeError('Unsupported data type!')
+        # NOTE(review): ``pos`` is unpacked but not consumed by any layer
+        # below -- confirm whether a position-aware layer is still to come.
+
+        if x.dtype == torch.int64:
+            x = self.encoder_atom(x)
+        else:
+            x = self.encoder(x)
+
+        x = self.gnn((x, edge_index))
+        x = self.pooling(x, batch)
+        x = self.linear(x)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = self.clf(x)
+        return x
\ No newline at end of file
diff --git a/federatedscope/gfkd/model/graph3D_model.py b/federatedscope/gfkd/model/graph3D_model.py
new file mode 100644
index 000000000..279e2a9bb
--- /dev/null
+++ b/federatedscope/gfkd/model/graph3D_model.py
@@ -0,0 +1,44 @@
+from typing import Callable, Union
+
+import torch
+import torch.nn.functional as F
+from torch_geometric.nn import DimeNetPlusPlus
+
+class DimeNetPlusPlus_for_QM7b(DimeNetPlusPlus):
+    r"""``DimeNetPlusPlus`` whose constructor arguments all have defaults."""
+    def __init__(
+        self,
+        hidden_channels: int = 256,
+        out_channels: int = 14,
+        num_blocks: int = 3,
+        int_emb_size: int = 64,
+        basis_emb_size: int = 8,
+        out_emb_channels: int = 256,
+        num_spherical: int = 7,
+        num_radial: int = 6,
+        cutoff: float = 5.0,
+        max_num_neighbors: int = 32,
+        envelope_exponent: int = 5,
+        num_before_skip: int = 1,
+        num_after_skip: int = 2,
+        num_output_layers: int = 3,
+        act: Union[str, Callable] = 'swish',
+    ):
+        # ``num_bilinear`` is not forwarded: DimeNetPlusPlus does not take it.
+        super().__init__(
+            hidden_channels=hidden_channels,
+            out_channels=out_channels,
+            num_blocks=num_blocks,
+            int_emb_size=int_emb_size,
+            basis_emb_size=basis_emb_size,
+            out_emb_channels=out_emb_channels,
+            num_spherical=num_spherical,
+            num_radial=num_radial,
+            cutoff=cutoff,
+            max_num_neighbors=max_num_neighbors,
+            envelope_exponent=envelope_exponent,
+            num_before_skip=num_before_skip,
+            num_after_skip=num_after_skip,
+            num_output_layers=num_output_layers,
+            act=act,
+        )
\ No newline at end of file
diff --git a/federatedscope/gfkd/model/model_builder.py b/federatedscope/gfkd/model/model_builder.py
new file mode 100644
index 000000000..335d5d8f7
--- /dev/null
+++ b/federatedscope/gfkd/model/model_builder.py
@@ -0,0 +1,57 @@
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+from federatedscope.gfkd.model.graph2D_model import GNN_Net
+from federatedscope.gfkd.model.graph3D_model import DimeNetPlusPlus_for_QM7b
+from federatedscope.gfkd.model.SMILES_model import SMILESTransformer
+
+
+def get_gnn(model_config, input_shape):
+    r"""Build the model selected by ``model_config.type``.
+
+    Arguments:
+        model_config: model configuration; ``type`` picks the model and the
+            other entries read below parameterise it.
+        input_shape (tuple): ``(x_shape, num_label, num_edge_features)``.
+
+    :returns: the constructed model
+    :raises ValueError: if ``model_config.type`` is not recognised
+    """
+    x_shape, num_label, num_edge_features = input_shape
+    if not num_label:
+        num_label = 0
+    if model_config.type == 'SMILESTransformer':
+        # assume `data` is a dict where key is the client index,
+        # and value is a PyG object
+        model = SMILESTransformer(ntoken=415,
+                                  ninp=128,
+                                  nhead=8,
+                                  nhid=model_config.hidden,
+                                  nlayers=model_config.layer,
+                                  dropout=model_config.dropout)
+    elif model_config.type == 'GNN_Net':
+        # ``type`` is 'GNN_Net' in this branch, so it cannot double as the
+        # backbone name; read an optional ``gnn`` entry and fall back to
+        # GNN_Net's own default.
+        # NOTE(review): confirm which config key should select the backbone.
+        model = GNN_Net(x_shape[-1],
+                        max(model_config.out_channels, num_label),
+                        hidden=model_config.hidden,
+                        max_depth=model_config.layer,
+                        dropout=model_config.dropout,
+                        gnn=getattr(model_config, 'gnn', 'gcn'),
+                        pooling=model_config.graph_pooling)
+    elif model_config.type == 'DimeNetPlusPlus':
+        model = DimeNetPlusPlus_for_QM7b(
+            hidden_channels=model_config.hidden,
+            out_channels=model_config.out_channels,
+            num_blocks=model_config.num_blocks,
+            int_emb_size=model_config.int_emb_size,
+            basis_emb_size=model_config.basis_emb_size,
+            out_emb_channels=model_config.out_emb_channels)
+    else:
+        raise ValueError('not recognized model {}'.format(
+            model_config.type))
+
+    return model
diff --git a/federatedscope/gfkd/trainer/graphtrainer.py b/federatedscope/gfkd/trainer/graphtrainer.py
new file mode 100644
index 000000000..8f806c1a9
--- /dev/null
+++ b/federatedscope/gfkd/trainer/graphtrainer.py
@@ -0,0 +1,85 @@
+import logging
+
+from federatedscope.core.monitors import Monitor
+from federatedscope.register import register_trainer
+from federatedscope.core.trainers import GeneralTorchTrainer
+from federatedscope.core.trainers.context import CtxVar
+from federatedscope.core.trainers.enums import LIFECYCLE
+
+logger = logging.getLogger(__name__)
+
+
+class GraphMiniBatchTrainer(GeneralTorchTrainer):
+    """Trainer that feeds each whole batch object to the model."""
+    def _hook_on_batch_forward(self, ctx):
+        """Run the model on the batch and store loss, size, labels, preds."""
+        batch = ctx.data_batch.to(ctx.device)
+        pred = ctx.model(batch)
+        # TODO: deal with the type of data within the dataloader or dataset
+        if 'regression' in ctx.cfg.model.task.lower():
+            label = batch.y
+        else:
+            label = batch.y.squeeze(-1).long()
+        if len(label.size()) == 0:
+            label = label.unsqueeze(0)
+        ctx.loss_batch = ctx.criterion(pred, label)
+
+        ctx.batch_size = len(label)
+        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
+        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
+
+    def _hook_on_batch_forward_flop_count(self, ctx):
+        """Add this batch's estimated FLOPs to the monitor total."""
+        if not isinstance(self.ctx.monitor, Monitor):
+            logger.warning(
+                f"The trainer {type(self)} does not contain a valid monitor, "
+                f"this may be caused by initializing trainer subclasses "
+                f"without passing a valid monitor instance."
+                f"Plz check whether this is you want.")
+            return
+
+        if self.cfg.eval.count_flops and self.ctx.monitor.flops_per_sample \
+                == 0:
+            # calculate the flops_per_sample
+            try:
+                batch = ctx.data_batch.to(ctx.device)
+                from torch_geometric.data import Data
+                if isinstance(batch, Data):
+                    x, edge_index = batch.x, batch.edge_index
+                from fvcore.nn import FlopCountAnalysis
+                flops_one_batch = FlopCountAnalysis(ctx.model,
+                                                    (x, edge_index)).total()
+                if self.model_nums > 1 and ctx.mirrored_models:
+                    flops_one_batch *= self.model_nums
+                    logger.warning(
+                        "the flops_per_batch is multiplied by "
+                        "internal model nums as self.mirrored_models=True."
+                        "if this is not the case you want, "
+                        "please customize the count hook")
+                self.ctx.monitor.track_avg_flops(flops_one_batch,
+                                                 ctx.batch_size)
+            except Exception:
+                logger.warning(
+                    "current flop count implementation is for general "
+                    "GraphMiniBatchTrainer case: "
+                    "1) the ctx.model takes only batch = ctx.data_batch as "
+                    "input."
+                    "Please check the forward format or implement your own "
+                    "flop_count function")
+                self.ctx.monitor.flops_per_sample = -1  # warning at the
+                # first failure
+
+        # by default, we assume the data has the same input shape,
+        # thus simply multiply the flops to avoid redundant forward
+        self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
+            ctx.batch_size
+
+
+def call_graph_level_trainer(trainer_type):
+    """Return ``GraphMiniBatchTrainer`` for its registered name, else None."""
+    if trainer_type == 'graphminibatch_trainer':
+        trainer_builder = GraphMiniBatchTrainer
+        return trainer_builder
+
+
+register_trainer('graphminibatch_trainer', call_graph_level_trainer)
diff --git a/federatedscope/gfl/trainer/graphtrainer.py b/federatedscope/gfl/trainer/graphtrainer.py
index 8f806c1a9..7d200d062 100644
--- a/federatedscope/gfl/trainer/graphtrainer.py
+++ b/federatedscope/gfl/trainer/graphtrainer.py
@@ -13,7 +13,6 @@ class GraphMiniBatchTrainer(GeneralTorchTrainer):
     def _hook_on_batch_forward(self, ctx):
         batch = ctx.data_batch.to(ctx.device)
         pred = ctx.model(batch)
-        # TODO: deal with the type of data within the dataloader or dataset
         if 'regression' in ctx.cfg.model.task.lower():
             label = batch.y
         else:
@@ -28,11 +27,6 @@ def _hook_on_batch_forward(self, ctx):
 
     def _hook_on_batch_forward_flop_count(self, ctx):
         if not isinstance(self.ctx.monitor, Monitor):
-            logger.warning(
-                f"The trainer {type(self)} does contain a valid monitor, "
-                f"this may be caused by initializing trainer subclasses "
-                f"without passing a valid monitor instance."
-                f"Plz check whether this is you want.")
             return
 
         if self.cfg.eval.count_flops and self.ctx.monitor.flops_per_sample \
@@ -48,26 +42,12 @@ def _hook_on_batch_forward_flop_count(self, ctx):
                                                     (x, edge_index)).total()
                 if self.model_nums > 1 and ctx.mirrored_models:
                     flops_one_batch *= self.model_nums
-                    logger.warning(
-                        "the flops_per_batch is multiplied by "
-                        "internal model nums as self.mirrored_models=True."
-                        "if this is not the case you want, "
-                        "please customize the count hook")
                 self.ctx.monitor.track_avg_flops(flops_one_batch,
                                                  ctx.batch_size)
             except:
-                logger.warning(
-                    "current flop count implementation is for general "
-                    "GraphMiniBatchTrainer case: "
-                    "1) the ctx.model takes only batch = ctx.data_batch as "
-                    "input."
-                    "Please check the forward format or implement your own "
-                    "flop_count function")
                 self.ctx.monitor.flops_per_sample = -1  # warning at the
                 # first failure
 
-        # by default, we assume the data has the same input shape,
-        # thus simply multiply the flops to avoid redundant forward
         self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
             ctx.batch_size
 