Vanishing gradient và long-term, short-term dependency
Trong bài viết về mô hình RNN truyền thống, mình đã có đề cập đến vấn đề vanishing gradient của nó dựa vào công thức của quá trình BPPT (Back-propagation Through Time). Hệ quả của vấn đề này là RNN gặp khó khăn trong việc ghi nhớ thông tin trong những câu có nhiều từ.
Để minh họa rõ hơn về hệ quả của vanishing gradient, ta xét ví dụ với mô hình RNN dùng để sinh ra văn bản. Giả sử đoạn văn bản đang được sinh ra như sau:
- Trưa hôm nay, trời đã mưa rất to và tôi thì lại để quên áo mưa ở nhà. Vì sao quên thì là do sáng nay ngủ dậy muộn nên tôi chỉ tập trung nhanh chóng vệ sinh cá nhân, soạn sách vở rồi ăn sáng để đến lớp thôi. …(vài câu gì đó nữa)…. Kết quả là lúc về đến nhà, cả người tôi đã bị _
Từ tiếp theo được sinh ra ở vị trí của kí tự _ lúc này nên là “ướt”, nhưng những thông tin liên quan đến vấn đề bị ướt này thì lại cách vị trí hiện tại rất xa, ở tận phía đầu của đoạn văn bản (trời mưa, quên áo mưa). Khi đó, RNN sẽ khó mà nhớ được những chi tiết này, dẫn đến từ sinh ra sẽ không phù hợp.
Ta có thể gọi sự phụ thuộc giữa từ “ướt” nên được sinh ra và các chi tiết ở đầu đoạn văn là long-term dependency. Mô hình cần phải nhớ được những chi tiết đó thì ở sau nó mới có thể sinh ra được từ hợp lý. Như vậy, vì gặp vấn đề vanishing gradient mà mô hình RNN truyền thống gặp khó khăn trong việc ghi nhớ các long-term dependency. Đây là một điểm yếu rõ rệt nhất của RNN.
Ngược với long-term thì ta có short-term dependency. Sự phụ thuộc này chỉ những mối tương quan giữa những từ ở gần nhau trong đoạn văn bản. Với RNN truyền thống thì nó hoàn toàn có thể nhớ được các sự phụ thuộc này.
Long Short-Term Memory (LSTM, 1997) và Gated Recurrent Unit (GRU, 2014) là các cải tiến của mô hình RNN truyền thống, nhằm tập trung khắc phục điểm yếu của nó trong vấn đề ghi nhớ các long-term dependency.
Lưu ý.
- Trước khi đi đến các phần sau, ta quy ước rằng output của các cell (LSTM cell, GRU cell) tại thời điểm $t$ sẽ được ký hiệu là $\bold{Y}_t$ (tránh nhầm với $\bold{O}_t$ trong RNN truyền thống). Nói chung là thay O thành Y 😜
- Trong bài viết này, mình cũng sẽ bỏ qua các giá trị bias, tương tự như trong bài viết về RNN.
Long Short-Term Memory (LSTM)
Ý tưởng về internal state
Trước tiên, ta nhắc lại quá trình feed-forward của RNN truyền thống một chút. Tại thời điểm $t$, từ input $\bold{X}_t$ và hidden state $\bold{H}_{t-1}$ thì ta có
$$ \bold{H}_t = \phi_h (\bold{W}_{xh} \bold{X}_t + \bold{W}_{hh} \bold{H}_{t-1})$$ $$\bold{Y}_t = \phi_y (\bold{W}_{hy} \bold{H}_t) $$
Hidden state $\bold{H}_t$ chính là thành phần “nhớ” các short-term dependency trong RNN. Hơn nữa, cũng vì lý do ta tính $\bold{H}_t$ dựa vào $\bold{H}_{t-1}$ theo công thức như trên nên RNN mới gặp vấn đề vanishing gradient descent.
Ý tưởng của Long Short-Term Memory (LSTM) xuất phát từ việc xây dựng thêm một thành phần tại mỗi thời điểm để ghi nhớ long-term dependency, nó được gọi là internal state. Đối với short-term thì ta vẫn sẽ ghi nhớ chúng bằng hidden state như trong RNN truyền thống.
Kí hiệu internal state tại thời điểm $t$ là $\bold{C}_t$. Ta có các thao tác liên quan đến $\bold{C}_t$ như sau:
-
Cập nhật internal state $\bold{C}_t$:
-
Loại bỏ một số thông tin không cần thiết trong long-term dependency nhớ được từ $t-1$ thời điểm trước (hay là quên bớt thông tin trong $\bold{C}_{t-1}$)
-
Thêm vào các thông tin cần thiết từ các input của thời điểm hiện tại vào internal state (hay là cập nhật thêm thông tin). Ta thường kí hiệu lượng thông tin này là $\tilde{\bold{C}}_t$.
- Input của thời điểm hiện tại bao gồm $\bold{H}_{t-1}$ (short-term dependency) và $\bold{X}_t$
- $\tilde{\bold{C}}_t$ còn được gọi là candiate internal state.
-
Trong đó, ta có hai giá trị điều chỉnh tỉ lệ loại bỏ và thêm vào tại thời điểm $t$, nó sẽ kiểu như
$$ \bold{C}_t = \alpha_t \bold{C}_{t-1} + \beta_t \tilde{\bold{C}}_t $$
-
-
Từ internal state $\bold{C}_t$, ta chắt lọc các thông tin có vai trò như là những short-term dependency mà mô hình nên nhớ ở thời điểm hiện tại, tức là tính ra $\bold{H}_t$.
-
Ngoài ra, nếu cần tính ra output $\bold{Y}_t$ thì ta cũng sẽ dựa vào $\bold{C}_t$.
Như vậy, input của LSTM cell tại thời điểm $t$ sẽ có tổng cộng 3 phần là $\bold{X}_t$, $\bold{H}_{t-1}$, $\bold{C}_{t-1}$ và output sẽ bao gồm $\bold{H}_t$, $\bold{C}_t$ (có thể có thêm $\bold{Y}_t$).
Nhận xét.
- Trong LSTM, short-term dependency và long-term dependency được tách ra và nó sử dụng hai cổng để ghi nhớ.
- Từ công thức tính $\bold{C}_t$ ở trên, nếu $\alpha_t = 0$ thì có nghĩa là ta sẽ quên hết các thông tin phía trước luôn, chỉ tập trung hiện tại thôi, còn $\beta_t = 0$ thì xem như ta không quan tâm hiện tại, chỉ dùng đúng những gì đã biết trong quá khứ. Thật ra $\alpha_t$ và $\beta_t$ là các ma trận và phép nhân được thực hiện là element-wise.
- Nhờ tính internal state $\bold{C}_t$ theo ý tưởng của LSTM, ta đã có thể hạn chế vấn đề vanishing gradient (hạn chế thôi, vẫn có thể gặp phải nhưng hiếm hơn 😀)
Forget gate, input gate, output gate và candidate internal state
Trong các thao tác liên quan đến $\bold{C}_t$ ở trên, quan trọng nhất là các chi tiết về loại bỏ, thêm vào và chắt lọc. Chúng sẽ lần lượt ứng với ba “cổng” là forget gate $\bold{F}_t$, input gate $\bold{I}_t$ và output gate $\bold{O}_t$trong LSTM cell.
- Lưu ý rằng các giá trị này là tỉ lệ, liên quan đến thao tác loại bỏ, thêm vào và chắt lọc các thông tin truyền từ bên ngoài vào nên ta sẽ dùng activation function $\sigma$ (sigmoid)
Ngoài ra, thao tác tính toán thông tin candidate internal state $\tilde{\bold{C}}_t$ dựa vào $\bold{X}_t$ và $\bold{H}_{t-1}$ sẽ được biểu diễn thông qua thành phần input node. Activation function được dùng để tính giá tị này là $\tanh$.
Khi đó, những thành phần này được biểu diễn trong LSTM cell như sau:
Minh họa các thành phần trong LSTM cell
Nguồn: Dive into DL
Như vậy thì ta đã có kha khá ký hiệu được sử dụng để biểu diễn cho các giá trị. Điều này cũng có nghĩa là sẽ có rất nhiều ma trận trọng số 😀 Cụ thể hơn, với mỗi thành phần $\bold{F}_t$, $\bold{I}_t$, $\bold{O}_t$ và $\tilde{\bold{C}}_t$ thì ta sẽ có hai ma trận trọng số, ví dụ như $\bold{W}_{xf}, \bold{W}_{hf}$ đối với thành phần $\bold{F}_t$.
Quá trình feed-forward
Tổng quan quá trình feed-forward trong LSTM cell được thể hiện trong hình ảnh bên dưới
Quá trình feed-forward trong LSTM cell
Nguồn: Dive into DL
Ta sẽ có khá nhiều phép tính, trước hết là tính các “cổng” $\bold{F}_t$, $\bold{I}_t$, $\bold{O}_t$ dựa vào $\bold{X}_t$ và $\bold{H}_{t-1}$:
$$ \mathbf{I}_t = \sigma( \mathbf{W}_{xi}\mathbf{X}_t + \mathbf{W}_{hi}\mathbf{H}_{t-1} )$$ $$\mathbf{F}_t = \sigma( \mathbf{W}_{xf} \mathbf{X}_t + \mathbf{W}_{hf} \mathbf{H}_{t-1})$$ $$\mathbf{O}_t = \sigma( \mathbf{W}_{xo}\mathbf{X}_t + \mathbf{W}_{ho} \mathbf{H}_{t-1}) $$
Bên cạnh đó, ta tính $\tilde{\bold{C}}_t$ bằng công thức
$$ \tilde{\mathbf{C}}_t = \text{tanh}(\mathbf{W}_{xc} \mathbf{X}_t + \mathbf{W}_{hc} \mathbf{H}_{t-1}) $$
Từ đó, internal state $\bold{C}_t$ và hidden state $\bold{H}_t$ sẽ được tính như sau:
$$ \mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t$$ $$\mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t) $$
Nếu ở thời điểm này cần tính ra output $\bold{Y}_t$ thì sẽ tính theo $\bold{C}_t$ cùng với ma trận trọng số $\bold{W}_{cy}$.
Gated Recurrent Units
Ý tưởng đơn giản hóa LSTM
Gated Recurrent Units (GRU - 2014) là có thể nói là một phiên bản tối ưu hơn của LSTM về mặt độ phức tạp, trong khi về hiệu năng thì có thể nói là hai mô hình này ngang ngửa nhau. Do đó, ta thường thấy GRU được sử dụng nhiều hơn.
Ý tưởng mấu chốt của GRU là chỉ sử dụng hidden state $\bold{H}_t$ để vừa nhớ cả short-term và long-term dependency. Trong khi đó, ở LSTM thì ta có sự phân tách giữa hai thông tin này.
Bên cạnh đó, nhìn vào công thức tính internal state $\bold{C}_t$ dựa vào hai giá trị tỉ lệ $\bold{F}_t$ và $\bold{I}_t$ trong LSTM cell, ta có thể có “cảm giác” là thông thường thì tổng của chúng hay bằng 1, nên là thôi bỏ một cái đi 😀 GRU làm theo đúng như thế.
Ngoài ra, trong LSTM có candidate internal state $\tilde{\bold{C}}_t$ dùng để tính ra các thông tin cần thiết từ các input $\bold{X}_t$ và $\bold{H}_{t-1}$. Với GRU thì ta cũng có thành phần tương tự là candidate hidden state $\tilde{\bold{H}}_t$ với cùng mục đích như thế.
Reset gate, update gate và candidate hidden state
Thay vì sử dụng ba “cổng” như LSTM thì GRU sẽ dùng hai là reset gate $\bold{R}_t$ và update gate $\bold{Z}_t$. Một thành phần nữa cũng rất quan trọng trong GRU là candidate hidden state $\tilde{\bold{H}}_t$. Trong đó:
- $\bold{R}_t$ sẽ đóng vai trò loại bỏ một số thông tin không cần thiết về các short-term dependency trong hidden state của thời điểm trước là $\bold{H}_{t-1}$. Ta sẽ dùng nó để tính $\tilde{\bold{H}}_t$. Do đó, candidate hidden state $\tilde{\bold{H}}_t$ sẽ chứa các thông tin có ích về short-term dependency.
- $\bold{Z}_t$ sẽ thay thế cho cả $\bold{F}_t$ và $\bold{I}_t$ trong việc điều chỉnh tỉ lệ loại bỏ thông tin không cần thiết về long-term dependency trong $\bold{H}_{t-1}$ và thêm vào các thông tin cần thiết về short-term dependency (chính là $\tilde{\bold{H}}_t$).
Minh họa các thành phần trong GRU cell
Nguồn: Dive into DL
Ta thấy rằng, số ma trận trọng số cần sử dụng trong GRU cell sẽ ít hơn LSTM cell hai ma trận. Cụ thể hơn thì với mỗi thành phần $\bold{R}_t$, $\bold{Z}_t$ và $\tilde{\bold{H}}_t$ thì ta đều cần hai ma trận trọng số.
Quá trình feed-forward
Tổng quan quá trình feed-forward trong GRU cell được thể hiện trong hình ảnh bên dưới.
Minh họa các thành phần trong GRU cell
Nguồn: Dive into DL
Trước hết, ta tính các “cổng” $\bold{R}_t$, $\bold{Z}_t$ dựa vào $\bold{X}_t$ và $\bold{H}_{t-1}$:
$$ \mathbf{R}_t = \sigma( \mathbf{W}_{xr} \mathbf{X}_t + \mathbf{W}_{hr} \mathbf{H}_{t-1})$$ $$\mathbf{Z}_t = \sigma(\mathbf{W}_{xz} \mathbf{X}_t + \mathbf{W}_{hz} \mathbf{H}_{t-1} ) $$
Từ $\bold{R}_t$, ta tính được $\tilde{\bold{H}}_t$ như sau:
$$ \tilde{\mathbf{H}}_t = \tanh( \mathbf{W}_{xh} \mathbf{X}_t + \mathbf{W}_{hh}\left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) ) $$
Sử dụng $\bold{Z}_t$ và $\tilde{\bold{H}}_t$, hidden state tại thời điểm $t$ được xác định bằng
$$ \bold{H}_t = \bold{Z}_t \odot \bold{H}_{t-1} + (1 - \bold{Z}_t) \odot \tilde{\bold{H}}_t $$