본문 바로가기

파이썬(PYTHON)

PYTORCH 숫자 이미지 인식 및 분류

728x90
# Image dataset preparation in PyTorch (DataLoaders and Transforms)

import torchvision
import torch
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import numpy as np
import time

# Dataset roots. torchvision's ImageFolder expects one sub-folder per class
# under each root.
train_dataset_path = './data/num/training/training'
test_dataset_path = './data/num/validation/validation'

# Per-channel normalization statistics passed to transforms.Normalize.
# NOTE(review): presumably computed over the training images — confirm.
mean = [0.4363, 0.4328, 0.3291]
std = [0.2120, 0.2075, 0.2038]

# Training pipeline: resize, random augmentation, tensor conversion, normalize.
# NOTE(review): RandomHorizontalFlip mirrors each image; if the classes are
# written digits, a mirrored digit may no longer match its label — confirm
# this augmentation suits the dataset.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)),
])

# Evaluation pipeline: deterministic on purpose. Random flips/rotations at
# evaluation time make validation results differ from run to run, so only the
# resize, tensor conversion and normalization steps of the training pipeline
# are applied here.
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)),
])

train_dataset = torchvision.datasets.ImageFolder(root=train_dataset_path, transform=train_transforms)
test_dataset = torchvision.datasets.ImageFolder(root=test_dataset_path, transform=test_transforms)

def show_transforms_images(dataset, batch_size=9, nrow=3, norm_mean=None, norm_std=None):
    """Display one random batch from *dataset* as a labelled image grid.

    Draws ``batch_size`` shuffled samples and tiles them ``nrow`` per row with
    ``torchvision.utils.make_grid``. Each sample's class is written on top of
    its tile. The figure is shown for one second and then cleared, so the
    caller can draw the next batch into the same figure.

    Args:
        dataset: A map-style dataset yielding ``(image_tensor, label)`` pairs,
            e.g. a ``torchvision.datasets.ImageFolder``.
        batch_size: Number of samples to draw (default 9, a 3x3 grid).
        nrow: Number of tiles per grid row.
        norm_mean, norm_std: Per-channel statistics used by the dataset's
            ``Normalize`` transform. When both are given, the normalization is
            undone so the images show their real colours. Otherwise the batch
            is min-max rescaled to [0, 1] for display.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    images, labels = next(iter(loader))

    # make_grid separates tiles by `padding` pixels. The label positions below
    # are derived from the same value so they stay aligned with the tiles.
    padding = 2
    if norm_mean is not None and norm_std is not None:
        # Undo Normalize (x = (orig - mean) / std) to recover displayable colours.
        channel_shape = (1, -1, 1, 1)
        std_t = torch.as_tensor(norm_std, dtype=images.dtype).view(channel_shape)
        mean_t = torch.as_tensor(norm_mean, dtype=images.dtype).view(channel_shape)
        grid = torchvision.utils.make_grid(
            (images * std_t + mean_t).clamp(0, 1), nrow=nrow, padding=padding)
    else:
        # Normalized tensors hold values outside [0, 1], which imshow would
        # clip. Let make_grid min-max rescale the batch for display instead.
        grid = torchvision.utils.make_grid(
            images, nrow=nrow, padding=padding, normalize=True)

    # (C, H, W) -> (H, W, C), the layout matplotlib expects.
    plt.imshow(grid.permute(1, 2, 0).numpy())

    tile_h = images.shape[-2] + padding
    tile_w = images.shape[-1] + padding
    # ImageFolder numbers its class folders in sorted (string) order, so a
    # label index is not necessarily "folder name - 1". Show the real folder
    # name when the dataset exposes one; otherwise fall back to index + 1.
    class_names = getattr(dataset, "classes", None)
    for idx, label in enumerate(labels.tolist()):
        row, col = divmod(idx, nrow)
        text = class_names[label] if class_names else label + 1
        plt.text(
            padding + col * tile_w + tile_w // 2,
            padding + row * tile_h + tile_h // 6,
            text,
            fontsize=20,
            ha="center",
        )

    plt.pause(1)
    plt.clf()

if __name__ == "__main__":
    # Preview a fresh random training batch every second until the window is
    # closed. A bare `while True` never exits: once the window is closed, the
    # next plt.clf()/plt.imshow() silently creates a new figure and the loop
    # keeps going. It also ran on import, which the __main__ guard prevents.
    fig = plt.figure(figsize=(11, 11))
    window_open = [True]

    def _on_close(event):
        """Stop the preview loop once the preview window is closed."""
        window_open[0] = False

    fig.canvas.mpl_connect("close_event", _on_close)
    while window_open[0]:
        show_transforms_images(train_dataset)

#train_loadert = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
#test_loadert = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=True)
 
 

 

728x90