Week 11. RNN๊ณผ LSTM
์๊ฐ์ด ํ๋ฅด๋ ๋ฐ์ดํฐ โ ๋ง, ์์ , ์ฃผ๊ฐ, ์ผ์ ์ ํธ โ ๋ฅผ ๋ค๋ฃจ๋ ์ ๊ฒฝ๋ง. ๊ทธ๋ฆฌ๊ณ ๊ทธ๊ฒ์ ์ค์ฉํํ LSTM์ ๊ฒ์ดํธ ๊ตฌ์กฐ.
์ด๋ฒ ์ฃผ์ ๋ฐฐ์ฐ๋ ๊ฒ
- ์ํ์ค ๋ฐ์ดํฐ์ ํน์ง
- vanilla RNN๊ณผ BPTT
- ์ฅ๊ธฐ ์์กด์ฑ ๋ฌธ์
- LSTM์ ์ธ ๊ฒ์ดํธ
- GRU โ ๋ ๊ฐ๋จํ ๋์
1. ์ํ์ค ๋ฐ์ดํฐ โ ์๊ฐ์ด ํ๋ฅด๋ ์ธ์
์ง๊ธ๊น์ง ๋ค๋ฃฌ ์ ๊ฒฝ๋ง์ ์ ๋ ฅ์ด ๊ณ ์ ๋ ํฌ๊ธฐ์ ๋ฒกํฐ์์ต๋๋ค. W9 CNN์ ์ด๋ฏธ์ง๋ $224 \times 224$๋ก ํฌ๊ธฐ๊ฐ ๊ณ ์ ์ ๋๋ค. ๊ทธ๋ฐ๋ฐ ํ์ค์๋ ๊ธธ์ด๊ฐ ๊ฐ๋ณ์ ์ด๊ณ ์์๊ฐ ์ค์ํ ๋ฐ์ดํฐ๊ฐ ์์์ด ๋ง์ต๋๋ค:
- ์์ฐ์ด โ ๋ฌธ์ฅ์ ๊ธธ์ด๋ ์ ๊ฐ๊ฐ. ๋จ์ด ์์๊ฐ ๋ฐ๋๋ฉด ์๋ฏธ๊ฐ ์์ ํ ๋ฌ๋ผ์ง.
- ์์ฑ โ ์ํ ์๋ ๋ฐํ ๊ธธ์ด์ ๋ฐ๋ผ ๋ค๋ฆ. ์๊ฐ ์์๋๋ก ๋ค์ด์ผ ์๋ฏธ๊ฐ ์๊น.
- ์์ โ ๋ฉ๋ก๋๋ ์ํ์ ์๊ฐ ์์.
- ์ฃผ๊ฐยท์ผ์ ์ ํธ โ ์๊ณ์ด ๊ทธ ์์ฒด.
- ๋น๋์ค โ ํ๋ ์์ ์ํ์ค.
๋ ๊ฐ์ง ํน์ฑ์ด ์ด๋ค์ ์์ ์ฐ๊ฒฐ๋ง์ด๋ CNN์ผ๋ก ๋ค๋ฃจ๊ธฐ ์ด๋ ต๊ฒ ๋ง๋ญ๋๋ค.
ํน์ฑ 1 โ ๊ฐ๋ณ ๊ธธ์ด. "์๋ "์ ๋จ์ด 1๊ฐ, "์ค๋ ๋ ์จ๊ฐ ์ฐธ ์ข๋ค์"๋ 5๊ฐ. ์์ ์ฐ๊ฒฐ๋ง์ ์ ๋ ฅ ์ฐจ์์ ๊ณ ์ ์ด๋ผ ๋ ๋ฌธ์ฅ์ ๊ฐ์ ๋ฐฉ์์ผ๋ก ์ฒ๋ฆฌํ ์ ์์ต๋๋ค. ํจ๋ฉ(0์ผ๋ก ์ฑ์ฐ๊ธฐ)์ด๋ ์๋ฅด๊ธฐ๋ ์ ๋ณด ์์ค์ ๋ง๋ญ๋๋ค.
ํน์ฑ 2 โ ์์๊ฐ ๊ณง ์๋ฏธ. "๋๋ ํ๊ต์ ๊ฐ๋ค"์ "๊ฐ๋ค ํ๊ต์ ๋๋"์ ๋จ์ด ์งํฉ์ ๋๊ฐ์ง๋ง ์๋ฏธ๊ฐ ๋ค๋ฆ ๋๋ค. ๊ทธ๋ฐ๋ฐ ๋จ์ด๋ฅผ bag-of-words๋ก ๋จ์ํ ํฉ์น๋ฉด ์ด ๋ ๋ฌธ์ฅ์ด ๊ตฌ๋ณ๋์ง ์์ต๋๋ค. ๊ทธ๋ฆฌ๊ณ "๋๋ ์ด์ ๋์๊ด์์ ์น๊ตฌ๋ค๊ณผ ์๋ก ๋์จ ์ฑ ์ ์ฝ์๋ค"์์, ๋ง์ง๋ง ๋์ฌ๋ฅผ ๊ฒฐ์ ํ๋ ๋จ์(์ฃผ์ด "๋๋")๋ ๋จผ ๊ณผ๊ฑฐ์ ์์ต๋๋ค. ๋ชจ๋ธ์ด ์ฅ๊ฑฐ๋ฆฌ ์์กด์ฑ์ ํฌ์ฐฉํด์ผ ํฉ๋๋ค.
์ด ๋ ์๊ตฌ๋ฅผ ํด๊ฒฐํ๋ ์ ๊ฒฝ๋ง์ด ์ํ ์ ๊ฒฝ๋ง(Recurrent Neural Network, RNN)์ ๋๋ค.
2. Vanilla RNN โ ์๊ฐ์ด๋ผ๋ ๊ณต์ ์ถ์ ๋ฐ๋ผ
RNN์ ํต์ฌ ์์ด๋์ด๋ ๋ฑ ํ๋์ ๋๋ค: "์ง๊ธ๊น์ง ๋ณธ ๊ฒ์ ์์ฝํ ์๋ ์ํ(hidden state)๋ฅผ ๋ค์ ์์ ์ผ๋ก ๋๊ธฐ์". ์ด ์๋ ์ํ $h_t$๊ฐ ์ผ์ข ์ "๊ธฐ์ต"์ ๋๋ค.
๊ฐ์ฅ ๋จ์ํ RNN์ ์ ๋ฐ์ดํธ ๊ท์น (Elman, 1990):
$$ h_t = \tanh(W_h h_{t-1} + W_x x_t + b_h) $$ $$ y_t = W_y h_t + b_y $$์ฌ๊ธฐ์ $x_t$๋ ์์ $t$์ ์ ๋ ฅ(์: ํ์ฌ ๋จ์ด์ ์๋ฒ ๋ฉ), $h_t$๋ ์์ $t$์ ์๋ ์ํ(๊ณผ๊ฑฐ ์ ๋ถ๋ฅผ ์์ฝ), $y_t$๋ ์์ $t$์ ์ถ๋ ฅ. ํต์ฌ์ ๊ฐ์ ๊ฐ์ค์น ํ๋ ฌ $W_h, W_x, W_y$๋ฅผ ๋ชจ๋ ์์ ์ด ๊ณต์ ํ๋ค๋ ๊ฒ์ ๋๋ค. ๊ธธ์ด๊ฐ 10์ด๋ 1000์ด๋ ํ๋ผ๋ฏธํฐ ์๋ ๊ทธ๋๋ก โ ์ด๋ก์จ ๊ฐ๋ณ ๊ธธ์ด ๋ฌธ์ ๊ฐ ํด๊ฒฐ๋ฉ๋๋ค.
์๋ ๋ฐฉ์์ ์์ํด๋ด ์๋ค. $h_0 = 0$์ผ๋ก ์์. ์ฒซ ๋จ์ด $x_1$์ด ๋ค์ด์ค๋ฉด $h_1 = \tanh(W_x x_1 + b_h)$. ๋์งธ ๋จ์ด $x_2$๊ฐ ๋ค์ด์ค๋ฉด $h_2 = \tanh(W_h h_1 + W_x x_2 + b_h)$ โ ์ฌ๊ธฐ์ $h_1$์ ํตํด ์ฒซ ๋จ์ด ์ ๋ณด๊ฐ ์์ ๋๋ค. ์ด๋ฐ ์์ผ๋ก $h_t$๋ $x_1, \dots, x_t$ ์ ๋ถ์ ์์ฝ์ด ๋ฉ๋๋ค.
์ด ๊ตฌ์กฐ์ ๋ ๋ค๋ฅธ ๋งค๋ ฅ์ ํ๋ผ๋ฏธํฐ ๊ณต์ ๊ฐ ์ผ๋ฐํ๋ฅผ ๋๋๋ค๋ ๊ฒ์ ๋๋ค. ๊ฐ์ $W$๋ฅผ ์์ ๋ง๋ค ์ฐ๋ฏ๋ก, "๋ฌธ์ฅ ์ฒซ ์๋ฆฌ์์ ๋ณธ ํจํด"์ด "๋ฌธ์ฅ ์ค๊ฐ์ ๋ค์ ๋ํ๋๋ฉด" ๊ฐ์ ์ฒ๋ฆฌ๊ฐ ์ ์ฉ๋ฉ๋๋ค. ์ด๋ W9 CNN์ "๊ณต๊ฐ ์์น์ ์๊ด์์ด ๊ฐ์ ํํฐ"์ ๊ฐ์ ์๋ฆฌ โ ๋ค๋ง ์ฐจ์์ด ๊ณต๊ฐ์ด ์๋๋ผ ์๊ฐ.
2.1 BPTT โ ์๊ฐ์ ํผ์ณ ์ญ์ ํํ๊ธฐ
ํ์ต์ ์ด๋ป๊ฒ ํ ๊น์? RNN์ "์๊ฐ์ ๋ฐ๋ผ ํผ์น๋ฉด(unroll)" ์ฌ์ค์ ๊ฐ ์์ ๋ง๋ค ์ธต์ด ํ๋์ฉ ์๋ ์์ฃผ ๊น์ ํผ๋ํฌ์๋ ๋คํธ์ํฌ๊ฐ ๋ฉ๋๋ค. ์ฌ๊ธฐ์ W8 ์ญ์ ํ๋ฅผ ๊ทธ๋๋ก ์ ์ฉํ๋ ๊ฒ์ด BPTT(Backpropagation Through Time)์ ๋๋ค.
์์ค $L = \sum_t L_t$๊ฐ ์์ ๋, $W_h$์ ๋ํ ๊ทธ๋๋์ธํธ๋:
$$ \frac{\partial L}{\partial W_h} = \sum_t \sum_{k \le t} \frac{\partial L_t}{\partial h_t} \cdot \left(\prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}}\right) \cdot \frac{\partial h_k}{\partial W_h} $$ํต์ฌ์ ์์ชฝ์ Jacobian ๊ณฑ $\prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}}$์ ๋๋ค. ์ด๊ฒ์ด ๋ค์ ์น์ ์ ์ฃผ์ธ๊ณต์ ๋๋ค.
BPTT๋ ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ง์ด ๋ญ๋๋ค. ๋ชจ๋ ์์ ์ ์๋ ์ํ๋ฅผ ์ ์ฅํด์ผ ์ญ์ ํํ ์ ์๊ธฐ ๋๋ฌธ์, ๊ธด ์ํ์ค์์๋ TBPTT(Truncated BPTT)๋ก ๋ช ์์ ๋ง ๊ฑฐ๊พธ๋ก ์ ํํ๋ ๊ทผ์ฌ๊ฐ ์์ฃผ ์ฐ์ ๋๋ค.
๐ฎ ์ธํฐ๋ํฐ๋ธ: RNN ํผ์นจ ์๊ฐํ
๊ฐ์ ์ ์ด ์๊ฐ์ ๋ฐ๋ผ 5๋ฒ ๋ฐ๋ณต๋๋ ๋ชจ์ต์ ๋ด ๋๋ค. ์ฌ๋ผ์ด๋๋ก ์์ ์ ์ฎ๊ธฐ๋ฉด ๊ทธ ์๊ฐ์ ์๋ ์ํ๊ฐ ๊ฐ์กฐ๋ฉ๋๋ค.
3. ์ฅ๊ธฐ ์์กด์ฑ ๋ฌธ์ โ ๊ธฐ์ธ๊ธฐ๊ฐ ์ฌ๋ผ์ง๋ ์ด์
RNN์ ์์์ ์ฐ์ํ์ง๋ง ํ์ค์์ ์ฌ๊ฐํ ๋ฌธ์ ๊ฐ ์์ต๋๋ค. ์๋ฌธ: "๋๋ ์ด์ ๋์๊ด์์ ์น๊ตฌ๋ค๊ณผ ๊ณต๋ถํ๊ณ ์๋ก ๋์จ ์์ค์ ๋น๋ ค ์ง์ ์์ ์ ๋ ์ ๋จน๊ณ ์ ๋ค๊ธฐ ์ ์ ํ ์๊ฐ ๋์ ๊ทธ ์ฑ ์ ์ฝ์๋ค." ๋ง์ง๋ง ๋์ฌ "์ฝ์๋ค"๊ฐ ๊ณผ๊ฑฐํ์ธ ๊ฒ์ ๋ฌธ์ฅ ๋งจ ์์ "์ด์ "์ ๊ด๋ จ์ด ์์ต๋๋ค. ์ด ๊ฑฐ๋ฆฌ๊ฐ 20๋จ์ด ์ด์.
์ด๋ก ์ ์ผ๋ก RNN์ $h_t$์ ๋ชจ๋ ๊ณผ๊ฑฐ ์ ๋ณด๋ฅผ ๋ฃ์ ์ ์์ง๋ง, ์ค์ ์์๋ 5~10๋จ์ด ์ด์ ๋จ์ด์ง ์ ๋ณด๋ฅผ ๊ธฐ์ตํ์ง ๋ชปํฉ๋๋ค. ์ ๊ทธ๋ด๊น์?
BPTT์ Jacobian ๊ณฑ์ ๋ค์ ๋ด ์๋ค:
$$ \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} = \prod_{j=k+1}^{t} W_h^\top \text{diag}(\tanh'(z_j)) $$์ด ๊ณฑ์ด $t - k$์ ์ง์ํจ์์ ๋๋ค. ๋ ๊ฒฝ์ฐ๋ก ๋๋ฉ๋๋ค:
- ๊ธฐ์ธ๊ธฐ ์์ค(vanishing) โ ๊ฐ์ค์น์ ์ต๋ ํน์ด๊ฐ์ด 1๋ณด๋ค ์์ผ๋ฉด ๊ณฑ์ด $0$์ผ๋ก ์ง์ ๊ฐ์ . ๋จผ ๊ณผ๊ฑฐ์ ๋จ์๊ฐ ํ์ฌ ์ ๋ฐ์ดํธ์ ์ ํ ์ํฅ์ ๋ชป ์ค. $\tanh'$์ ์ต๋๊ฐ์ด 1์ด๊ณ ๋๋ถ๋ถ 0๋ณด๋ค ์์ผ๋ฏ๋ก ์ด ์ชฝ์ด ๋ ํํจ.
- ๊ธฐ์ธ๊ธฐ ํญ๋ฐ(exploding) โ ๋ฐ๋๋ก ๊ณฑ์ด $\infty$๋ก ๋ฐ์ฐ. ์ ๋ฐ์ดํธ๊ฐ ๋๋ฌด ์ปค์ NaN์ด ๋จ๊ฑฐ๋ ๋ฐ์ฐ. ๊ทธ๋๋์ธํธ ํด๋ฆฌํ(gradient clipping)์ผ๋ก ์ด๋ ์ ๋ ์ํ ๊ฐ๋ฅ.
์ด ๋ฌธ์ ๋ W8 ยง2 ๊ธฐ์ธ๊ธฐ ์์ค์์ ๋ณธ ํ์๊ณผ ๋ณธ์ง์ ์ผ๋ก ๊ฐ์ต๋๋ค. ๋ค๋ง ์๊ฐ์ถ์์ ๋ฐ์ํ ๋ฟ. 1994๋ Bengio ๋ฑ์ด ์ด ๋ฌธ์ ๋ฅผ ์ํ์ ์ผ๋ก ๋ถ์ํด "vanilla RNN์ ์ฅ๊ธฐ ์์กด์ฑ์ ๋ณธ์ง์ ์ผ๋ก ํ์ตํ๊ธฐ ์ด๋ ต๋ค"๋ ๋ถ์ ์ ๊ฒฐ๋ก ์ ๋ด๋ฆฌ๋ฉด์ RNN ์ฐ๊ตฌ๋ ํ๋์ ์นจ์ฒด๋์์ต๋๋ค.
4. LSTM โ ์ ๋ณด์ ๊ณ ์๋๋ก
ํด๋ต์ ๋๋๊ฒ๋ Bengio์ ๋ถ์๊ณผ ๊ฑฐ์ ๊ฐ์ ์๊ธฐ์ ์ด๋ฏธ ์ ์๋์ด ์์์ต๋๋ค. 1997๋ Sepp Hochreiter์ Jรผrgen Schmidhuber๋ ๋ ผ๋ฌธ "Long Short-Term Memory"์์ RNN ์ ์ ์์ ํ ์ฌ์ค๊ณํ LSTM์ ๋ฐํํ์ต๋๋ค. ํต์ฌ ์์ด๋์ด๋ ํ๋์ ๋๋ค: ์๋ ์ํ ์ธ์ "์ ์ํ(cell state) $C_t$"๋ผ๋ ๋ณ๋์ ์ ๋ณด ํต๋ก๋ฅผ ๋๊ณ , ์ด ํต๋ก๋ ๊ณฑ์ ์ด ์๋ ๋ง์ ์ผ๋ก ์ ๋ฐ์ดํธ๋๊ฒ ํ์.
๋ง์ ์ ๊ณฑ์ ๊ณผ ๋ฌ๋ฆฌ ๊ธฐ์ธ๊ธฐ๋ฅผ ์์ค์ํค์ง ์์ต๋๋ค. $\frac{\partial C_t}{\partial C_{t-1}} = 1$ (๊ฒ์ดํธ๊ฐ ์ ์ ํ ์ค์ ๋๋ฉด)์ด๋ฉด ๊ธฐ์ธ๊ธฐ๊ฐ ์๋ฐฑ ์์ ์ ๊ฑฐ์ฌ๋ฌ ์ฌ๋ผ๊ฐ๋ ๊ฑฐ์ ๊ทธ๋๋ก ์ ์ง๋ฉ๋๋ค. ์ด๊ฒ์ด LSTM์ "์ ๋ณด์ ๊ณ ์๋๋ก"๋ผ ๋ถ๋ฅด๋ ์ด์ . ์ ์ํ๊ฐ ๊ณ ์๋๋ก๊ณ , ๊ฒ์ดํธ๊ฐ ๊ทธ ๊ณ ์๋๋ก์ ๋ฌด์์ ์ฌ๋ฆฌ๊ณ ๋ด๋ฆด์ง ์ ์ดํฉ๋๋ค.
LSTM์ ์ธ ๊ฐ์ ๊ฒ์ดํธ๋ก ์ ์ํ๋ฅผ ์ ์ดํฉ๋๋ค:
- ๋ง๊ฐ ๊ฒ์ดํธ(forget gate) $f_t$: ๊ณผ๊ฑฐ ์ ์ํ $C_{t-1}$์์ ๋ฌด์์ ์ง์ธ์ง. ๊ฐ์ด 0์ด๋ฉด ์์ ํ ์๊ณ , 1์ด๋ฉด ์์ ํ ์ ์ง. ์๊ทธ๋ชจ์ด๋๋ฅผ ์จ์ $[0, 1]$ ๋ฒ์๋ก ๋ง๋ญ๋๋ค.
- ์ ๋ ฅ ๊ฒ์ดํธ(input gate) $i_t$: ํ์ฌ ์์ ์ ์์ฑ๋ ์ ํ๋ณด $\tilde C_t$ ์ค ๋ฌด์์ ์ ์ ์๋ก ์ธ์ง. ์ญ์ ์๊ทธ๋ชจ์ด๋.
- ์ถ๋ ฅ ๊ฒ์ดํธ(output gate) $o_t$: ์ ๋ฐ์ดํธ๋ ์ ์ํ ์ค ๋ฌด์์ ์ธ๋ถ(์๋ ์ํ $h_t$)์ ๋ณด์ผ์ง. ์๊ทธ๋ชจ์ด๋.
๐ฎ ์ธํฐ๋ํฐ๋ธ: LSTM ๊ฒ์ดํธ ๊ฐ ์กฐ์
์ธ ๊ฒ์ดํธ ๊ฐ์ ์ง์ ์์ง์ด๋ฉฐ ์ ์ํ๊ฐ ์ด๋ป๊ฒ ๋ณํ๋์ง ๋ณด์ธ์. ๋ง๊ฐ์ด 0์ด๋ฉด ๊ณผ๊ฑฐ๋ฅผ ์์ ํ ์ง์ฐ๊ณ , ์ ๋ ฅ์ด 0์ด๋ฉด ์ ์ ๋ณด๋ฅผ ๋ฐ์ง ์์ต๋๋ค.
4.1 LSTM ์์ ํ ์ค์ฉ ์ดํดํ๊ธฐ
LSTM์ ์ ์ฒด ์ ๋ฐ์ดํธ ์์ ๋ค์ ์ ์ด๋ด ์๋ค:
$$ f_t = \sigma(W_f [h_{t-1}, x_t] + b_f) $$ $$ i_t = \sigma(W_i [h_{t-1}, x_t] + b_i) $$ $$ o_t = \sigma(W_o [h_{t-1}, x_t] + b_o) $$ $$ \tilde C_t = \tanh(W_C [h_{t-1}, x_t] + b_C) $$ $$ C_t = f_t \odot C_{t-1} + i_t \odot \tilde C_t $$ $$ h_t = o_t \odot \tanh(C_t) $$$[h_{t-1}, x_t]$๋ ๋ ๋ฒกํฐ๋ฅผ ์ด์ด๋ถ์ธ ๊ฒ์ด๊ณ , $\odot$๋ ์์๋ณ ๊ณฑ. ๊ฐ ์์ ์๋ฏธ๋ฅผ ํ ๋ฌธ์ฅ์ฉ:
- $f_t$: "๊ณผ๊ฑฐ์์ ๋ฌด์์ ์์๊น?" โ 0์ด๋ฉด ์์ ํ ์๊ณ 1์ด๋ฉด ์์ ํ ๊ธฐ์ต.
- $i_t$: "์ง๊ธ ๋ฌด์์ ์๋ก ๋ฐฐ์ธ๊น?" โ ์ง๊ธ ์ด ์๊ฐ์ ์ ๋ณด ์ค ์ ์ ์ธ ๊ฒ์ ๋น์จ.
- $\tilde C_t$: "์๋ก ์ธ ํ๋ณด ๊ฐ" โ $\tanh$๋ก $[-1, 1]$ ๋ฒ์.
- $C_t = f_t \odot C_{t-1} + i_t \odot \tilde C_t$: "์ ์ ์ํ์์ ์์ ๊ฑด ์ง์ฐ๊ณ ์๋ก ๋ฐฐ์ธ ๊ฑด ๋ํ๊ธฐ". ์ด ๋ง์ ๊ตฌ์กฐ๊ฐ ๊ธฐ์ธ๊ธฐ ์์ค์ ๋ง๋ ํต์ฌ.
- $o_t$: "์ ์ ๋ด์ฉ ์ค ์ด๋๋ฅผ ๋ฐ์ผ๋ก ๋ณด๋ผ๊น?"
- $h_t = o_t \odot \tanh(C_t)$: "์ ์ํ๋ฅผ ์ค์ผ์ผ ์กฐ์ ํด ์ธ๋ถ๋ก ์ถ๋ ฅ".
4.2 ์ ๊ธฐ์ธ๊ธฐ๊ฐ ๋ ์ด์ ์ฌ๋ผ์ง์ง ์๋๊ฐ
ํต์ฌ์ $C_t = f_t \odot C_{t-1} + i_t \odot \tilde C_t$์ $\partial C_t / \partial C_{t-1} = f_t$. ๋ง๊ฐ ๊ฒ์ดํธ์ ๊ฐ์ ๋๋ค. ์ค์ํ ๊ด์ฐฐ:
- $f_t$๊ฐ 1์ ๊ฐ๊น์ฐ๋ฉด $\partial C_t / \partial C_{t-1} \approx 1$. ๊ธฐ์ธ๊ธฐ๊ฐ ๊ทธ๋๋ก ์ ๋ฌ๋จ.
- $\tanh$ ์ฒ๋ผ 0์ ๊ฐ๊น์ด ๊ฐ์ด ์๋๋ผ ๋คํธ์ํฌ๊ฐ ํ์ต์ผ๋ก 1 ๊ทผ์ฒ ๊ฐ์ ์ ํํ ์ ์์.
- ๋ฐ๋ผ์ "๊ธฐ์ตํ ๊ฐ์น๊ฐ ์๋ ์ ๋ณด"๋ $f_t = 1$๋ก ์ ์ง, "์์ ์ ๋ณด"๋ $f_t = 0$์ผ๋ก ์ง์.
์ฆ LSTM์ ์ค์ค๋ก "์ผ๋ง๋ ์ค๋ ๊ธฐ์ตํ ์ง"๋ฅผ ๋ฐ์ดํฐ๋ก๋ถํฐ ํ์ตํฉ๋๋ค. ์ด ๋ฅ๋ ฅ์ด ์๋ vanilla RNN๊ณผ์ ๊ฒฐ์ ์ ์ฐจ์ด์ ๋๋ค.
5. GRU โ ๋จ์ํจ์ ๋ฏธ๋
2014๋ Kyunghyun Cho ๋ฑ์ด ์ ์ํ GRU(Gated Recurrent Unit)๋ LSTM์ ์ธ ๊ฒ์ดํธ๋ฅผ ๋ ๊ฐ๋ก ์ค์ธ ๋จ์ํ ๋ฒ์ ์ ๋๋ค. ํต์ฌ ์์ด๋์ด:
- ์ ์ํ $C$์ ์๋ ์ํ $h$๋ฅผ ํ๋๋ก ํฉ์นจ (์ ์ํ ์์).
- ๋ง๊ฐ ๊ฒ์ดํธ์ ์ ๋ ฅ ๊ฒ์ดํธ๋ฅผ ํ๋์ ์ ๋ฐ์ดํธ ๊ฒ์ดํธ $z_t$๋ก ํฉ์นจ. "์ง์ฐ๋ ๋งํผ๋ง ์๋ก ์ฐ์" โ $f + i = 1$์ ์ ์ฝ.
- ์ ๋ฆฌ์ ๊ฒ์ดํธ $r_t$๋ "๊ณผ๊ฑฐ ์ํ๋ฅผ ์ผ๋ง๋ ๋ค์ ๋ณผ์ง"๋ฅผ ๊ฒฐ์ .
GRU์ ํ๋ผ๋ฏธํฐ ์๋ LSTM์ ์ฝ 3/4์ ๋๋ค. ์ฑ๋ฅ์ ์์ ์ ๋ฐ๋ผ ๋น์ทํ๊ฑฐ๋ ์ด์ง ๋ค๋ฅธ๋ฐ, ์ผ๋ฐ์ ์ผ๋ก ์์ ๋ฐ์ดํฐ์ ์์๋ GRU๊ฐ ์ ๋ฆฌํ๊ณ (๊ณผ์ ํฉ์ด ๋ํจ), ํฐ ๋ฐ์ดํฐ์ ์์๋ LSTM์ด ์ฝ๊ฐ ๋ ์ ์๋ํ๋ ๊ฒฝํฅ์ด ์์ต๋๋ค. ์ค๋ฌด์์๋ ๋ ๋ค ์๋ํด๋ณด๊ณ ๊ฒ์ฆ ์ฑ๋ฅ์ด ์ข์ ์ชฝ์ ์ ํํฉ๋๋ค.
์ดํ Transformer(W13)๊ฐ ๋ฑ์ฅํ๋ฉด์ RNN ๊ณ์ด์ ์์ฐ์ด ์ฒ๋ฆฌ ์ฃผ๋ฅ์์ ๋ฌผ๋ฌ๋ฌ์ง๋ง, ์ฌ์ ํ ์ผ์ ์ ํธ ์ฒ๋ฆฌ, ์๊ณ์ด ์์ธก, ์์ฑ ์ธ์ ๊ฐ์ ํน์ ๋ถ์ผ์์๋ LSTM/GRU๊ฐ ๊ธฐ๋ณธ ์ ํ์ ๋๋ค. ํนํ ์ฃ์ง ๋๋ฐ์ด์ค๋ ์ค์๊ฐ ์ฒ๋ฆฌ๊ฐ ํ์ํ ํ๊ฒฝ์์๋ Transformer์ $O(n^2)$ ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ถ๋ด์ด๋ผ RNN์ด ์ฌ์ ํ ์ ํธ๋ฉ๋๋ค.
6. ์๊ณ์ด ์์ธก ๋ฐ๋ชจ
๐ฎ ์ธํฐ๋ํฐ๋ธ: ์ฌ์ธํ ์์ธก
๊ฐ๋จํ 1์ RNN์ด ์ฌ์ธํ๋ฅผ ํ์ตํฉ๋๋ค. ์งํญ๊ณผ ์ฃผํ์๋ฅผ ์กฐ์ ํด๋ณด์ธ์(์๋ฎฌ๋ ์ด์ ).
7. ์ฝ๋ ์์ (PyTorch)
import torch.nn as nn
class LSTMNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
out, (h, c) = self.lstm(x)
return self.fc(out[:, -1, :])
๐ ๋ ๊น์ด ๊ณต๋ถํ๊ธฐ
- Sung Kim Lec 12 โ RNN/LSTM ํ๊ตญ์ด ๊ฐ์.
- Understanding LSTM Networks โ Chris Olah, colah.github.io. ์ด ๋ถ์ผ์ ๊ณ ์ ๊ธ.
- LSTM ์๋ ผ๋ฌธ โ HochreiterยทSchmidhuber(1997). ์ธ๋ด์ฌ์ ๊ฐ์ง๊ณ ๋ณด์ธ์.
- The Unreasonable Effectiveness of RNNs โ Andrej Karpathy.