00:00:00
Mamba is a new neural net architecture that
is better than transformers at language modelling.
00:00:05
Yes that’s right, after reigning supreme
for 7 years, the transformer has finally been
00:00:12
dethroned.
00:00:13
Well, maybe, so far Mamba has only been tested
at small model sizes up to a few billion parameters,
00:00:22
but the results so far are promising!
00:00:24
In addition, Mamba uses less compute than
transformers.
00:00:28
For an input sequence of n words, Mamba only
uses O(nlog(n)) compute, whereas transformers
00:00:34
use O(n^2).
00:00:36
So Mamba based language models should allow
for much greater context sizes to be used.
00:00:42
In this video we’re going to do a deep dive
into the Mamba architecture: what is it, why
00:00:46
does it work so well, and how could you have
gone about designing such an architecture
00:00:50
yourself?
00:00:54
Usually Mamba is presented as an extension
of something called a state-space model.
00:00:59
State-space models are another type of sequence
model that have been steadily gaining popularity
00:01:03
over the past few years, but, to be honest,
the theory behind state-space models is massively
00:01:08
over-complicated and uses some pretty advanced
mathematics.
00:01:12
Fortunately, Mamba can also be understood
as an extension of recurrent neural networks,
00:01:17
or RNNs for short, which are much easier to
understand.
00:01:22
So in this video we will be taking the RNN
path to understanding Mamba.
00:01:27
Now let’s get started: what is a recurrent
neural network?
00:01:35
Given a sequence of input vectors, a convolutional
layer applies a neural net to consecutive
00:01:40
groups of vectors.
00:01:42
The key thing is that the neural net only
sees a small number of vectors at a time,
00:01:47
which makes the model easy to train.
00:01:50
The downside is that information from vectors
which are far away can’t be combined until
00:01:55
many convolutional layers have been applied.
00:01:58
This makes it difficult for convolutional
neural nets to understand long range dependencies
00:02:02
in their input, and such long-range dependencies
occur all the time in natural language text.
00:02:10
To remedy this flaw, the transformer architecture
was invented, which successfully allows a
00:02:15
single layer to combine information from vectors
no matter how far away they are.
00:02:20
I previously made a video explaining how and
why transformers work in detail, which you
00:02:25
can find here.
00:02:27
And while transformers work great, they have
a significant limitation, which is that the
00:02:32
amount of compute they use is quadratic in
the input length.
00:02:36
This isn’t a huge deal for small inputs,
but if you want to have a million vectors
00:02:40
in the input, that means you need to do a
million times a million operations, which
00:02:46
is a lot.
00:02:48
Recurrent neural nets take a completely different
approach to improving convolutional layers.
00:02:53
The idea is very simple: instead of applying
the neural net to two consecutive input vectors,
00:02:59
you apply it to one input vector and the previous
output of the neural net.
00:03:05
This seems like a small change, but it has
profound consequences: each output vector
00:03:10
now contains information from all of the input
vectors prior to it, instead of only one previous
00:03:16
vector.
00:03:17
This final output vector contains information
from every vector in the input, no matter
00:03:22
how many there are.
00:03:24
And we have not used any more compute than
a convolutional layer.
00:03:28
We’ve managed to incorporate long-range
information, for free.
00:03:32
This is exactly what we want.
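Here is a minimal sketch of that recurrence in code, just to make the data flow concrete; the step function f and the initial state h0 are placeholders, not part of any specific architecture:
```python
def rnn_layer(xs, f, h0):
    """Minimal sketch of a recurrent layer: each output is produced from the
    current input and the previous output (f and h0 are placeholders here)."""
    h = h0
    outputs = []
    for x in xs:          # strictly sequential: step i needs the result of step i-1
        h = f(h, x)       # the net sees the previous output and the current input
        outputs.append(h)
    return outputs        # the last output has seen every input before it
```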
00:03:34
Or at least, it would be, if it weren’t
for 2 small problems with RNNs which make
00:03:40
them almost impossible to use in practice.
00:03:43
The first problem is that, while a recurrent
layer uses the same amount of compute as a
00:03:48
convolutional layer, that compute cannot be
parallelized across multiple processors.
00:03:55
Even if you have lots of processors available,
you can’t begin evaluating the neural net
00:04:00
on an input until all of the previous steps
have finished, because you need to feed the
00:04:05
output from the previous step into the neural
net.
00:04:10
Compare this to a convolutional layer, where
the neural net only needs to see the original
00:04:15
input.
00:04:16
You can run the neural net on all inputs at
the same time, so long as you have enough
00:04:20
processors available.
00:04:23
And since modern hardware, such as GPUs, is
highly specialized for parallel computation
00:04:28
with thousands of processors, RNNs are actually
a lot slower than CNNs in practice.
00:04:37
In fact RNNs are even slower than transformers,
despite doing less computation.
00:04:45
And the second problem is that RNNs are incredibly
difficult to train.
00:04:50
While in theory, a single recurrent layer
can incorporate information from arbitrarily
00:04:55
many inputs, in practice, it doesn’t.
00:04:59
Instead, it only learns to incorporate information
from the previous few dozen inputs at most.
00:05:06
The idea for RNNs has been around since the
1980s, but because of these 2 problems, RNNs
00:05:11
have fallen out of favour, with convolutional
neural nets and transformers being much more
00:05:16
successful in practice.
00:05:18
In fact, RNNs have hardly been used at all
in the past decade.
00:05:24
Until now.
00:05:25
Last year, a new paper was published showing
that linear RNNs can avoid both of these problems,
00:05:32
and therefore linear RNNs are highly effective
long sequence models.
00:05:37
So what is a linear recurrent neural network?
00:05:42
Well you simply replace the neural net with
a linear function.
00:05:53
This might seem like a bad idea, since linear
functions can only perform relatively simple
00:05:58
transformations of their inputs, but we can
make up for it by applying a full neural net
00:06:03
to each output vector afterwards.
00:06:06
This is similar to how in transformers you
can replace the value neural nets with simple
00:06:11
linear functions, and then add neural nets
in between self-attention layers to make up
00:06:15
for the lack of non-linear processing power.
00:06:19
So just like in a transformer, we will alternate
linear recurrent layers with element-wise
00:06:25
neural networks.
00:06:28
But importantly, by making the recurrent operation
purely linear it becomes possible to solve
00:06:34
both of the RNN problems.
00:06:37
To start with I’ll explain how a linear
recurrence applied to n vectors can be computed
00:06:42
in parallel in just O(log(n)) time.
00:06:46
And then I’ll explain how the training issues
that plague regular RNNs can be fixed in linear
00:06:52
recurrences.
00:06:59
The linear recurrence operator is given by
this formula: to get the i’th output vector
00:07:04
you multiply the previous, (i-1)’th, output
vector with a matrix W_y, and add the i’th
00:07:11
input vector multiplied by a different matrix
W_x.
00:07:15
The entries in the W matrices are the parameters
which will be learned by the model, so they
00:07:20
start off as random samples from a normal
distribution centred at 0, and are then updated
00:07:26
with gradient descent.
00:07:28
And since the W_x matrix is just applied to
each input independently, we can actually
00:07:33
just think of it as being part of the previous
layer, so we can simplify our recurrence operator
00:07:39
to just add the input x, assuming that a linear
function has already been applied to the input
00:07:45
in the previous layer.
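Written out as equations (using y_i for the i’th output and x_i for the i’th input, to match the description above), the recurrence and its simplified form are:
```latex
y_i = W_y \, y_{i-1} + W_x \, x_i
\qquad\Longrightarrow\qquad
y_i = W \, y_{i-1} + x_i
\quad \text{(with } W_x \text{ folded into the previous layer)}
```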
00:07:48
A linear recurrence is actually a special
case of a more general operation called a
00:07:53
scan, so let’s start with the simplest example
of a scan: a cumulative sum.
00:07:59
Given a list of n numbers as input, the goal
is to compute the list of partial sums, up
00:08:04
to each term.
00:08:06
So the i’th item in the output list should
be the sum of the first i items of the
00:08:11
input list.
00:08:14
While it is trivial to compute this by simply
adding the numbers together, one at a time,
00:08:19
we want to do it in parallel.
00:08:24
And it turns out we can do so as follows:
first add together each consecutive pair of
00:08:30
numbers.
00:08:32
Then, from the resulting list, add together
pairs of numbers which are 2 steps apart.
00:08:43
Then 4 steps apart.
00:08:45
And 8… and so on, each iteration doubling
the step size, until the step size is as large
00:08:53
as the entire input list, which will be after
log(n) steps.
00:08:59
This algorithm works because at each iteration,
the i’th output element contains the sum
00:09:04
of the previous step size numbers.
00:09:08
For example, in the first iteration, each
output number is the sum of the previous 2
00:09:13
terms.
00:09:14
In the next iteration, each item contains
the sum of the previous 2 terms plus the sum
00:09:19
of the previous 2 terms starting 2 away, that
is the sum of the previous 4 terms.
00:09:26
And so on.
00:09:27
When the step size is the size of the input,
each output contains the sum of all previous
00:09:33
terms, as desired.
00:09:35
It’s trivial to see that each iteration
can be computed in parallel; however, the different
00:09:41
iterations do still need to be computed sequentially,
and there are log(n) iterations.
00:09:48
So, if you have n processors, the total run
time of this algorithm is O(log(n)), down
00:09:53
from O(n) of the naive sequential version.
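Here is a short sketch of that doubling-step procedure for the cumulative sum; numpy’s vectorised shift-and-add stands in for the per-iteration parallel work, and the loop itself runs log(n) times:
```python
import numpy as np

def parallel_cumsum(xs):
    """Inclusive scan via step doubling: log2(n) iterations, each one adding
    in the value from `step` positions back (every position independently)."""
    out = np.array(xs, dtype=float)
    step = 1
    while step < len(out):
        shifted = np.concatenate([np.zeros(step), out[:-step]])
        out = out + shifted   # all positions updated at once -> parallelisable
        step *= 2             # double the step size each iteration
    return out

# parallel_cumsum([1, 2, 3, 4]) -> array([ 1.,  3.,  6., 10.])
```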
00:09:57
And this same algorithm works for computing
lists of cumulative applications of any binary
00:10:03
operator, not just addition, so long as the
binary operator is associative.
00:10:11
Associative means that you can change the
order of application and you’ll still end
00:10:15
up with the same result.
00:10:17
This is true of addition, which is why our
parallel cumulative sum algorithm works.
00:10:25
And it’s also true of a bunch of other
operations.
00:10:29
In particular, this binary operator is associative:
f((W1, x1), (W2, x2)) = (W2*W1, W2*x1 + x2), where (W1, x1) is the element that comes earlier in the sequence.
00:10:35
Note that this operator uses a pair of a matrix
and a vector as input and output, instead
00:10:41
of just a single number like with addition.
00:10:46
And remarkably, performing a scan with this
operator is equivalent to a linear recurrence.
00:10:52
We first need to replace our input list of
vectors with a list of pairs, where the first
00:10:58
element is the recurrent weight matrix and
the second element is the input vector, but
00:11:03
then we just perform the scan as usual.
00:11:09
You can check for yourself that this operator
is in fact associative by expanding a few
00:11:14
terms in the other order.
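For example, writing three consecutive elements as (A, a), (B, b) and (C, c), both groupings reduce to the same pair:
```latex
f\big(f((A,a),(B,b)),\,(C,c)\big) = f\big((BA,\;Ba+b),\,(C,c)\big) = (CBA,\;CBa+Cb+c) \\
f\big((A,a),\,f((B,b),(C,c))\big) = f\big((A,a),\,(CB,\;Cb+c)\big) = (CBA,\;CBa+Cb+c)
```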
00:11:18
To summarize, we just need to do our parallel
cumulative sum algorithm with this operator
00:11:24
in place of addition, and we get the result
of a linear recurrent layer in just O(log(n))
00:11:31
time.
00:11:33
Except for one small problem.
00:11:35
If you look closely at this operation, the
way it works is by using the first element
00:11:41
of the tuples as a cumulative matrix, which
contains the product of all of the matrices
00:11:46
seen so far.
00:11:47
That’s why the first element of the output
tuple is the product of the two input matrices.
00:11:54
But this means we’re performing a [d, d]
times [d, d] matrix multiplication in every
00:12:01
step, where d is the dimension of the vectors.
00:12:05
This is really slow.
00:12:08
Note that in the original sequential RNN we
didn’t need to keep track of this cumulative
00:12:13
matrix, and so we only ever multiply the weight
matrix with a length [d] input vector at each
00:12:19
step, which is an O(d^2) operation.
00:12:23
But now we have to do an O(d^3) operation in
every step.
00:12:28
For standard model sizes, this is easily a
thousand-fold increase in computation.
00:12:33
And that’s bad.
00:12:36
Fortunately, there is a way around this: matrix
diagonalization.
00:12:41
You see (almost) every square matrix can be
factored into the product of an invertible
00:12:47
matrix P, a diagonal matrix D, and P^-1, so
long as the matrix elements are allowed to
00:12:55
be complex numbers.
00:12:58
Here’s an example.
00:13:02
Note that this middle matrix is diagonal,
that is all elements except for the main diagonal
00:13:08
are 0.
00:13:09
What’s neat about this is that when you multiply
the matrix by itself in this form, the inner
00:13:14
P inverse and P terms cancel, and the product
of 2 diagonal matrices is just the diagonal
00:13:21
matrix whose entries are the products of the corresponding diagonal elements.
00:13:23
That is, in order to compute D^2, all you
need to do is square the elements on the main
00:13:29
diagonal of D, which can be done in just O(m)
operations instead of O(m^3). Much better.
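In symbols, with W = P D P^{-1}, repeated multiplication telescopes, so powers of W only ever require powers of the diagonal entries:
```latex
W^2 = P D P^{-1} P D P^{-1} = P D^2 P^{-1},
\qquad
W^n = P D^n P^{-1},
\qquad
(D^n)_{ii} = (D_{ii})^n
```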
00:13:37
So then, what we can do is represent the recurrent
weight matrix in diagonalized form, which
00:13:43
means we only need to use a complex vector
which contains the elements of the main diagonal
00:13:48
of D.
00:13:51
That is to say, we first apply a complex matrix
P to the input vectors, then perform the linear
00:13:57
recurrence with a complex weight vector w,
using element-wise multiplication, and finally
00:14:03
apply P^-1 to the output.
00:14:07
The result of this will then be equivalent
to a linear recurrence for some real valued
00:14:12
weight matrix W. But when computed this way,
the recurrence operator only needs to compute
00:14:18
element-wise multiplication between two vectors
to update the cumulative weights, instead
00:14:23
of matrix multiplication.
00:14:26
When we plug this operator into our parallel
scan algorithm, the total compute is now just
00:14:31
O(d*n*log(n)), and the parallel run time is
O(log(n)).
00:14:35
Much better.
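Here is a minimal numpy sketch of that diagonalised recurrence, written sequentially for clarity (the parallel scan from before applies unchanged, since element-wise multiplication keeps the operator associative); the names w, P and P_inv follow the description above:
```python
import numpy as np

def diagonalized_linear_rnn(xs, w, P, P_inv):
    """Sequential reference for the diagonalised linear recurrence.
    xs: [n, d] real inputs; w: [d] complex recurrent weights (the diagonal of D);
    P, P_inv: [d, d] complex matrices applied before and after the recurrence."""
    h = np.zeros_like(w)
    outputs = []
    for x in xs:
        h = w * h + P @ x          # element-wise multiply replaces matrix multiply
        outputs.append(P_inv @ h)  # map back out of the diagonal basis
    return np.real(np.stack(outputs))  # real part is taken before the next layer
```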
00:14:38
Note that the parameters of this layer are
the complex entries in the recurrent weight
00:14:42
vector w and matrix P. In practice you would
just use two separate real numbers to represent
00:14:48
the real and imaginary components of each
parameter, which are initialized by sampling
00:14:53
from a normal distribution centred at 0, and
updated with gradient descent as usual.
00:15:00
Lastly, computing matrix inverses is really
slow, so in practice we don’t bother, and
00:15:07
instead just use 2 independent complex matrices
before and after the linear recurrence.
00:15:13
This actually makes the model more expressive
than a real valued linear RNN, and it saves
00:15:18
computation.
00:15:19
But it does mean that the model is no longer
equivalent to a real valued recurrence, and
00:15:23
the output can now be a complex number, so
we will need to take the real part
00:15:28
of the output before passing it to the next
layer.
00:15:32
Ok, so we’ve seen how to make linear RNNs
fast for modern hardware, but what about the
00:15:38
other problem, that RNNs are very difficult
to train?
00:15:44
Before we solve this problem, here’s a quick
recap of why training RNNs is so problematic
00:15:48
in the first place: neural nets are trained
by subtracting the gradient of the loss function
00:15:53
from each weight in the model.
00:15:56
What is the gradient?
00:15:57
Well imagine evaluating the neural net, then
increasing the value of a weight by a very
00:16:04
small amount, and then evaluating it again.
00:16:09
The difference between these two scores is (proportional
to) the gradient for that weight, and it tells
00:16:14
you how to change the weight to make the neural
net better.
00:16:17
So let’s evaluate the gradient of a linear
recurrent layer.
00:16:23
Actually to make this a bit easier, let’s
simplify the model and suppose that every
00:16:27
input after the first is 0, so we can just
ignore them.
00:16:33
When we evaluate the recurrent layer, at each
step the previous output is multiplied by
00:16:38
the weight vector, so after n steps the output
vector is equal to the recurrent weight vector
00:16:43
to the power of n times the first vector x_1.
00:16:48
When we increase the weight by a small amount
and evaluate it again we get this.
00:16:54
Taking the difference, we get, up to a constant
scaling factor, w^(n-1) x_1.
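Spelling that step out, with y_n the output after n steps, the hidden constant factor is just n:
```latex
y_n = w^n x_1,
\qquad
(w+\epsilon)^n x_1 - w^n x_1 \;\approx\; n\,\epsilon\, w^{\,n-1} x_1
\qquad\Longrightarrow\qquad
\frac{\partial y_n}{\partial w} = n\, w^{\,n-1} x_1
```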
00:17:04
The problem here is that as n becomes large,
this term, w^(n-1), either gets very small
00:17:11
or very large, depending on whether the values
in w are less than or greater than 1.
00:17:19
In either case it’s a problem: If the gradient
is very large then the neural net weights
00:17:24
change too much, and the existing functionality
already learned by the neural net gets destroyed.
00:17:31
If the gradient is very small then the weights
don’t change enough and the neural net doesn’t
00:17:35
learn anything at all.
00:17:38
This is what makes training RNNs difficult:
while in principle RNNs can use infinitely
00:17:44
long context, in practice, with gradient based
training techniques, the RNN will only learn
00:17:50
to use context for as many steps as the gradient
remains the right size for learning.
00:17:56
This is known as the problem of vanishing
and exploding gradients.
00:18:01
And when we add back in non-zero inputs, this
problem only gets worse, as the additional
00:18:06
inputs make the gradients even more unstable.
00:18:11
And to be clear, the reason why this isn’t
a problem for regular neural nets is because
00:18:16
they use different weights in each layer.
00:18:20
Some layers can have weights smaller than
1, and some layers can have weights larger
00:18:23
than 1; so long as the gradient remains about
the same size, the neural net will be able
00:18:28
to learn.
00:18:32
There are lots and lots of different configurations
of weights that result in stable gradients,
00:18:37
and it’s easy to stay in stable configurations
all throughout training.
00:18:44
But for RNNs, you’re using the same weight
in each step, so there is exactly one stable
00:18:50
configuration which is when the weight is
1.
00:18:54
Any deviation from 1 and you have exponentially
growing or decaying gradients.
00:19:00
Note that when the weights are complex numbers
the same argument applies, just using the
00:19:05
absolute value of the weights.
00:19:09
So how can we fix vanishing and exploding
gradients?
00:19:12
Well, we saw that the RNN gradients are stable
so long as the recurrent weights are 1 and
00:19:18
the inputs are 0, so in the linear RNN paper
the authors propose to initialize their linear
00:19:25
RNN in this stable state.
00:19:29
Specifically, they parameterize the weights
in complex polar form ae^(ib), where a controls the magnitude
00:19:36
and b is the angle.
00:19:37
They then restrict the magnitude to be less
than 1 by running a through the function e^(-e^(a)),
00:19:43
which always outputs a number between
0 and 1, and instead of randomly sampling
00:19:49
a from a normal distribution centred at 0,
as we usually do, they initialize a so that
00:19:55
the magnitude e^(-e^(a)) is uniformly distributed
between 0.999 and 1.
00:20:03
They initialize the angle, b, uniformly between
0 and π/10 radians.
00:20:09
This ensures that, at initialization, the
weights are all very close to 1.
00:20:14
Finally, they multiply the inputs by γ, which
is another learnable parameter initialized
00:20:19
to sqrt(1-e^(-e^(a))), and since e^(-e^(a))
is close to 1, this is some very small number.
00:20:30
This ensures that at initialization the inputs
are all close to 0, and so they don’t interfere
00:20:35
with the recurrence.
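Here is a sketch of that initialisation in numpy; the function name and sampling details are illustrative, but the formulas follow the description above:
```python
import numpy as np

def init_stable_linear_rnn(d, rng=None):
    """Initialise recurrent weights with magnitude close to 1 and a small angle,
    plus a small input scale, as described above (details illustrative)."""
    rng = rng or np.random.default_rng(0)
    # pick a so that the magnitude exp(-exp(a)) is uniform in [0.999, 1)
    mag = rng.uniform(0.999, 1.0, size=d)
    a = np.log(-np.log(mag))                   # inverts mag = exp(-exp(a))
    # angle b uniform in [0, pi/10]
    b = rng.uniform(0.0, np.pi / 10, size=d)
    w = np.exp(-np.exp(a)) * np.exp(1j * b)    # recurrent weights, all close to 1
    gamma = np.sqrt(1.0 - np.exp(-np.exp(a)))  # input scale, close to 0
    return a, b, gamma, w
```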
00:20:37
So at initialization, this model is approximately
the same as the stable RNN I showed you before.
00:20:44
After the model begins training and the weights
change, there is no guarantee that it will
00:20:48
remain stable, but in practice it appears
that just initializing the model like this
00:20:53
is sufficient to allow it to learn to remember
context for tens of thousands of steps.
00:20:59
And there we have it, with these modifications,
we now have a linear RNN that is fast to compute,
00:21:04
and learns to use extremely long context.
00:21:08
In the linear RNN paper, they evaluate this
model on the long range arena benchmark, which
00:21:13
is a collection of 6 synthetic tasks that
evaluate a model’s ability to perform long
00:21:18
range reasoning.
00:21:21
For example, in the PathX task the model must
classify images according to whether or not they contain
00:21:27
a complete dotted path between two circles.
00:21:32
Except that the images are flattened into one
long sequence of 16 thousand pixels.
00:21:41
The linear RNN achieved state-of-the-art performance
on the long range arena, outperforming transformers
00:21:47
by about 33% on average across tasks.
00:21:54
So now that we understand the linear RNN,
what’s with all the talk about state-space
00:21:59
models?
00:22:00
Well, it turns out that state-space models
are just linear RNNs.
00:22:04
State space models were inspired by control
theory, and were derived from a totally different
00:22:10
idea of trying to discretize a continuous
dynamical system, but the end result is just
00:22:16
a linear RNN, with a slightly different initialization
scheme.
00:22:20
The most common form of state space model
parameterizes each recurrent weight as w = e^(Δ(a+bi)),
00:22:26
where Δ is again a learnable parameter which
is initialized to a very small number, usually
00:22:34
between 0.0001 and 0.1.
00:22:39
Multiplying (a+bi) by the small number Δ makes the exponent close to 0, and when you take e to the
00:22:43
power of something close to 0 you get something
close to 1.
00:22:46
This again ensures that at initialization
the recurrent weights are all approximately
00:22:51
one, so training is stable.
00:22:53
State space models also multiply the inputs
by (Δ(a+bi))^-1(w-1), because that’s what’s
00:23:01
prescribed by control theory, but empirically
you get the same performance when you just
00:23:06
scale the inputs by Δ, as in the linear RNN setup.
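Written out as display equations, with a and b the learned parameters and Δ the small learnable step size (the name of the scaled input here is just for illustration), the state-space parameterisation is:
```latex
w = e^{\Delta\,(a + b\,i)},
\qquad
x_{\text{scaled}} = \big(\Delta\,(a + b\,i)\big)^{-1}\,(w - 1)\; x
```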
00:23:11
On the long range arena, the control theory
inspired state-space initialization performs
00:23:16
roughly the same as the linear RNN initialization.
00:23:19
Anyway, whenever you hear state-space model,
think linear RNN.
00:23:24
And finally we can talk about Mamba.
00:23:26
You see, while linear RNNs do perform really
well on the long range arena benchmark, this
00:23:32
does not mean they are good language models.
00:23:34
For language modelling, linear RNNs perform
way worse than transformers.
00:23:39
Here is the performance of various state-of-the-art
language models; lower is better on this graph.
00:23:46
As you can see, everything, including state-space
models, does significantly worse than transformers.
00:23:53
The reason for this, as identified in the
Mamba paper, is that linear RNNs are incapable
00:23:58
of selectively forgetting information from
the output vector.
00:24:02
If the weights are close to 0 then the output
vector will be set to 0 after every input;
00:24:08
effectively the model will always immediately
forget whatever came before the current input.
00:24:13
If the recurrent weights are close to 1 then
the output vector doesn’t change when it’s
00:24:17
multiplied by the weights, so the output
vector will accumulate information from all
00:24:22
inputs observed.
00:24:24
What you want is for the model to be able
to decide when to store information and when
00:24:29
to forget information, based on what input
it sees.
00:24:33
Mamba proposes an elegant solution: instead
of using the same weights in each step, use
00:24:39
different weights which depend on the input.
00:24:43
Mamba applies a linear function to each input
vector to generate a separate weight vector
00:24:49
for that input.
00:24:52
Then the recurrent scan is performed using
these generated weights.
00:24:57
This way, certain inputs can generate weights
close to 0 and thereby erase information from
00:25:02
the output vector, but other inputs can generate
weights close to 1 thereby leaving the output
00:25:08
vector unchanged.
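Here is a toy sketch of that idea, run sequentially for clarity (Mamba computes the same kind of recurrence with a parallel scan); the gating matrix W_gate and the sigmoid squashing are my simplifications for illustration, not Mamba’s exact parameterisation:
```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def selective_linear_recurrence(xs, W_gate):
    """Input-dependent recurrent weights: each input generates its own weight
    vector, so the model can choose per step how much of the past to keep.
    xs: [n, d] inputs; W_gate: hypothetical [d, d] matrix (illustrative)."""
    h = np.zeros(xs.shape[1])
    outputs = []
    for x in xs:
        w = sigmoid(W_gate @ x)   # weights in (0, 1), decided by the current input
        h = w * h + x             # w near 0: erase the past; w near 1: keep it
        outputs.append(h)
    return np.stack(outputs)
```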
00:25:11
And I also suspect that using different weights
at each step helps with vanishing and exploding
00:25:16
gradients, since there should now be many
different stable configurations, like in feed-forward
00:25:22
networks, although this wasn’t mentioned
in the Mamba paper.
00:25:27
Mamba also uses one more trick, which is to
increase the size of the output vectors.
00:25:33
In a standard RNN the output vectors are the
same size as the input vectors.
00:25:38
Mamba expands the size of the output vectors
by a factor of 16.
00:25:43
This allows it to store much more information
from previous inputs.
00:25:48
The output vectors are then projected back
down to the original size before being passed
00:25:53
to the next layer.
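As a rough illustration of the shapes involved (the sizes and projection matrices here are hypothetical, not Mamba’s exact block layout): each size-d vector is expanded by a factor of 16 for the recurrence, then projected back down to size d afterwards:
```python
import numpy as np

d, expand = 8, 16                                   # illustrative sizes only
rng = np.random.default_rng(0)
W_up   = rng.normal(size=(expand * d, d)) * 0.02    # hypothetical expansion matrix
W_down = rng.normal(size=(d, expand * d)) * 0.02    # hypothetical projection back down

x = rng.normal(size=d)
h = W_up @ x             # the recurrence operates on this 16x larger vector
y = W_down @ h           # projected back to size d before the next layer
print(h.shape, y.shape)  # (128,) (8,)
```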
00:25:57
Usually this would increase the computation
time by a factor of 16, but it turns out that
00:26:02
the major bottleneck of a Mamba layer on modern
GPUs is the time it takes to read and write
00:26:07
data into high performance memory.
00:26:10
You see, modern GPUs actually have 2 different
types of memory.
00:26:14
Data is stored in main memory, but in order
to do computations, data first needs to be
00:26:19
transferred into high-performance memory.
00:26:23
For the Mamba recurrence operation, it turns
out that the time taken to transfer data is
00:26:28
actually much larger than the time it takes
to do the computation itself.
00:26:33
Mamba therefore transfers the input vectors
and model parameters to high performance memory,
00:26:39
then computes the whole Mamba operation in
a single block, including projecting outputs
00:26:44
back down to the smaller original size, before
writing the results back to main memory.
00:26:51
This way you only transfer vectors of the
original size to and from high performance
00:26:56
memory, so the transfer time is unchanged.
00:27:01
The actual computation time is 16 times slower,
but the computation time was so small compared
00:27:06
to the transfer time that it doesn’t really
affect the overall time taken.
00:27:10
You essentially get to use 16 times larger
vectors for free.
00:27:15
And there we have it, this, along with a few
minor architecture modifications, is Mamba,
00:27:21
the dynamic linear recurrent neural network,
which performs better than transformers at
00:27:26
language modelling, while using only O(nlog(n))
compute, down from O(n^2).
00:27:32
Ok, now that we’ve made it through all of
those boring technical details, we can finally
00:27:40
talk about what really matters: the Mamba
drama.
00:27:49
You see, the Mamba paper caused quite a bit
of controversy in the machine learning community
00:27:53
this year.
00:27:55
The Mamba paper was submitted to ICLR 2024,
which is one of the most prestigious machine
00:28:00
learning conferences in the world.
00:28:03
And in January, it was rejected by peer reviewers.
00:28:07
But so what?
00:28:09
Papers get rejected from top conferences all
the time, right?
00:28:12
Well, to give some context, the Mamba pre-print
has been publicly available since last year
00:28:18
and during this time several different groups
have re-implemented Mamba and all successfully
00:28:23
reproduced the results claimed in the Mamba
paper, namely that Mamba performs better than
00:28:29
transformers and uses less computation.
00:28:32
And since transformers are all anyone has
talked about for the last 5 years, that’s
00:28:37
kind of a big deal.
00:28:39
Because of this, everyone in the community
was expecting the Mamba paper to be accepted,
00:28:45
if not win a best-paper award.
00:28:50
So then, if the Mamba architecture really
works, what glaring flaws are in the paper
00:28:56
that caused it to be rejected?
00:28:57
Well, the ICLR peer review is publicly available
for everyone to view.
00:29:04
So let’s take a look.
00:29:05
According to the meta review, Mamba wasn’t
tested on the long range arena benchmark.
00:29:13
Remember that benchmark I talked about, where
linear RNNs performed way better than transformers?
00:29:18
This reviewer wanted to see how well Mamba
performed on that task.
00:29:23
Now this is a really dumb reason to reject
a paper, because, the long range arena is
00:29:28
a completely different task to language modelling,
and Mamba is specifically a language model.
00:29:34
Keep in mind, transformers perform way worse
than linear RNNs on the long range arena,
00:29:39
but transformers are still way better language
models.
00:29:43
So performance on the long range arena is
in no way indicative of language modelling
00:29:47
ability.
00:29:49
Mamba sets a new state of the art for language
modelling; it shouldn’t be rejected because
00:29:53
it doesn’t also solve some other, unrelated
task.
00:29:59
The only other major criticism in the meta
review is that Mamba was only evaluated on
00:30:03
language modelling, that is the accuracy when
predicting the next word in a piece of text.
00:30:10
The reviewers argue that this metric isn’t
indicative of a language model’s utility,
00:30:15
and instead Mamba should have been evaluated
on downstream tasks that measure a model’s
00:30:20
reasoning ability.
00:30:22
Except that... this is exactly what they did
in the Mamba paper: they pre-trained Mamba
00:30:27
as a language model and then performed zero-shot
prompting on a bunch of standard downstream
00:30:33
benchmark tasks.
00:30:34
And surprise, surprise, Mamba outperforms
all other language models.
00:30:39
As a bonus, another reviewer said, and I quote,
“Mamba has a quadratic memory requirement
00:30:46
during training, just like transformers”.
00:30:49
Which… is just not true.
00:30:53
Neither Mamba nor transformers have quadratic
memory costs.
00:30:58
Transformers have a quadratic compute cost,
but their memory cost is linear.
00:31:03
So is Mamba’s…
00:31:04
I’m not sure how it’s even possible to come
to the conclusion that Mamba has a quadratic
00:31:08
memory cost if you understand how it works
at all…
00:31:13
So as you can imagine, this less than ideal
peer review sparked some debate in the machine
00:31:17
learning community about peer reviewing practices
and whether or not Mamba should have been
00:31:22
rejected.
00:31:23
You can probably guess which side of the debate
I fall on, but let me know your thoughts about
00:31:27
how broken academic peer review is in the
comments below.
00:31:31
Or thoughts about the actual Mamba architecture
itself, I guess that’s fine too.