2강 Data Representation Learning

 

Step 1 Load MNIST dataset

In [28]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
import keras
from keras.datasets import mnist
from sklearn import preprocessing
import numpy as np

(train_xs, train_ys), (test_xs, test_ys) = mnist.load_data()
dim_x = train_xs.shape[1] * train_xs.shape[2]
dim_y = 10

train_xs = train_xs.reshape(train_xs.shape[0], dim_x).astype(np.float32)

scaler = preprocessing.MinMaxScaler().fit(train_xs)
train_xs = scaler.transform(train_xs)
print(train_xs.shape)
print(train_ys.shape)
 
 
 
(60000, 784)
(60000,)
 

Step 2 Data Sampling

In [27]:
ridx = np.random.randint(train_xs.shape[0], size=10000)
np_train_xs = train_xs[ridx, :]
np_train_ys = train_ys[ridx]
print(np_train_xs.shape)
print(np_train_ys.shape)
 
(10000, 784)
(10000,)
 

Step 3 Import t-SNE & seaborn

In [12]:
%matplotlib inline
import sklearn
from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt

sns.set_style('darkgrid')
sns.set_palette('muted')
sns.set_context("notebook", font_scale=1.5, rc={"lines.linewidth": 2.5})
 

Step 4 Define Scatterplot method. Run t-SNE.

MNIST 데이터를 t-SNE를 이용하여 차원 축소(784차원 -> 2차원) 및 시각화 하기

  • t-SNE의 이론적 배경 소개
    • Perplexity 의미 : 가까운 점은 그 Label에 상관 없이 같은 군집이라고 가정하는 척도
    • Perplexity가 작으면 : 군집간 거리가 멀어짐
    • Perplexity가 크면 : 군집간 거리가 가까워 짐 #### Step1 Display the result
In [26]:
def draw_scatter(x, n_class, colors):
    sns.palplot(sns.color_palette())
    palette = np.array(sns.color_palette())
    
    f = plt.figure(figsize=(14,14))
    ax = plt.subplot(aspect='equal')
    sc = ax.scatter(x[:,0], x[:,1], lw=0, s=540, c=palette[colors.astype(np.int)], alpha=0.2)
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('off')
    ax.axis('tight')
    plt.show()
    
    
tsne_train_xs = TSNE(random_state=42).fit_transform(np_train_xs)
draw_scatter(tsne_train_xs, dim_y, np_train_ys)
 
 
 

t-SNE의 Perplexity 실습하기

In [30]:
ridx = np.random.randint(train_xs.shape[0], size = 1000) #data 크기를 줄임
np_train_xs = train_xs[ridx, :]
np_train_ys = train_ys[ridx]

sns.palplot(sns.color_palette()) # 숫자 0~9를 Color로 표시하여 보여줌
palette = np.array(sns.color_palette())
# 화면 구성은 3x3으로 보여줌
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(15,15))
for ax, perplexity in zip(axs.flat, [2,5,10,20,30,50,75,100,150]):
    tsne_out = TSNE(n_components = 2, perplexity = perplexity).fit_transform(np_train_xs)
    title = 'Perpelexity = {}'.format(perplexity)
    ax.set_title(title)
    ax.scatter(tsne_out[:,0], tsne_out[:,1], lw=0, s=25, c=palette[np_train_ys.astype(np.int)], alpha=0.3)
    ax.axis('tight')

plt.show()
 
 
 
  • Perplexity가 작을 수록 분포의 면적이 넓어짐
  • Perplexity가 낮은 값에서부터 조금씩 증가할 수 록 군집 내 거리를 가깝고 군집간 거리가 멀어지는 것을 알 수 있음.
  • Perplexity가 30이 넘어가면서 전체 데이터가 차지하는 공간이 작아지면서 군집간 거리도 좁아지는 것을 알 수 있음.
  • Perplexity가 100이 넘으면 오히려 군집간 식별력이 떨어질 수 있음(본 문제의 경우)

+ Recent posts