Abstract:The efficiency of attention is critical because its time complexity grows quadratically with sequence length. SageAttention2 addresses this by utilizing quantization to accelerate matrix multiplications (Matmul) in attention. To further accelerate SageAttention2, we propose using the faster FP8 Matmul instruction that accumulates in FP16. This instruction is 2x faster than the FP8 Matmul used in SageAttention2. Our experiments show that SageAttention2++ achieves a 3.9x speedup over FlashAttention while maintaining the same attention accuracy as SageAttention2. This means SageAttention2++ effectively accelerates various models, including those for language, image, and video generation, with negligible end-to-end metrics loss. The code will be available at https://github.com/thu-ml/SageAttention.
Abstract:The efficiency of attention is important due to its quadratic time complexity. We enhance the efficiency of attention through two key contributions. First, we leverage the new FP4 Tensor Cores in Blackwell GPUs to accelerate attention computation. Our implementation achieves 1038 TOPS on RTX5090, which is a 5x speedup over the fastest FlashAttention on RTX5090. Experiments show that our FP4 attention can accelerate inference of various models in a plug-and-play way. Second, we pioneer the application of low-bit attention to training tasks. Existing low-bit attention works like FlashAttention3 and SageAttention focus only on inference. However, the efficiency of training large models is also important. To explore whether low-bit attention can be effectively applied to training tasks, we design an accurate and efficient 8-bit attention for both forward and backward propagation. Experiments indicate that 8-bit attention achieves lossless performance in fine-tuning tasks but exhibits slower convergence in pretraining tasks. The code will be available at https://github.com/thu-ml/SageAttention.
Abstract:Single domain generalization (SDG) has recently attracted growing attention in medical image segmentation. One promising strategy for SDG is to leverage consistent semantic shape priors across different imaging protocols, scanner vendors, and clinical sites. However, existing dictionary learning methods that encode shape priors often suffer from limited representational power with a small set of offline computed shape elements, or overfitting when the dictionary size grows. Moreover, they are not readily compatible with large foundation models such as the Segment Anything Model (SAM). In this paper, we propose a novel Mixture-of-Shape-Experts (MoSE) framework that seamlessly integrates the idea of mixture-of-experts (MoE) training into dictionary learning to efficiently capture diverse and robust shape priors. Our method conceptualizes each dictionary atom as a shape expert, which specializes in encoding distinct semantic shape information. A gating network dynamically fuses these shape experts into a robust shape map, with sparse activation guided by SAM encoding to prevent overfitting. We further provide this shape map as a prompt to SAM, utilizing the powerful generalization capability of SAM through bidirectional integration. All modules, including the shape dictionary, are trained in an end-to-end manner. Extensive experiments on multiple public datasets demonstrate its effectiveness.
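The sparse gating step described in this abstract can be illustrated with a minimal sketch. It assumes a single vector of gate logits per image and a plain top-k softmax; the abstract does not specify either, and the names `fuse_shape_experts`, `gate_logits`, `shape_atoms`, and `top_k` are hypothetical. The SAM-guided part of the gating is not modeled here.

```python
import math


def fuse_shape_experts(gate_logits, shape_atoms, top_k):
    """Fuse dictionary atoms (shape experts) with sparse top-k gating.

    gate_logits: one score per expert (hypothetical gating output).
    shape_atoms: one flattened shape map per expert, all the same length.
    Returns the fused shape map and the per-expert weights.
    """
    # Keep only the top_k experts with the largest gate scores.
    order = sorted(range(len(gate_logits)), key=lambda i: gate_logits[i], reverse=True)
    active = order[:top_k]

    # Softmax restricted to the active experts (shifted for numerical stability).
    peak = max(gate_logits[i] for i in active)
    exps = {i: math.exp(gate_logits[i] - peak) for i in active}
    total = sum(exps.values())
    weights = [exps.get(i, 0.0) / total for i in range(len(gate_logits))]

    # Weighted sum of the active atoms, position by position.
    size = len(shape_atoms[0])
    fused = [sum(weights[i] * shape_atoms[i][p] for i in active) for p in range(size)]
    return fused, weights
```

Experts outside the top-k receive a weight of exactly zero, which is the sparse activation the abstract refers to; how the gate scores are produced is left open in this sketch.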
Abstract:Transformer models have achieved remarkable success across various AI applications but face significant training costs. Low-bit training, such as INT8 training, can leverage computational units with higher throughput, and has already demonstrated its effectiveness on GPT2 models with block-level quantization. However, it struggles with modern Transformer variants incorporating GLU units, because those variants exhibit complex distributions of activation outliers. To address this challenge, we propose Fallback Quantization, implementing mixed-precision GEMM that dynamically falls back from 8-bit to 16-bit for activation blocks containing outliers. Experiments show that our approach performs robustly in both fine-tuning and pretraining settings. Moreover, our method achieves a 1.57x end-to-end training speedup on RTX4090 GPUs.
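The idea of falling back to higher precision for blocks with outliers can be sketched as follows. This is only an illustration under stated assumptions: the outlier test is a fixed magnitude threshold, and the fallback simply keeps the block unquantized as a stand-in for 16-bit storage. The abstract does not describe the actual outlier criterion or fallback mechanism, and `fallback_quantize`, `block_size`, and `outlier_threshold` are hypothetical names.

```python
def quantize_block_int8(block):
    """Symmetric INT8 quantize-dequantize of one block (round trip)."""
    scale = max(abs(x) for x in block) / 127.0 or 1.0  # avoid a zero scale
    return [round(x / scale) * scale for x in block]


def fallback_quantize(values, block_size, outlier_threshold):
    """Quantize block by block, keeping outlier blocks in higher precision.

    Returns the reconstructed values and one flag per block that is True
    when the block fell back (assumed criterion: max magnitude > threshold).
    """
    reconstructed, fell_back = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        has_outlier = max(abs(x) for x in block) > outlier_threshold
        fell_back.append(has_outlier)
        # Fallback path: keep the block as-is (stand-in for 16-bit).
        reconstructed.extend(block if has_outlier else quantize_block_int8(block))
    return reconstructed, fell_back
```

Without the fallback, a single large activation would set the scale for its whole block and push the small values in that block toward zero; handling such blocks separately keeps the 8-bit path for everything else.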
Abstract:An efficient attention implementation is essential for large models due to its quadratic time complexity. Fortunately, attention commonly exhibits sparsity, i.e., many values in the attention map are near zero, allowing for the omission of corresponding computations. Many studies have utilized the sparse pattern to accelerate attention. However, most existing works focus on optimizing attention within specific models by exploiting certain sparse patterns of the attention map. A universal sparse attention that guarantees both the speedup and end-to-end performance of diverse models remains elusive. In this paper, we propose SpargeAttn, a universal sparse and quantized attention for any model. Our method uses a two-stage online filter. In the first stage, we rapidly and accurately predict the attention map, allowing some matrix multiplications in attention to be skipped. In the second stage, we design an online softmax-aware filter that incurs no extra overhead and further skips some matrix multiplications. Experiments show that our method significantly accelerates diverse models, including language, image, and video generation, without sacrificing end-to-end metrics. The code is available at https://github.com/thu-ml/SpargeAttn.
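A softmax-aware skip of the kind mentioned for the second stage can be sketched for a single query row. The sketch assumes the rule "skip the value product for a key block whose largest logit is more than `skip_gap` below the running maximum"; the abstract does not state the exact criterion, and `attention_row` and `skip_gap` are hypothetical names. The first-stage prediction of the attention map is not modeled.

```python
import math


def attention_row(scores, values, block_size, skip_gap=None):
    """Online-softmax attention for one query, processed block by block.

    scores: pre-softmax logits, one per key.
    values: one scalar value per key (a stand-in for rows of V).
    Returns the attention output and the number of skipped value products.
    """
    running_max = -math.inf
    denom = 0.0  # running softmax denominator
    acc = 0.0    # running weighted sum of values
    skipped = 0
    for start in range(0, len(scores), block_size):
        s = scores[start:start + block_size]
        v = values[start:start + block_size]
        block_max = max(s)
        new_max = max(running_max, block_max)
        rescale = math.exp(running_max - new_max)  # 0.0 on the first block
        weights = [math.exp(x - new_max) for x in s]
        denom = denom * rescale + sum(weights)
        acc *= rescale
        if skip_gap is not None and new_max - block_max > skip_gap:
            # All weights in this block are below exp(-skip_gap): skip P @ V.
            skipped += 1
        else:
            acc += sum(w * x for w, x in zip(weights, v))
        running_max = new_max
    return acc / denom, skipped
```

The check reuses the running maximum that online softmax already tracks, which is one reading of why such a filter adds no extra overhead; calling the function with `skip_gap=None` gives the unfiltered result for comparison.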
Abstract:Temporal embryo images and parental fertility table indicators are both valuable for pregnancy prediction in in vitro fertilization embryo transfer (IVF-ET). However, current machine learning models cannot make full use of the complementary information between the two modalities to improve pregnancy prediction performance. In this paper, we propose a Decoupling Fusion Network called DeFusion to effectively integrate the multi-modal information for IVF-ET pregnancy prediction. Specifically, we propose a decoupling fusion module that decouples the information from the different modalities into related and unrelated information, thereby achieving a finer-grained fusion. We also fuse temporal embryo images with a spatial-temporal position encoding, and extract fertility table indicator information with a table transformer. To evaluate the effectiveness of our model, we use a new dataset including 4046 cases collected from Southern Medical University. The experiments show that our model outperforms state-of-the-art methods. Meanwhile, the performance on the eye disease prediction dataset reflects the model's good generalization. Our code and dataset are available at https://github.com/Ou-Young-1999/DFNet.
Abstract:Although quantization for linear layers has been widely used, its application to accelerate the attention process remains limited. SageAttention utilizes 8-bit matrix multiplication, 16-bit matrix multiplication with 16-bit accumulator, and precision-enhancing methods, yielding an accurate kernel that is 2x faster than FlashAttention2. To further enhance the efficiency of attention computation while maintaining precision, we propose SageAttention2, which utilizes significantly faster 4-bit matrix multiplication (Matmul) alongside additional precision-enhancing techniques. First, we propose to quantize matrices $(Q, K)$ to INT4 at warp-level granularity and quantize matrices $(\widetilde P, V)$ to FP8. Second, we propose a method to smooth $Q$ and $V$, enhancing the accuracy of attention with INT4 $QK$ and FP8 $PV$. Third, we analyze the quantization accuracy across timesteps and layers, then propose an adaptive quantization method to ensure the end-to-end metrics over various models. The operations per second (OPS) of SageAttention2 surpass FlashAttention2 and xformers by about 3x and 5x on RTX4090, respectively. Comprehensive experiments confirm that our approach incurs negligible end-to-end metrics loss across diverse models, including those for language processing, image generation, and video generation. The code is available at https://github.com/thu-ml/SageAttention.
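Why smoothing $Q$ can help INT4 quantization of $QK^\top$ may be easier to see in a small sketch. It assumes that smoothing means subtracting the per-channel mean of $Q$ before quantization and adding the term $\bar{q} K^\top$ back in full precision; the abstract does not spell out the method. It also uses one quantization scale per matrix rather than warp-level granularity, and all function names are hypothetical.

```python
def quantize_int4(block):
    """Symmetric quantization of a matrix to integers in [-7, 7]."""
    scale = max(abs(x) for row in block for x in row) / 7.0 or 1.0
    return [[round(x / scale) for x in row] for row in block], scale


def matmul_t(a, b):
    """Return a @ b^T for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


def plain_qk(q, k):
    """Quantize Q and K directly, multiply in integers, then dequantize."""
    q_int, q_scale = quantize_int4(q)
    k_int, k_scale = quantize_int4(k)
    return [[s * q_scale * k_scale for s in row] for row in matmul_t(q_int, k_int)]


def smoothed_qk(q, k):
    """Subtract the channel mean of Q before quantizing (assumed smoothing)."""
    dim = len(q[0])
    q_mean = [sum(row[j] for row in q) / len(q) for j in range(dim)]
    q_centered = [[row[j] - q_mean[j] for j in range(dim)] for row in q]
    low_bit = plain_qk(q_centered, k)
    # The mean's contribution is the same for every query row; it is
    # computed in full precision and added back.
    correction = [sum(m * x for m, x in zip(q_mean, rk)) for rk in k]
    return [[s + c for s, c in zip(row, correction)] for row in low_bit]


def max_abs_err(a, b):
    """Largest element-wise absolute difference between two matrices."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))
```

When the rows of $Q$ share a large common offset, direct INT4 quantization spends its few levels on that offset and loses the differences between rows; comparing `plain_qk` and `smoothed_qk` against `matmul_t(q, k)` with `max_abs_err` on such an input shows the gap.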
Abstract:Laryngo-pharyngeal cancer (LPC) is a highly lethal malignancy in the head and neck region. Recent advancements in tumor detection, particularly through dual-branch network architectures, have significantly improved diagnostic accuracy by integrating global and local feature extraction. However, challenges remain in accurately localizing lesions and fully capitalizing on the complementary nature of features within these branches. To address these issues, we propose SAM-Swin, an innovative SAM-driven Dual-Swin Transformer for laryngo-pharyngeal tumor detection. This model leverages the robust segmentation capabilities of the Segment Anything Model 2 (SAM2) to achieve precise lesion segmentation. Meanwhile, we present a multi-scale lesion-aware enhancement module (MS-LAEM) designed to adaptively enhance the learning of nuanced complementary features across various scales, improving the quality of feature extraction and representation. Furthermore, we implement a multi-scale class-aware guidance (CAG) loss that delivers multi-scale targeted supervision, thereby enhancing the model's capacity to extract class-specific features. To validate our approach, we compiled three LPC datasets from the First Affiliated Hospital (FAHSYSU), the Sixth Affiliated Hospital (SAHSYSU) of Sun Yat-sen University, and Nanfang Hospital of Southern Medical University (NHSMU). The FAHSYSU dataset is utilized for internal training, while the SAHSYSU and NHSMU datasets serve for external evaluation. Extensive experiments demonstrate that SAM-Swin outperforms state-of-the-art methods, showcasing its potential for advancing LPC detection and improving patient outcomes. The source code of SAM-Swin is available at https://github.com/VVJia/SAM-Swin.
Abstract:Diffusion models have recently gained recognition for generating diverse and high-quality content, especially in the domain of image synthesis. These models excel not only in creating fixed-size images but also in producing panoramic images. However, existing methods often struggle with spatial layout consistency when producing high-resolution panoramas, due to the lack of guidance of the global image layout. In this paper, we introduce the Multi-Scale Diffusion (MSD) framework, a plug-and-play module that extends the existing panoramic image generation framework to multiple resolution levels. By utilizing gradient descent techniques, our method effectively incorporates structural information from low-resolution images into high-resolution outputs. A comprehensive evaluation of the proposed method was conducted, comparing it with the prior works in qualitative and quantitative dimensions. The evaluation results demonstrate that our method significantly outperforms others in generating coherent high-resolution panoramas.
Abstract:Laryngo-pharyngeal cancer (LPC) is a highly fatal malignant disease affecting the head and neck region. Previous studies on endoscopic tumor detection, particularly those leveraging dual-branch network architectures, have shown significant advancements. These studies highlight the potential of dual-branch networks in improving diagnostic accuracy by effectively integrating global and local (lesion) feature extraction. However, they are still limited in their ability to accurately locate the lesion region and capture the discriminative feature information between the global and local branches. To address these issues, we propose a novel SAM-guided fusion network (SAM-FNet), a dual-branch network for laryngo-pharyngeal tumor detection. By leveraging the powerful object segmentation capabilities of the Segment Anything Model (SAM), we introduce SAM into SAM-FNet to accurately segment the lesion region. Furthermore, we propose a GAN-like feature optimization (GFO) module to capture the discriminative features between the global and local branches, enhancing the fusion feature complementarity. Additionally, we collect two LPC datasets from the First Affiliated Hospital (FAHSYSU) and the Sixth Affiliated Hospital (SAHSYSU) of Sun Yat-sen University. The FAHSYSU dataset is used as the internal dataset for training the model, while the SAHSYSU dataset is used as the external dataset for evaluating the model's performance. Extensive experiments on both the FAHSYSU and SAHSYSU datasets demonstrate that SAM-FNet achieves competitive results, outperforming state-of-the-art counterparts. The source code of SAM-FNet is available at https://github.com/VVJia/SAM-FNet.