-
TTA (Test Time Augmentation) 란딥러닝(deep Learning)/딥러닝 이슈(Issue) 2021. 6. 19. 23:04반응형
0. 서론
- TTA ( Test Time Augmentation) 이란, 말 그대로 model 을 테스트 할때에도, Data Augmentation 을 한다고 이해하면 될 것이다. 그림으로 살펴 보았을때, 밑의 input 즉 원본 이미지를 flip 및 rotation, zoom 등을 하여, 원본으로부터 변형된 여러가지 Image Augmentation 에 평가를 실시하여, 최종 분류값이 무엇인지 예측하는 기법이다.
- 좀 더 직관적으로 본다면, 모델에 한가지의 이미지를 주는 것보다는 여러가지 변형된 이미지를 주어, 평가를 하게 되면, 발생하는 오차는 작아지는 것이 당연해 보인다.
- TTA 를 쓰게 되면, 모델이 편향된 학습결과를 가지고 있을때, 그러한 편향에서 벗어나 좀 더 좋은 예측을 할 수 있게 된다. 하지만 언제까지나 좋은 TTA를 썻을때, 좋은 데이터 모델링이 나오게 된다는 것 명심하자.
1. Pytorch implementation
0. Module loading¶
In [1]:import torch import ttach as tta import timm import numpy as np from PIL import Image import matplotlib.pyplot as plt import cv2
In [66]:image_path = "./golden.jpg" image = np.array(Image.open(image_path)) / 255 # 이미지를 읽고 min max scaling image = cv2.resize(image, (384, 384)) # Vision Transformer base model input size plt.imshow(image) plt.title("original image") image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(torch.float32) # unsqeeze 를 해준 이유는 batchsize 를 추가시켜주었음. print("image_shape =",image.shape) # batch, channel , height, width
image_shape = torch.Size([1, 3, 384, 384])
- Image.open 을 통해골든 리트리버 이미지를 불러 왔고, image 에 unsqeeze(0) 을 해주어 1차원의 데이터를 추가해주었는데, 배치사이즈 를 추가해준 것이다. (B,C,H,W) 순이다.
- model 로써 vit_base_patch 를 사용하였는데, 이는 (384,384) 이미지를 사용하였기에, 이미지를 resize 시켜준다.
2. Models¶
In [14]:transforms = tta.Compose( [ tta.HorizontalFlip(), tta.Rotate90(angles=[0, 90]), # tta.Scale(scales=[1, 2]), # tta.FiveCrops(384, 384), tta.Multiply(factors=[0.7, 1]), ] ) model = timm.create_model("vit_base_patch16_384", pretrained=True) # model = timm.create_model("seresnet50", pretrained=True)
- Transforms 에 tta.compose 를 통하여 원하는 function 을 집어넣어주었다.
- 아까 말했듯이 model 은 timm 모듈로 "vit_base_patch" 을 사용
3. Transformer¶
In [23]:# transforms 를 살펴보게 되면 다음과 같이 여러가지의 tools 들이 조합하여 생겨났다. transforms.aug_transform_parameters # HorizontalFlip ,Rotate90, Multiply 순
Out[23]:[(False, 0, 0.7), (False, 0, 1), (False, 90, 0.7), (False, 90, 1), (True, 0, 0.7), (True, 0, 1), (True, 90, 0.7), (True, 90, 1)]
- transforms 에서는 각각을 augmentation 시켜 주었을때, itertools.product(*[]) 메소드를 통해 마치 순열의 형태로 각각의 함수에 대해 data augmentation 을 진행해준다.
In [76]:fig = plt.figure(figsize = (10,10)) original_image = np.array((image*255).squeeze()).transpose(1, 2, 0).astype(np.uint8) fig.add_subplot(1,3,1) plt.title("original image") plt.imshow(original_image) # augmented 된 사진 aug_image = next(iter(transforms)).image_pipeline(image) augmented_image = np.array((aug_image*255).squeeze()).transpose(1, 2, 0).astype(np.uint8) fig.add_subplot(1,3,2) plt.title("augment image") plt.imshow(augmented_image)
Out[76]:<matplotlib.image.AxesImage at 0x15c010b9e48>
In [15]:imagenet_labels = dict(enumerate(open('./imagenet1000_labels.txt'))) # ImageNet class name 1000개의 라벨링 데이터 fig = plt.figure(figsize=(20, 20)) columns = 3 rows = 3 for i, transformer in enumerate(transforms): # custom transforms augmented_image = transformer.augment_image(image) output = model(augmented_image) predicted = imagenet_labels[output.argmax(1).item()].strip() augmented_image = np.array((augmented_image*255).squeeze()).transpose(1, 2, 0).astype(np.uint8) fig.add_subplot(rows, columns, i+1) plt.imshow(augmented_image) plt.title(predicted) plt.show()
- 잘못 예측된 값 [3번] = "Great Pyrenees" 라는 강아지
- 결론 : 위의 결과에서 보았듯이 여러가지 augmentation 을 통해서 살펴본 결과, 모델이 데이터를 잘못 예측할 확률이 많이 적어진다는 것을 확인 할 수 있었다. 따라서, 모델을 평가하거나 성능을 조금이라도 높이고 싶을때, TTA 를 자주 사용해보자.
REF.
1. https://stepup.ai/test_time_data_augmentation/
2.https://visionhong.tistory.com/26
3. https://github.com/qubvel/ttach
반응형