This repository contains the code and resources for our research paper on Task-Specific Dynamic Token Pruning (TS-DTP) for Large Language Models (LLMs). TS-DTP is a novel approach designed to adapt token pruning to the specific demands of downstream NLP tasks, enhancing both efficiency and, in some cases, performance.
Large Language Models (LLMs) have achieved remarkable performance across a wide range of Natural Language Processing (NLP) tasks. However, their computational demands, particularly during inference, can be a major barrier to deployment, especially in resource-constrained environments. Dynamic token pruning, which selectively removes less crucial tokens during inference, has been proposed as a promising solution to reduce this computational burden.
Our work introduces Task-Specific Dynamic Token Pruning (TS-DTP), which builds upon the concept of dynamic token pruning but enhances it by incorporating task-specific information. This is achieved through task-specific attention mechanisms and feature representations, allowing the pruning strategy to be tailored to the specific requirements of each downstream task, ensuring the retention of the most relevant tokens. This adaptation is key to not only improving efficiency but also maintaining or improving performance on downstream tasks.
Our research makes the following key contributions:
The core components of our implementation are in the TS-DTP.ipynb
file.
The `TaskSpecificAttention` class computes task-specific attention weights based on the input hidden states and task-specific feature representations. It contains a task-specific weight (`task_specific_weight`) and a feature layer that computes the task-specific features; the `forward` method computes the task-specific attention from these components.
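For orientation, here is a minimal sketch of what such a module might look like, assuming a PyTorch implementation; the hidden size, the feature-norm scoring, and the softmax normalization are illustrative choices, not necessarily what the notebook uses:

```python
import torch
import torch.nn as nn


class TaskSpecificAttention(nn.Module):
    """Computes task-specific attention weights from hidden states (sketch)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # Task-specific weight that scales the task-specific contribution.
        self.task_specific_weight = nn.Parameter(torch.ones(1))
        # Feature layer mapping hidden states to task-specific features.
        self.feature_layer = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size)
        task_features = torch.tanh(self.feature_layer(hidden_states))
        # Score each token by the norm of its task-specific features,
        # then normalize across the sequence to obtain attention weights.
        scores = self.task_specific_weight * task_features.norm(dim=-1)
        return torch.softmax(scores, dim=-1)  # (batch, seq_len)
```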
The `TaskSpecificDynamicTokenPruning` class implements the core TS-DTP framework. It instantiates a `TaskSpecificAttention` module for each layer. `calculate_token_importance` computes the cumulative importance scores for tokens, and `prune_tokens` implements a dynamic thresholding mechanism to prune them. The `forward` method combines task-specific attention with standard attention, prunes tokens, and then performs classification using a linear layer. `calculate_auxiliary_loss` computes an auxiliary loss term used during training.
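A rough sketch of how the importance scoring and dynamic-threshold pruning could fit together, building on the `TaskSpecificAttention` sketch above; the mean-based scoring, the `keep_ratio` threshold rule, and the masking strategy are assumptions for illustration, not the notebook's exact logic:

```python
import torch
import torch.nn as nn


class TaskSpecificDynamicTokenPruning(nn.Module):
    """Skeleton of the TS-DTP pruning logic (illustrative sketch)."""

    def __init__(self, hidden_size, num_layers, num_classes, keep_ratio=0.7):
        super().__init__()
        # One task-specific attention module per transformer layer.
        self.task_attention = nn.ModuleList(
            [TaskSpecificAttention(hidden_size) for _ in range(num_layers)]
        )
        self.classifier = nn.Linear(hidden_size, num_classes)
        self.keep_ratio = keep_ratio

    def calculate_token_importance(self, attn_weights, task_weights):
        # attn_weights: (batch, heads, seq_len, seq_len) standard attention;
        # task_weights: (batch, seq_len) task-specific attention.
        # Average the standard attention each token receives, then add the
        # task-specific score to obtain a cumulative importance per token.
        standard = attn_weights.mean(dim=1).mean(dim=1)
        return standard + task_weights

    def prune_tokens(self, hidden_states, importance):
        # Dynamic threshold: keep tokens whose importance exceeds a fraction
        # (keep_ratio) of the per-example mean importance; zero out the rest.
        threshold = importance.mean(dim=-1, keepdim=True) * self.keep_ratio
        keep_mask = (importance >= threshold).unsqueeze(-1).float()
        return hidden_states * keep_mask, keep_mask
```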
The `if __name__ == '__main__':` block shows an example of how to train and evaluate TS-DTP. It loads the SST-2 sentiment analysis data (using the `datasets` library from Hugging Face). An `SST2Dataset` class handles data loading and preparation for the sentiment analysis task, including tokenization, padding, and truncation; it should be extended for other tasks according to their input requirements.
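A minimal version of such a dataset class, assuming the Hugging Face `datasets` and `transformers` APIs and SST-2's `sentence`/`label` columns; the backbone tokenizer and maximum length are assumptions, not necessarily what the notebook uses:

```python
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import AutoTokenizer


class SST2Dataset(Dataset):
    """Wraps the GLUE SST-2 split with tokenization, padding, and truncation."""

    def __init__(self, split="train", model_name="bert-base-uncased", max_length=128):
        self.data = load_dataset("glue", "sst2", split=split)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        encoding = self.tokenizer(
            example["sentence"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(example["label"], dtype=torch.long),
        }
```

For a different task, swap the dataset name and the fields used in `__getitem__` (for example, sentence pairs for an NLI task).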
Install the dependencies:
`pip install torch transformers datasets scikit-learn tqdm`

Clone the repository:
`git clone https://github.com/ahmadpanah/TS-DTP`
`cd TS-DTP`

Open and run the notebook with Jupyter:
`jupyter notebook TS-DTP.ipynb`
Adjust the model and dataset parameters in the `if __name__ == '__main__':` section according to your needs.
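These are the kinds of parameters you would typically adjust there; the names and values below are illustrative defaults, not the settings used in the paper:

```python
# Illustrative configuration values; adapt to your task and hardware.
MODEL_NAME = "bert-base-uncased"   # backbone encoder (assumed)
MAX_LENGTH = 128                   # tokenization length for SST2Dataset
BATCH_SIZE = 32
NUM_EPOCHS = 3
LEARNING_RATE = 2e-5
KEEP_RATIO = 0.7                   # controls the dynamic pruning threshold
AUX_LOSS_WEIGHT = 0.1              # weight of the auxiliary loss term
```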
The experimental evaluation, detailed in our paper, uses three datasets:
Our approach, TS-DTP, is compared to several baselines:
TS-DTP achieves competitive accuracy (or BLEU scores, depending on the task) while reducing computational cost, with faster inference and lower memory consumption, on all datasets.
Future research directions include:
This project is licensed under the MIT License.
For any questions or inquiries, please contact: h.ahmadpanah@iau.ac.ir