Nguyen, D. M. H., Le, A. T., Nguyen, T. Q., Diep, N. T., Nguyen, T., Duong-Tran, D., Peters, J., Shen, L., Niepert, M., & Sonntag, D. (2024). Dude: Dual Distribution-Aware Context Prompt Learning For Large Vision-Language Model.
Proceedings of Machine Learning Research.
https://arxiv.org/abs/2407.04489
Abstract
Prompt learning methods are gaining increasing attention due to their ability to customize large vision-language models to new domains using pre-trained contextual knowledge and minimal training data. However, existing works typically rely on optimizing unified prompt inputs, often struggling with fine-grained classification tasks due to insufficient discriminative attributes. To tackle this, we consider a new framework based on a dual context of both domain-shared and class-specific contexts, where the latter is generated by Large Language Models (LLMs) such as GPTs. Such dual prompt methods enhance the model's feature representation by joining implicit and explicit factors encoded in LLM knowledge. Moreover, we employ Unbalanced Optimal Transport (UOT) theory to quantify the relationships between constructed prompts and visual tokens. Through partial matching, UOT can properly align discrete sets of visual tokens and prompt embeddings under different mass distributions, which is particularly valuable for handling irrelevant or noisy elements, ensuring that the preservation of mass does not restrict transport solutions. Furthermore, UOT's characteristics integrate seamlessly with image augmentation, expanding the training sample pool while maintaining a reasonable distance between perturbed images and prompt inputs. Extensive experiments across few-shot classification and adapter settings substantiate the superiority of our model over current state-of-the-art baselines.
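The UOT alignment described above can be pictured with a generic entropic solver. The minimal NumPy sketch below runs Sinkhorn-style scaling iterations with softened (KL-relaxed) marginals on a toy cosine cost between random "visual token" and "prompt" embeddings; the function name, the cost, and the hyperparameters (eps, rho) are illustrative assumptions, not the paper's implementation.

```python
# Minimal sketch of entropic Unbalanced Optimal Transport between visual-token and
# prompt embeddings: Sinkhorn-style scaling iterations with KL-relaxed marginals.
# Illustrative only; not the authors' implementation.
import numpy as np

def unbalanced_sinkhorn(C, a, b, eps=0.05, rho=1.0, n_iters=200):
    """C: (n, m) cost matrix, a: (n,) and b: (m,) marginal masses."""
    K = np.exp(-C / eps)                      # Gibbs kernel
    u, v = np.ones_like(a), np.ones_like(b)
    exponent = rho / (rho + eps)              # softens the marginal constraints
    for _ in range(n_iters):
        u = (a / (K @ v + 1e-30)) ** exponent
        v = (b / (K.T @ u + 1e-30)) ** exponent
    return u[:, None] * K * v[None, :]        # transport plan (partial matching)

# Toy usage: align 6 visual tokens with 3 prompt embeddings via a cosine cost.
rng = np.random.default_rng(0)
tokens = rng.normal(size=(6, 16))
prompts = rng.normal(size=(3, 16))
cos = tokens @ prompts.T / (
    np.linalg.norm(tokens, axis=1, keepdims=True) * np.linalg.norm(prompts, axis=1)
)
plan = unbalanced_sinkhorn(1.0 - cos, np.full(6, 1 / 6), np.full(3, 1 / 3))
print(plan.shape, plan.sum())  # total transported mass need not equal 1
```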
Tran, H.-C., Nguyen, D. M. H., Nguyen, M.-D., Le, N. H., & Nguyen, B. T. (2024, May). Energy Minimizing-based Token Merging for Accelerating Transformers. Proceedings of Practical ML for Low Resource Settings in Science Workshop at ICLR 2024, May 7-11, 2024, Austria.
Abstract
Model compression is an active research field aimed at reducing the size and complexity of models. In a recent noteworthy study, ToMe and its variants utilize the Bipartite Soft Matching (BSM) algorithm in which tokens representing patches in an image are split into two sets, and top-k similar tokens from one set are merged. This approach utilizes pre-trained weights, enhances speed, and reduces memory usage. However, these algorithms have some drawbacks. First, the choice of a token-splitting strategy significantly influences algorithm performance since tokens in one set can only perceive tokens in the other set, leading to mis-merging issues. Furthermore, although ToMe is effective in the initial layers, it becomes increasingly problematic in deeper layers as the number of tokens diminishes and informative tokens are damaged. To address these limitations, rather than relying on specific splitting strategies like BSM, we propose a new algorithm called PiToMe. Specifically, we prioritize the protection of informative tokens using an additional factor called energy score. In experiments, PiToMe achieved up to a 50% memory reduction while exhibiting superior off-the-shelf performance on image classification (a 1.71% average performance drop compared to 2.6% for ToMe) and image-text retrieval (1.35% average performance drop compared to 6.89% for ToMe) compared to previous BSM-based approaches that rely solely on token similarity.
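As a point of reference for the bipartite soft matching step that the abstract critiques, here is a toy NumPy sketch: tokens are split into two sets by alternating index, each token in one set proposes its most similar partner in the other, and the r most similar pairs are averaged. This illustrates the mechanism only; it is not the ToMe reference code, and the alternating split and pair averaging are simplifying assumptions.

```python
# Toy Bipartite Soft Matching step: split tokens into two sets, pair each token in
# set A with its most similar token in set B, and average the r best pairs.
import numpy as np

def bsm_merge(tokens, r):
    a, b = tokens[0::2], tokens[1::2]                       # alternating split
    a_n = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_n = b / np.linalg.norm(b, axis=1, keepdims=True)
    sim = a_n @ b_n.T                                       # cosine similarities
    best = sim.argmax(axis=1)                               # each A token's partner in B
    order = np.argsort(-sim[np.arange(len(a)), best])       # most similar pairs first
    merged_b = b.copy()
    keep_a = np.ones(len(a), dtype=bool)
    for i in order[:r]:
        merged_b[best[i]] = 0.5 * (merged_b[best[i]] + a[i])  # average the pair
        keep_a[i] = False
    return np.concatenate([a[keep_a], merged_b], axis=0)

x = np.random.default_rng(1).normal(size=(8, 4))
print(bsm_merge(x, r=2).shape)   # 8 tokens reduced to 6
```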
Zaverkin, V., Holzmüller, D., Christiansen, H., Errica, F., Alesiani, F., Takamoto, M., Niepert, M., & Kästner, J. (2024). Uncertainty-biased molecular dynamics for learning uniformly accurate interatomic potentials.
npj Comput. Mater.,
10(1), Article 1.
https://doi.org/10.1038/s41524-024-01254-1
Abstract
Efficiently creating a concise but comprehensive data set for training machine-learned interatomic potentials (MLIPs) is an under-explored problem. Active learning, which uses biased or unbiased molecular dynamics (MD) to generate candidate pools, aims to address this objective. Existing biased and unbiased MD-simulation methods, however, are prone to miss either rare events or extrapolative regions—areas of the configurational space where unreliable predictions are made. This work demonstrates that MD, when biased by the MLIP’s energy uncertainty, simultaneously captures extrapolative regions and rare events, which is crucial for developing uniformly accurate MLIPs. Furthermore, exploiting automatic differentiation, we enhance bias-forces-driven MD with the concept of bias stress. We employ calibrated gradient-based uncertainties to yield MLIPs with similar or, sometimes, better accuracy than ensemble-based methods at a lower computational cost. Finally, we apply uncertainty-biased MD to alanine dipeptide and MIL-53(Al), generating MLIPs that represent both configurational spaces more accurately than models trained with conventional MD.
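The core idea of biasing MD with the model's uncertainty can be illustrated in one dimension. The sketch below runs overdamped Langevin dynamics on a biased energy E(x) - k * sigma(x), where sigma is the disagreement of a tiny toy ensemble, so the trajectory is pulled toward high-uncertainty regions; the potentials, the linear bias form, and all constants are assumptions made for illustration, not the paper's setup.

```python
# Toy 1D uncertainty-biased MD: overdamped Langevin dynamics on E(x) - k * sigma(x),
# where sigma(x) is the standard deviation across a small toy model ensemble.
import numpy as np

rng = np.random.default_rng(0)
ensemble = [lambda x, a=a: (x**2 - 1.0) ** 2 + a * x for a in (0.00, 0.05, -0.05)]

def energy(x):           return np.mean([f(x) for f in ensemble])
def sigma(x):            return np.std([f(x) for f in ensemble])
def grad(f, x, h=1e-4):  return (f(x + h) - f(x - h)) / (2 * h)

x, k, dt, kT = -1.0, 2.0, 1e-3, 0.1
visited = []
for _ in range(5000):
    force = -grad(energy, x) + k * grad(sigma, x)        # bias pulls toward uncertainty
    x += dt * force + np.sqrt(2 * kT * dt) * rng.normal()
    visited.append(x)
print(f"explored range: [{min(visited):.2f}, {max(visited):.2f}]")
```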
Zaverkin, V., Alesiani, F., Maruyama, T., Errica, F., Christiansen, H., Takamoto, M., Weber, N., & Niepert, M. (2024). Higher-Rank Irreducible Cartesian Tensors for Equivariant Message Passing.
In Proceedings of the 38th Annual Conference on Neural Information Processing Systems (NeurIPS 2024).
https://doi.org/10.48550/arXiv.2405.14253
Abstract
The ability to perform fast and accurate atomistic simulations is crucial for advancing the chemical sciences. By learning from high-quality data, machine-learned interatomic potentials achieve accuracy on par with ab initio and first-principles methods at a fraction of their computational cost. The success of machine-learned interatomic potentials arises from integrating inductive biases such as equivariance to group actions on an atomic system, e.g., equivariance to rotations and reflections. In particular, the field has notably advanced with the emergence of equivariant message passing. Most of these models represent an atomic system using spherical tensors, tensor products of which require complicated numerical coefficients and can be computationally demanding. Cartesian tensors offer a promising alternative, though state-of-the-art methods lack flexibility in message-passing mechanisms, restricting their architectures and expressive power. This work explores higher-rank irreducible Cartesian tensors to address these limitations. We integrate irreducible Cartesian tensor products into message-passing neural networks and prove the equivariance and traceless property of the resulting layers. Through empirical evaluations on various benchmark data sets, we consistently observe on-par or better performance than that of state-of-the-art spherical and Cartesian models.
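A rank-2 example makes the irreducible Cartesian tensor construction concrete: the symmetric, traceless part of an outer product, T_ij = r_i r_j - (|r|^2 / 3) delta_ij. The short sketch below builds it and numerically checks tracelessness and rotation equivariance; it is a toy illustration of the tensor algebra, not the message-passing layers proposed in the paper.

```python
# Rank-2 irreducible Cartesian tensor of a vector: symmetric and traceless part of
# the outer product, with a numerical check of rotation equivariance T(Rr) = R T(r) R^T.
import numpy as np

def irreducible_rank2(r):
    outer = np.outer(r, r)
    return outer - (np.dot(r, r) / 3.0) * np.eye(3)   # remove the trace

def random_rotation(rng):
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    return q * np.sign(np.linalg.det(q))              # proper rotation, det = +1

rng = np.random.default_rng(0)
r, R = rng.normal(size=3), random_rotation(rng)
T = irreducible_rank2(r)
print(np.isclose(np.trace(T), 0.0))                           # traceless
print(np.allclose(irreducible_rank2(R @ r), R @ T @ R.T))     # equivariant
```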
Hagnberger, J., Kalimuthu, M., Musekamp, D., & Niepert, M. (2024, May). Vectorized Conditional Neural Fields: A Framework for Solving Time-dependent PDEs. Proceedings of the AI4DifferentialEquations in Science Workshop at ICLR 2024, May 7-11, 2024, Austria.
Abstract
Neural Operators are a recent class of data-driven models for learning solutions to Partial Differential Equations (PDEs). Traditionally, these models are trained in an autoregressive fashion using data collected at discrete time points in the evolution of the PDE. This setup gives rise to two problems: (i) poor temporal generalization due to error accumulation and (ii) poor zero-shot super-resolution capabilities. To address these issues, we propose Vectorized Conditional Neural Fields (VCNeF), a general framework that utilizes transformers and implicit neural representations to efficiently solve time-dependent PDEs of varying coefficients. A comprehensive evaluation of VCNeF on the challenging 1D and 2D PDEs from PDEBench demonstrates the superiority of our model over four state-of-the-art baselines. Furthermore, our proposed model achieves faster inference and generalizes better to unseen PDE parameters than the compared models.
Qian, C., Manolache, A., Morris, C., & Niepert, M. (2024). Probabilistic Graph Rewiring via Virtual Nodes.
In Proceedings of the 38th Annual Conference on Neural Information Processing Systems (NeurIPS 2024).
https://doi.org/10.48550/arXiv.2405.17311
Abstract
Message-passing graph neural networks (MPNNs) have emerged as a powerful paradigm for graph-based machine learning. Despite their effectiveness, MPNNs face challenges such as under-reaching and over-squashing, where limited receptive fields and structural bottlenecks hinder information flow in the graph. While graph transformers hold promise in addressing these issues, their scalability is limited due to quadratic complexity regarding the number of nodes, rendering them impractical for larger graphs. Here, we propose implicitly rewired message-passing neural networks (IPR-MPNNs), a novel approach that integrates implicit probabilistic graph rewiring into MPNNs. By introducing a small number of virtual nodes, i.e., adding additional nodes to a given graph and connecting them to existing nodes, in a differentiable, end-to-end manner, IPR-MPNNs enable long-distance message propagation, circumventing quadratic complexity. Theoretically, we demonstrate that IPR-MPNNs surpass the expressiveness of traditional MPNNs. Empirically, we validate our approach by showcasing its ability to mitigate under-reaching and over-squashing effects, achieving state-of-the-art performance across multiple graph datasets. Notably, IPR-MPNNs outperform graph transformers while maintaining significantly faster computational efficiency.
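The virtual-node rewiring can be sketched in a few lines: add m extra nodes to an n-node graph and connect each real node to one virtual node, so any two real nodes sharing a virtual node become two hops apart. In IPR-MPNNs the assignment probabilities are learned end-to-end and sampled differentiably; in the toy below they are random placeholders.

```python
# Toy rewiring via virtual nodes: each real node is attached to one virtual node,
# shortening long-range paths without quadratic all-pairs attention.
import numpy as np

def add_virtual_nodes(adj, m, rng):
    n = adj.shape[0]
    logits = rng.normal(size=(n, m))                       # stand-in for learned scores
    assign = logits.argmax(axis=1)                         # one virtual node per real node
    big = np.zeros((n + m, n + m), dtype=adj.dtype)
    big[:n, :n] = adj
    for i, v in enumerate(assign):
        big[i, n + v] = big[n + v, i] = 1                  # connect node i to virtual node v
    return big

rng = np.random.default_rng(0)
path = np.diag(np.ones(5), 1) + np.diag(np.ones(5), -1)   # a 6-node path graph
rewired = add_virtual_nodes(path, m=2, rng=rng)
print(rewired.shape)                                       # (8, 8): 6 real + 2 virtual nodes
```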
Torres, E., & Niepert, M. (2024). Survey: Adaptive Physics-informed Neural Networks.
NeurIPS 2024 Workshop on Foundation Models for Science: Progress, Opportunities, and Challenges.
https://openreview.net/forum?id=bYP6YB84Pq
Abstract
Physics-informed neural networks (PINNs) have emerged as a promising approach for solving partial differential equations (PDEs) using neural networks, particularly in data-scarce scenarios due to their unsupervised training capability. However, a key limitation is the need for re-optimization with each change in PDE parameters, similar to the challenge in traditional numerical methods where each system of equations corresponds to a specific PDE instance. This characteristic poses a barrier to the widespread adoption of PINNs across scientific and engineering applications. This survey explores research addressing this limitation through transfer learning and meta-learning, synthesizing insights to establish a foundation for efficient data generation strategies tailored to PINNs. These methods can potentially improve PINNs' training efficiency, enabling quicker adaptation to new PDEs with fewer data and computational demands. While numerical methods directly solve systems of equations to derive solutions, neural networks implicitly learn solutions by adjusting their parameters. One notable advantage of neural networks lies in their capacity to abstract away from specific problem domains, enabling them to retain, discard, or adapt learned representations to efficiently address similar problems. By understanding how these techniques can be applied to PINNs, this survey seeks to identify promising directions for future research to enable the widespread adoption of PINNs across a wide range of scientific and engineering applications.
Abstract
In this work, we propose a simple transformer-based baseline for multimodal molecular representation learning, integrating three distinct modalities: SMILES strings, 2D graph representations, and 3D conformers of molecules. A key aspect of our approach is the aggregation of 3D conformers, allowing the model to account for the fact that molecules can adopt multiple conformations, an important factor for accurate molecular representation. The tokens for each modality are extracted using modality-specific encoders: a transformer for SMILES strings, a message-passing neural network for 2D graphs, and an equivariant neural network for 3D conformers. The flexibility and modularity of this framework enable easy adaptation and replacement of these encoders, making the model highly versatile for different molecular tasks. The extracted tokens are then combined into a unified multimodal sequence, which is processed by a downstream transformer for prediction tasks. To efficiently scale our model for large multimodal datasets, we utilize Flash Attention 2 and bfloat16 precision. Despite its simplicity, our approach achieves state-of-the-art results across multiple datasets, demonstrating its effectiveness as a strong baseline for multimodal molecular representation learning.
Musekamp, D., Kalimuthu, M., Holzmüller, D., Takamoto, M., & Niepert, M. (2024). Active Learning for Neural PDE Solvers.
NeurIPS 2024 Workshop on Data-Driven and Differentiable Simulations, Surrogates, and Solvers.
https://openreview.net/forum?id=LD63WlGRQQ
Abstract
Solving partial differential equations (PDEs) is a fundamental problem in engineering and science. While neural PDE solvers can be more efficient than established numerical solvers, they often require large amounts of training data that is costly to obtain. Active Learning (AL) could help surrogate models reach the same accuracy with smaller training sets by querying classical solvers with more informative initial conditions and PDE parameters. While AL is more common in other domains, it has yet to be studied extensively for neural PDE solvers. To bridge this gap, we introduce AL4PDE, a modular and extensible active learning benchmark. It provides multiple parametric PDEs and state-of-the-art surrogate models for the solver-in-the-loop setting, enabling the evaluation of existing and the development of new AL methods for PDE solving. We use the benchmark to evaluate batch active learning algorithms such as uncertainty- and feature-based methods. We show that AL reduces the average error by up to 71% compared to random sampling and significantly reduces worst-case errors. Moreover, AL generates similar datasets across repeated runs, with consistent distributions over the PDE parameters and initial conditions. The acquired datasets are reusable, providing benefits for surrogate models not involved in the data generation.
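One round of the solver-in-the-loop setting can be summarized as: score a pool of candidate PDE parameters by the disagreement of a surrogate ensemble, then send the top-k candidates to the expensive numerical solver. The sketch below shows that loop with stand-in surrogates and a stand-in solver; none of the functions correspond to AL4PDE components.

```python
# Toy active-learning round: ensemble disagreement as the acquisition score,
# top-k candidates sent to a (placeholder) classical solver.
import numpy as np

rng = np.random.default_rng(0)

def surrogate(theta, w):          # stand-in surrogate mapping PDE parameters to a quantity of interest
    return np.sin(w * theta).sum(axis=1)

def expensive_solver(theta):      # stand-in for the classical numerical solver
    return np.sin(1.0 * theta).sum(axis=1)

pool = rng.uniform(0.0, 2.0, size=(500, 3))                # candidate PDE parameters
ensemble = [surrogate(pool, w) for w in (0.8, 1.0, 1.2)]   # ensemble predictions
scores = np.std(ensemble, axis=0)                          # disagreement = acquisition score
batch = np.argsort(-scores)[:16]                           # top-k most informative candidates
new_data = (pool[batch], expensive_solver(pool[batch]))    # query the solver, grow the training set
print(batch.shape, new_data[1].shape)
```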
Wang, Z., Cai, S., Mu, Z., Lin, H., Zhang, C., Liu, X., Li, Q., Liu, A., Ma, X., & Liang, Y. (2024). OmniJARVIS: Unified Vision-Language-Action Tokenization Enables Open-World Instruction Following Agents.
In Proceedings of the 38th Annual Conference on Neural Information Processing Systems (NeurIPS 2024).
https://doi.org/10.48550/arXiv.2407.00114
Abstract
This paper presents OmniJARVIS, a novel Vision-Language-Action (VLA) model for open-world instruction-following agents in Minecraft. Compared to prior works that either emit textual goals to separate controllers or produce the control command directly, OmniJARVIS seeks a different path to ensure both strong reasoning and efficient decision-making capabilities via unified tokenization of multimodal interaction data. First, we introduce a self-supervised approach to learn a behavior encoder that produces discretized tokens for behavior trajectories and an imitation learning policy decoder conditioned on these tokens. These additional behavior tokens are added to the vocabulary of pretrained Multimodal Language Models. With this encoder, we then pack long-term multimodal interactions involving task instructions, memories, thoughts, observations, textual responses, behavior trajectories, etc., into unified token sequences and model them with autoregressive transformers. Thanks to the semantically meaningful behavior tokens, the resulting VLA model, OmniJARVIS, can reason (by producing chain-of-thoughts), plan, answer questions, and act (by producing behavior tokens for the imitation learning policy decoder). OmniJARVIS demonstrates excellent performance on a comprehensive collection of atomic, programmatic, and open-ended tasks in open-world Minecraft. Our analysis further unveils the crucial design principles in interaction data formation, unified tokenization, and its scaling potential.
Abstract
Graph Neural Networks (GNNs) are a popular class of machine learning models. Inspired by the learning to explain (L2X) paradigm, we propose L2XGNN, a framework for explainable GNNs that provides faithful explanations by design. L2XGNN learns a mechanism for selecting explanatory subgraphs (motifs), which are exclusively used in the GNN message-passing operations. L2XGNN can select, for each input graph, a subgraph with specific properties, such as being sparse and connected. Imposing such constraints on the motifs often leads to more interpretable and effective explanations. Experiments on several datasets suggest that L2XGNN achieves the same classification accuracy as baseline methods using the entire input graph while ensuring that only the provided explanations are used to make predictions. Moreover, we show that L2XGNN can identify motifs responsible for the graph properties the model is intended to predict.
Qian, C., Manolache, A., Ahmed, K., Zeng, Z., den Broeck, G. V., Niepert, M., & Morris, C. (2024, May). Probabilistically Rewired Message-Passing Neural Networks.
Proceedings of the International Conference on Learning Representations (ICLR 2024), May 7-11, 2024, Austria.
https://doi.org/10.48550/arXiv.2310.02156
Abstract
Message-passing graph neural networks (MPNNs) emerged as powerful tools for processing graph-structured input. However, they operate on a fixed input graph structure, ignoring potential noise and missing information. Furthermore, their local aggregation mechanism can lead to problems such as over-squashing and limited expressive power in capturing relevant graph structures. Existing solutions to these challenges have primarily relied on heuristic methods, often disregarding the underlying data distribution. Hence, devising principled approaches for learning to infer graph structures relevant to the given prediction task remains an open challenge. In this work, leveraging recent progress in exact and differentiable k-subset sampling, we devise probabilistically rewired MPNNs (PR-MPNNs), which learn to add relevant edges while omitting less beneficial ones. For the first time, our theoretical analysis explores how PR-MPNNs enhance expressive power, and we identify precise conditions under which they outperform purely randomized approaches. Empirically, we demonstrate that our approach effectively mitigates issues like over-squashing and under-reaching. In addition, on established real-world datasets, our method exhibits competitive or superior predictive performance compared to traditional MPNN models and recent graph transformer architectures.
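The forward sampling step behind probabilistic rewiring can be illustrated with a hard perturb-and-top-k heuristic: perturb learned edge logits with Gumbel noise and keep the k highest-scoring candidate edges. PR-MPNNs rely on exact and differentiable k-subset sampling with principled gradient estimators, which this toy does not attempt; the candidate-edge set and logits below are illustrative.

```python
# Toy hard k-subset edge sampling: Gumbel-perturb edge logits and keep the top-k edges.
import numpy as np

def sample_k_edges(logits, k, rng):
    gumbel = -np.log(-np.log(rng.uniform(size=logits.shape)))   # Gumbel(0, 1) noise
    chosen = np.argsort(-(logits + gumbel))[:k]                  # top-k perturbed scores
    mask = np.zeros_like(logits)
    mask[chosen] = 1.0
    return mask

rng = np.random.default_rng(0)
n = 6
cand = [(i, j) for i in range(n) for j in range(i + 1, n)]       # all candidate edges
logits = rng.normal(size=len(cand))                              # stand-in for learned edge scores
mask = sample_k_edges(logits, k=4, rng=rng)
added = [e for e, m in zip(cand, mask) if m == 1.0]
print(added)                                                     # 4 sampled edges for rewiring
```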
Liu, A., Niepert, M., & den Broeck, G. V. (2024, May). Image Inpainting via Tractable Steering of Diffusion Models.
Proceedings of the International Conference on Learning Representations (ICLR 2024), May 7-11, 2024, Austria.
https://doi.org/10.48550/arXiv.2401.03349
Abstract
Diffusion models are the current state of the art for generating photorealistic images. However, controlling the sampling process for constrained image generation tasks such as inpainting remains challenging since exact conditioning on such constraints is intractable. While existing methods use various techniques to approximate the constrained posterior, this paper proposes to exploit the ability of Tractable Probabilistic Models (TPMs) to exactly and efficiently compute the constrained posterior, and to leverage this signal to steer the denoising process of diffusion models. Specifically, this paper adopts a class of expressive TPMs termed Probabilistic Circuits (PCs). Building upon prior advances, we further scale up PCs and make them capable of guiding the image generation process of diffusion models. Empirical results suggest that our approach can consistently improve the overall quality and semantic coherence of inpainted images across three natural image datasets (i.e., CelebA-HQ, ImageNet, and LSUN) with only ~10% additional computational overhead brought by the TPM.
Errica, F., & Niepert, M. (2024, May). Tractable Probabilistic Graph Representation Learning with Graph-Induced Sum-Product Networks.
Proceedings of the International Conference on Learning Representations (ICLR 2024), May 7-11, 2024, Austria.
https://doi.org/10.48550/arXiv.2305.10544
Abstract
We introduce Graph-Induced Sum-Product Networks (GSPNs), a new probabilistic framework for graph representation learning that can tractably answer probabilistic queries. Inspired by the computational trees induced by vertices in the context of message-passing neural networks, we build hierarchies of sum-product networks (SPNs) where the parameters of a parent SPN are learnable transformations of the a-posteriori mixing probabilities of its children's sum units. Due to weight sharing and the tree-shaped computation graphs of GSPNs, we obtain the efficiency and efficacy of deep graph networks with the additional advantages of a probabilistic model. We show the model's competitiveness on scarce supervision scenarios, under missing data, and for graph classification in comparison to popular neural models. We complement the experiments with qualitative analyses on hyper-parameters and the model's ability to answer probabilistic queries.
Elenter, J., Chamon, L. F. O., & Ribeiro, A. (2024, May). Near-Optimal Solutions of Constrained Learning Problems.
Proceedings of the International Conference on Learning Representations (ICLR 2024), May 7-11, 2024, Austria.
https://doi.org/10.48550/arXiv.2403.11844
Abstract
With the widespread adoption of machine learning systems, the need to curtail their behavior has become increasingly apparent. This is evidenced by recent advancements towards developing models that satisfy robustness, safety, and fairness requirements. These requirements can be imposed (with generalization guarantees) by formulating constrained learning problems that can then be tackled by dual ascent algorithms. Yet, though these algorithms converge in objective value, even in non-convex settings, they cannot guarantee that their outcome is feasible. Doing so requires randomizing over all iterates, which is impractical in virtually any modern application. Still, final iterates have been observed to perform well in practice. In this work, we address this gap between theory and practice by characterizing the constraint violation of Lagrangian minimizers associated with optimal dual variables, despite lack of convexity. To do this, we leverage the fact that non-convex, finite-dimensional constrained learning problems can be seen as parametrizations of convex, functional problems. Our results show that rich parametrizations effectively mitigate the issue of feasibility in dual methods, shedding light on prior empirical successes of dual learning. We illustrate our findings in fair learning tasks.
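The dual ascent scheme the paper analyzes can be seen on a tiny convex example: minimize f(x) subject to g(x) <= 0 by alternating gradient descent on the Lagrangian in the primal variable with projected gradient ascent on the multiplier. The quadratic objective and linear constraint below are illustrative; the paper's results concern non-convex parametrizations, which this toy does not capture.

```python
# Toy dual-ascent loop: minimize ||x - target||^2 subject to x1 + x2 <= 1 by
# alternating primal descent on the Lagrangian and projected dual ascent.
import numpy as np

target = np.array([2.0, 1.0])
c, b = np.array([1.0, 1.0]), 1.0           # constraint: x1 + x2 <= 1

def f_grad(x):  return 2.0 * (x - target)
def g(x):       return c @ x - b

x, lam = np.zeros(2), 0.0
for _ in range(2000):
    x -= 0.05 * (f_grad(x) + lam * c)      # primal descent on the Lagrangian
    lam = max(0.0, lam + 0.05 * g(x))      # projected dual ascent
print(np.round(x, 3), round(lam, 3), round(g(x), 4))
```

For this instance the iterates approach the projection of the target onto the feasible set, roughly x = (1, 0), with the multiplier settling near its optimal value of 2.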
Liu, X., Liu, A., den Broeck, G. V., & Liang, Y. (2024). A Tractable Inference Perspective of Offline RL.
In Proceedings of the 38th Annual Conference on Neural Information Processing Systems (NeurIPS 2024).
https://doi.org/10.48550/arXiv.2311.00094
Abstract
A popular paradigm for offline Reinforcement Learning (RL) tasks is to first fit the offline trajectories to a sequence model, and then prompt the model for actions that lead to high expected return. In addition to obtaining accurate sequence models, this paper highlights that tractability, the ability to exactly and efficiently answer various probabilistic queries, plays an important role in offline RL. Specifically, due to the fundamental stochasticity from the offline data-collection policies and the environment dynamics, highly non-trivial conditional/constrained generation is required to elicit rewarding actions. While it is still possible to approximate such queries, we observe that such crude estimates significantly undermine the benefits brought by expressive sequence models. To overcome this problem, this paper proposes Trifle (Tractable Inference for Offline RL), which leverages modern Tractable Probabilistic Models (TPMs) to bridge the gap between good sequence models and high expected returns at evaluation time. Empirically, Trifle achieves the most state-of-the-art scores in 9 Gym-MuJoCo benchmarks against strong baselines. Further, owing to its tractability, Trifle significantly outperforms prior approaches in stochastic environments and safe RL tasks (e.g. with action constraints) with minimal algorithmic modifications.
Hagnberger, J., Kalimuthu, M., Musekamp, D., & Niepert, M. (2024). Vectorized Conditional Neural Fields: A Framework for Solving Time-dependent Parametric Partial Differential Equations.
In Proceedings of the 41st International Conference on Machine Learning (ICML 2024).
https://arxiv.org/abs/2406.03919
Abstract
Transformer models are increasingly used for solving Partial Differential Equations (PDEs). Several adaptations have been proposed, all of which suffer from the typical problems of Transformers, such as quadratic memory and time complexity. Furthermore, all prevalent architectures for PDE solving lack at least one of several desirable properties of an ideal surrogate model, such as (i) generalization to PDE parameters not seen during training, (ii) spatial and temporal zero-shot super-resolution, (iii) continuous temporal extrapolation, (iv) support for 1D, 2D, and 3D PDEs, and (v) efficient inference for longer temporal rollouts. To address these limitations, we propose Vectorized Conditional Neural Fields (VCNeFs), which represent the solution of time-dependent PDEs as neural fields. Contrary to prior methods, however, VCNeFs compute, for a set of multiple spatio-temporal query points, their solutions in parallel and model their dependencies through attention mechanisms. Moreover, VCNeF can condition the neural field on both the initial conditions and the parameters of the PDEs. An extensive set of experiments demonstrates that VCNeFs are competitive with and often outperform existing ML-based surrogate models.
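The vectorized-query interface can be sketched independently of the architecture: one forward pass evaluates a conditional neural field at an arbitrary batch of (x, t) points, conditioned on a latent code of the initial condition, which is what enables zero-shot super-resolution. The random two-layer MLP below stands in for VCNeF's transformer-based conditioning and attention; it only illustrates the interface, not the model.

```python
# Toy conditional neural field evaluated at a whole batch of (x, t) queries in parallel,
# conditioned on a discretized initial condition. Random weights, illustration only.
import numpy as np

rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(34, 64)) * 0.1, np.zeros(64)
W2, b2 = rng.normal(size=(64, 1)) * 0.1, np.zeros(1)

def field(queries, cond):
    """queries: (n, 2) array of (x, t); cond: (32,) latent code of the initial condition."""
    inp = np.concatenate([queries, np.tile(cond, (len(queries), 1))], axis=1)
    return np.tanh(inp @ W1 + b1) @ W2 + b2          # u(x, t) for all queries at once

u0 = np.sin(np.linspace(0, 2 * np.pi, 32))           # discretized initial condition
xs, ts = np.meshgrid(np.linspace(0, 1, 50), np.linspace(0, 1, 20))
queries = np.stack([xs.ravel(), ts.ravel()], axis=1) # any resolution of query points
print(field(queries, u0).shape)                      # (1000, 1) solution values
```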
Nguyen, D. M. H., Lukashina, N., Nguyen, T., Le, A. T., Nguyen, T., Ho, N., Peters, J., Sonntag, D., Zaverkin, V., & Niepert, M. (2024). Structure-Aware E(3)-Invariant Molecular Conformer Aggregation Networks.
In Proceedings of the 41st International Conference on Machine Learning (ICML 2024).
https://arxiv.org/abs/2402.01975
Abstract
A molecule's 2D representation consists of its atoms, their attributes, and the molecule's covalent bonds. A 3D (geometric) representation of a molecule is called a conformer and consists of its atom types and Cartesian coordinates. Every conformer has a potential energy, and the lower this energy, the more likely it occurs in nature. Most existing machine learning methods for molecular property prediction consider either 2D molecular graphs or 3D conformer structure representations in isolation. Inspired by recent work on using ensembles of conformers in conjunction with 2D graph representations, we propose E(3)-invariant molecular conformer aggregation networks. The method integrates a molecule's 2D representation with that of multiple of its conformers. Contrary to prior work, we propose a novel 2D-3D aggregation mechanism based on a differentiable solver for the Fused Gromov-Wasserstein Barycenter problem and the use of an efficient conformer generation method based on distance geometry. We show that the proposed aggregation mechanism is E(3) invariant and propose an efficient GPU implementation. Moreover, we demonstrate that the aggregation mechanism helps to significantly outperform state-of-the-art molecule property prediction methods on established datasets.
Chamon, L. F. O., Karimi, M. R., & Korba, A. (2024). Constrained Sampling with Primal-Dual Langevin Monte Carlo.
In Proceedings of the 38th Annual Conference on Neural Information Processing Systems (NeurIPS 2024).
https://doi.org/10.48550/arXiv.2411.00568
Abstract
This work considers the problem of sampling from a probability distribution known up to a normalization constant while satisfying a set of statistical constraints specified by the expected values of general nonlinear functions. This problem finds applications in, e.g., Bayesian inference, where it can constrain moments to evaluate counterfactual scenarios or enforce desiderata such as prediction fairness. Methods developed to handle support constraints, such as those based on mirror maps, barriers, and penalties, are not suited for this task. This work therefore relies on gradient descent-ascent dynamics in Wasserstein space to put forward a discrete-time primal-dual Langevin Monte Carlo algorithm (PD-LMC) that simultaneously constrains the target distribution and samples from it. We analyze the convergence of PD-LMC under standard assumptions on the target distribution and constraints, namely (strong) convexity and log-Sobolev inequalities. To do so, we bring classical optimization arguments for saddle-point algorithms to the geometry of Wasserstein space. We illustrate the relevance and effectiveness of PD-LMC in several applications.
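The algorithm admits a compact toy instance: sample from a standard normal subject to the moment constraint E[x] >= 1 by alternating a Langevin step on the Lagrangian potential U(x) + lam * g(x), with g(x) = 1 - x, and a projected ascent step on lam. Step sizes, burn-in, and the single-sample dual update below are illustrative choices rather than the paper's experimental setup.

```python
# Toy 1D primal-dual Langevin Monte Carlo: target exp(-x^2/2), constraint E[x] >= 1.
import numpy as np

rng = np.random.default_rng(0)
eta, eta_dual = 1e-2, 1e-2
x, lam = 0.0, 0.0
samples = []
for t in range(50_000):
    grad = x - lam                               # d/dx [x^2/2 + lam * (1 - x)]
    x = x - eta * grad + np.sqrt(2 * eta) * rng.normal()
    lam = max(0.0, lam + eta_dual * (1.0 - x))   # ascent on the constraint violation
    if t > 10_000:
        samples.append(x)
print(round(np.mean(samples), 2))                # close to 1.0: constraint holds in expectation
```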
Tran, H.-C., Nguyen, D. M. H., Nguyen, D. M., Nguyen, T.-T., Le, N., Xie, P., Sonntag, D., Zou, J. Y., Nguyen, B. T., & Niepert, M. (2024). Accelerating Transformers with Spectrum-Preserving Token Merging.
In Proceedings of the 38th Annual Conference on Neural Information Processing Systems (NeurIPS 2024).
https://doi.org/10.48550/arXiv.2405.16148
Abstract
Increasing the throughput of the Transformer architecture, a foundational component used in numerous state-of-the-art models for vision and language tasks (e.g., GPT, LLaVa), is an important problem in machine learning. One recent and effective strategy is to merge token representations within Transformer models, aiming to reduce computational and memory requirements while maintaining accuracy. Prior works have proposed algorithms based on Bipartite Soft Matching (BSM), which divides tokens into distinct sets and merges the top-k similar tokens. However, these methods have significant drawbacks, such as sensitivity to token-splitting strategies and damage to informative tokens in later layers. This paper presents a novel paradigm called PiToMe, which prioritizes the preservation of informative tokens using an additional metric termed the energy score. This score identifies large clusters of similar tokens as high-energy, indicating potential candidates for merging, while smaller (unique and isolated) clusters are considered as low-energy and preserved. Experimental findings demonstrate that PiToMe saves 40-60% of the FLOPs of the base models while exhibiting superior off-the-shelf performance on image classification (0.5% average performance drop of ViT-MAE-H compared to 2.6% for baselines), image-text retrieval (0.3% average performance drop of CLIP on Flickr30k compared to 4.5% for others), and analogously in visual question answering with LLaVa-7B. Furthermore, PiToMe is theoretically shown to preserve intrinsic spectral properties of the original token space under mild conditions.
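The energy-score intuition can be mimicked with a toy scoring rule: take a token's energy to be its average cosine similarity to all other tokens, so tokens sitting in large redundant clusters score high and get merged into a near neighbor, while isolated (informative) tokens score low and are preserved. The exact energy definition and merging schedule in PiToMe differ; the sketch below only illustrates the protect-low-energy idea.

```python
# Toy energy-based merging: high-energy (redundant) tokens are folded into their
# nearest remaining neighbor; low-energy (isolated) tokens are kept untouched.
import numpy as np

def energy_merge(tokens, r):
    x = tokens / np.linalg.norm(tokens, axis=1, keepdims=True)
    sim = x @ x.T
    np.fill_diagonal(sim, -np.inf)
    energy = np.where(np.isfinite(sim), sim, 0.0).sum(axis=1) / (len(tokens) - 1)
    merge_ids = np.argsort(-energy)[:r]          # highest-energy tokens get merged
    keep = np.ones(len(tokens), dtype=bool)
    out = tokens.copy()
    for i in merge_ids:
        keep[i] = False
        cand = np.where(keep)[0]
        j = cand[sim[i, cand].argmax()]          # nearest neighbor that is still kept
        out[j] = 0.5 * (out[j] + out[i])
    return out[keep]

rng = np.random.default_rng(0)
cluster = rng.normal(size=(1, 8)) + 0.05 * rng.normal(size=(6, 8))   # 6 redundant tokens
unique = rng.normal(size=(2, 8))                                     # 2 informative tokens
print(energy_merge(np.vstack([cluster, unique]), r=3).shape)         # (5, 8) tokens remain
```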