A Walkthrough of the SwAV Code
```python
def train(train_loader, model, optimizer, epoch, lr_schedule, queue):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    softmax = nn.Softmax(dim=1).cuda()
    model.train()
    use_the_queue = False

    end = time.time()
    for it, inputs in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # update learning rate
        iteration = epoch * len(train_loader) + it
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_schedule[iteration]

        # normalize the prototypes
        with torch.no_grad():
            w = model.module.prototypes.weight.data.clone()
            w = nn.functional.normalize(w, dim=1, p=2)
            model.module.prototypes.weight.copy_(w)

        # ============ multi-res forward passes ... ============
        """
        embedding.size() = [32, 128]  = [sum(args.nmb_crops) * args.batch_size, args.feat_dim]
        output.size()    = [32, 3000] = [sum(args.nmb_crops) * args.batch_size, args.nmb_prototypes]
        bs = args.batch_size (per-GPU batch size)
        """
        embedding, output = model(inputs)
        embedding = embedding.detach()
        bs = inputs[0].size(0)

        # ============ swav loss ... ============
        loss = ...  # detailed below

        # ============ backward and optim step ... ============
        optimizer.zero_grad()
        if args.use_fp16:
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # cancel some gradients
        if iteration < args.freeze_prototypes_niters:
            for name, p in model.named_parameters():
                if "prototypes" in name:
                    p.grad = None
        optimizer.step()

        # ============ misc ... ============
        losses.update(loss.item(), inputs[0].size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        if args.rank == 0 and it % 50 == 0:
            logger.info(
                "Epoch: [{0}][{1}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                "Lr: {lr:.4f}".format(
                    epoch,
                    it,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    lr=optimizer.optim.param_groups[0]["lr"],
                )
            )
    return (epoch, losses.avg), queue
```
- This is the overall flow of one training epoch.

In `embedding, output = model(inputs)`, `embedding` corresponds to $\boldsymbol{z}_{t}$ and `output` to $\boldsymbol{z}_{t}^{\top}\boldsymbol{C}$, i.e. `output = prototypes(embedding)`.
- The prototypes module is built into the model itself (see the sketch below).
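To make that relationship concrete, here is a minimal sketch of the prototype head, assuming it is a bias-free linear layer whose weight rows are the prototype vectors $\boldsymbol{c}_k$, with `feat_dim = 128` and `nmb_prototypes = 3000` taken from the size annotations above (the real model also has a projection head in front of this layer):

```python
import torch
import torch.nn as nn

feat_dim, nmb_prototypes = 128, 3000  # values taken from the size annotations above

# Prototype matrix C as a bias-free linear layer: weight has shape [nmb_prototypes, feat_dim],
# so prototypes(z) = z @ C.t(), i.e. one dot product z . c_k per prototype.
prototypes = nn.Linear(feat_dim, nmb_prototypes, bias=False)

z = nn.functional.normalize(torch.randn(32, feat_dim), dim=1, p=2)  # stand-in for "embedding" (z_t)
scores = prototypes(z)                                              # stand-in for "output"
print(scores.shape)  # torch.Size([32, 3000])
```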
```python
# ============ swav loss ... ============
loss = 0
for i, crop_id in enumerate(args.crops_for_assign):
    with torch.no_grad():
        out = output[bs * crop_id: bs * (crop_id + 1)]  # out.size() = [batch_size, nmb_prototypes]

        # time to use the queue
        if queue is not None:
            if use_the_queue or not torch.all(queue[i, -1, :] == 0):
                use_the_queue = True
                out = torch.cat((torch.mm(
                    queue[i],                           # queue[i].size() = [3840, 128] = [args.queue_length, args.feat_dim]
                    model.module.prototypes.weight.t()  # prototypes.weight.size() = [nmb_prototypes, args.feat_dim]
                ), out))                                # out.size() = [3872, 3000]
            # fill the queue
            queue[i, bs:] = queue[i, :-bs].clone()
            queue[i, :bs] = embedding[crop_id * bs: (crop_id + 1) * bs]  # size [batch_size, args.feat_dim]

        # get assignments
        q = out / args.epsilon
        if args.improve_numerical_stability:
            M = torch.max(q)
            dist.all_reduce(M, op=dist.ReduceOp.MAX)
            q -= M
        q = torch.exp(q).t()
        q = distributed_sinkhorn(q, args.sinkhorn_iterations)[-bs:]

    # cluster assignment prediction
    subloss = 0
    for v in np.delete(np.arange(np.sum(args.nmb_crops)), crop_id):
        p = softmax(output[bs * v: bs * (v + 1)] / args.temperature)
        subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))
    loss += subloss / (np.sum(args.nmb_crops) - 1)
loss /= len(args.crops_for_assign)
```
Computation of the SwAV (swapped-prediction) loss: the Sinkhorn code `q` from one assignment crop supervises the softmax prediction `p` of every other crop.
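In the notation of the SwAV paper, this block corresponds to the swapped-prediction term (sketched here; $\tau$ is `args.temperature`, $\boldsymbol{c}_k$ are the prototype vectors, and $\boldsymbol{q}_s$ is the Sinkhorn code computed above):

$$
\ell(\boldsymbol{z}_t, \boldsymbol{q}_s) = -\sum_{k} \boldsymbol{q}_s^{(k)} \log \boldsymbol{p}_t^{(k)},
\qquad
\boldsymbol{p}_t^{(k)} = \frac{\exp\!\left(\boldsymbol{z}_t^{\top}\boldsymbol{c}_k / \tau\right)}{\sum_{k'} \exp\!\left(\boldsymbol{z}_t^{\top}\boldsymbol{c}_{k'} / \tau\right)}
$$

The line `subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))` is exactly this cross-entropy averaged over the batch, and the surrounding loops average it over all (assignment crop, prediction crop) pairs.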
```python
def distributed_sinkhorn(Q, nmb_iters):
    with torch.no_grad():
        Q = shoot_infs(Q)
        sum_Q = torch.sum(Q)
        dist.all_reduce(sum_Q)
        Q /= sum_Q
        r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
        c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (args.world_size * Q.shape[1])
        for it in range(nmb_iters):
            u = torch.sum(Q, dim=1)
            dist.all_reduce(u)
            u = r / u
            u = shoot_infs(u)
            Q *= u.unsqueeze(1)
            Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
        return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
```
Iterative (Sinkhorn-Knopp) computation of the code matrix Q.
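For intuition, here is a minimal single-GPU sketch of the same normalization (my own simplification: the `dist.all_reduce` calls and the `shoot_infs` guard are dropped, and `args.world_size` is assumed to be 1). Each iteration rescales the rows of Q toward a uniform distribution over prototypes and the columns toward a uniform distribution over samples:

```python
import torch

def sinkhorn_single_gpu(Q, nmb_iters=3):
    # Q: [nmb_prototypes, batch_size], entries exp(z^T c_k / epsilon) as built in the loss code above
    with torch.no_grad():
        Q /= Q.sum()                                 # normalize the whole matrix to sum to 1
        K, B = Q.shape
        r = torch.ones(K, device=Q.device) / K       # target row marginals (uniform over prototypes)
        c = torch.ones(B, device=Q.device) / B       # target column marginals (uniform over samples)
        for _ in range(nmb_iters):
            Q *= (r / Q.sum(dim=1)).unsqueeze(1)     # rescale rows to match r
            Q *= (c / Q.sum(dim=0)).unsqueeze(0)     # rescale columns to match c
        return (Q / Q.sum(dim=0, keepdim=True)).t()  # each sample's code sums to 1; shape [B, K]
```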