본문 바로가기
경제학코딩 2023

[torch.nn - (1)]

by 개발도사(진) 2023. 12. 14.

torch.nn이란?

torch.nn module은 PyTorch의 모든 Neural network의 Base Class이다.

 

1. dataset 불러오기

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

분석 대상으로 삼기 위해, 위 code를 사용해서 tensorflow가 제공하는 mnist dataset을 불러온다.

tensorflow가 제공하는 mnist dataset은 다음과 같은 숫자 손글씨 dataset이다.

우리의 목표는 x_train data를 사용해서 y_train data를 맞추는 모델을 만들고 해당 모델을 이용해 x_test를 사용해서  y_test를 맞추는 것이다.

 

2. dataset 정제

처음 data를 받아오고 나면, x_train은 28*28 numpy array로 만들어진 그림이 60000개 있는 형태이다. 0~59999까지의 index를 이용해 직접 x_train 값을 출력해 봄으로써 확인할 수 있다.

앞으로의 분석을 위해, 다음 코드를 사용해 데이터를 정제한다.

x_train = torch.tensor(x_train.reshape(60000, 784)/255,dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)

데이터를 정제하면서 x_train 값을 255로 나눠준 이유는, 해당 값이 색을 나타내는 0~255 사이의 값이기 때문에 그 범위를 0~1로 바꿔주기 위함이다.

 

3. Regression

x_train@weight+bias

= [60000*784]@[784*10]+[10] = [60000*10]

위 matrix muliplication을 사용하여 x_train을 60000*10 형태로 바꿔 줄 것이다. 여기에 사용되는 code는 아래와 같다.

weights = torch.randn((784,10))
weights.requires_grad_()
bias = torch.zeros((10, ), requires_grad = True)

def model_0(inputs):
  outputs = inputs@weights+bias
  return outputs

prob = model_0(x_train)
prob

weights, bias에 사용된 requires_grad는 neural network training을 위해 필요하기 때문에 위와 같은 방식으로 각각 추가해 주어야 한다. 위 코드를 수행하고 우리가 얻은 결과값을 확인해 보면

 

위와 같은데, 각 value 들은 regression 결과 해당하는 data값이 0이 될 '확률'을 나타낸다. 그러나 현재 출력된 값들을 살펴보면 우리가 흔히 다루는 확률이라고 볼 수 없다. 이를 보정하기 위해 우리는 여기에 softmax를 도입해야 한다.

 

4. softmax, log_softmax

softmax는 다음과 같이 구한다. [a, b, c] 라는 3개의 수가 주어졌다고 가정했을 때, 

우리는 이 값의 log 버전을 이용할 것이다. 이를 위한 code는 아래와 같다.

def log_softmax(x):
    return x - x.exp().sum(-1, keepdim=True).log()

 

이제 log_softmax를 적용한 model_1을 만들어 x_train 값을 적용시키면, 아래와 같은 값을 구할 수 있다.

def model_1(inputs):
    outputs = inputs @ weights + bias
    return log_softmax(outputs)
    
model_1(x_train)

 

'경제학코딩 2023' 카테고리의 다른 글

[Encoder Only Model with imdb]  (0) 2023.12.17
[Encoder Only Model with Random Data]  (1) 2023.12.17
[Keras Layer - (2)]  (0) 2023.12.16
[Keras Layer-(1)]  (0) 2023.12.16
[torch.nn-(2)]  (1) 2023.12.15