Meta Learning — A path to Artificial General Intelligence Series Part III — Optimization-Based Meta Learning

CellStrat
4 min read · Apr 26, 2022


This is the finale of our "Meta Learning: A Path to Artificial General Intelligence" series. Most deep learning models today learn by backpropagating gradients and then applying gradient descent. This procedure is not designed for learning from small amounts of data, and it typically needs many iterations to reach a minimum. Optimization-based meta-learning aims to design algorithms that modify the training procedure itself, so that a model can learn from less data in just a few training steps. Usually, this means learning an initialization of the parameters that can be fine-tuned with a few gradient updates. Some examples of such algorithms are:

  1. LSTM Meta-Learner
  2. Model-Agnostic Meta-Learning (MAML)
  3. Reptile

How does optimization-based meta-learning compare with transfer learning? In transfer learning, we first pretrain a model on a very large dataset. The pretrained weights are then used as the initial parameters for fine-tuning on a new, medium-sized dataset. However, transfer learning is not designed specifically for learning from very little data. It works well on medium-sized datasets because of the features learned during pretraining. Optimization-based meta-learning instead looks for initial parameters that generalize across a wide range of related problems, so that a new problem needs only a few gradient updates of fine-tuning on a small dataset. These algorithms are explicitly designed and trained to find an initialization that can be fine-tuned in a couple of training steps.

Model-Agnostic Meta-Learning (MAML) is a meta-learning algorithm that explicitly learns initial parameters which can adapt to a new task after fine-tuning with one or a few gradient steps. It is model-agnostic, i.e. it works with any model that is trained with gradient descent. It has been tested on few-shot classification, regression and reinforcement learning to demonstrate its flexibility.

During meta-training, we optimise for initial parameters θ that lie close to good solutions for all related tasks. A few steps of fine-tuning on a specific task should then be enough to fit that task well without overfitting.

Let’s look at this with an example.

Consider a few-shot image classification problem. The dataset is split into multiple tasks, and each task is sampled as an N-way K-shot classification problem. The Support Set (train) consists of K images for each of the N classes. The Query Set (test) also contains K images for each of the same N classes. The model is a vanilla CNN-based image classifier.
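The task sampling described above can be sketched in a few lines of Python. This is only an illustration under assumptions: the function name `sample_episode`, the `k_query` parameter and the toy dictionary of "image ids" are hypothetical and stand in for a real image dataset.

```python
import random

def sample_episode(data_by_class, n_way, k_shot, k_query, rng):
    """Sample one N-way K-shot task as (support, query) lists of (item, label)."""
    classes = rng.sample(sorted(data_by_class), n_way)      # pick N classes
    support, query = [], []
    for label, cls in enumerate(classes):                   # relabel classes 0..N-1 per task
        items = rng.sample(data_by_class[cls], k_shot + k_query)
        support += [(x, label) for x in items[:k_shot]]     # K items per class to adapt on
        query += [(x, label) for x in items[k_shot:]]       # held-out items to evaluate on
    return support, query

# Toy stand-in for an image dataset: 10 classes with 20 "image ids" each (hypothetical).
toy = {f"class_{c}": [f"img_{c}_{i}" for i in range(20)] for c in range(10)}
support, query = sample_episode(toy, n_way=5, k_shot=1, k_query=1, rng=random.Random(0))
print(len(support), len(query))  # 5 5
```

Sampling without replacement within each class keeps the Support Set and Query Set disjoint, so the Query Set measures how well the adapted model generalises within the task.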

Training steps include:

  1. Initialise the model parameters θ
  2. For each epoch:

    a. Sample a batch of tasks Tᵢ

    b. For each task Tᵢ:

      i. Sample the Support Set Sᵢ

      ii. Make predictions on Sᵢ using θ

      iii. Calculate the loss L_Sᵢ(f_θ)

      iv. Compute the fast weights θᵢ′ = θ − α ∇θ L_Sᵢ(f_θ)

      v. Sample the Query Set Qᵢ for the meta update

      vi. Calculate the meta loss L_Qᵢ(f_θᵢ′) on Qᵢ using the fast weights θᵢ′

    c. Gather the meta losses: L_Meta = Σᵢ L_Qᵢ(f_θᵢ′)

    d. Update θ = θ − β ∇θ L_Meta
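The steps above can be sketched on a problem small enough to differentiate by hand. The example below is an assumption-laden toy, not the CNN setup from this article: each task is a 1-D regression y = a·x with a task-specific slope a, the model is f_w(x) = w·x, and the gradient and second derivative of the mean squared error are written out analytically instead of using an autodiff library. All function names and hyperparameter values are hypothetical.

```python
import random

def mse(w, data):
    return sum((w * x - y) ** 2 for x, y in data) / len(data)

def mse_grad(w, data):
    # dL/dw for L = mean((w*x - y)^2)
    return sum(2 * x * (w * x - y) for x, y in data) / len(data)

def mse_hess(data):
    # d2L/dw2 = mean(2*x^2); constant in w for this linear model
    return sum(2 * x * x for x, _ in data) / len(data)

def sample_task(rng, k=5):
    a = rng.uniform(1.0, 3.0)                      # task-specific slope (toy assumption)
    def draw():
        xs = [rng.uniform(-1.0, 1.0) for _ in range(k)]
        return [(x, a * x) for x in xs]
    return draw(), draw()                          # support set, query set

def maml_train(steps=500, tasks_per_batch=8, alpha=0.5, beta=0.1, seed=0):
    rng = random.Random(seed)
    w = 0.0                                        # initial parameter
    for _ in range(steps):
        meta_grad = 0.0
        for _ in range(tasks_per_batch):
            support, query = sample_task(rng)
            w_fast = w - alpha * mse_grad(w, support)          # inner update (fast weight)
            # Chain rule through the inner update:
            # d L_Q(w_fast) / dw = L_Q'(w_fast) * (1 - alpha * L_S''(w))
            meta_grad += mse_grad(w_fast, query) * (1 - alpha * mse_hess(support))
        w -= beta * meta_grad / tasks_per_batch                # meta update
    return w

w_init = maml_train()

# Adapt the learned initialization to a fresh task with one gradient step.
support, query = sample_task(random.Random(1))
w_fast = w_init - 0.5 * mse_grad(w_init, support)
print(w_init, mse(w_init, query), mse(w_fast, query))
```

The factor `(1 - alpha * mse_hess(support))` is where the second derivative of the support loss enters the meta update. In a real model this role is played by backpropagating through the inner update with an autodiff framework.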

In the last step, the gradient of the meta loss is taken with respect to the initial parameters θ. Backpropagation therefore runs through the entire computation graph, including the task-specific fine-tuning step where we already computed the gradient of the support-set loss L_Sᵢ. This means taking the gradient of a gradient, so the meta update involves second-order derivatives (Hessian-vector products). Intuitively, the update moves the initial parameters θ to a place from which every task can be reached by fine-tuning with a single gradient update.
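To see where the second-order term comes from, the chain rule can be written out for a single inner gradient step. This is a sketch that assumes one inner update and uses the symbols from the training steps: θ for the initial parameters, θᵢ′ for the fast weights, α for the inner learning rate, and L_Sᵢ and L_Qᵢ for the support and query losses.

```latex
\theta_i' = \theta - \alpha \nabla_\theta L_{S_i}(f_\theta)

\nabla_\theta L_{Q_i}(f_{\theta_i'})
  = \left( I - \alpha \nabla_\theta^2 L_{S_i}(f_\theta) \right)
    \nabla_{\theta'} L_{Q_i}(f_{\theta'}) \Big|_{\theta' = \theta_i'}
```

The matrix ∇²θ L_Sᵢ is the Hessian of the support loss. It is never formed explicitly in practice, because only its product with the query-loss gradient is needed. Dropping the Hessian term gives the cheaper first-order approximation of MAML.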

MAML can be used for few-shot classification or any similar problem where data is scarce. Because it is model-agnostic, it is flexible enough to fit most domains. However, the second-order meta update makes it more compute-intensive than ordinary training. Applications of optimization-based meta-learning include few-shot image classification, regression, reinforcement learning and, in principle, any deep learning model trained with a gradient-based optimizer (Adam, SGD, etc.).

