LoRA: Low-Rank Adaptation
LoRA?
LoRA (Low-Rank Adaptation)은 PEFT (Parameter Efficient Fine-Tuning) 방법의 하나로써,
모델을 학습시킬 때 전체 가중치를 업데이트하는 것이 아닌, 일부 파라미터만 효율적으로 학습시킴으로써 GPU 메모리 사용량을 획기적으로 줄일 수 있는 방법이다.
특히나 요즘같이 LLM과 같은 기반 모델의 크기가 매우 커지면서 FFT (Full Fine-Tuning)을 한정된 GPU 메모리 속에서 진행하기 힘든 시기에 PEFT 기법이 많이 주목받았는데, LoRA는 그중에서도 FFT에 비해 상대적으로 준하는 성능을 보이기에 더욱 사랑받는 기법이다.
기존 방법의 문제
앞서 설명했듯이, 기존에는 전체 가중치를 업데이트하면서 학습시켰는데 (FFT), 이를 위한 메모리 사용량이 매우 커진다.
LLM은 더욱 확연하게 체감할 수 있는데, 논문에서는 GPT-3 175B를 예시로 든다.
175B면 약 1750억개의 파라미터로 구성되어 있다는 건데, 학습하자고 행렬곱을 몇 번을 때려야 하는지 벌써부터 상상이 안 간다.
물론 사람들이 주구장창 학습이 완료될 때까지 FFT를 기다린 것은 아니다.
많은 사람들이 일부분의 파라미터만으로 학습시키려는 PEFT 방식을 연구했지만, 초기 PEFT 방식들은 inference latency가 많이 발생했기에 한계가 있었다. 무엇보다도, 성능마저 현저히 저하되기에 더더욱 효율적인 기법이 필요했다.
방법
LoRA를 만든 Edward Hu를 비롯한 연구진들은 다음과 같은 생각으로부터 LoRA 기법을 떠올렸다고 한다.
학습된 over-parametrized model이 실제로는 낮은 고유 차원 (Low Intrinsic Rank)에 있음을 많은 논문들이 보여주네..?
그럼, fine-tuning을 위해 사용되는 가중치들의 변화 또한 낮은 고유 차원이지 않을까?
아래는 논문에 적힌 LoRA의 원리다.
LoRA allows us to train some dense layers in a neural network indirectly by optimizing rank decomposition matrices of the dense layers' change during adaptation instead, while keeping the pre-trained weights frozen.
즉, pre-trained 모델 가중치는 그대로 두고, update 해야 할 가중치들(W 기울기) 행렬을 재구성하여 최적화하는 방법이다.
특징
우선 장점은 명확하다. GPU 메모리 사용량이 현저하게 줄어든다.
비록 모델 파라미터는 FFT와 비교 시 동일할지언정, GPU 메모리를 사용하는 4대 요소 (모델 파라미터, Gradient State, Optimizer State, 순전파 상태) 중 Gradient와 Optimizer 상태값을 획기적으로 줄일 수 있다.
당연하다. FFT에서는 W (d차원 * d차원) 만큼의 파라미터를 학습시켰다면, LoRA는 A, B 2개의 행렬을 학습할지언정 (각 d차원*r차원), 파라미터 수가 수십억 대인 LLM 세계에서는 엄청난 절약인 거다.
GPT-3 175B 기준, LoRA의 학습 파라미터 수가 FFT의 0.01%라니까... 말 다했다.
또한, 저렴한 비용으로 task 간 Context Switching을 할 수 있다는 것도 큰 특징이다.
이 역시 모든 파라미터가 아닌 특정 가중치들만 교환하기 때문에 가능한 일이다.
게다가 대부분의 파라미터에 대한 기울기 계산이 필요 없기에, 학습 속도도 FFT보다 빠른 성능을 보인다.
다만, 단점 또한 존재한다.
추가적인 inference latency를 없애기 위해 A, B 행렬을 W로 흡수하자고 선택하면, 이 A와 B가 다른 여러 task에 대한 입력을 배치로 처리하는 것이 어렵다.
Reference
LoRA arxiv : https://arxiv.org/abs/2106.09685
Cloudflare LoRA : https://www.cloudflare.com/ko-kr/learning/ai/what-is-lora/
kimjy99's blog : https://kimjy99.github.io/%EB%85%BC%EB%AC%B8%EB%A6%AC%EB%B7%B0/lora/