import os
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset import *
from model import *
test function의 소스입니다.
# Define a test function.
def test(model, criterion, dataloader, device):
# Perform an evaluation using the defined network.
model.eval()
# Wrap the iterable dataloader with tqdm.
bar = tqdm(dataloader)
samples = 0
total_loss, total_acc =0, 0
for batch_idx, (data,label) in enumerate(bar):
# Move both data and label to device (e.g. GPU).
data = data.type(torch.FloatTensor).to(device)
label = label.to(device)
# Pass the input data through the defined network architecture.
pred = model(data, extract=True)
# Compute a loss function.
loss = criterion(pred, label)
total_loss += loss.item()*len(label)
# Compute speaker recognition accuracy.
samples += len(label)
acc = torch.sum(torch.eq(torch.argmax(pred,-1),label)).item()
total_acc += acc
return total_loss/samples, (total_acc/samples)*100.
test function 의 세부 로직을 살펴보기에 앞서,
이 호출되는 위치를 살펴보고, 전체적인 기능이 무엇인지 설명드리겠습니다.
test function은 main function 소스 내에 epoch loop 내에서 호출됩니다. (###.. 으로 표기해두었습니다.)
for epoch in range(start, epochs+1):
# Train the network.
train(epoch, model, softmax_criterion, optimizer, train_loader, device)
# Test the network.
#######################################################################
#######################################################################
opt_loss, opt_acc = test(model, softmax_criterion, test_loader, device)
#######################################################################
#######################################################################
# Save the optimal model.
if opt_loss < prev_loss:
prev_loss = opt_loss
torch.save({'epoch': epoch,
'model': model.state_dict(),
'optimizer': optimizer.state_dict()},
'./model/model_opt.pth')
ct_edec = 0
else:
ct_dec += 1
# Decrease the learning rate by 2 when the test loss decreases 3 times in a row.
if ct_dec == 3:
optim_state = optimizer.state_dict()
optim_state['param_groups'][0]['lr'] /= 2
optimizer.load_state_dict(optim_state)
print('lr is divided by 2.')
ct_dec = 0
test function 의 기능은 로직 흐름을 보면 알 수 있습니다.
train 한 후, test function을 호출하게 되는데,
test function 에서 return 되는 값이 opt_loss, opt_acc 입니다.
그 다음으로 이어지는 로직을 보면
opt_loss 값이 기존 건(prev_loss)보다 작은 경우, 해당 모델을 저장합니다.
요약하자면, test function의 기능은 epoch loop 돌며, train한 모델 중 test 결과로 loss가 가장 적은 모델을 model_opt.pth 파일로 저장하는 것임을 알 수 있습니다.
model_opt.pth 파일은 아래 로직을 보면 사용방법을 알 수 있습니다.
# Train a model.
main()
# Load the pre-trained model and train more.
###########################################
###########################################
#main(model_path='./model/model_opt.pth')
###########################################
###########################################
main function을 호출할 때,
parameter로 model_path로 해당 파일경로를 지정해주면 됩니다.
model_path가 main function 내에서 어떻게 사용되는지 확인해보면
#################################### Load pre-trained model ########################################
start = 0
# Load the pre-trained model.
print('Directory of the pre-trained model: {}'.format(model_path))
if model_path:
check = torch.load(model_path)
start = check['epoch']
model.load_state_dict(check['model'])
optimizer.load_state_dict(check['optimizer'])
print('## Successfully load the model at {} epochs!'.format(start))
torch.load(model_path) 로 저장해둔 파일을 load 하고 -> check 변수에 저장합니다.
check 변수 중 'epoch' 값을 -> start 로 넣습니다.
위 start 값은 이 다음에 이어지는 epoch loop 돌 때, start 값으로 기입됩니다.
model과 optimizer는 load_state_dict 으로 불러옵니다.
* state_dict 에 대한 참고글 : tutorials.pytorch.kr/beginner/saving_loading_models.html#state-dict
모델 저장하기 & 불러오기 — PyTorch Tutorials 1.6.0 documentation
Note Click here to download the full example code 모델 저장하기 & 불러오기 Author: Matthew Inkawhich번역: 박정환 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다
tutorials.pytorch.kr
다음으로 test function 내부 로직에 대한 세부 분석결과 전달드리겠습니다.
# Define a test function.
def test(model, criterion, dataloader, device):
# Perform an evaluation using the defined network.
model.eval()
# Wrap the iterable dataloader with tqdm.
bar = tqdm(dataloader)
samples = 0
total_loss, total_acc = 0, 0
for batch_idx, (data,label) in enumerate(bar):
# Move both data and label to device (e.g. GPU).
data = data.type(torch.FloatTensor).to(device)
label = label.to(device)
# Pass the input data through the defined network architecture.
pred = model(data, extract=True)
# Compute a loss function.
loss = criterion(pred, label)
total_loss += loss.item()*len(label)
# Compute speaker recognition accuracy.
samples += len(label)
acc = torch.sum(torch.eq(torch.argmax(pred,-1),label)).item()
total_acc += acc
return total_loss/samples, (total_acc/samples)*100.
model.eval()
(출처 : go-hard.tistory.com/64)
- Perform an evaluation using the defined network
- model.eval()은 eval mode에서 사용할 것이라고 모든 레이어에 선언하는 것이며, 배치 정규화나 dropout layer들은 학습모드 대신에 eval mode로 작동한다. (eval 모드에서는 dropout은 비활성화, 배치 정규화는 학습에서 저장된 파라미터를 사용)
bar = tqdm(dataloader)
(출처 : https://skillmemory.tistory.com/17)
- 파이썬으로 어떤 작업을 수행중인데, 프로그램이 내가 의도한 데로 돌아가고 있는 중인가, 진행상황이 궁금할 때가 있다. 시간이 걸리는 작업의 경우에 상태확인 해주는 역할
- 이미지 예시
for batch_idx, (data,label) in enumerate(bar):
(출처 : wikidocs.net/20792)
- 리스트가 있는 경우 순서와 리스트의 값을 전달하는 기능
- 이 함수는 순서가 있는 자료형(list, set, tuple, dictionary, string)을 입력으로 받아 인덱스 값을 포함하는 enumerate 객체를 리턴
- 보통 enumerate 함수는 for문과 함께 자주 사용
to(device)
# Move both data and label to device (e.g. GPU).
data = data.type(torch.FloatTensor).to(device)
label = label.to(device)
(출처 : tutorials.pytorch.kr/recipes/recipes/save_load_across_devices.html)
- GPU에서 학습하고 저장된 모델을 GPU에서 불러올 때는, 초기화된 모델에 model.to(torch.device(‘cuda’)) 을 호출하여 CUDA에 최적화된 모델로 변환함. 그리고 모든 입력에 .to(torch.device('cuda')) 함수를 호출해야 모델에 데이터를 제공할 수 있음
- device 변수는 main function 에서 cuda 가능여부에 따라 cuda 또는 cpu 로 세팅됨
pred = model(data, extract=True)
# Pass the input data through the defined network architecture.
pred = model(data, extract=True)
- model 은 model.py 로 교수님이 주신 class 소스
- model.py 내의 forward function이 수행하게 됨
- 파라미터 extract 가 False 인 경우, return 되는 값은 d_vector가 되도록 설계되어있음
즉, extract 는 d_vector 얻어내기 위한 flag 파라미터
d_vector 는 아래 수업자료 참고
나머지는
loss 값과 acc 값 각각의 전체 sum을 구해서 return 해주면서 test function 은 완료