Nuclear Norm Optimizer

Deriving and Implementing an Optimizer under the Nuclear Norm

An assignment by an exciting startup working on applied interpretability.

A day ago I was given a take-home assignment by an exciting startup working on applied interpretability. I might write a separate blog about that later.

Before reading the assignment, I told them that I could finish it within a day. I won't forget their shocked faces after hearing that, and I fully understood why once I read it. In short, I needed to derive and implement an optimizer that minimizes the approximation error under the nuclear norm. Then, I needed to train a NanoGPT-scale model with the new optimizer.

Still, here I am, one day later with my work.

So this blog is my attempt to document what I did, explain my thought process, and eventually present my work. Turning this journey into a blog post might seem excessive, but I'm still doing it because it has been too long since I last wrote anything (sorry!).

The first half here is going to be storytelling, which will basically talk about the things I tried. So if you are here to read my immature derivation and implementation, feel free to skip this part and read the second half :)

Spoiler alert: it was only after I finished training all 60 epochs (4 hours on an RTX 4090) that I realized I had misunderstood the task. I'm going to elaborate on this at the end.

“What even is an optimizer?”

When I first saw the task (a very short one), I could fully understand each individual word. But once the words were connected into a sentence, I felt like I was trying to comprehend a new language.

Yet, you might have seen how I put “building safer and more efficient LLMs” as one of the goals on my personal bio page. How can someone who's interested in LLMs not know about norms and optimizers? The reason is that I have been learning the so-called “advanced concepts” that feel cutting-edge, such as Transformer-based algorithms. If you were to ask me about DeepSeek's papers, I could probably hand-draw most of their methods. However, if you asked me about normalization and gradient descent, the last time I really learned about them was in Andrew Ng's deep learning specialization course.

Optimistically speaking, this assignment was also a wake-up call. I really have to understand things at a fundamental level in order to innovate and provide real value.

Jumping back to solving the task, the first thing I did was to understand it. Ignore the derivation and implementation for now. I first needed to learn what the spectral and nuclear norms are. The only resource I used here was Gemini 2.5 Pro. It did a great job building up the conceptual and mathematical terms from the ground up.

After learning the basics (the threshold here is probably that I can follow what a formal methodology is talking about), I turned to o3 (my recent search engine), which found a few resources on the Muon optimizer. Muon is the inspiration for the assignment. It is analogous to steepest descent under the spectral norm, which is the dual of the nuclear norm. If you want to stick with me through the second half of the blog, I highly recommend reading these two walkthroughs by the original authors:

  1. Muon: An optimizer for hidden layers in neural networks: https://kellerjordan.github.io/posts/muon/
  2. Deriving Muon: https://jeremybernste.in/writing/deriving-muon

They are basically everything I needed other than the source code and some math. I personally prefer sticking to a few very high-quality resources and understanding them fully. I probably spent the most time on these two blogs. Again, Gemini helped a lot. I first let it read the blog and then asked it to explain different parts by simply screenshotting them. I also went through other posts, with the goal of seeing some diverse intuitions rather than the mathematical derivation.

Here I'm going to assume that you, just like me, have also gone through the derivation of Muon under the spectral norm. The two posts are going to be enough here, but if I were to conduct a research-level study, I would surely read the formal paper, which would at least quadruple the time and effort.

Tauon, a clean, nuclear‑norm, low‑rank optimizer

Theoretical Derivation: Steepest Descent under Nuclear Norm Constraints

The classical steepest descent problem aims to find an update $\Delta W$ that maximizes the local decrease in the loss function $L(W)$, approximated to first order by $-\text{Tr}(G^T \Delta W)$ where $G = \nabla_W L$, subject to a constraint on the step size $\|\Delta W\|$. Constraining the step with the nuclear norm, $\|\Delta W\|_* \leq \eta$, and using norm duality (the nuclear norm $\|\cdot\|_*$ and the spectral norm $\|\cdot\|_2$ are duals) yields the optimal rank-1 update direction:

$$\Delta W_{sd}^{(1)} = -\eta u_1 v_1^T$$

Where $u_1, v_1$ are the top left and right singular vectors of the gradient $G$, and the resulting first-order decrease is $\eta \sigma_1(G) = \eta \|G\|_2$. This mirrors the Muon optimizer's update, which is derived under the dual (spectral) norm and therefore takes a step proportional to $-U V^T$, built from all singular vectors rather than only the top pair. This concept can be extended to a rank-$k$ steepest direction $\Delta W_{sd}^{(k)} = -\sum_{i=1}^k \tau_i u_i v_i^T$ by imposing constraints only on the sum of the top-$k$ singular values of $\Delta W$.
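To make the duality argument concrete, here is a small NumPy sketch (an illustration only, not the code from the repository). The random matrix `G` and the step size `eta` are stand-ins I made up. It checks that $-\eta u_1 v_1^T$ has nuclear norm $\eta$, that it decreases the linearized loss by $\eta \sigma_1$, and that a plain normalized-gradient step with the same nuclear norm budget does no better.

```python
import numpy as np

rng = np.random.default_rng(0)
G = rng.standard_normal((6, 4))   # stand-in for a gradient matrix (assumption)
eta = 0.1                         # step-size budget (assumption)

# Top singular pair of the gradient.
U, S, Vt = np.linalg.svd(G, full_matrices=False)
u1, v1 = U[:, 0], Vt[0, :]

# Steepest-descent step under the nuclear-norm constraint.
delta_W = -eta * np.outer(u1, v1)
nuc = np.linalg.norm(delta_W, ord="nuc")   # uses the whole budget eta
decrease = -np.trace(G.T @ delta_W)        # equals eta * sigma_1

# Competitor: the raw gradient direction rescaled to the same nuclear norm.
alt = -eta * G / np.linalg.norm(G, ord="nuc")
alt_decrease = -np.trace(G.T @ alt)

print(nuc, decrease, eta * S[0], alt_decrease)
```

The competitor's decrease is $\eta \sum_i \sigma_i^2 / \sum_i \sigma_i$, which can never exceed $\eta \sigma_1$, so the rank-1 step wins whenever $G$ has more than one non-zero singular value.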

However, computing the singular vectors $u_i, v_i$ of the gradient $G$ at each iteration, even using iterative methods like power iteration, incurs significant computational overhead, particularly for large matrices prevalent in modern neural networks. Therefore, while the steepest descent analysis under nuclear norm constraints provides valuable theoretical motivation for associating gradient singular vectors with optimal low-rank updates, we need to pursue a more computationally efficient practical algorithm.

A Quick Derivation

Before we look into the algorithm, let's go over the derivation that lays the groundwork.

The nuclear norm of a matrix $W \in \mathbb{R}^{m \times n}$ is defined as the sum of its singular values. Mathematically, this is expressed as:

$$\|W\|_* = \sum_i \sigma_i(W)$$

where $\sigma_i(W)$ are the singular values of the matrix $W$, arranged in non-increasing order. To understand how we can exploit the nuclear norm, let's first talk about the Singular Value Decomposition (SVD). Any matrix $W$ can be decomposed into three matrices as follows:

$$W = U \Sigma V^T$$

Where:

  • $U \in \mathbb{R}^{m \times m}$ is an orthogonal matrix (meaning $U^T U = I$),
  • $V \in \mathbb{R}^{n \times n}$ is also an orthogonal matrix (meaning $V^T V = I$),
  • $\Sigma \in \mathbb{R}^{m \times n}$ is a rectangular diagonal matrix containing the singular values $\sigma_1 \geq \sigma_2 \geq \dots \geq 0$ on its diagonal; if $W$ has rank $r$, only the first $r$ of them are non-zero.

This decomposition tells us that we can express any matrix $W$ as a product of these three matrices, where the diagonal entries of $\Sigma$ represent the singular values of $W$.

It's important to note that if $W$ has rank $r$, then only the first $r$ singular values are non-zero. This decomposition will be extremely useful when we start thinking about low-rank approximations of $W$.
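As a quick sanity check of these definitions, the following NumPy sketch builds a small rank-2 matrix (the sizes are arbitrary choices of mine), takes its full SVD, and reads off the nuclear norm and the rank from the singular values.

```python
import numpy as np

rng = np.random.default_rng(0)

# A 5x4 matrix of rank 2, built as a product of two thin matrices.
W = rng.standard_normal((5, 2)) @ rng.standard_normal((2, 4))

# Full SVD: U is 5x5, S holds min(5, 4) = 4 singular values, Vt is 4x4.
U, S, Vt = np.linalg.svd(W)

# Rebuild the rectangular diagonal Sigma and reconstruct W = U Sigma V^T.
Sigma = np.zeros((5, 4))
Sigma[:4, :4] = np.diag(S)
W_rebuilt = U @ Sigma @ Vt

nuclear_norm = S.sum()            # sum of singular values
rank = int(np.sum(S > 1e-10))     # number of non-zero singular values

print(S, nuclear_norm, rank)
```

Only two of the four singular values are meaningfully non-zero here, which is exactly the rank we built in.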

Low-Rank Factorization Idea

Now, imagine we define:

$$L := U_r \Sigma_r^{1/2} \quad \text{and} \quad R := V_r \Sigma_r^{1/2}$$

Where:

  • $U_r = \text{first } r \text{ columns of } U$,
  • $V_r = \text{first } r \text{ columns of } V$,
  • $\Sigma_r \in \mathbb{R}^{r \times r}$ is the diagonal matrix of positive singular values.

Thus, $L$ is $m \times r$, and $R$ is $n \times r$.

Now check:

$$LR^T = (U_r \Sigma_r^{1/2})(V_r \Sigma_r^{1/2})^T = U_r \Sigma_r^{1/2} (\Sigma_r^{1/2} V_r^T) = U_r \Sigma_r V_r^T = W$$

We reconstructed $W$ exactly!

Compute the Frobenius Norms of $L$ and $R$

Next, we need to compute the Frobenius norms of $L$ and $R$, which will play an important role in regularizing the rank of $W$. The squared Frobenius norm of a matrix $M$ is simply the sum of the squares of its entries, which equals $\text{trace}(M^T M)$. For our two factors:

$$\|L\|_F^2 = \text{trace}(L^T L)$$
$$\|R\|_F^2 = \text{trace}(R^T R)$$

Now, let's expand these for $L$ and $R$.

For $L$, we have:

$$L^T L = (\Sigma_r^{1/2})^T (U_r^T U_r) \Sigma_r^{1/2}$$

Since $U_r^T U_r = I_r$ (because the columns of $U_r$ are orthonormal) and $\Sigma_r^{1/2}$ is diagonal, this simplifies to:

$$L^T L = \Sigma_r$$

Thus, the Frobenius norm squared of $L$ becomes:

$$\|L\|_F^2 = \text{trace}(\Sigma_r) = \sum_{i=1}^{r} \sigma_i$$

Similarly for $R$, we get:

$$R^T R = \Sigma_r$$

And thus:

$$\|R\|_F^2 = \sum_{i=1}^{r} \sigma_i$$

Add the Norms Together

Now, let's combine the Frobenius norms of $L$ and $R$:

$$\|L\|_F^2 + \|R\|_F^2 = \sum_{i=1}^{r} \sigma_i + \sum_{i=1}^{r} \sigma_i = 2 \sum_{i=1}^{r} \sigma_i = 2 \|W\|_*$$

Thus, we have the important result:

$$\frac{1}{2} \left( \|L\|_F^2 + \|R\|_F^2 \right) = \|W\|_*$$
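The whole chain, from the balanced factors to the identity above, can be checked numerically. This is a minimal NumPy sketch on a random matrix of my own choosing; the variable names simply follow the symbols in the derivation.

```python
import numpy as np

rng = np.random.default_rng(2)
W = rng.standard_normal((7, 5))

# Thin SVD: U is 7x5, S has 5 entries, Vt is 5x5.
U, S, Vt = np.linalg.svd(W, full_matrices=False)

# Balanced factors: L = U Sigma^{1/2}, R = V Sigma^{1/2}.
# Multiplying by a 1-D array scales each column by the matching sqrt(sigma_i).
L = U * np.sqrt(S)
R = Vt.T * np.sqrt(S)

reconstruction_ok = np.allclose(L @ R.T, W)
half_sum = 0.5 * (np.linalg.norm(L, "fro") ** 2 + np.linalg.norm(R, "fro") ** 2)
nuclear = S.sum()

print(reconstruction_ok, half_sum, nuclear)
```

The two printed numbers agree up to floating-point error, and $L^T L$ comes out as the diagonal matrix of singular values, as in the derivation.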

Practical Algorithm: Optimizing Low-Rank Factors via a Surrogate Objective

With this derivation, we can achieve computational efficiency by parameterizing a weight matrix into low-rank factors:

$$W = LR^T$$

where $L \in \mathbb{R}^{d_{\text{out}} \times r}$ and $R \in \mathbb{R}^{d_{\text{in}} \times r}$ are the trainable parameters, and $r \ll \min(d_{\text{out}}, d_{\text{in}})$ is a fixed, pre-determined rank.

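Instead of directly regularizing the nuclear norm $\|LR^T\|_*$, we use the variational form of our derived result. The derivation above shows that the SVD-based factorization attains $\|W\|_*$; that no other factorization of $W$ does better is a standard result that I take as given here:

$$\|W\|_* = \min_{W=LR^T} \frac{1}{2} \left( \|L\|_F^2 + \|R\|_F^2 \right)$$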
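The "min" matters because a matrix has many factorizations. In the sketch below (an illustration under my own assumptions, not a proof), I start from the balanced SVD factors and mix them with random invertible matrices `A`, which leaves the product $LR^T$ unchanged. The penalty of every such factorization should come out at or above the nuclear norm.

```python
import numpy as np

rng = np.random.default_rng(3)
W = rng.standard_normal((6, 4))
U, S, Vt = np.linalg.svd(W, full_matrices=False)
L = U * np.sqrt(S)          # balanced factors from the SVD
R = Vt.T * np.sqrt(S)
nuclear = S.sum()

def penalty(L, R):
    """Half the sum of squared Frobenius norms of the two factors."""
    return 0.5 * (np.sum(L ** 2) + np.sum(R ** 2))

balanced = penalty(L, R)

# (L A)(R A^{-T})^T = L A A^{-1} R^T = W for any invertible A.
penalties, products_match = [], []
for _ in range(5):
    A = np.eye(4) + 0.1 * rng.standard_normal((4, 4))  # mild, invertible mixing
    L2 = L @ A
    R2 = R @ np.linalg.inv(A).T
    products_match.append(np.allclose(L2 @ R2.T, W))
    penalties.append(penalty(L2, R2))

print(nuclear, balanced, penalties)
```

The balanced factors hit the nuclear norm exactly, while the perturbed factorizations represent the same $W$ at a higher penalty.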

Which suggests minimizing the task loss jointly with quadratic penalties on the Frobenius norms of the factors. Tauon therefore minimizes the following composite objective function $\Phi$ with respect to the factors $L$ and $R$:

$$\min_{L, R} \Phi(L, R) = L(LR^T) + \frac{\lambda}{2} \left( \|L\|_F^2 + \|R\|_F^2 \right)$$

Here, $L(LR^T)$ is the primary task loss evaluated using the reconstructed weights (written with an argument, $L(\cdot)$ denotes the loss, while a bare $L$ denotes the left factor), and $\lambda \geq 0$ is a hyperparameter controlling the strength of the regularization on the factors, which serves as a surrogate for nuclear norm regularization on $W$.

The gradients of this smooth objective $\Phi$ with respect to the factors $L$ and $R$ can be computed exactly using the chain rule. Letting $G = \nabla_W L(LR^T)$ be the gradient of the task loss with respect to the implicit full matrix $W$, we have:

$$\nabla_L \Phi = GR + \lambda L$$
$$\nabla_R \Phi = G^T L + \lambda R$$
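These two formulas are easy to get wrong by a transpose, so here is a finite-difference check. The task loss is a toy one I picked for the check, $\frac{1}{2}\|LR^T - T\|_F^2$ with a random target `T`, for which $G = LR^T - T$. The sizes and `lam` are arbitrary assumptions too.

```python
import numpy as np

rng = np.random.default_rng(4)
m, n, r, lam = 5, 4, 2, 0.1
T = rng.standard_normal((m, n))   # hypothetical regression target
L = rng.standard_normal((m, r))
R = rng.standard_normal((n, r))

def phi(L, R):
    """Toy task loss plus the factor-norm penalty."""
    W = L @ R.T
    return 0.5 * np.sum((W - T) ** 2) + 0.5 * lam * (np.sum(L ** 2) + np.sum(R ** 2))

# Analytic gradients from the chain rule.
G = L @ R.T - T                   # gradient of the toy loss w.r.t. W
grad_L = G @ R + lam * L
grad_R = G.T @ L + lam * R

def numerical_grad(f, X, eps=1e-6):
    """Central finite differences, one entry at a time."""
    g = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += eps
        Xm[idx] -= eps
        g[idx] = (f(Xp) - f(Xm)) / (2 * eps)
    return g

num_L = numerical_grad(lambda X: phi(X, R), L)
num_R = numerical_grad(lambda X: phi(L, X), R)

print(np.max(np.abs(grad_L - num_L)), np.max(np.abs(grad_R - num_R)))
```

Both differences should be tiny. In a real training loop the framework's autograd produces these factor gradients automatically, so the check is mainly useful for convincing yourself the math is right.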

The Tauon Update Rule

This looks pretty familiar, doesn't it?

The intuitive next step is then to apply Stochastic Gradient Descent (SGD) with optional momentum to minimize the objective $\Phi(L, R)$ using the exact factor gradients derived above. It operates only on the factor parameters $L$ and $R$.

Given factor parameters $L_t$, $R_t$ at step $t$, a learning rate $\eta$, weight decay factor $\lambda$, momentum factor $\beta$, and Nesterov flag:

  1. Compute task loss gradient $G_t = \nabla_W L(L_t R_t^T)$ via backpropagation.
  2. Compute factor gradients w.r.t. task loss: $$\nabla_L L = G_t R_t, \quad \nabla_R L = G_t^T L_t$$
  3. Compute full gradients of the surrogate objective $\Phi$: $$g_{L,t} = \nabla_L L + \lambda L_t, \quad g_{R,t} = \nabla_R L + \lambda R_t$$
  4. If $\beta > 0$, update momentum buffers (example using standard momentum with dampening $d = 1 - \beta$): $$m_{L,t} = \beta m_{L,t-1} + (1 - \beta) g_{L,t}, \quad m_{R,t} = \beta m_{R,t-1} + (1 - \beta) g_{R,t}$$ Let $u_{L,t}$, $u_{R,t}$ be the final update directions derived from $g_{L,t}$, $g_{R,t}$, potentially incorporating $m_{L,t}$, $m_{R,t}$ and the Nesterov adjustment if enabled (otherwise $u = g$).
  5. Update factors: $$L_{t+1} = L_t - \eta u_{L,t}, \quad R_{t+1} = R_t - \eta u_{R,t}$$
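The five steps above can be sketched in plain NumPy as follows. This is not the code from the repository: the function name `tauon_step`, the default hyperparameters, and the toy quadratic loss are all assumptions of mine. The Nesterov branch is one common convention (blend the current gradient with the momentum buffer), picked because the steps above do not pin down the exact form.

```python
import numpy as np

def tauon_step(L, R, G, m_L, m_R, lr=0.02, lam=1e-2, beta=0.9, nesterov=False):
    """One update of the factors, given G = dLoss/dW at W = L R^T."""
    # Steps 2-3: gradients of the surrogate objective w.r.t. the factors.
    g_L = G @ R + lam * L
    g_R = G.T @ L + lam * R
    u_L, u_R = g_L, g_R
    if beta > 0:
        # Step 4: momentum buffers with dampening 1 - beta.
        m_L = beta * m_L + (1 - beta) * g_L
        m_R = beta * m_R + (1 - beta) * g_R
        if nesterov:
            # Assumed Nesterov variant: interpolate gradient and buffer.
            u_L = (1 - beta) * g_L + beta * m_L
            u_R = (1 - beta) * g_R + beta * m_R
        else:
            u_L, u_R = m_L, m_R
    # Step 5: plain gradient step on both factors.
    return L - lr * u_L, R - lr * u_R, m_L, m_R

# Toy problem (hypothetical): fit a rank-2 target T with 0.5 * ||L R^T - T||_F^2.
rng = np.random.default_rng(5)
T = rng.standard_normal((8, 2)) @ rng.standard_normal((2, 6))
L = 0.1 * rng.standard_normal((8, 2))
R = 0.1 * rng.standard_normal((6, 2))
m_L, m_R = np.zeros_like(L), np.zeros_like(R)

def task_loss(L, R):
    return 0.5 * np.sum((L @ R.T - T) ** 2)

initial_loss = task_loss(L, R)
for _ in range(500):
    G = L @ R.T - T               # step 1 for this toy loss
    L, R, m_L, m_R = tauon_step(L, R, G, m_L, m_R)
final_loss = task_loss(L, R)

print(initial_loss, final_loss)
```

In a real model, step 1 comes from backpropagation, and the factor gradients `G @ R` and `G.T @ L` are what autograd returns for `L` and `R` directly, so the full matrix `G` never needs to be formed.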

There it is, the final algorithm. One that directly optimizes the low-rank factors using standard SGD dynamics applied to the surrogate objective, which implicitly encourages low-rank solutions via the factor norm penalties.

Importantly, it avoids computationally expensive SVD operations within the optimization loop.

Testing

Before jumping into the results part, there's a plot twist.

Everything we derived about the nuclear norm optimizer up to this point is wrong. By wrong, I don't mean the math. The math is correct (at least to my understanding), and I won't waste your time with fake math proofs.

But the direction in which we optimized the nuclear norm is wrong. It should be the opposite. Instead of trying to minimize the sum of the singular values (which we did), the actual way of implementing this optimizer should be to maximize it. You have probably realized this already if you understand the nuclear norm's properties: we should promote higher rank, not aggressively suppress it.

Despite the wrong direction, I still decided to write this process down and record my experiment, because both cases should be equally important from a learning perspective. And I hope they are.

Training a NanoGPT scale model with Tauon

Next up is the training process. For the experiment, I used a NanoGPT-scale model (6 Transformer layers, 8 attention heads, etc.) and trained it on the Tiny Shakespeare dataset for 15 epochs (around 4,700 iterations). To see the full model, you can access the full repository here.

I ran four different models in total (60 epochs), three having the same architecture but with different optimizers. The model trained with Tauon has a different architecture in its linear layers: I replaced every linear layer with a low-rank linear layer, which, instead of directly projecting dimension $x$ to $y$, first projects $x$ down to a rank $r$ and then from $r$ back up to $y$. This is also the intuition behind the optimizer. The rank $r$ was set extremely low in this case, to 4. If this experiment were continued, an ablation study on $r$ should be a top priority.
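For readers who want to picture the low-rank layer, here is a minimal NumPy sketch of the forward pass. It is not the layer from the repository: the class name `LowRankLinear`, the initialization scales, and the toy sizes are assumptions made for illustration.

```python
import numpy as np

class LowRankLinear:
    """Computes y = x W^T with W = L R^T, without ever forming W."""

    def __init__(self, d_in, d_out, rank, rng):
        # Initialization scales are an arbitrary choice for this sketch.
        self.R = rng.standard_normal((d_in, rank)) / np.sqrt(d_in)
        self.L = rng.standard_normal((d_out, rank)) / np.sqrt(rank)

    def __call__(self, x):
        h = x @ self.R            # (batch, d_in)  -> (batch, rank)
        return h @ self.L.T       # (batch, rank)  -> (batch, d_out)

rng = np.random.default_rng(6)
layer = LowRankLinear(d_in=64, d_out=48, rank=4, rng=rng)
x = rng.standard_normal((10, 64))
y = layer(x)

# Dense equivalent, built only to compare against the factored forward pass.
W = layer.L @ layer.R.T           # shape (48, 64), rank at most 4
print(y.shape, np.allclose(y, x @ W.T), np.linalg.matrix_rank(W))
```

The factored forward pass gives the same output as the dense matrix would, but every activation has to squeeze through a width-$r$ bottleneck, which is why the choice of $r$ matters so much.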

The other three models are 1) a Muon-optimized NanoGPT, 2) the standard NanoGPT, and 3) something I let o3 create (this had the worst performance).

Results

There's no training or validation loss plot this time.

I made the mistake of not screenshotting the TensorBoard results, and by now the RTX 4090 has been rented out.

But here is the one-word summary of Tauon's performance: bad. The training loss after all 15 epochs was above 3, while the standard NanoGPT model reached a training loss below 0.15.

The inference speed was slightly higher for Tauon, though nothing too significant.

Ultimately, I think the two biggest reasons for the bad performance are 1) I optimized the nuclear norm in the opposite direction (this is single-handedly the most devastating one) and 2) the rank $r$ was set too low. The largest linear layer has a size of $3840 \times 3840$, but in Tauon's model, it was factored down to $3840 \times 4$ and $4 \times 3840$, losing more than 99% of its original parameters.
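To put a number on "more than 99%", here is the parameter count for that largest layer, using only the sizes quoted above.

```python
d, r = 3840, 4

full_params = d * d                 # dense 3840 x 3840 layer
factored_params = d * r + r * d     # 3840 x 4 plus 4 x 3840
kept_fraction = factored_params / full_params

print(full_params)        # 14745600
print(factored_params)    # 30720
print(f"{kept_fraction:.4%} of the parameters are kept")
```

The factored layer keeps 1 in 480 parameters, roughly 0.21%, so about 99.8% of the layer's capacity is gone before training even starts.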

That's a wrap. This was not a research-level project, but simply an experiment in which I learn by doing. There might be a future extension, stay tuned!
