Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions SimCLR/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from SimCLR import loss

from .utils import save_checkpoint, save_config_file
from torch.profiler import profile, record_function, ProfilerActivity


class SimCLR(object):
Expand Down Expand Up @@ -36,21 +37,26 @@ def train(self, train_loader):
if dist_utils.is_dist_avail_and_initialized():
train_loader.sampler.set_epoch(epoch_counter)
for images, _ in tqdm(train_loader):
view1_images = images["view1"].cuda(self.device_id) # noqa: PLW2901
view2_images = images["view2"].cuda(self.device_id) # noqa: PLW2901
# Concatenate the two views so we run inference once.
images = torch.cat([view1_images, view2_images], dim=0) # noqa: PLW2901
images = images.cuda(self.device_id) # noqa: PLW2901

with autocast(enabled=self.args.fp16_precision):
features = self.model(images)
hidden1, hidden2 = torch.split(features, features.shape[0] // 2)
loss = self.criterion(hidden1, hidden2, self.device_id)

self.optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
with_stack=True,
profile_memory=True) as prof:

with record_function("forward_pass"):
view1_images = images["view1"].cuda(self.device_id) # noqa: PLW2901
view2_images = images["view2"].cuda(self.device_id) # noqa: PLW2901
# Concatenate the two views so we run inference once.
images = torch.cat([view1_images, view2_images], dim=0) # noqa: PLW2901
images = images.cuda(self.device_id) # noqa: PLW2901
with record_function("compute_loss"), autocast(enabled=self.args.fp16_precision):
features = self.model(images)
hidden1, hidden2 = torch.split(features, features.shape[0] // 2)
loss = self.criterion(hidden1, hidden2, self.device_id)
with record_function("backward_pass"):
self.optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()

if n_iter % self.args.log_every_n_steps == 0:
print(
Expand All @@ -62,6 +68,7 @@ def train(self, train_loader):
self.scheduler.get_last_lr()[0],
global_step=n_iter,
)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

n_iter += 1

Expand All @@ -86,4 +93,4 @@ def train(self, train_loader):
f"Model checkpoint and metadata has been saved at {self.writer.log_dir}."
)

print("Training has finished.")
print("Training has finished.")