Deep learning models are now trained on increasingly large datasets, making it crucial to reduce computational costs and improve data quality. Dataset distillation aims to distill a large dataset into a small synthesized dataset such that models trained on it can achieve similar performance to those trained on the original dataset. While there have been many empirical efforts to improve dataset distillation algorithms, a thorough theoretical analysis and provable, efficient algorithms are still lacking.
In this paper, by focusing on dataset distillation for kernel ridge regression (KRR), we show that one data point per class is already necessary and sufficient to recover the original model’s performance in many settings.
An illustration of dataset distillation from [1]
Given an original dataset \((X, Y) \in \mathbb{R}^{d \times n} \times \mathbb{R}^{k \times n}\), where \(d\) is the dimension of the data, \(n\) is the number of original data points, and \(k\) is the dimension of the label, dataset distillation aims to distill the large original dataset into a small synthesized dataset \((X_S, Y_S) \in \mathbb{R}^{d \times m} \times \mathbb{R}^{k \times m}\), such that models trained on it can achieve similar performance to those trained on the original dataset, where \(m\) is the number of distilled data points and \(m \ll n\).
A kernel ridge regression (KRR) model is \(f(x) = W \phi(x)\), where \(\phi: \mathbb{R}^{d} \mapsto \mathbb{R}^{p}\) is a feature map and \(W \in \mathbb{R}^{k \times p}\) minimizes the following loss:
$$\min_{W} \|Y - W \phi(X)\|_F^2 + \lambda \|W\|_F^2$$
where \(\lambda \geq 0\) is the regularization parameter (for \(\lambda = 0\), \(W\) is taken to be the minimum-norm solution).
The solution can be computed analytically as \(W = Y \phi_{\lambda}(X)^+\), where
$$\phi_{\lambda}(X)^+ = \left\{
\begin{array}{ll}
(K(X, X) + \lambda I_n)^{-1} \phi(X)^\top = \phi(X)^\top (\phi(X) \phi(X)^\top + \lambda I_p)^{-1}, &\text{if $\lambda >0$, } \\
\phi(X)^+, &\text{if $\lambda =0$.}
\end{array} \right. $$
and \(K(X, X) = \phi(X)^\top \phi(X) \in \mathbb{R}^{n \times n}\) is the kernel matrix. \(\phi_{\lambda}(X)\) can be viewed as a matrix of regularized features.
Linear ridge regression is a special case of kernel ridge regression with \(\phi(X) = X \).
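As a quick numerical check (not from the paper), the NumPy sketch below verifies that the two expressions for \(\phi_{\lambda}(X)^+\) coincide and then forms the closed-form solution \(W = Y \phi_{\lambda}(X)^+\); the feature map `phi` used here is an arbitrary illustrative choice.

```python
import numpy as np

rng = np.random.default_rng(0)
d, p, n, k, lam = 5, 8, 100, 3, 0.1

# An arbitrary illustrative feature map phi(x) = tanh(A x)
# (an assumption for this sketch, not a feature map from the paper).
A = rng.normal(size=(p, d))
phi = lambda X: np.tanh(A @ X)          # phi(X) has shape (p, n)

X = rng.normal(size=(d, n))             # original data
Y = rng.normal(size=(k, n))             # original labels
Phi = phi(X)

# Kernel (dual) form: (K(X, X) + lam I_n)^{-1} phi(X)^T with K = phi(X)^T phi(X)
K = Phi.T @ Phi
dual = np.linalg.solve(K + lam * np.eye(n), Phi.T)

# Primal form: phi(X)^T (phi(X) phi(X)^T + lam I_p)^{-1}
primal = Phi.T @ np.linalg.inv(Phi @ Phi.T + lam * np.eye(p))

print(np.allclose(dual, primal))        # True: both expressions give phi_lambda(X)^+

W = Y @ dual                            # closed-form KRR solution W = Y phi_lambda(X)^+
```

For \(\lambda = 0\), the regularized pseudoinverse above reduces to `np.linalg.pinv(Phi)`.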
Similarly, a KRR model trained on the distilled dataset with regularization \(\lambda_S \geq 0\) is \(f_S(x) = W_S \phi(x)\), where \(W_S = Y_S \phi_{\lambda_S}(X_S)^+\).
The goal of dataset distillation is to find \((X_S, Y_S)\) such that \(W_S = W\).
For an LRR model, we show that \(k\) distilled data points (one per class) are necessary and sufficient to guarantee \(W_S = W\). We provide analytical solutions for such \((X_S, Y_S)\), allowing us to compute the distilled dataset in closed form instead of learning it heuristically as in existing works.
Intuitively, the original dataset \((X, Y)\) is compressed into \((X_S, Y_S)\) through the original model's parameters \(W\).
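As a concrete illustration (a minimal sketch, not necessarily the exact construction from the paper), the following NumPy snippet distills an LRR model into \(k\) points by choosing \(Y_S = I_k\) and \(X_S = W^+\) and retraining with \(\lambda_S = 0\); the retrained weights match \(W\) exactly.

```python
import numpy as np

rng = np.random.default_rng(1)
d, n, k, lam = 20, 500, 10, 0.5

X = rng.normal(size=(d, n))             # original data
Y = rng.normal(size=(k, n))             # original labels

# Original LRR model: W = Y X_lambda^+ (phi(X) = X)
W = Y @ np.linalg.solve(X.T @ X + lam * np.eye(n), X.T)    # shape (k, d)

# Distill to k points: one simple choice consistent with the result above
# is Y_S = I_k and X_S = W^+ (an illustrative instantiation).
Y_S = np.eye(k)
X_S = np.linalg.pinv(W)                 # shape (d, k)

# Retrain on the distilled set with lambda_S = 0 (minimum-norm solution)
W_S = Y_S @ np.linalg.pinv(X_S)

print(np.allclose(W_S, W))              # True: k distilled points recover W exactly
```

The key step is that \((W^+)^+ = W\), so \(W_S = Y_S X_S^+ = I_k (W^+)^+ = W\).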
The results for LRR can be extended to KRR by replacing \(X_S\) with \(\phi(X_S)\). When \(\phi\) is surjective or bijective, we can always find an \(X_S\) for any desired \(\phi(X_S)\); surjective \(\phi\) exist, for example, when \(p \leq d\).
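Here is a hedged sketch of the surjective case, using a linear feature map \(\phi(x) = Bx\) with \(p \leq d\) and \(\lambda = \lambda_S = 0\) (illustrative assumptions, not the paper's general setting): the desired distilled features are pulled back through \(\phi\) via \(B^+\), and \(k\) distilled points again recover \(W\).

```python
import numpy as np

rng = np.random.default_rng(2)
d, p, n, k = 30, 10, 400, 5             # p <= d so a linear phi can be surjective

# Surjective (full row rank, almost surely) linear feature map phi(x) = B x
# -- an illustrative choice for this sketch.
B = rng.normal(size=(p, d))
phi = lambda X: B @ X

X = rng.normal(size=(d, n))             # original data
Y = rng.normal(size=(k, n))             # original labels

# Original KRR model with lambda = 0: W = Y phi(X)^+
W = Y @ np.linalg.pinv(phi(X))          # shape (k, p)

# Distill: pick desired features Phi_S = W^+ with Y_S = I_k, then pull them
# back through phi; this preimage exists because phi is surjective.
Y_S = np.eye(k)
Phi_S = np.linalg.pinv(W)               # desired phi(X_S), shape (p, k)
X_S = np.linalg.pinv(B) @ Phi_S         # one valid preimage, shape (d, k)

W_S = Y_S @ np.linalg.pinv(phi(X_S))    # retrain on the distilled set
print(np.allclose(W_S, W))              # True: k distilled points recover W
```

Since \(B\) has full row rank, \(BB^+ = I_p\), so \(\phi(X_S) = W^+\) exactly.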
For non-surjective \(\phi\), such as deep nonlinear NNs, one data point per class is generally not sufficient when \((Y_S^+ W)^+\) does not lie in the range of \(\phi\). For deep linear NNs, we show that \(m = k+1\) can be sufficient under certain conditions.
Below is a comparison with existing theoretical analyses of dataset distillation.
For surjective \(\phi\), our algorithm outperforms previous work such as KIP [2] while being significantly more efficient.
[1] Tongzhou Wang, Jun-Yan Zhu, Antonio Torralba, and Alexei A. Efros. "Dataset Distillation." 2020.
[2] Timothy Nguyen, Roman Novak, Lechao Xiao, and Jaehoon Lee. "Dataset Distillation with Infinitely Wide Convolutional Networks." Advances in Neural Information Processing Systems, 2021.
@inproceedings{chen2024provable,
  title={Provable and Efficient Dataset Distillation for Kernel Ridge Regression},
  author={Yilan Chen and Wei Huang and Tsui-Wei Weng},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=WI2VpcBdnd}
}