-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
118 lines (94 loc) · 4.14 KB
/
main.py
File metadata and controls
118 lines (94 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import pandas as pd
import torch
import torchvision.transforms as T
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from configs.configs import get_config
from utils.utils import create_results_dir, print_gpu_usage
from utils.dataset import Crack500Dataset
from utils.evaluate import evaluate_model
from utils.losses import FocalLoss
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
# ----- Configuration & run setup -----
# Pull hyperparameters/paths from the project config and create a fresh
# results directory; the best checkpoint for this run lives inside it.
config = get_config()
root_dir = config.root_dir
results_dir = create_results_dir(config.base_results_dir)
best_model_path = os.path.join(results_dir, 'best_model.pth')

# Prefer the GPU when one is visible; fall back to CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# ----- Data pipeline -----
# Geometric augmentations are applied to training crops only; validation and
# test splits get a plain tensor conversion so evaluation stays deterministic.
augmentation = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(10),
    T.ToTensor(),
])

datasets = {
    'traincrop': Crack500Dataset(root_dir=root_dir, split='traincrop', transform=augmentation),
    'valcrop': Crack500Dataset(root_dir=root_dir, split='valcrop', transform=T.ToTensor()),
    'testcrop': Crack500Dataset(root_dir=root_dir, split='testcrop', transform=T.ToTensor()),
}


def _make_loader(dataset, shuffle):
    # Single place to configure loader settings (batch size, workers).
    return torch.utils.data.DataLoader(
        dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=0
    )


train_loader = _make_loader(datasets['traincrop'], True)
val_loader = _make_loader(datasets['valcrop'], False)
test_loader = _make_loader(datasets['testcrop'], False)
# ----- Model, loss, optimizer, AMP scaler -----
# ViT segmentation network; only the multi-filter variant is enabled here,
# the other architectural switches are turned off.
model = ViT_seg(
    config,
    img_size=config.img_size,
    num_classes=config.n_classes,
    use_multi_filter=True,
    use_locally_position_aware=False,
    use_channel_attention=False,
    use_multi_fusion=False,
)
model = model.to(device)

# Focal loss to cope with the heavy foreground/background imbalance of cracks.
criterion = FocalLoss(alpha=0.5, gamma=2.0).to(device)
optimizer = AdamW(model.parameters(), lr=config.learning_rate)
scaler = GradScaler()  # gradient scaler for mixed-precision training
# ----- Training loop: mixed precision, early stopping, best-model checkpointing -----
best_val_loss = float('inf')
patience_counter = 0
metrics = {"epoch": [], "train_loss": [], "val_loss": []}
for epoch in range(config.num_epochs):
    print(f"\n Epoch {epoch+1}/{config.num_epochs}")

    # ---------------- Training ----------------
    model.train()
    running_loss = 0.0
    for images, labels, _ in tqdm(train_loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            # Model output is lower-resolution than the masks; upsample the
            # logits to the label size before computing the loss.
            logits = torch.nn.functional.interpolate(outputs, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    # BUGFIX: empty_cache() was previously called once per *batch*, which forces
    # a device synchronization and defeats the caching allocator — a large
    # slowdown with no memory benefit. Releasing cached blocks once per epoch
    # (and only when actually running on CUDA) preserves the original intent.
    if device.type == "cuda":
        torch.cuda.empty_cache()
    train_loss = running_loss / len(train_loader)
    print(f"Train Loss: {train_loss:.4f}")
    print_gpu_usage()

    # ---------------- Validation ----------------
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, labels, _ in tqdm(val_loader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            logits = torch.nn.functional.interpolate(outputs, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            val_loss += criterion(logits, labels).item()
    val_loss /= len(val_loader)
    print(f"Validation Loss: {val_loss:.4f}")

    # ---------------- Checkpoint & early stopping ----------------
    metrics["epoch"].append(epoch + 1)
    metrics["train_loss"].append(train_loss)
    metrics["val_loss"].append(val_loss)
    if val_loss < best_val_loss:
        # New best validation loss: save weights and reset patience.
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        print("Best model saved.")
    else:
        patience_counter += 1
        if patience_counter >= config.early_stopping_patience:
            print("Early stopping triggered.")
            break
# ----- Persist metrics and evaluate the best checkpoint on the test set -----
pd.DataFrame(metrics).to_excel(os.path.join(results_dir, 'training_metrics.xlsx'), index=False)

print("\n Evaluating on test set...")
# IMPROVEMENT: map_location=device makes the restore robust — tensors are
# remapped onto the active device instead of whatever device they were saved
# from (matters if the checkpoint is reused on a different machine/device).
model.load_state_dict(torch.load(best_model_path, map_location=device))
evaluate_model(model, test_loader, results_dir)
print("Evaluation complete.")