Deep RNN và Bidirectional RNN

Giới thiệu hai biến thể phổ biến của mô hình RNN trong các bài toán là Deep RNN và Bidirectional RNN

Giới thiệu

Qua các bài viết về RNN truyền thống, LSTM và GRU thì mình đều trình bày về các mô hình với duy nhất một cell trong kiến trúc (recurrent cell, LSTM cell hoặc là GRU cell). Ngoài ra, ta thấy các hidden state cũng được truyền theo một hướng cố định là từ trái sang phải (thời điểm $t$ đến thời điểm $t + 1$). Nếu bỏ qua chi tiết về các “thời điểm” thì nhìn chúng sẽ không khác gì một mô hình MLP cơ bản trong Machine Learning.

Các dạng mô hình RNN truyền thống
Nguồn: Javatpoint

Đây chỉ mới là sự khởi đầu của RNN. Để đạt được hiệu năng tốt nhất có thể trong các bài toán, ta cần phải có những cải tiến nhất định. Hai cải tiến, hay là hai biến thể, phổ biến của RNN mà mình giới thiệu trong bài viết này là Deep RNN (dùng nhiều cell trong kiến trúc) và Bidirectional RNN (truyền hidden state theo cả hai hướng).

Lưu ý. Để cho đơn giản, các cell được sử dụng trong kiến trúc mô hình mà mình trình bày bên dưới đều là recurrent cell. Ta hoàn toàn có thể thay thế nó bằng LSTM cell, GRU cell.

Deep RNN

Tất nhiên, Deep Learning mà, dùng nhiều cell (hay là hidden layer) ngay 😜 Trong các mô hình RNN mà mình đã trình bày cho đến trước bài viết này thì chúng chỉ có duy nhất một recurrent cell và cell này cứ nhận vào input, tính ra hidden state và output. Nếu chúng ta dùng nhiều cell thì sao?

Kết quả sẽ có dạng như hình bên dưới. Trong đó, input của recurrent cell thứ $l$ là hidden state của recurrent cell thứ $l - 1$.

Minh họa mô hình Deep RNN
Nguồn: Dive into DL

Ta kí hiệu:

  • $L$ là số recurrent cell của mô hình
  • Tại recurrent cell thứ $l$ thì
    • $\bold{H}_t^{(l)} \in \mathbb{R}^{h}$ là hidden state tại thời điểm $t$ (quy ước $\bold{H}_t^{(0)} = \bold{X}_t$)
    • Các ma trận trọng số lần lượt là $(\bold{W}_{xh}^{(l)}, \bold{W}_{hh}^{(l)})$ (với $l < L$).
    • Activation function dùng để tính hidden state là $\phi_l$.
  • $\bold{O}_t \in \mathbb{R}^{o}$ là output tại thời điểm $t$ của mô hình.
  • Tại recurrent cell thứ $L$, ma trận trọng số để tính ra output là $\bold{W}_{ho}$ và activation function là $\phi_o$.

Khi đó, quá trình feed-forward trong Deep RNN được mô tả như sau: Tại thời điểm $t$ thì

  • Qua từng recurrent cell thứ $l = 1, 2, …, L$, ta có

$$\bold{H}_t^{(l)} = \phi_l(\bold{W}_{xh}^{(l)} \bold{H}_t^{(l-1)} + \bold{W}_{hh}^{(l)}\bold{H}_{t-1}^{(l)})$$

  • Output tại thời điểm $t$ là

$$\bold{O}_t = \phi_o (\bold{W}_{ho} \bold{H}_t^{(L)})$$

Bidirectional RNN (BiRNN)

Ý tưởng của Bidirectional RNN (BiRNN) rất là tự nhiên và giống với cách con người đọc hiểu ngôn ngữ. Đầu tiên, ta xét bài toán Name Entity Recognition với câu sau:

  1. Can you see that? Teddy bears are on sales.
  2. He said that Teddy Rooosevelt was a great president.

Ở câu (2) thì ta có thể gán cho Teddy thuộc lớp Name, và nó đúng là tên của một người thật. Tuy nhiên, trong câu (1) mà gán như thế là sai. Để gán đúng với câu (1) thì ta cần biết được từ phía sau đó nữa (và ta phải gán nguyên cụm Teddy bears). Như vậy, RNN truyền thống sẽ thất bại trong ví dụ (1), vì khi xét tới Teddy thì ta chưa có bất kì thông tin gì về các từ phía sau nó.

  • Theo cách con người đọc hiểu ngôn ngữ, ở câu (1) thì ta cũng cần phải đọc thêm từ “bears” ở phía sau để biết được từ “Teddy” ở trước mang ý nghĩa gì.

Như vậy, trong BiRNN, các trạng thái ẩn sẽ được truyền theo cả hai chiều (xuôi và ngược). Kiến trúc của nó sẽ có dạng như hình bên dưới (để cho đơn giản thì ta chỉ xét với một recurrent cell 😜).

Minh họa mô hình BiRNN
Nguồn: Dive into DL

Ta kí hiệu:

  • Trong mỗi hướng truyền xuôi và ngược thì:
    • Hướng truyền xuôi: Hidden state là $\overrightarrow{\mathbf{H}}_t$ và hai ma trận trọng số là $(\bold{W}_{xh}^{(f)}, \bold{W}_{hh}^{(f)})$.
    • Hướng truyền ngược: Hidden state là $\overleftarrow{\mathbf{H}}_t$ và hai ma trận trọng số là $(\bold{W}_{xh}^{(b)}, \bold{W}_{hh}^{(b)})$.
  • Ma trận trọng số để tính ra output là $\bold{W}_{ho}$.

Quá trình feed-forward của BiRNN sẽ diễn ra như sau:

  • Lần lượt theo các hướng truyền xuôi, ta tính được hidden state theo các công thức

$$\overrightarrow{\bold{H}}_t = \phi_h(\bold{W}_{xh}^{(f)}\bold{X}_t + \bold{W}_{hh}^{(f)}\overrightarrow{\bold{H}}_{t-1})$$

$$\overleftarrow{\bold{H}}_t = \phi_h(\bold{W}_{xh}^{(b)}\bold{X}_t + \bold{W}_{hh}^{(b)}\overleftarrow{\bold{H}}_{t+1})$$

  • Sau khi đã tính xong hidden state tại toàn bộ các thời điểm, ta tính output:

$$\bold{O}_t = \phi_o \left ( \bold{W}_{ho} \left [ \overrightarrow{\bold{H}}_t , \overleftarrow{\bold{H}}_t \right ] \right )$$

, trong đó $\left [ \overrightarrow{\bold{H}}_t , \overleftarrow{\bold{H}}_t \right ]$ nghĩa là nối hai hidden state với nhau (concatenate).

Nhận xét.

  • Qua quá trình feed-forward của BiRNN, ta thấy rằng mô hình phải thực hiện tính toán hidden state tại toàn bộ các thời điểm rồi mới bắt đầu đưa ra output của mỗi thời điểm. Do đó, đôi khi BiRNN sẽ không thực sự phù hợp cho các bài toán real-time như speech recognition.

Tài liệu tham khảo

Lưu ý. Nếu phần Comment không load ra được thì các bạn vào DNS setting của Wifi/LAN và đổi thành "8.8.8.8" nhé (server của Google)!

Built with Hugo
Theme Stack designed by Jimmy