LSTM เท่าที่เข้าใจ
LSTM หรือชื่อเต็มคือ Long Short-Term Memory เป็นโครงข่ายประสาทเทียมแบบหนึ่งที่ถูกออกแบบมาสำหรับการประมวลผลลำดับ (sequence)
LSTM ถูกนำเสนอมาตั้งนานแล้วแต่พึ่งมาได้รับความนิยมไม่นานมานี้ หนึ่งในเหตุผล (อย่างน้อยสำหรับผม) คือมันดูวุ่นวาย เข้าใจยาก บทความแนะนำ LSTM มักจะเน้นไปที่ว่ามันทำงานอย่างไร แต่ในส่วนว่าทำไมมันถึงต้องเป็นเช่นนั้นนั้นมักถูกซ่อนไว้ในคำอธิบายที่ไม่ชัดเจนหรือรูปวาด ที่โดยส่วนตัวแล้วผมดูไม่ค่อยรู้เรื่อง
วันก่อนหยิบ LSTM มาอ่านใหม่แล้วรู้เรื่องมากขึ้น เลยมาจดไว้หน่อย
RNN พื้นฐานและ gradient vanishing
ก่อนอื่น LSTM นั้นจัดว่าเป็นโครงข่ายประเภท Recurrent Neural Network (RNN) นั่นคือ NN ที่มีการนำเอา output ของมันเองก่อนหน้านี้กลับมาใช้ใหม่เช่นรูปซ้ายข้างล่างนี้
รูปซ้ายนั้นคือตัวอย่าง RNN ที่มี 1 layer และ 1 node ใน layer นั้น เมื่อเราเอา RNN นี้ไปใช้ประมวลผลลำดับ (x(1),y(1)),…,(x(10),y(10)) x(t) คือข้อมูลนำเข้า ณ เวลา t และ y(t) คือค่าส่งออกที่เราต้องการ เช่น x(1),…,x(10) อาจจะเป็นลำดับคำในภาษาไทย ส่วน y(1),…,y(10) เป็น part of speech ของคำเหล่านี้เป็นต้น หรืออาจจะเป็นงาน sequence-to-sequence อื่นก็ได้
ในการปรับ RNN ข้างซ้ายจากข้อมูลลำดับนี้เราจะต้องทำการกางมันออกตามรูปขวา
จากนั้นขั้นตอนการ train ก็ใช้ gradient descent ตามปกติ โดย gradient ก็ได้จาก backprop มาตรฐาน
สิ่งที่ควรรู้คือการทำ backprop นั้นขึ้นกับ input และ output เช่นเราสามารถพิจารณา input x(1 )และ backprop gradient กลับมาจาก y(1) หรือ y(2) หรือกระทั้ง y(10)
สิ่งที่ต่างกันคือหากเรา backprop จาก y(1) ไป x(1) ข้อมูล gradient นั้นจะผ่านเพียง layer เดียว หากเรา backprop จาก y(2) ไป x(1) ก็ต้องผ่าน 2 layers หรือจาก y(10) ไป x(1) ก็ผ่าน 10 layers
ปัญหาที่เกิดเมื่อเราส่ง gradient ผ่านหลายๆ layer คือ “ขนาด” หรือ amplitude ของมันจะลดลง ทำให้สุดท้ายแล้วค่าที่ได้มานั้นเล็กมาก เมื่อนำมาใช้กับ learning rate ที่ปกติก็เล็กอยู่แล้วจะทำให้ weights ของ layer ที่เราพิจารณานั้นแทบไม่ถูกปรับเลย
เราเรียกปัญหานี้ว่า gradient vanishing
ควรทราบว่า gradient vanishing นั้นเกิดได้เมื่อเราสร้างโครงข่ายปกติอย่าง multi-layer Perceptron (MLP) แต่ใช้หลายชั้น เช่นกัน ในกรณีนั้น weights ของ layer ล่างๆ ก็จะไม่ถูกปรับ ค่า weights ที่ได้จากการ train ก็คือค่าที่ random ไว้ตั้งแต่ต้น ดังนั้นผลที่ได้จึงให้ผลที่ไม่ดี
กระบวนการทำ pre-training ที่ค่อยๆ สร้างทีละ layer โดยใช้ unsupervised criteria นั้นถึงจะไม่ได้แก้ gradient vanishing โดยตรงแต่ช่วยให้ weights ที่ได้นั้นอย่างน้อยมีคุณภาพดีกว่าการ random ปกติ
สำหรับ RNN นั้นสถาณการณ์ต่างออกไปบ้างนั่นเพราะหากเรา backprop จาก y(t) ไป x(t) มันก็แค่ชั้นเดียว ดังนั้นเราสามารถ train weights ได้อยู่แล้ว แต่การที่ gradient ที่ backprop จากตำแหน่งไกลๆ มันหายไปนั้นทำให้ RNN ที่สร้างไม่สามารถ capture long-term dependency ได้
แล้ว gradient vanishing เกิดได้อย่างไร? คำตอบคือมาจากกฎ chain rule ที่ใช้ใน backprop นี่แหละ
chain rule ที่เรียน ม.ปลาย คือถ้า h(x) = f(g(x)) แล้ว h’(x) = f’(g(x))g’(x)
สำหรับ NN แล้วการทำงานภายในแต่ละ node จบที่ activation function ที่มักเป็น non-linear function ดังนั้นค่า derivative ของมันนั้นมักจะน้อยกว่า 1 ดังนั้นเมื่อนำมาคูณต่อไปเรื่อยๆ ค่า gradient จึงเล็กลงเรื่อยๆ
Key หลักเพื่อให้ capture long-term dependency ได้ก็คือหาทางส่ง gradient กลับมาให้ได้มากที่สุด
Formulation พื้นฐานและการส่ง gradient
เพื่อประมวลผลลำดับ node ของ RNN ต้องมี หน่วยความจำภายใน ที่จำสิ่งที่เกิดขึ้นแล้วและใช้ในการตัดสินใจในเวลาถัดไป โดยหน่วยความจำภายในนี้ก็ต้องถูกปรับไปเรื่อยๆ ตามค่าของลำดับที่เราได้ประมวลผลมาเช่นกัน
ให้ h(t) เป็นหน่วยความจำภายใน ณ เวลา t
h(t) เองก็ต้องถูกปรับโดยใช้ 1) h(t-1), 2) ค่าส่งออกก่อนนี้ y(t-1) และ 3) ข้อมูลนำเข้า x(t)
สังเกตว่ามันก็เหมือนกับที่เราใช้ในการคำนวณค่าส่งออก y(t) ด้วย
นั่นคือ โดย concept แล้วแต่ละ node ต้องคำนวณ
h(t) = f1(x(t), y(t-1), h(t-1))
y(t) = f2(x(t), y(t-1), h(t-1))
โดย f1 และ f2 เป็นฟังก์ชันใดๆ หนึ่งในตัวอย่างที่เป็นไปได้คือ
h(t) = h(t-1) + tanh(U x(t) + W y(t-1))
y(t) = tanh( h(t) )
(tanh นี้ apply กับทุก element ของ vector แยกกัน)
ตัวแปรที่เราต้องปรับของ node นี้คือเมตริกซ์ U และ W
Formulation นี้น่าสนใจเพราะการปรับหน่วยความจำภายใน h(t) นั้นทำแบบ linear หากดูตาม concept แล้วก็เหมือนกับเรามี 2 เส้นทางในการ backprop gradient ทางแรกคือผ่าน h(t-1) ที่ผ่านง่าย อีกทางหนึ่งคือผ่าน tanh ที่ต้องโดน derivative ตบลง
เส้นทางแรกนี้เองที่ทำให้เราสามารถส่ง gradient กลับมาได้ง่ายขึ้น ทำให้เราสามารถจับ long-term dependency ได้ดีขึ้น
Gates
นอกจากการเพิ่มเส้นทางนำ gradient แล้ว LSTM ยังเสนอแนวความคิดเพิ่มอีกคือ
- ข้อมูลบางอย่างก็ควรจะลืมไปบ้าง
- ข้อมูลบางอย่างก็เป็น noise ที่ไม่ควรนำมาพิจารณา
- ข้อมูลบางอย่างอาจต้อง scale หรือ filter ก่อนส่งออก
แนวความคิดทั้ง 3 นี้นำไปสู่การเพิ่ม gates ต่างๆ คือ forget gate f, input gate i, และ output gate oให้กับ formulation ก่อนหน้านี้ นั่นคือ
h(t) = f⊗h(t-1) + i⊗tanh(U x(t) + W y(t-1))
y(t) = o⊗tanh( h(t) )
โดยสัญลักษณ์ ⊗ แทนการคูณกันของแต่ละ coordinate แยกกัน
และ gates ทั้ง 3 นี้ให้ค่าอยู่ในช่วง [0,1] ค่า 0 แปลว่าจะ ทำการลืม/ลบ input นั้นทิ้ง/ไม่ส่งค่านั้นออก หากค่าของ gates ทั้ง 3 เป็น 1 เราก็จะกลับไปที่ formulation ตั้งต้นข้างบน
Gates ทั้งสามนี้ยังสามารถตั้งให้เป็น function ที่สามารถถูก train ไปพร้อมๆ กับ node นี้ได้เช่น
f = logistic( Uf x(t) + Wf y(t-1) )
i = logistic( Ui x(t) + Wi y(t-1) )
o = logistic( Uo x(t) + Wo y(t-1) )
เรายังสามารถอนุญาตให้ gates ต่างๆ เข้าถึงหน่วยความจำภายใน h(t) ได้อีก เช่นให้
f = logistic( Uf x(t) + Wf y(t-1) + Vf h(t-1) )
i = logistic( Ui x(t) + Wi y(t-1) + Vi h(t-1) )
o = logistic( Uo x(t) + Wo y(t-1) + Vo h(t-1))
รู้สึกว่า formulation นี้จะเรียกว่ามีการเพิ่ม peephole gate หรือช่องแอบดู ให้กับ LSTM มาตรฐาน
หมายเหตุ
- ผมมาจากสาย Markov model (HMM, n-gram, variable length n-gram, …) ดังนั้นเลยไม่ถนัด RNN เลยจริงๆ ยุคที่ผมเรียนนั้น Markov model เป็นมาตรฐานด้านการประมวลผลลำดับ สงสัยต้องลองใช้ LSTM ดูบ้างแล้ว
- ใน model อย่าง HMM นั้นหน่วยความจำภายในหรือสถานะ (state) นั้นมีได้เยอะมากและลำดับของการเปลี่ยนสถานะเหล่านี้ก็มีที่กำหนดไว้แน่ชัด ถ้าเอาแนวความคิดนี้มาใส่ LSTM อาจได้อะไรสนุกๆ
ของที่ต่างกันคือการใช้งาน HMM นั้นอิงลำดับของ state แบบแน่ชัด ซึ่งถ้าเอามาใช้อาจต้องมีขั้นตอนเพิ่มสำหรับเก็บ state และ backtrack บางทีถ้าใช้ framework พวก Kalman filter ที่เป็น continuous state-space model (คล้ายๆ HMM ที่เราทำงานกับความน่าจะเป็นที่จะอยู่ใน state ต่างๆ แทนที่จำทำงานกับ state โดยตรง) อาจเหมาะกว่า - การคั้งให้ y(t) = tanh(h(t)) นั้นผมว่าเป็นเพราะต้องการจำกัด range ของค่า output ที่เป็นไปได้ แต่การใส่ tanh อาจไม่ใช่วิธีดีที่สุด ผมเดาว่าถ้าเปลี่ยนมาทำ batch normalization แทนอาจให้ผลที่ดีเหมือนกัน และอาจจะทำให้เก็บ dependency ได้ยาวขึ้นอีกนิด
ไว้ลองแล้วได้ผลไงจะมาเล่าให้ฟังอีกที