Navigate Beyond Shortcuts: Debiased Learning through the Lens of Neural Collapse


Yining Wang, Junjie Sun, Chenyue Wang, Mi Zhang, Min Yang
School of Computer Science, Fudan University, China
{ynwang22@m.,jjsun22@m.,wangcy23@m.,mi_zhang@,m_yang@}fudan.edu.cn
Corresponding Author.
Abstract

Recent studies have noted an intriguing phenomenon termed Neural Collapse: when neural networks establish the right correlation between feature spaces and training targets, their last-layer features, together with the classifier weights, collapse into a stable and symmetric structure. In this paper, we extend the investigation of Neural Collapse to biased datasets with imbalanced attributes. We observe that models easily fall into the pitfall of shortcut learning and form a biased, non-collapsed feature space early in training, which is hard to reverse and limits their generalization capability. To tackle the root cause of biased classification, we build on the recent idea of prime training and propose an avoid-shortcut learning framework that adds no training complexity. With shortcut primes designed from the Neural Collapse structure, models are encouraged to skip the pursuit of simple shortcuts and naturally capture the intrinsic correlations. Experimental results demonstrate that our method induces better convergence properties during training and achieves state-of-the-art generalization performance on both synthetic and real-world biased datasets.

1 Introduction

Figure 1: Illustration of (a) Neural Collapse phenomenon on balanced datasets, where the simplex ETF structure maximizes the class-wise angles, and (b) Biased classification on datasets with imbalanced attributes, where the model takes the shortcut of attributes to make predictions and fails to collapse into the simplex ETF. The color of points represents different class labels and the shape of points represents different attributes.

When the input-output correlation learned by a neural network is consistent with its training target, the last-layer features and classifier weights attract and reinforce each other, forming a stable, symmetric and robust structure. This is the Neural Collapse phenomenon discovered by Papyan et al. [24]: in the terminal phase of training on balanced datasets, the last-layer features of each class converge towards their class center, and the classifier weights align with these class centers. This convergence ultimately collapses the feature space into a simplex equiangular tight frame (ETF) structure, as illustrated in Fig. 1(a). The elegant structure has been shown to enhance the generalization, robustness, and interpretability of trained models [24, 5]. Consequently, a wave of empirical and theoretical analyses of Neural Collapse has followed [2, 8, 6, 25, 26, 38, 40], and a series of studies have adopted the simplex ETF as the optimal geometric structure of the classifier to enforce maximal class-wise separation in class-imbalanced training [20, 43, 37, 35, 36].

However, in practical visual recognition tasks, besides inter-class imbalance, we also encounter intra-class imbalance, where the majority of samples are dominated by bias attributes (i.e., misleading content such as background, color, or texture). For example, the widely used LFW dataset [12] for facial recognition has been shown to be severely imbalanced in gender, age and ethnicity [3]. A biased dataset often contains a majority of bias-aligned samples and a minority of bias-conflicting ones. The prevalent bias-aligned samples exhibit a strong correlation between the ground-truth labels and the bias attributes, while the scarce bias-conflicting samples exhibit no such correlation. Once a model relies on the simple but spurious shortcut of bias attributes for prediction, it ignores the intrinsic relations and struggles to generalize to out-of-distribution test samples. The potential impact of biased classification may range from political and economic disparities to social inequalities within AI systems, as emphasized in EDRi’s latest report [1].

Therefore, the fundamental solution to biased classification lies in deferring, or ideally preventing, the learning of shortcut correlations. However, previous debiased learning methods rely heavily on additional training costs. For example, a bias-amplified auxiliary model is often adopted to identify and up-weight the bias-conflicting samples [23, 19, 16], or to guide input-level and feature-level augmentations [13, 18, 21]. Some disentangle-based debiasing methods, grounded in causal intervention [32, 42] or Information Bottleneck theory [30], also require large numbers of contrastive samples or a pre-training process to disentangle the biased features, significantly increasing the burden of debiased learning.

In this paper, we extend the investigation of Neural Collapse to biased visual datasets with imbalanced attributes. Through the lens of Neural Collapse, we observe that models first go through a period of shortcut learning, quickly forming a biased feature space based on misleading attributes in the early stage of training. Only after the bias-aligned samples reach zero training error does the model begin to discover the intrinsic correlation within bias-conflicting samples. However, due to i) the scarcity of bias-conflicting samples and ii) the stability of the established feature space, the learned shortcut correlation is difficult to reverse and eliminate. The mismatch between the biased feature space and the training target degrades generalization and hinders the convergence of Neural Collapse, as shown in Fig. 1(b).

To achieve efficient model debiasing, we follow the idea of prime training and encourage the model to skip the active learning of shortcut correlations. Primes are often provided as additional supervisory signals that redirect a model’s reliance on shortcuts, which helps improve generalization in image classification and CARLA autonomous driving [33]. To redirect the model’s attention to the intrinsic correlations, we define the primes with a training-free simplex ETF structure, which approximates the “optimal” shortcut features and guides the model to pursue unbiased classification from the beginning of training. Our method requires neither auxiliary models nor additional optimization of prime features. Experimental results substantiate its state-of-the-art debiasing performance on both synthetic and real-world biased datasets.

Our contributions are summarized as follows:

  • For the first time, we investigate the Neural Collapse phenomenon on biased datasets with imbalanced attributes. Through an empirical analysis of feature convergence, we characterize the shortcut learning stage of training as well as the fundamental issues of biased classification.

  • We propose an efficient avoid-shortcut training paradigm, which introduces the simplex ETF structure as prime features to redirect the model’s attention to the intrinsic correlations.

  • We demonstrate the state-of-the-art debiasing performance of our method on 2 synthetic and 3 real-world biased datasets, as well as the improved convergence properties of the debiased models.

2 Related Works

Debiased Learning. Extensive efforts have been dedicated to model debiasing, but they are significantly limited by additional training costs. Recent advances can be divided into three categories: reweight-based, augmentation-based, and disentangle-based. Relying on the heuristic that biased features are easy to learn, reweight-based approaches require pre-trained bias-amplified models to identify and emphasize the bias-conflicting training samples [23, 19, 16]. Augmentation-based approaches, guided by explicit bias annotations, conduct image-level and feature-level augmentations to enhance the diversity of training datasets [13, 18, 21]. Disentangle-based approaches attempt to remove the bias-related part of features from the perspective of Information Bottleneck theory [30] or causal intervention [32, 42], but at the cost of substantial contrastive samples. Model debiasing has also been studied in graph neural networks [4, 41], language models [7, 22] and multi-modal tasks [34, 11].

Neural Collapse. Discovered by Papyan et al. [24], the Neural Collapse phenomenon reveals the convergence of the last-layer feature space to an elegant geometry. In the terminal phase of training on balanced datasets, the feature centers and classifier weights collapse together into a simplex ETF, defined in Section 3.1. Recent works have examined the phenomenon more deeply, providing theoretical support under different constraints or regularizations [2, 40, 8], as well as empirical studies of intermediate features and transfer learning [26, 38, 25, 6]. For class-imbalanced datasets, Fang et al. [5] identify the Minority Collapse phenomenon, where the features of long-tailed classes merge together and become hard to classify. As a remedy, the classifier can be fixed as an ETF structure during training, which guarantees the optimal geometric property in imbalanced learning [36], semantic segmentation [43], and federated learning [37]. Going a step further, our work fills the gap of Neural Collapse analysis on biased datasets with shortcut correlations.

Avoid-shortcut Learning. The recent idea of avoid-shortcut learning aims to postpone, or even prevent, the learning of shortcut relations in model training. With well-crafted contrastive samples [27, 29, 28] or artificial shortcut signals [33, 42], avoid-shortcut learning has demonstrated its efficacy in image classification, autonomous driving and question answering models. One representative method is prime training, which provides richer supervisory signals about key input features (i.e., primes) to guide the establishment of correct correlations, thereby improving generalization on OOD samples [33]. In this work, we leverage the approximated “optimal” shortcuts as primes to encourage models to bypass shortcut learning.

3 Preliminaries

3.1 Neural Collapse Phenomenon

Consider a biased dataset $\mathcal{D}$ with $K$ classes of training samples. We denote by $\mathbf{x}_{k,i}$ the $i$-th sample of the $k$-th class and by $\mathbf{z}_{k,i}\in\mathbb{R}^{d}$ its corresponding last-layer feature. A linear classifier with weights $\mathbf{W}=[\mathbf{w}_{1},\dots,\mathbf{w}_{K}]\in\mathbb{R}^{d\times K}$ is trained on top of the last-layer features to make predictions.

The Neural Collapse (NC) phenomenon shows that, when neural networks are trained on balanced datasets, correctly learned correlations naturally lead to the convergence of feature spaces. Given enough training steps after reaching zero classification error, the last-layer features and classifier weights collapse to the vertices of a simplex equiangular tight frame (ETF), defined below.

Definition 1 (Simplex Equiangular Tight Frame) A collection of vectors $\mathbf{m}_{k}\in\mathbb{R}^{d}$, $k=1,2,\dots,K$, $d\geq K-1$, is said to be a $K$-simplex equiangular tight frame if:

$$\mathbf{M}=\sqrt{\frac{K}{K-1}}\,\mathbf{P}\left(\mathbf{I}_{K}-\frac{1}{K}\mathbf{1}_{K}\mathbf{1}_{K}^{\mathrm{T}}\right) \tag{1}$$

where $\mathbf{M}=[\mathbf{m}_{1},\dots,\mathbf{m}_{K}]\in\mathbb{R}^{d\times K}$, and $\mathbf{P}\in\mathbb{R}^{d\times K}$ is a partial orthogonal matrix satisfying $\mathbf{P}^{\mathrm{T}}\mathbf{P}=\mathbf{I}_{K}$, where $\mathbf{I}_{K}$ denotes the identity matrix and $\mathbf{1}_{K}$ denotes the all-ones vector. Within the ETF structure, all vectors have unit norm and every pair has the same cosine similarity of $-\frac{1}{K-1}$, the smallest value attainable by $K$ equiangular vectors, namely the maximal equiangular separation.
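To make Definition 1 concrete, the following NumPy sketch builds a simplex ETF from Eq. 1 and checks its geometry numerically. It is an illustrative sketch, not the authors' code: drawing $\mathbf{P}$ from a reduced QR decomposition of a Gaussian matrix is our assumption (any matrix with orthonormal columns works), and the helper name `simplex_etf` is hypothetical. Because $\mathbf{P}$ here has $K$ orthonormal columns, the sketch requires $d \geq K$.

```python
import numpy as np

def simplex_etf(d: int, K: int, seed: int = 0) -> np.ndarray:
    """Build a d x K simplex ETF as in Eq. 1: M = sqrt(K/(K-1)) P (I_K - 1_K 1_K^T / K)."""
    if K < 2 or d < K:
        raise ValueError("this sketch assumes K >= 2 and d >= K")
    rng = np.random.default_rng(seed)
    # Reduced QR of a Gaussian d x K matrix gives P with orthonormal columns (P^T P = I_K).
    P, _ = np.linalg.qr(rng.standard_normal((d, K)))
    centering = np.eye(K) - np.ones((K, K)) / K
    return np.sqrt(K / (K - 1)) * P @ centering

M = simplex_etf(d=16, K=5)
G = M.T @ M                                             # Gram matrix of the K vertices
norms_ok = np.allclose(np.diag(G), 1.0)                 # unit-norm vertices
cos_ok = np.allclose(G[~np.eye(5, dtype=bool)], -1 / 4) # pairwise cosine -1/(K-1)
centered_ok = np.allclose(M.sum(axis=1), 0.0)           # vertices sum to the zero vector
```

Every vertex has unit norm, each pair has cosine $-1/(K-1)$ (here $-1/4$), and the vertices sum to zero, which is the effect of the centering matrix $\mathbf{I}_K-\frac{1}{K}\mathbf{1}_K\mathbf{1}_K^{\mathrm T}$.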

Besides the convergence to the simplex ETF structure, the Neural Collapse phenomenon can be summarized by the following properties in the terminal phase of training:

NC1: Variability collapse. The last-layer features $\mathbf{z}_{k,i}$ of the same class $k$ collapse to their class mean $\overline{\mathbf{z}}_{k}=\mathrm{Avg}_{i}\{\mathbf{z}_{k,i}\}$, and the within-class variation of the last-layer features approaches 0.

NC2: Convergence to simplex ETF. The normalized class means collapse to the vertices of a simplex ETF. We denote the global mean of all last-layer features as $\mathbf{z}_{G}=\mathrm{Avg}_{i,k}\{\mathbf{z}_{k,i}\}$, $k\in\{1,\dots,K\}$, and the normalized class means as $\tilde{\mathbf{z}}_{k}=(\overline{\mathbf{z}}_{k}-\mathbf{z}_{G})/\lVert\overline{\mathbf{z}}_{k}-\mathbf{z}_{G}\rVert$; the matrix $[\tilde{\mathbf{z}}_{1},\dots,\tilde{\mathbf{z}}_{K}]$ then satisfies Eq. 1.

NC3: Self duality. The classifier weights $\mathbf{w}_{k}$ align with the corresponding normalized class means $\tilde{\mathbf{z}}_{k}$, i.e., $\tilde{\mathbf{z}}_{k}=\mathbf{w}_{k}/\lVert\mathbf{w}_{k}\rVert$.

NC4: Simplification to nearest class center. After convergence, the model’s prediction collapses to simply choosing the class mean nearest to the input feature (in standard Euclidean distance). The prediction for a feature $\mathbf{z}$ can be written as $\arg\max_{k}\langle\mathbf{z},\mathbf{w}_{k}\rangle=\arg\min_{k}\lVert\mathbf{z}-\overline{\mathbf{z}}_{k}\rVert$.
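NC4 can be checked on a toy example. In the sketch below (an illustration under our own assumptions, not an experiment from the paper), the class means are unit-norm ETF vertices and, per NC3, the classifier weights are set equal to them; the helper names are hypothetical, and the classifier bias and global mean $\mathbf{z}_G$ are omitted for simplicity. Since $\lVert\mathbf{z}-\overline{\mathbf{z}}_k\rVert^2=\lVert\mathbf{z}\rVert^2-2\langle\mathbf{z},\overline{\mathbf{z}}_k\rangle+\lVert\overline{\mathbf{z}}_k\rVert^2$ and all $\lVert\overline{\mathbf{z}}_k\rVert$ are equal, maximizing the inner product and minimizing the distance select the same class.

```python
import numpy as np

def etf_class_means(d, K, seed=0):
    # Hypothetical helper: unit-norm class means arranged as a simplex ETF (Eq. 1), shape d x K.
    rng = np.random.default_rng(seed)
    P, _ = np.linalg.qr(rng.standard_normal((d, K)))
    return np.sqrt(K / (K - 1)) * P @ (np.eye(K) - np.ones((K, K)) / K)

def predict_linear(z, W):
    # argmax_k <z, w_k>, with W stored column-wise (d x K) as in Section 3.1.
    return np.argmax(z @ W, axis=-1)

def predict_ncc(z, means):
    # Nearest class center in Euclidean distance; means is d x K.
    dists = np.linalg.norm(z[:, None, :] - means.T[None, :, :], axis=-1)
    return np.argmin(dists, axis=-1)

d, K = 32, 10
means = etf_class_means(d, K)
W = means.copy()  # self-duality (NC3): classifier weights aligned with the class means
rng = np.random.default_rng(1)
z = rng.standard_normal((1000, d))  # arbitrary query features
agree = np.mean(predict_linear(z, W) == predict_ncc(z, means))
```

With equal-norm class means and weights proportional to them, the two decision rules agree on every query; if the class means had different norms, the two rules could disagree.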

Figure 2: Comparison of (a) test-set accuracy and (b-d) Neural Collapse metrics on unbiased (CIFAR-10) and synthetic biased (Corrupted CIFAR-10 with a bias ratio of 5.0%) datasets. All vanilla models are trained with the standard cross-entropy loss for 500 epochs. The suffixes -Aligned and -Conflicting indicate the results on bias-aligned and bias-conflicting samples, respectively. The NC1 metric evaluates the convergence of same-class features, NC2 evaluates the difference between the feature space and a simplex ETF, and NC3 measures the duality between feature centers and classifier weights. The vertical dashed line at epoch 60 divides the two stages of training.

3.2 Neural Collapse Observation on Biased Dataset

Beyond the findings on balanced datasets, some studies have explored Neural Collapse under class imbalance [36, 5]. Going a step further, we investigate the phenomenon on biased datasets with imbalanced attributes to advance the understanding of biased classification. To examine the convergence of last-layer features and classifier weights, we compare the Neural Collapse metrics on both unbiased and synthetic biased datasets. As shown in Fig. 2, we report NC1-NC3, which correspond to the first three convergence properties in Section 3.1 and evaluate the convergence of same-class features, the structure of the feature space, and self duality, respectively. The details of the NC metrics are summarized in Tab. 1.

| Metric | Computational details |
|---|---|
| NC1 | $\mathcal{NC}_{1}=\frac{1}{K}\mathrm{Tr}(\Sigma_{W}\Sigma_{B}^{\dagger})$, where $\mathrm{Tr}$ is the matrix trace, $\Sigma_{B}^{\dagger}$ denotes the pseudo-inverse of $\Sigma_{B}$, $n_k$ is the number of samples in class $k$, $\Sigma_{B}=\frac{1}{K}\sum_{k\in[K]}(\overline{\mathbf{z}}_{k}-\mathbf{z}_{G})(\overline{\mathbf{z}}_{k}-\mathbf{z}_{G})^{\mathrm{T}}$ and $\Sigma_{W}=\frac{1}{K}\sum_{k\in[K]}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}(\mathbf{z}_{k,i}-\overline{\mathbf{z}}_{k})(\mathbf{z}_{k,i}-\overline{\mathbf{z}}_{k})^{\mathrm{T}}$ |
| NC2 | $\mathcal{NC}_{2}=\Big\lVert\frac{\mathbf{W}^{\mathrm{T}}\mathbf{W}}{\lVert\mathbf{W}^{\mathrm{T}}\mathbf{W}\rVert_{\mathcal{F}}}-\frac{1}{\sqrt{K-1}}\big(\mathbf{I}_{K}-\frac{1}{K}\mathbf{1}_{K}\mathbf{1}_{K}^{\mathrm{T}}\big)\Big\rVert_{\mathcal{F}}$ |
| NC3 | $\mathcal{NC}_{3}=\Big\lVert\frac{\mathbf{W}^{\mathrm{T}}\overline{\mathbf{Z}}}{\lVert\mathbf{W}^{\mathrm{T}}\overline{\mathbf{Z}}\rVert_{\mathcal{F}}}-\frac{1}{\sqrt{K-1}}\big(\mathbf{I}_{K}-\frac{1}{K}\mathbf{1}_{K}\mathbf{1}_{K}^{\mathrm{T}}\big)\Big\rVert_{\mathcal{F}}$, where $\overline{\mathbf{Z}}=[\overline{\mathbf{z}}_{1}-\mathbf{z}_{G},\dots,\overline{\mathbf{z}}_{K}-\mathbf{z}_{G}]$ |

Table 1: The metrics for evaluating the Neural Collapse phenomenon, which are generally adopted in previous studies [46, 38, 24]. $\|\cdot\|_{\mathcal{F}}$ denotes the Frobenius norm, $\mathbf{I}_{K}$ is the identity matrix and $\mathbf{1}_{K}$ is the all-ones vector.
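The metrics in Table 1 can be computed from extracted features with a few lines of NumPy. This is a minimal sketch under our own assumptions (features `Z` stored row-wise as an `N x d` array, integer labels `y` in `0..K-1`, every class present), not the evaluation code behind Fig. 2; the function name `nc_metrics` is hypothetical. Because Section 3.1 stores $\mathbf{W}$ as a $d\times K$ matrix, the sketch forms the $K\times K$ products $\mathbf{W}^{\mathrm T}\mathbf{W}$ and $\mathbf{W}^{\mathrm T}\overline{\mathbf{Z}}$; `np.linalg.norm` applied to a 2-D array defaults to the Frobenius norm.

```python
import numpy as np

def nc_metrics(Z, y, W):
    """Return (NC1, NC2, NC3) following Table 1.

    Z: (N, d) last-layer features, y: (N,) integer labels in [0, K),
    W: (d, K) classifier weights stored column-wise.
    """
    d, K = W.shape
    z_G = Z.mean(axis=0)  # global mean over all features
    means = np.stack([Z[y == k].mean(axis=0) for k in range(K)])  # (K, d) class means
    # Within-class covariance: per-class average, then averaged over classes.
    Sigma_W = np.zeros((d, d))
    for k in range(K):
        D = Z[y == k] - means[k]
        Sigma_W += D.T @ D / len(D)
    Sigma_W /= K
    Zbar = (means - z_G).T  # (d, K): columns are centered class means
    Sigma_B = Zbar @ Zbar.T / K  # between-class covariance
    nc1 = np.trace(Sigma_W @ np.linalg.pinv(Sigma_B)) / K
    target = (np.eye(K) - np.ones((K, K)) / K) / np.sqrt(K - 1)  # unit-Frobenius-norm ETF Gram
    WtW = W.T @ W
    nc2 = np.linalg.norm(WtW / np.linalg.norm(WtW) - target)
    WtZ = W.T @ Zbar
    nc3 = np.linalg.norm(WtZ / np.linalg.norm(WtZ) - target)
    return nc1, nc2, nc3
```

On perfectly collapsed features (balanced classes, every sample equal to its ETF class mean, classifier equal to the means), all three metrics evaluate to zero up to floating-point error; adding within-class noise makes NC1 strictly positive.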

When trained on unbiased datasets (black lines in Fig. 2), the model displays the expected convergence properties, with metrics NC1-NC3 all converging to zero. We attribute this elegant collapse to the correct correlation between the feature space and the training objective, which is also supported by the analysis of benign global landscapes [46, 45].

However, when trained on biased datasets, the training process exhibits two stages, separated by the vertical dashed line: first a shortcut learning period, then an intrinsic learning period. During the shortcut learning period, the accuracy on bias-aligned samples increases quickly, and the NC1-NC3 metrics decline rapidly (green lines with ▲ in Fig. 2). This indicates that, when simple shortcuts exist in the training distribution, the model quickly establishes its feature space based on the bias attributes and exhibits a converging trend towards the simplex ETF structure.

After the bias-aligned samples approach zero error, the model enters the intrinsic learning period, in which it focuses on the intrinsic correlations within bias-conflicting samples to further reduce the empirical loss. However, although their training loss eventually reduces to zero, the bias-conflicting samples still display low accuracy and poor convergence (green lines with ▼). This implies that the intrinsic learning period merely induces over-fitting of the bias-conflicting samples and does not improve generalization. We attribute the failure of collapse to the early establishment of shortcut correlations. Once the biased feature space has been established based on misleading attributes, rectifying it becomes challenging, particularly with scarce bias-conflicting samples. In subsequent training steps, the misled features of bias-conflicting samples hinder the convergence of same-class features, thereby halting the converging trend towards the simplex ETF structure and leading to a non-collapsed, sub-optimal feature space.

To break the curse of shortcut learning, we turn the tricky shortcut into a training prime, which effectively guides the models to focus on intrinsic correlations and form a naturally collapsed feature space (blue lines in Fig. 2). Our method is presented in the following sections.

4 Methodology

Figure 3: Illustration of our method. We take the class climbing from BAR [23] as an example, which contains images of humans climbing with the bias attribute of different backgrounds (indicated by the color of the image frames). The framework contains: 1) Prime Construction: before training, a randomly initialized ETF structure is constructed as the shortcut primes; 2) Prime Training: the prime features $\mathbf{m}_{b}$ are retrieved based on the bias attribute $b$ of the input samples to guide the optimization of the learnable features $\mathbf{z}$ towards the intrinsic correlations, and the classifier $F_{\theta}$ takes both the learnable features and the fixed prime features to make predictions; 3) Unbiased Classification: the prime features are set to null vectors to evaluate the debiased model on test distributions.

4.1 Motivation

Following the previous analysis, we highlight the importance of redirecting the model’s emphasis from simple shortcuts to intrinsic relations. Since models can be easily misled by shortcuts in the training distribution, is it feasible to supply a “perfectly learned” shortcut feature that deceives the model into skipping the active learning of shortcuts and focusing directly on the intrinsic correlations?

We observe in Fig. 2 that the NC1-NC3 metrics decrease rapidly during shortcut learning but remain stable in the subsequent training epochs. However, if the training distribution fully followed the shortcut correlation (with no obstruction from bias-conflicting samples), the convergence would end in the optimal structure of a simplex ETF, just as on unbiased datasets. This inspires us to approximate the “perfectly learned” shortcut features with a simplex ETF structure, which requires no additional training and represents the optimal geometry of the feature space.

Therefore, motivated by the strong performance of prime training in OOD generalization [33], we introduce the approximated “perfect” shortcuts as primes for debiased learning. The shortcut primes are constructed with a training-free simplex ETF structure, which encourages the models to directly capture the intrinsic correlations; in our experiments, the resulting models exhibit superior generalizability and convergence properties.

4.2 Avoid-shortcut Learning with Neural Collapse

Building upon our motivation of avoid-shortcut learning, the proposed ETF-Debias is illustrated in Fig. 3. The debiased learning framework can be divided into three stages: prime construction, prime training, and unbiased classification. First, a prime ETF is constructed to approximate the “perfect” shortcut features. Then, during prime training, the model is guided by the prime features and the prime reinforcement regularization to directly capture the intrinsic correlations. In evaluation, we rely on the intrinsic correlations to perform unbiased classification. The details are as follows.

Prime construction. To construct the prime ETF, we first randomly initialize a simplex ETF $\mathbf{M}\in\mathbb{R}^{d\times B}$ that satisfies the definition in Eq. 1. The dimension $d$ is the same as that of the learnable features, and the number of vectors in $\mathbf{M}$ equals the number $B$ of bias attribute categories, with $b_{i}\in\{1,\dots,B\}$ pre-defined in the training distribution. After initialization, the vertices of the prime ETF $[\mathbf{m}_{1},\dots,\mathbf{m}_{B}]$ are treated as approximations of the “perfect” shortcut features for each attribute and serve as the prime features for avoid-shortcut training. During training, the prime feature is retrieved based on the bias attribute $b$ of each input sample.
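One possible way to realize prime construction and retrieval is sketched below: the prime ETF is built once with a fixed random seed and never updated, and retrieval is a column lookup indexed by the bias attribute. This is a sketch, not the authors' implementation. Assumptions: attribute ids are 0-based here (the text uses $1,\dots,B$), $d \geq B \geq 2$, the values $d=512$ and $B=6$ are hypothetical, and the helper names are ours.

```python
import numpy as np

def build_prime_etf(d, B, seed=0):
    """Randomly initialized, fixed simplex ETF with one prime per bias attribute (Eq. 1 with K = B)."""
    if B < 2 or d < B:
        raise ValueError("this sketch assumes B >= 2 and d >= B")
    rng = np.random.default_rng(seed)
    P, _ = np.linalg.qr(rng.standard_normal((d, B)))  # d x B, orthonormal columns
    return np.sqrt(B / (B - 1)) * P @ (np.eye(B) - np.ones((B, B)) / B)

def retrieve_primes(M, bias):
    """Look up m_b for each sample; `bias` holds 0-based attribute ids. Returns (batch, d)."""
    return M[:, bias].T

M = build_prime_etf(d=512, B=6)  # built once before training, kept fixed afterwards
bias = np.array([0, 3, 3, 5])    # bias attributes of a hypothetical mini-batch
primes = retrieve_primes(M, bias)
```

Samples sharing a bias attribute receive the identical prime, and primes of different attributes are maximally separated (pairwise cosine $-1/(B-1)$), which is what makes them a clean stand-in for an already-learned shortcut.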

Prime training. During prime training, we adopt an end-to-end model architecture with a backbone $E_{\phi}$ and a classifier $F_{\theta}$. For the $i$-th input $\mathbf{x}_{i,b}$ with bias attribute $b$, we first extract its learnable feature $\mathbf{z}_{i,b}=E_{\phi}(\mathbf{x}_{i,b})$ with the backbone model and retrieve its prime feature $\mathbf{m}_{b}$ based on the bias attribute $b$. The classifier $F_{\theta}$ takes both the learnable feature $\mathbf{z}_{i,b}$ and the prime feature $\mathbf{m}_{b}$ to make softmax predictions $\hat{\mathbf{y}}=F_{\theta}(\mathbf{z}_{i,b},\mathbf{m}_{b})$. The standard classification objective is defined as:

$$\min_{\phi,\theta}\ \mathcal{L}_{\rm CE}(\mathbf{x},\mathbf{y})=\sum_{i=1}^{N}\mathcal{L}\big(F_{\theta}(\mathbf{z}_{i,b},\mathbf{m}_b),\mathbf{y}_{i,b}\big) \tag{2}$$

where $\mathbf{y}$ is the ground-truth label. In our implementation, we use the standard cross-entropy loss as $\mathcal{L}$ and concatenate the prime features after the learnable features to perform predictions.
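As a minimal sketch of Eq. 2, the snippet below assumes $F_{\theta}$ is a single bias-free linear layer $\mathbf{W}\in\mathbb{R}^{K\times 2d}$ acting on $[\mathbf{z}_{i,b},\mathbf{m}_b]$ (the section only states that the prime features are concatenated after the learnable features) and sums the per-sample cross-entropy as Eq. 2 does. The backbone is replaced by random feature arrays; `prime_logits` and `ce_loss` are hypothetical names.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row-wise max for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def prime_logits(z, m, W):
    """Logits of a linear head on the concatenation [z, m]; W has shape (K, 2d)."""
    return np.concatenate([z, m], axis=1) @ W.T

def ce_loss(logits, labels):
    """Summed cross-entropy over the batch, matching the sum in Eq. 2."""
    labels = np.asarray(labels)
    return -log_softmax(logits)[np.arange(len(labels)), labels].sum()

rng = np.random.default_rng(0)
N, d, K = 4, 8, 3
z = rng.standard_normal((N, d))      # stand-in for E_phi(x)
m = rng.standard_normal((N, d))      # stand-in for the retrieved primes m_b
W = rng.standard_normal((K, 2 * d))  # hypothetical linear classifier F_theta
y = np.array([0, 1, 2, 1])
print(ce_loss(prime_logits(z, m, W), y))
```

Most training code averages rather than sums the per-sample loss; that only rescales the gradient.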

In essence, we provide a pre-defined prime feature for each training sample based on its bias attribute. The prime features, being strongly correlated with the bias attributes, can be viewed as the optimal solution to shortcut learning. Since the shortcut correlations are already represented “perfectly”, the model is forced to explore the intrinsic correlations within the training distribution. The prime-guided mechanism targets the fundamental issue of biased classification without inducing extra training costs.

Prime reinforcement regularization. Given the prime features, the model is encouraged to grasp the intrinsic correlations of the training distribution. However, there is another potential risk: despite the provided “perfectly learned” shortcut features, the model may still pursue the easy-to-follow shortcuts, leading to redundancy between the learnable feature $\mathbf{z}_{i,b}$ and the fixed prime feature $\mathbf{m}_b$. That is, the model may fail to establish a strong correlation between the prime features and the bias attributes, and instead keep optimizing the learnable features to compensate for the missing connections.

Therefore, we introduce a prime reinforcement regularization mechanism to enhance the model’s dependency on prime features. We encourage the model to classify the bias attributes with only the prime features, and the regularization loss is defined as:

$$\mathcal{L}_{\rm RE}(\mathbf{x},\mathbf{b})=\sum_{i=1}^{N}\mathcal{L}\big(F_{\theta}(\mathbf{z}_{i,b},\mathbf{m}_b)-F_{\theta}(\mathbf{z}_{i,b},\mathbf{m}_{\rm null}),\mathbf{b}\big) \tag{3}$$

where $\mathbf{m}_{\rm null}$ is implemented as an all-zero vector with the same dimension as $\mathbf{m}_b$, and $\mathcal{L}$ is the standard cross-entropy loss. In the ablation studies (Section 5.3), we observe improved generalization across test distributions as a result of the strengthened reliance on prime features. For the entire framework, we define the overall training objective as:

$$\min_{\phi,\theta}\ \mathcal{L}_{\rm CE}(\mathbf{x},\mathbf{y})+\alpha\,\mathcal{L}_{\rm RE}(\mathbf{x},\mathbf{b}) \tag{4}$$

where $\alpha$ is the hyper-parameter that controls the strength of the regularization.
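Eq. 3 subtracts the classifier output obtained with $\mathbf{m}_{\rm null}$ from the one obtained with $\mathbf{m}_b$. The minimal NumPy sketch below combines Eqs. 2 to 4 under three assumptions not stated in the text: $F_{\theta}$ is a bias-free linear head on $[\mathbf{z},\mathbf{m}]$, the difference is taken on logits rather than on softmaxed outputs, and the bias labels index the same $K$ output units as the classes. `total_loss` is a hypothetical name and `alpha=0.5` is an arbitrary placeholder.

```python
import numpy as np

def log_softmax(logits):
    # Row-wise log-softmax with a max shift for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def ce_loss(logits, labels):
    labels = np.asarray(labels)
    return -log_softmax(logits)[np.arange(len(labels)), labels].sum()

def prime_logits(z, m, W):
    # Hypothetical bias-free linear head on [z, m]; W has shape (K, 2d).
    return np.concatenate([z, m], axis=1) @ W.T

def total_loss(z, m, W, y, b, alpha=0.5):
    """L_CE (Eq. 2) + alpha * L_RE (Eq. 3), i.e. the objective of Eq. 4."""
    l_ce = ce_loss(prime_logits(z, m, W), y)
    m_null = np.zeros_like(m)  # all-zero "null" prime
    # Output difference of Eq. 3, taken on logits in this sketch.
    diff = prime_logits(z, m, W) - prime_logits(z, m_null, W)
    l_re = ce_loss(diff, b)    # assumes bias labels b share the K output units
    return l_ce + alpha * l_re, l_ce, l_re

rng = np.random.default_rng(1)
N, d, K = 4, 8, 3
z = rng.standard_normal((N, d))
m = rng.standard_normal((N, d))
W = rng.standard_normal((K, 2 * d))
y = np.array([0, 1, 2, 0])
b = np.array([0, 1, 2, 2])
total, l_ce, l_re = total_loss(z, m, W, y, b, alpha=0.5)
```

With a linear head, `diff` equals `m @ W[:, d:].T`, so in this sketch the regularizer only sends gradient to the prime-branch weights $\mathbf{a}_k$; a nonlinear $F_{\theta}$ need not behave this way.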

Unbiased classification. At evaluation time, we rely on the intrinsic correlations to perform unbiased classification. Given a test sample $\mathbf{x}_{i,b}$, we extract its learnable feature $\mathbf{z}_{i,b}$ and set its prime feature to $\mathbf{m}_{\rm null}$ to obtain the final output $\hat{\mathbf{y}}=F_{\theta}(\mathbf{z}_{i,b},\mathbf{m}_{\rm null})$.
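At test time the prime slot is filled with $\mathbf{m}_{\rm null}$. Assuming, as a simplification, a bias-free linear head $\mathbf{W}\in\mathbb{R}^{K\times 2d}$ on $[\mathbf{z},\mathbf{m}]$, this reduces prediction to the learnable branch alone, as the minimal sketch below checks; `predict_unbiased` is a hypothetical name.

```python
import numpy as np

def predict_unbiased(z, W):
    """Class predictions from F_theta(z, m_null) with a linear head W of shape (K, 2d)."""
    m_null = np.zeros_like(z)  # assumes the prime dimension equals d, as in the text
    logits = np.concatenate([z, m_null], axis=1) @ W.T
    return logits.argmax(axis=1), logits

rng = np.random.default_rng(2)
d, K = 8, 3
z_test = rng.standard_normal((5, d))
W = rng.standard_normal((K, 2 * d))
preds, logits = predict_unbiased(z_test, W)
```

No bias label is needed at inference, which matches the unbiased test setting.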

4.3 Theoretical Justification

Based on the analysis of Neural Collapse from the perspective of gradients [43, 36], we provide a brief theoretical justification for our method.

With the priming mechanism, we denote the $i$-th feature of the $k$-th class as $\widetilde{\mathbf{z}}_{k,i}=[\mathbf{z}_{k,i},\mathbf{m}_{i,b}]\in\mathbb{R}^{2d}$, the concatenation of the learnable feature $\mathbf{z}_{k,i}$ and the prime feature $\mathbf{m}_{i,b}$ determined by its bias attribute $b$. To keep the same form, we denote the classifier weights as $\widetilde{\mathbf{w}}_k=[\mathbf{w}_k,\mathbf{a}_k]\in\mathbb{R}^{2d}$, where $\mathbf{w}_k$ is the weight for intrinsic correlations and $\mathbf{a}_k$ is the weight for shortcut correlations. We observe that, because the prime features are fixed during training, $\mathbf{a}_k$ quickly collapses to the bias-correlated prime feature of class $k$ and can be viewed as constant after just a few training steps. With these definitions, the cross-entropy (CE) loss can be written as:

$$\mathcal{L}_{\rm CE}(\widetilde{\mathbf{z}}_{k,i},\widetilde{\mathbf{w}}_k)=-\log\left(\frac{\exp\big([\mathbf{z}_{k,i},\mathbf{m}_{i,b}]^{\mathrm{T}}[\mathbf{w}_k,\mathbf{a}_k]\big)}{\sum_{k'=1}^{K}\exp\big([\mathbf{z}_{k,i},\mathbf{m}_{i,b}]^{\mathrm{T}}[\mathbf{w}_{k'},\mathbf{a}_{k'}]\big)}\right) \tag{5}$$

We follow the analysis of previous works and compute the gradients of $\mathcal{L}_{\rm CE}$ w.r.t. both the classifier weights and the features.

Gradient w.r.t. classifier weights. We first compute the gradient of $\mathcal{L}_{\rm CE}$ w.r.t. the classifier weights $\widetilde{\mathbf{W}}=[\widetilde{\mathbf{w}}_1,\dots,\widetilde{\mathbf{w}}_K]$:

$$
\begin{aligned}
\frac{\partial\mathcal{L}_{\rm CE}}{\partial\widetilde{\mathbf{w}}_k}
&=\sum_{i=1}^{n_k}-\big(1-p_k(\widetilde{\mathbf{z}}_{k,i})\big)\widetilde{\mathbf{z}}_{k,i}+\sum_{k'\neq k}^{K}\sum_{j=1}^{n_{k'}}p_k(\widetilde{\mathbf{z}}_{k',j})\,\widetilde{\mathbf{z}}_{k',j}\\
&\leq\underbrace{\sum_{i=1}^{n_k}-\big(1-p_k^{(b)}(\mathbf{m}_{i,b})-p_k^{(l)}(\mathbf{z}_{k,i})\big)\widetilde{\mathbf{z}}_{k,i}}_{\rm pulling\ part}
+\underbrace{\sum_{k'\neq k}^{K}\sum_{j=1}^{n_{k'}}\big(p_k^{(b)}(\mathbf{m}_{j,b'})+p_k^{(l)}(\mathbf{z}_{k',j})\big)\widetilde{\mathbf{z}}_{k',j}}_{\rm forcing\ part}
\end{aligned}
\tag{6}
$$
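The first line of Eq. 6 is the standard softmax cross-entropy gradient; in matrix form it reads $\partial\mathcal{L}_{\rm CE}/\partial\widetilde{\mathbf{W}}=(\mathbf{P}-\mathbf{Y})^{\mathrm{T}}\widetilde{\mathbf{Z}}$, where the rows of $\widetilde{\mathbf{Z}}$ are the concatenated features, $\mathbf{P}$ holds the softmax probabilities and $\mathbf{Y}$ the one-hot labels. The short NumPy check below (random data, hypothetical helper names) compares this closed form against central finite differences; it covers only the equality, not the upper bound in the second line.

```python
import numpy as np

def softmax(logits):
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def loss(W_tilde, Z_tilde, labels):
    """Eq. 5 summed over all samples; row k of W_tilde is [w_k, a_k]."""
    P = softmax(Z_tilde @ W_tilde.T)
    return -np.log(P[np.arange(len(labels)), labels]).sum()

def grad_W(W_tilde, Z_tilde, labels):
    # Row k: sum over class-k samples of -(1 - p_k) z, plus sum over others of p_k z.
    P = softmax(Z_tilde @ W_tilde.T)
    Y = np.eye(W_tilde.shape[0])[labels]
    return (P - Y).T @ Z_tilde

rng = np.random.default_rng(3)
N, d, K = 6, 4, 3
Z_tilde = rng.standard_normal((N, 2 * d))  # rows [z_{k,i}, m_{i,b}]
W_tilde = rng.standard_normal((K, 2 * d))
labels = np.array([0, 0, 1, 1, 2, 2])

eps = 1e-6
numeric = np.zeros_like(W_tilde)
for idx in np.ndindex(*W_tilde.shape):
    step = np.zeros_like(W_tilde)
    step[idx] = eps
    numeric[idx] = (loss(W_tilde + step, Z_tilde, labels)
                    - loss(W_tilde - step, Z_tilde, labels)) / (2 * eps)
print(np.max(np.abs(numeric - grad_W(W_tilde, Z_tilde, labels))))
```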

where $p_k^{(l)}$ and $p_k^{(b)}$ are the predicted probabilities for class labels and bias attributes, calculated with softmax:

$$p_k^{(l)}(\mathbf{z}_{k,i})=\frac{\exp(\mathbf{z}_{k,i}^{\mathrm{T}}\mathbf{w}_k)}{\sum_{k'=1}^{K}\exp\big([\mathbf{z}_{k,i},\mathbf{m}_{i,b}]^{\mathrm{T}}[\mathbf{w}_{k'},\mathbf{a}_{k'}]\big)} \tag{7}$$

$$p_k^{(b)}(\mathbf{m}_{i,b})=\frac{\exp(\mathbf{m}_{i,b}^{\mathrm{T}}\mathbf{a}_k)}{\sum_{k'=1}^{K}\exp\big([\mathbf{z}_{k,i},\mathbf{m}_{i,b}]^{\mathrm{T}}[\mathbf{w}_{k'},\mathbf{a}_{k'}]\big)} \tag{8}$$
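Eqs. 7 and 8 keep the full softmax denominator $Z$ of Eq. 5 and only split the numerator, so $p^{(l)}$ and $p^{(b)}$ need not sum to one on their own; they relate to the ordinary probability through $p_k=p_k^{(l)}\,p_k^{(b)}\,Z$. The minimal NumPy sketch below computes all three for one sample with random weights and checks that identity; `split_probabilities` is a hypothetical name.

```python
import numpy as np

def split_probabilities(z, m, W, A):
    """Softmax probability p (Eq. 5) and the split terms p_l (Eq. 7) and p_b (Eq. 8).

    z, m: (d,) learnable and prime features of one sample.
    W, A: (K, d) stacks of intrinsic weights w_k and shortcut weights a_k.
    """
    s_l = W @ z                      # z^T w_k for each class k
    s_b = A @ m                      # m^T a_k for each class k
    denom = np.exp(s_l + s_b).sum()  # shared denominator Z of Eqs. 5, 7 and 8
    p = np.exp(s_l + s_b) / denom
    p_l = np.exp(s_l) / denom
    p_b = np.exp(s_b) / denom
    return p, p_l, p_b, denom

rng = np.random.default_rng(4)
d, K = 5, 3
z, m = rng.standard_normal(d), rng.standard_normal(d)
W, A = rng.standard_normal((K, d)), rng.standard_normal((K, d))
p, p_l, p_b, denom = split_probabilities(z, m, W, A)
```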

In Eq. 6, the gradient w.r.t. the classifier weights is divided into two parts. The pulling part consists of features from the same class, which pull $\mathbf{w}_k$ toward the $k$-th feature cluster, while the forcing part contains the features of other classes and pushes $\mathbf{w}_k$ away from their clusters. The weight factor of each feature $\widetilde{\mathbf{z}}_{k,i}$ represents its influence on the optimization of $\widetilde{\mathbf{w}}_k$, which implicitly plays the role of re-weighting in our method.

We assume that class $k$ is strongly correlated with bias attribute $b$. Since the weight $\mathbf{a}_k$ is observed to collapse quickly to the bias-correlated prime feature $\mathbf{m}_b$, the probability $p_k^{(b)}\propto\exp(\mathbf{m}_b^{\mathrm{T}}\mathbf{a}_k)$ of bias-aligned samples (with prime feature $\mathbf{m}_b$) is much greater than that of bias-conflicting samples (with prime feature $\mathbf{m}_{b'}$). Thus, through the weight factors in Eq. 6, the pulling and forcing effects of bias-aligned samples are relatively down-weighted, while the impact of bias-conflicting samples is up-weighted. This gradient re-weighting mitigates the tendency to pull $\widetilde{\mathbf{w}}_k$ toward the center of bias-aligned samples, which alleviates the misdirection caused by bias attributes.
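To make the re-weighting argument concrete, the toy computation below (an illustration under stated assumptions, not an experiment from the paper) sets $K=B=3$, lets the shortcut weights collapse exactly onto the ETF primes ($\mathbf{a}_k=\mathbf{m}_k$, so class $k$ is aligned with attribute $k$), and zeroes the intrinsic weights so that only the prime branch differs between samples. It compares the pulling-part factor $1-p_k^{(b)}-p_k^{(l)}$ from Eq. 6 for a bias-aligned sample of class 0 (prime $\mathbf{m}_0$) and a bias-conflicting one (prime $\mathbf{m}_1$). The helper names are hypothetical.

```python
import numpy as np

def random_simplex_etf(d, B, seed=0):
    # Standard simplex ETF form (assumed); columns are unit vectors with cosine -1/(B-1).
    U, _ = np.linalg.qr(np.random.default_rng(seed).standard_normal((d, B)))
    return np.sqrt(B / (B - 1)) * U @ (np.eye(B) - np.ones((B, B)) / B)

def pulling_factor(k, z, m, W, A):
    """1 - p_k^(b)(m) - p_k^(l)(z), with Eqs. 7-8 sharing the full denominator."""
    s_l, s_b = W @ z, A @ m
    denom = np.exp(s_l + s_b).sum()
    return 1.0 - np.exp(s_b[k]) / denom - np.exp(s_l[k]) / denom

d, K = 8, 3
M = random_simplex_etf(d, K)   # columns m_0, m_1, m_2
A = M.T.copy()                 # collapsed shortcut weights: a_k = m_k
W = np.zeros((K, d))           # intrinsic branch switched off for the toy
z = np.random.default_rng(5).standard_normal(d)

aligned = pulling_factor(0, z, M[:, 0], W, A)      # class 0 with its own attribute
conflicting = pulling_factor(0, z, M[:, 1], W, A)  # class 0 with another attribute
print(aligned, conflicting)
```

Because the ETF Gram matrix is fixed (1 on the diagonal, $-1/2$ off it), the result does not depend on the random seed: the aligned factor is about 0.05 while the conflicting one is about 0.59, so the bias-conflicting sample pulls $\widetilde{\mathbf{w}}_0$ roughly ten times harder in this toy.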

Gradient w.r.t. features. Similarly, we compute the gradient of $\mathcal{L}_{\rm CE}$ w.r.t. the feature $\widetilde{\mathbf{z}}_{k,i}$:

$$
\begin{aligned}
\frac{\partial\mathcal{L}_{\rm CE}}{\partial\widetilde{\mathbf{z}}_{k,i}}
&=-\big(1-p_k(\widetilde{\mathbf{z}}_{k,i})\big)\mathbf{w}_k+\sum_{k'\neq k}^{K}p_{k'}(\widetilde{\mathbf{z}}_{k,i})\,\mathbf{w}_{k'}\\
&\leq\underbrace{-\big(1-p_k^{(b)}(\mathbf{m}_{i,b})-p_k^{(l)}(\mathbf{z}_{k,i})\big)\mathbf{w}_k}_{\rm pulling\ part}
+\underbrace{\sum_{k'\neq k}^{K}\big(p_{k'}^{(b)}(\mathbf{m}_{i,b})+p_{k'}^{(l)}(\mathbf{z}_{k,i})\big)\mathbf{w}_{k'}}_{\rm forcing\ part}
\end{aligned}
\tag{9}
$$
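Because $\mathbf{m}_{i,b}$ is fixed, only the learnable block $\mathbf{z}_{k,i}$ of the feature receives gradient, and on that block the first line of Eq. 9 equals $(\mathbf{p}-\mathbf{e}_k)^{\mathrm{T}}\mathbf{W}$ with $\mathbf{W}=[\mathbf{w}_1,\dots,\mathbf{w}_K]^{\mathrm{T}}$ and $\mathbf{e}_k$ the one-hot vector of class $k$. The NumPy sketch below (random data, hypothetical names) checks this against central finite differences for a single sample.

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def sample_loss(z, m, W, A, k):
    """Eq. 5 for one sample of class k; W and A are (K, d) stacks of w_k and a_k."""
    return -np.log(softmax(W @ z + A @ m)[k])

def grad_z(z, m, W, A, k):
    # -(1 - p_k) w_k + sum_{k' != k} p_{k'} w_{k'}, written as (p - e_k)^T W
    p = softmax(W @ z + A @ m)
    return (p - np.eye(len(p))[k]) @ W

rng = np.random.default_rng(6)
d, K, k = 5, 4, 2
z, m = rng.standard_normal(d), rng.standard_normal(d)
W, A = rng.standard_normal((K, d)), rng.standard_normal((K, d))

eps = 1e-6
numeric = np.array([
    (sample_loss(z + eps * e, m, W, A, k) - sample_loss(z - eps * e, m, W, A, k)) / (2 * eps)
    for e in np.eye(d)
])
```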

In the gradient w.r.t. features, the pulling part directs the feature $\widetilde{\mathbf{z}}_{k,i}$ toward the weight of its class $\mathbf{w}_k$, and the forcing part repels it from the wrong classes. Regarding the weight factors, the bias-attribute probability $p_k^{(b)}$ again re-weights the influence of the classifier weights. Bias-aligned samples, with high $p_k^{(b)}$, have smaller pulling effects toward the classifier weight $\mathbf{w}_k$, which prevents bias-aligned features from dominating around the weight centers and hinders shortcut learning. In comparison, bias-conflicting samples receive stronger pulling and forcing effects, which strengthens their convergence toward the correct class. The detailed theoretical justification of our method, along with a comparison with vanilla training, is provided in Appendix A.

5 Experiments

Table 2: Comparison of debiasing performance on synthetic datasets. We report the accuracy on the unbiased test sets of Colored MNIST and Corrupted CIFAR-10. Best performances are marked in bold, and the number in parentheses indicates the improvement over the best baseline result. $(^{\ast})$ and $(^{\diamond})$ denote methods with/without bias supervision, respectively.

| Dataset | Ratio (%) | Vanilla | LfF [23] | LfF+BE [19] | EnD [30] | SD [42] | DisEnt [18] | Selecmix [13] | ETF-Debias |
|---|---|---|---|---|---|---|---|---|---|
| Colored MNIST | 0.5 | 32.22 ± 0.13 | 57.78 ± 0.81 | 69.69 ± 1.99 | 35.93 ± 0.40 | 56.96 ± 0.37 | 68.83 ± 1.62 | 70.53 ± 0.46 | **71.63 ± 0.28** (+1.10) |
| | 1.0 | 48.45 ± 0.06 | 72.29 ± 1.69 | 80.90 ± 1.40 | 49.32 ± 0.58 | 72.46 ± 0.18 | 79.49 ± 1.44 | **83.34 ± 0.37** | 81.97 ± 0.26 (-1.37) |
| | 2.0 | 58.90 ± 0.12 | 79.51 ± 1.82 | 84.90 ± 1.14 | 65.58 ± 0.46 | 79.37 ± 0.46 | 84.56 ± 1.19 | 85.90 ± 0.23 | **86.00 ± 0.03** (+0.10) |
| | 5.0 | 74.19 ± 0.04 | 83.96 ± 1.44 | 90.28 ± 0.18 | 80.70 ± 0.17 | 88.89 ± 0.21 | 88.83 ± 0.15 | 91.27 ± 0.31 | **91.36 ± 0.21** (+0.09) |
| Corrupted CIFAR-10 | 0.5 | 17.06 ± 0.12 | 31.00 ± 2.67 | 23.68 ± 0.50 | 14.30 ± 0.10 | 36.66 ± 0.74 | 30.12 ± 1.60 | 33.30 ± 0.26 | **40.06 ± 0.03** (+3.40) |
| | 1.0 | 21.48 ± 0.55 | 34.33 ± 1.76 | 30.72 ± 0.12 | 20.17 ± 0.19 | 45.66 ± 1.05 | 35.28 ± 1.39 | 38.72 ± 0.27 | **47.52 ± 0.26** (+1.86) |
| | 2.0 | 27.15 ± 0.46 | 39.68 ± 1.15 | 42.22 ± 0.60 | 30.10 ± 0.54 | 50.11 ± 0.69 | 40.34 ± 1.41 | 47.09 ± 0.17 | **54.64 ± 0.42** (+4.53) |
| | 5.0 | 39.46 ± 0.58 | 53.04 ± 0.76 | 57.93 ± 0.58 | 45.85 ± 0.21 | 62.43 ± 0.57 | 49.99 ± 0.84 | 54.69 ± 0.29 | **65.34 ± 0.60** (+2.91) |

Table 3: Comparison of debiasing performance on real-world datasets. We report the accuracy on the unbiased test sets of Biased FFHQ, Dogs & Cats, and BAR. The class-wise accuracy on BAR is reported in Appendix D. Best performances are marked in bold, and the number in parentheses indicates the improvement over the best baseline result. $(^{\ast})$ and $(^{\diamond})$ denote methods with/without bias supervision, respectively.

| Dataset | Ratio (%) | Vanilla | LfF [23] | LfF+BE [19] | EnD [30] | SD [42] | DisEnt [18] | Selecmix [13] | ETF-Debias |
|---|---|---|---|---|---|---|---|---|---|
| Biased FFHQ | 0.5 | 53.27 ± 0.61 | 65.60 ± 2.27 | 67.07 ± 2.37 | 55.93 ± 1.62 | 65.60 ± 0.20 | 63.07 ± 1.14 | 65.00 ± 0.82 | **73.60 ± 1.22** (+6.53) |
| | 1.0 | 57.13 ± 0.64 | 72.33 ± 2.19 | 73.53 ± 1.62 | 61.13 ± 0.50 | 69.20 ± 0.20 | 68.53 ± 2.32 | 67.50 ± 0.30 | **76.53 ± 1.10** (+3.00) |
| | 2.0 | 67.67 ± 0.81 | 74.80 ± 2.03 | 80.20 ± 2.78 | 66.87 ± 0.64 | 78.40 ± 0.20 | 72.00 ± 2.51 | 69.80 ± 0.87 | **85.20 ± 0.61** (+5.00) |
| | 5.0 | 78.87 ± 0.83 | 80.27 ± 2.02 | 87.40 ± 2.00 | 80.87 ± 0.42 | 84.80 ± 0.20 | 80.60 ± 0.53 | 83.47 ± 0.61 | **94.00 ± 0.72** (+6.60) |
| Dogs & Cats | 1.0 | 51.96 ± 0.90 | 71.17 ± 5.24 | 78.87 ± 2.40 | 51.91 ± 0.24 | 78.13 ± 1.06 | 65.13 ± 2.07 | 54.19 ± 1.61 | **80.07 ± 0.90** (+1.20) |
| | 5.0 | 76.59 ± 1.27 | 85.83 ± 1.62 | 88.60 ± 1.21 | 79.07 ± 0.28 | 89.12 ± 0.18 | 82.47 ± 2.86 | 81.50 ± 1.06 | **92.18 ± 0.62** (+3.06) |
| BAR | 1.0 | 68.00 ± 0.43 | 68.30 ± 0.97 | 71.70 ± 1.33 | 68.25 ± 0.19 | 67.33 ± 0.35 | 69.30 ± 1.27 | 69.83 ± 1.02 | **72.79 ± 0.21** (+1.09) |
| | 5.0 | 79.34 ± 0.19 | 80.25 ± 1.27 | 82.00 ± 1.24 | 78.86 ± 0.36 | 79.10 ± 0.42 | 81.19 ± 0.70 | 78.79 ± 0.52 | **83.66 ± 0.21** (+1.66) |

5.1 Experimental Settings

Datasets and models. We validate the effectiveness of ETF-Debias on common debiasing benchmarks, which cover various types of bias attributes including color, corruption, gender, and background. We adopt 2 synthetic biased datasets, Colored MNIST [15] and Corrupted CIFAR-10 [10], with bias-conflicting training-sample ratios of {0.5%, 1.0%, 2.0%, 5.0%}, and 3 real-world biased datasets: Biased FFHQ (BFFHQ) [18] with ratios of {0.5%, 1.0%, 2.0%, 5.0%}, and BAR [23] and Dogs & Cats [15], each with ratios of {1.0%, 5.0%}.

As for the model architecture, we adopt a three-layer MLP for Colored MNIST and ResNet-20 [9] for the other datasets. Since BAR has a tiny training set, we follow previous work [19] and initialize the parameters with models pre-trained on the corresponding datasets. All results are averaged over three independent trials. More details about the datasets and implementation are available in Appendix B.

Baselines. According to the three categories of debiased learning in Section 2, we compare the performance of ETF-Debias with six recent methods. For reweight-based debiasing, we consider LfF [23] with auxiliary bias models, and its improved version LfF+BE [19]. For disentangle-based debiasing, we consider EnD [30] and SD [42], which stem from the Information Bottleneck theory and causal intervention respectively. For augmentation-based debiasing, we consider DisEnt [18] and Selecmix [13], to include both the feature-level and image-level augmentations.

5.2 Main Results

Comparison on synthetic datasets. To evaluate debiasing performance, we report the accuracy on the unbiased test sets of the 2 synthetic datasets in Tab. 2. ETF-Debias achieves the best accuracy in seven of the eight settings; the only exception is Colored MNIST at the 1.0% ratio, where Selecmix is 1.37 points higher. We observe that some baseline methods (e.g., EnD) do not display a satisfactory debiasing effect on synthetic datasets, as they rely heavily on diverse contrastive samples to identify and mitigate the bias features. In contrast, our approach directly provides the approximated shortcut features as training primes, which achieves superior performance on synthetic bias attributes.

Comparison on real-world datasets. To verify the scalability of ETF-Debias to real-world scenarios with more diverse bias attributes, we test our method on three real-world biased datasets in Tab. 3. ETF-Debias shows an even greater performance gain on real-world datasets than on synthetic ones, which may be attributed to the semantically meaningful prime features constructed with the simplex ETF structure. On the large-scale BFFHQ dataset, our method improves accuracy by up to 6.6% over the baseline methods, demonstrating its potential in real-world applications.

Convergence of Neural Collapse. In Fig. 2, we display the trajectory of the NC metrics during training on the Corrupted CIFAR-10 dataset. Guided by the prime features, the model establishes the correct correlation and converges much better on biased datasets, which contributes to its superior generalization capability. More convergence results are available in Appendix C.
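As a rough illustration of what such convergence metrics capture, the sketch below computes two NC2-style quantities in the spirit of Papyan et al. [24]: the equinorm ratio (standard deviation over mean of the centered class-mean norms) and the average deviation of the pairwise class-mean cosines from the simplex ETF value -1/(K-1). These definitions are an assumption and may differ from the exact NC metrics used in Section 3; the function names are hypothetical, and plain Python is used so the example has no dependencies.

```python
import math
import statistics


def class_means(features, labels, num_classes):
    """Mean feature vector of each class (every class must have a sample)."""
    dim = len(features[0])
    sums = [[0.0] * dim for _ in range(num_classes)]
    counts = [0] * num_classes
    for z, y in zip(features, labels):
        counts[y] += 1
        for d in range(dim):
            sums[y][d] += z[d]
    return [[s / counts[k] for s in sums[k]] for k in range(num_classes)]


def nc2_metrics(features, labels, num_classes):
    """Return (equinorm, equiangularity); both are 0 for a perfect simplex ETF."""
    n, dim = len(features), len(features[0])
    # Global mean over all training samples.
    global_mean = [sum(z[d] for z in features) / n for d in range(dim)]
    means = class_means(features, labels, num_classes)
    centered = [[m[d] - global_mean[d] for d in range(dim)] for m in means]
    norms = [math.sqrt(sum(v * v for v in c)) for c in centered]
    # Equinorm: relative spread of the centered class-mean norms.
    equinorm = statistics.pstdev(norms) / statistics.mean(norms)
    # Equiangularity: mean |cos + 1/(K-1)| over all class pairs.
    target = -1.0 / (num_classes - 1)
    deviations = []
    for a in range(num_classes):
        for b in range(a + 1, num_classes):
            dot = sum(x * y for x, y in zip(centered[a], centered[b]))
            deviations.append(abs(dot / (norms[a] * norms[b]) - target))
    return equinorm, statistics.mean(deviations)


# Collapsed features: each class sits on a vertex of a 2-D simplex ETF (K = 3).
collapsed = [(1.0, 0.0), (-0.5, math.sqrt(3) / 2), (-0.5, -math.sqrt(3) / 2)]
# Non-collapsed features: classes 0 and 1 are nearly merged.
biased = [(1.0, 0.0), (0.8, 0.2), (-1.0, 0.0)]
print(nc2_metrics(collapsed, [0, 1, 2], 3))
print(nc2_metrics(biased, [0, 1, 2], 3))
```

In practice the features would be last-layer embeddings collected over the training set, with many samples per class.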

5.3 Ablation Study

Ablation on the influence of regularization. To measure the sensitivity of our method to the strength of prime reinforcement regularization, we compare the accuracy on the unbiased test set with α ranging from 0.0 to 1.0 in Fig. 4(a). The debiasing performance remains significant across different strengths of regularization, and a proper level of prime reinforcement brings an extra performance gain.

Figure 4: Ablation studies on the hyper-parameter and prime features. We report (a) test set accuracy on different datasets with hyper-parameter α ranging from 0.0 to 1.0, and (b) test set accuracy on 5 datasets with different prime features. The shaded areas represent the standard deviation, and the bias ratio of all datasets is 5%.

Ablation on the influence of ETF prime features. As described before, we choose the vertices of the simplex ETF as the “perfectly learned” shortcut features, thus redirecting the model’s attention to intrinsic correlations. To demonstrate the efficacy of ETF prime features, we compare them with randomly initialized prime features of the same dimension. As shown in Fig. 4(b), the randomly initialized primes suffer a severe performance degradation, underscoring the advantage of ETF-based prime features in approximating the optimal structure.
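For reference, a K-vertex simplex ETF can be built with the standard construction M = sqrt(K/(K-1)) U (I_K - (1/K) 1 1^T), where U has K orthonormal columns (see Papyan et al. [24]). The sketch below is a minimal, dependency-free version under two assumptions: the prime dimension d is at least K, and U is obtained by Gram-Schmidt on random Gaussian vectors. It contrasts the resulting pairwise cosines, all equal to -1/(K-1), with those of normalized random primes of the same dimension. The function names and the choice of U are hypothetical and not taken from the paper's implementation.

```python
import math
import random


def gram_schmidt(vectors):
    """Orthonormalize linearly independent vectors (classical Gram-Schmidt)."""
    basis = []
    for v in vectors:
        w = list(v)
        for b in basis:
            proj = sum(x * y for x, y in zip(w, b))
            w = [x - proj * y for x, y in zip(w, b)]
        norm = math.sqrt(sum(x * x for x in w))
        basis.append([x / norm for x in w])
    return basis


def simplex_etf(num_classes, dim, seed=0):
    """K unit vectors in R^dim whose pairwise cosines all equal -1/(K-1)."""
    if dim < num_classes:
        raise ValueError("this construction assumes dim >= num_classes")
    rng = random.Random(seed)
    u = gram_schmidt([[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_classes)])
    scale = math.sqrt(num_classes / (num_classes - 1))
    vertices = []
    for k in range(num_classes):
        # Column k of (I - 11^T / K): 1 - 1/K at position k, -1/K elsewhere.
        coeffs = [(1.0 if j == k else 0.0) - 1.0 / num_classes for j in range(num_classes)]
        vertices.append([scale * sum(coeffs[j] * u[j][d] for j in range(num_classes)) for d in range(dim)])
    return vertices


def random_primes(num_classes, dim, seed=0):
    """Baseline: independent random directions normalized to unit length."""
    rng = random.Random(seed)
    primes = []
    for _ in range(num_classes):
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(x * x for x in v))
        primes.append([x / norm for x in v])
    return primes


def pairwise_cosines(vectors):
    cosines = []
    for a in range(len(vectors)):
        for b in range(a + 1, len(vectors)):
            dot = sum(x * y for x, y in zip(vectors[a], vectors[b]))
            na = math.sqrt(sum(x * x for x in vectors[a]))
            nb = math.sqrt(sum(x * x for x in vectors[b]))
            cosines.append(dot / (na * nb))
    return cosines


etf = simplex_etf(num_classes=5, dim=16)
print(min(pairwise_cosines(etf)), max(pairwise_cosines(etf)))  # both close to -0.25
rand = random_primes(num_classes=5, dim=16)
print(min(pairwise_cosines(rand)), max(pairwise_cosines(rand)))
```

The ETF vertices are also centered (they sum to the zero vector), which is the maximal-separation property that random primes lack.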

5.4 Visualization

Figure 5: Comparison of CAM [44] visualizations on bias-conflicting samples (first row) between vanilla models (second row) and ETF-Debias models (third row) on (a) the Corrupted CIFAR-10 dataset, where the bias attribute is the type of corruption applied to the entire image, (b) the Dogs & Cats dataset, where the bias attribute is the color of the animals, and (c) the BFFHQ dataset, where the bias attribute is gender.

To intuitively reveal the effectiveness of our method, we compare the CAM [44] visualizations of vanilla models and of debiased models trained with ETF-Debias. As shown in Fig. 5, ETF-Debias substantially rectifies the model’s attention. For example, on the Corrupted CIFAR-10 dataset, vanilla models are easily misled by corruptions spread over the entire image, whereas the debiased models, guided by the prime features, shift their attention to the objects themselves. Our method also avoids attending to misleading areas and encourages a focus on more discriminative, finer-grained regions. On facial datasets, our method likewise breaks the spurious correlation with specific visual attributes [19] and considers more facial features, as shown in Fig. 5(c). More visualization results are available in Appendix E.
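CAM [44] computes, for a target class c, a weighted sum of the last convolutional feature maps, CAM_c(x, y) = Σ_k w_k^c f_k(x, y), where w^c are the classifier weights applied after global average pooling. The sketch below is a minimal pure-Python version under that assumption (no bias term, nested lists instead of tensors, hypothetical function names). It also checks the defining property that averaging the raw map over all positions recovers the class logit of the pooled features.

```python
def raw_cam(feature_maps, class_weights):
    """feature_maps: C channels, each an H x W nested list; class_weights: C floats."""
    height, width = len(feature_maps[0]), len(feature_maps[0][0])
    return [
        [sum(w * fmap[y][x] for w, fmap in zip(class_weights, feature_maps)) for x in range(width)]
        for y in range(height)
    ]


def gap_logit(feature_maps, class_weights):
    """Class logit of a global-average-pooling + linear head (bias omitted)."""
    pooled = [sum(map(sum, fmap)) / (len(fmap) * len(fmap[0])) for fmap in feature_maps]
    return sum(w * p for w, p in zip(class_weights, pooled))


def normalize(cam):
    """Min-max scale to [0, 1] for overlaying on the input image."""
    lo = min(min(row) for row in cam)
    hi = max(max(row) for row in cam)
    if hi == lo:
        return [[0.0 for _ in row] for row in cam]
    return [[(v - lo) / (hi - lo) for v in row] for row in cam]


# Toy 3 x 3 example: channel 0 fires on a central object,
# channel 1 on an image-wide corruption.
object_map = [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
corruption_map = [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]
weights = [2.0, 0.1]
cam = raw_cam([object_map, corruption_map], weights)
print(normalize(cam))
```

Because min-max normalization removes any spatially uniform offset, the visual contrast of a CAM reflects how the class weights distribute over spatially varying channels.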

6 Conclusion

In this paper, we propose an avoid-shortcut learning framework built on insights from the Neural Collapse phenomenon. By extending the analysis of Neural Collapse to biased datasets, we introduce the simplex ETF as the prime features to redirect the model’s attention to intrinsic correlations. With state-of-the-art debiasing performance on various benchmarks, we hope our work advances the understanding of Neural Collapse and sheds light on fundamental solutions to model debiasing.

Acknowledgement

We appreciate the valuable comments from the anonymous reviewers that improved the paper’s quality. This work was supported in part by the National Key Research and Development Program (2021YFB3101200) and the National Natural Science Foundation of China (U1736208, U1836210, U1836213, 62172104, 62172105, 61902374, 62102093, 62102091). Min Yang is a faculty member of the Shanghai Institute of Intelligent Electronics & Systems, the Shanghai Institute for Advanced Communication and Data Science, and the Engineering Research Center of Cyber Security Auditing and Monitoring, Ministry of Education, China.

References

  • Balayn and Gürses [2021] Agathe Balayn and Seda Gürses. Beyond debiasing: Regulating AI and its inequalities. EDRi Report. https://edri.org/wp-content/uploads/2021/09/EDRi_Beyond-Debiasing-Report_Online.pdf, 2021.
  • Dang et al. [2023] Hien Dang, Tan Nguyen, Tho Tran, Hung Tran, and Nhat Ho. Neural collapse in deep linear network: From balanced to imbalanced data. ICML, 2023.
  • Deviyani [2022] Athiya Deviyani. Assessing dataset bias in computer vision. arXiv preprint arXiv:2205.01811, 2022.
  • Fan et al. [2022] Shaohua Fan, Xiao Wang, Yanhu Mo, Chuan Shi, and Jian Tang. Debiasing graph neural networks via learning disentangled causal substructure. NeurIPS, 35:24934–24946, 2022.
  • Fang et al. [2021] Cong Fang, Hangfeng He, Qi Long, and Weijie J Su. Exploring deep neural networks via layer-peeled model: Minority collapse in imbalanced training. Proceedings of the National Academy of Sciences, 118(43):e2103091118, 2021.
  • Galanti et al. [2022] Tomer Galanti, András György, and Marcus Hutter. On the role of neural collapse in transfer learning. In ICLR, 2022.
  • Guo et al. [2022] Yue Guo, Yi Yang, and Ahmed Abbasi. Auto-debias: Debiasing masked language models with automated biased prompts. In ACL, pages 1012–1023, 2022.
  • Han et al. [2022] XY Han, Vardan Papyan, and David L Donoho. Neural collapse under MSE loss: Proximity to and dynamics on the central path. In ICLR, 2022.
  • He et al. [2016] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, pages 770–778, 2016.
  • Hendrycks and Dietterich [2018] Dan Hendrycks and Thomas Dietterich. Benchmarking neural network robustness to common corruptions and perturbations. In ICLR, 2018.
  • Hirota et al. [2023] Yusuke Hirota, Yuta Nakashima, and Noa Garcia. Model-agnostic gender debiased image captioning. In CVPR, pages 15191–15200, 2023.
  • Huang et al. [2008] Gary B Huang, Marwan Mattar, Tamara Berg, and Eric Learned-Miller. Labeled faces in the wild: A database for studying face recognition in unconstrained environments. In Workshop on Faces in 'Real-Life' Images: Detection, Alignment, and Recognition, 2008.
  • Hwang et al. [2022] Inwoo Hwang, Sangjun Lee, Yunhyeok Kwak, Seong Joon Oh, Damien Teney, Jin-Hwa Kim, and Byoung-Tak Zhang. SelecMix: Debiased learning by contradicting-pair sampling. NeurIPS, 35:14345–14357, 2022.
  • Karras et al. [2019] Tero Karras, Samuli Laine, and Timo Aila. A style-based generator architecture for generative adversarial networks. In CVPR, pages 4401–4410, 2019.
  • Kim et al. [2019] Byungju Kim, Hyunwoo Kim, Kyungsu Kim, Sungjin Kim, and Junmo Kim. Learning not to learn: Training deep neural networks with biased data. In CVPR, pages 9012–9020, 2019.
  • Kim et al. [2022] Nayeong Kim, Sehyun Hwang, Sungsoo Ahn, Jaesik Park, and Suha Kwak. Learning debiased classifier with biased committee. NeurIPS, 35:18403–18415, 2022.
  • Krizhevsky et al. [2009] Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.
  • Lee et al. [2021] Jungsoo Lee, Eungyeup Kim, Juyoung Lee, Jihyeon Lee, and Jaegul Choo. Learning debiased representation via disentangled feature augmentation. NeurIPS, 34:25123–25133, 2021.
  • Lee et al. [2023] Jungsoo Lee, Jeonghoon Park, Daeyoung Kim, Juyoung Lee, Edward Choi, and Jaegul Choo. Revisiting the importance of amplifying bias for debiasing. In AAAI, pages 14974–14981, 2023.
  • Li et al. [2023] Zexi Li, Xinyi Shang, Rui He, Tao Lin, and Chao Wu. No fear of classifier biases: Neural collapse inspired federated learning with synthetic and fixed classifier. ICCV, 2023.
  • Lim et al. [2023] Jongin Lim, Youngdong Kim, Byungjai Kim, Chanho Ahn, Jinwoo Shin, Eunho Yang, and Seungju Han. BiasAdv: Bias-adversarial augmentation for model debiasing. In CVPR, pages 3832–3841, 2023.
  • Lyu et al. [2023] Yougang Lyu, Piji Li, Yechang Yang, Maarten de Rijke, Pengjie Ren, Yukun Zhao, Dawei Yin, and Zhaochun Ren. Feature-level debiased natural language understanding. In AAAI, pages 13353–13361, 2023.
  • Nam et al. [2020] Junhyun Nam, Hyuntak Cha, Sungsoo Ahn, Jaeho Lee, and Jinwoo Shin. Learning from failure: De-biasing classifier from biased classifier. NeurIPS, 33:20673–20684, 2020.
  • Papyan et al. [2020] Vardan Papyan, XY Han, and David L Donoho. Prevalence of neural collapse during the terminal phase of deep learning training. Proceedings of the National Academy of Sciences, 117(40):24652–24663, 2020.
  • Peifeng et al. [2023] Gao Peifeng, Qianqian Xu, Peisong Wen, Zhiyong Yang, Huiyang Shao, and Qingming Huang. Feature directions matter: Long-tailed learning via rotated balanced representation. ICML, 2023.
  • Rangamani et al. [2023] Akshay Rangamani, Marius Lindegaard, Tomer Galanti, and Tomaso A Poggio. Feature learning in deep classifiers through intermediate neural collapse. In ICML, pages 28729–28745. PMLR, 2023.
  • Robinson et al. [2021] Joshua Robinson, Li Sun, Ke Yu, Kayhan Batmanghelich, Stefanie Jegelka, and Suvrit Sra. Can contrastive learning avoid shortcut solutions? NeurIPS, 34:4974–4986, 2021.
  • Saranrittichai et al. [2022] Piyapat Saranrittichai, Chaithanya Kumar Mummadi, Claudia Blaiotta, Mauricio Munoz, and Volker Fischer. Overcoming shortcut learning in a target domain by generalizing basic visual factors from a source domain. In ECCV, pages 294–309. Springer, 2022.
  • Shinoda et al. [2023] Kazutoshi Shinoda, Saku Sugawara, and Akiko Aizawa. Which shortcut solution do question answering models prefer to learn? In AAAI, pages 13564–13572, 2023.
  • Tartaglione et al. [2021] Enzo Tartaglione, Carlo Alberto Barbano, and Marco Grangetto. EnD: Entangling and disentangling deep representations for bias correction. In CVPR, pages 13508–13517, 2021.
  • Thrampoulidis et al. [2022] Christos Thrampoulidis, Ganesh Ramachandra Kini, Vala Vakilian, and Tina Behnia. Imbalance trouble: Revisiting neural-collapse geometry. NeurIPS, 35:27225–27238, 2022.
  • Wang et al. [2021] Tan Wang, Chang Zhou, Qianru Sun, and Hanwang Zhang. Causal attention for unbiased visual recognition. In CVPR, pages 3091–3100, 2021.
  • Wen et al. [2022] Chuan Wen, Jianing Qian, Jierui Lin, Jiaye Teng, Dinesh Jayaraman, and Yang Gao. Fighting fire with fire: Avoiding DNN shortcuts through priming. In ICML, pages 23723–23750. PMLR, 2022.
  • Wen et al. [2021] Zhiquan Wen, Guanghui Xu, Mingkui Tan, Qingyao Wu, and Qi Wu. Debiased visual question answering from feature and sample perspectives. NeurIPS, 34:3784–3796, 2021.
  • Xie et al. [2023] Liang Xie, Yibo Yang, Deng Cai, and Xiaofei He. Neural collapse inspired attraction–repulsion-balanced loss for imbalanced learning. Neurocomputing, 527:60–70, 2023.
  • Yang et al. [2022a] Yibo Yang, Shixiang Chen, Xiangtai Li, Liang Xie, Zhouchen Lin, and Dacheng Tao. Inducing neural collapse in imbalanced learning: Do we really need a learnable classifier at the end of deep neural network? NeurIPS, 35:37991–38002, 2022a.
  • Yang et al. [2022b] Yibo Yang, Haobo Yuan, Xiangtai Li, Zhouchen Lin, Philip Torr, and Dacheng Tao. Neural collapse inspired feature-classifier alignment for few-shot class-incremental learning. In ICLR, 2022b.
  • Yang et al. [2023] Yongyi Yang, Jacob Steinhardt, and Wei Hu. Are neurons actually collapsed? On the fine-grained structure in neural representations. ICML, 2023.
  • Cortes and LeCun [2010] Corinna Cortes and Yann LeCun. MNIST handwritten digit database. Available at http://yann.lecun.com/exdb/mnist/, 2010.
  • Yaras et al. [2022] Can Yaras, Peng Wang, Zhihui Zhu, Laura Balzano, and Qing Qu. Neural collapse with normalized features: A geometric analysis over the Riemannian manifold. NeurIPS, 35:11547–11560, 2022.
  • Zhang et al. [2023a] Qing Zhang, Xiaoying Zhang, Yang Liu, Hongning Wang, Min Gao, Jiheng Zhang, and Ruocheng Guo. Debiasing recommendation by learning identifiable latent confounders. KDD, 2023a.
  • Zhang et al. [2023b] Yi Zhang, Jitao Sang, Junyang Wang, Dongmei Jiang, and Yaowei Wang. Benign shortcut for debiasing: Fair visual recognition via intervention with shortcut features. ACM MM, 2023b.
  • Zhong et al. [2023] Zhisheng Zhong, Jiequan Cui, Yibo Yang, Xiaoyang Wu, Xiaojuan Qi, Xiangyu Zhang, and Jiaya Jia. Understanding imbalanced semantic segmentation through neural collapse. In CVPR, pages 19550–19560, 2023.
  • Zhou et al. [2016] Bolei Zhou, Aditya Khosla, Agata Lapedriza, Aude Oliva, and Antonio Torralba. Learning deep features for discriminative localization. In CVPR, pages 2921–2929, 2016.
  • Zhou et al. [2022] Jinxin Zhou, Chong You, Xiao Li, Kangning Liu, Sheng Liu, Qing Qu, and Zhihui Zhu. Are all losses created equal: A neural collapse perspective. NeurIPS, 35:31697–31710, 2022.
  • Zhu et al. [2021] Zhihui Zhu, Tianyu Ding, Jinxin Zhou, Xiao Li, Chong You, Jeremias Sulam, and Qing Qu. A geometric analysis of neural collapse with unconstrained features. NeurIPS, 34:29820–29834, 2021.

Supplementary Material

Appendix A Detailed Theoretical Justification

A.1 Analysis of Vanilla Training

To illustrate why vanilla models tend to pursue shortcut learning, we follow the analysis of previous works [36, 43] and re-examine the issue of biased classification from the perspective of gradients.

Following the definition in Section 3.1, we denote $\mathbf{x}_{k,i}$ as the $i$-th sample of the $k$-th class, $\mathbf{z}_{k,i}\in\mathbb{R}^{d}$ as its corresponding last-layer feature, and $\mathbf{W}=[\mathbf{w}_{1},\dots,\mathbf{w}_{K}]$ as the weights of the classifier. In vanilla training, the cross-entropy loss is defined as:

$$\mathcal{L}_{\rm CE}(\mathbf{z}_{k,i},\mathbf{W}) = -\log\left(\frac{\exp(\mathbf{z}_{k,i}^{\mathrm{T}}\mathbf{w}_{k})}{\sum_{k'=1}^{K}\exp(\mathbf{z}_{k,i}^{\mathrm{T}}\mathbf{w}_{k'})}\right) \tag{A.1}$$

Gradient w.r.t. classifier weights. To analyze the learning behavior of the classifier, we first compute the gradient of $\mathcal{L}_{\rm CE}$ w.r.t. the classifier weights:

$$\frac{\partial\mathcal{L}_{\rm CE}}{\partial\mathbf{w}_{k}} = \underbrace{\sum_{i=1}^{n_{k}} -\big(1-p_{k}(\mathbf{z}_{k,i})\big)\mathbf{z}_{k,i}}_{\rm pulling\ part} + \underbrace{\sum_{k'\neq k}^{K}\sum_{j=1}^{n_{k'}} p_{k}(\mathbf{z}_{k',j})\,\mathbf{z}_{k',j}}_{\rm forcing\ part} \tag{A.2}$$

where $n_{k}$ denotes the number of training samples in the $k$-th class, and $p_{k}(\mathbf{z})$ is the predicted probability that $\mathbf{z}$ belongs to the $k$-th class, calculated with the softmax function:

$$p_{k}(\mathbf{z}) = \frac{\exp(\mathbf{z}^{\mathrm{T}}\mathbf{w}_{k})}{\sum_{k'=1}^{K}\exp(\mathbf{z}^{\mathrm{T}}\mathbf{w}_{k'})},\quad 1\leq k\leq K \tag{A.3}$$

In Eq. A.2, we decompose the gradient w.r.t. the classifier weight $\mathbf{w}_{k}$ into two parts: the pulling part and the forcing part. The pulling part contains the effects of features from the same class (i.e., $\mathbf{z}_{k,i}$), which pull the classifier weight towards the $k$-th feature cluster, with each feature having an influence of $1-p_{k}(\mathbf{z}_{k,i})$. Meanwhile, the forcing part contains the features from other classes, which push $\mathbf{w}_{k}$ away from the wrong clusters, with each feature having an influence of $p_{k}(\mathbf{z}_{k',j})$. When the vanilla model is trained on biased datasets, the prevalent bias-aligned samples of the $k$-th class dominate the pulling part of the gradient. The classifier weight $\mathbf{w}_{k}$ is therefore pulled towards the center of the bias-aligned features, which exhibit a strong correlation between the class label $k$ and a bias attribute $b_{k}$. As a result of the imbalanced gradient magnitudes across attributes, a shortcut-based biased feature space forms early in training.
Similarly, the forcing part of the gradient is dominated by the bias-aligned samples of other classes, further reinforcing the tendency toward shortcut learning. This confirms our observation in Section 3.2 that the model’s pursuit of shortcut correlations leads to a biased, non-collapsed feature space, which is hard to rectify in subsequent training steps.
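To make Eq. A.2 concrete, the sketch below splits the gradient w.r.t. $\mathbf{w}_k$ into its pulling and forcing parts and checks their sum against a central finite difference of the summed cross-entropy loss. The toy data are an assumption for illustration: 2-D features whose first coordinate carries the intrinsic (class) signal and whose second carries the bias attribute, with 9 bias-aligned and 1 bias-conflicting sample per class. At $\mathbf{W}=\mathbf{0}$ every $p_k$ equals 1/2, so the pulling part for class 0 is $-\frac{1}{2}\sum_i \mathbf{z}_{0,i} = (-5, -4)$: the bias coordinate receives almost as much pull as the intrinsic one, whereas it would cancel out with balanced attributes. Function names are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def total_ce(features, labels, weights):
    """Summed cross-entropy loss (Eq. A.1) over all samples."""
    return -sum(math.log(softmax([dot(z, w) for w in weights])[y]) for z, y in zip(features, labels))


def grad_wk_parts(features, labels, weights, k):
    """Pulling and forcing parts of dL/dw_k as in Eq. A.2."""
    dim = len(weights[k])
    pulling, forcing = [0.0] * dim, [0.0] * dim
    for z, y in zip(features, labels):
        p_k = softmax([dot(z, w) for w in weights])[k]
        if y == k:  # same-class feature, influence 1 - p_k
            pulling = [g - (1.0 - p_k) * v for g, v in zip(pulling, z)]
        else:  # other-class feature, influence p_k
            forcing = [g + p_k * v for g, v in zip(forcing, z)]
    return pulling, forcing


def numeric_grad_wk(features, labels, weights, k, eps=1e-6):
    """Central finite difference of the summed loss w.r.t. w_k."""
    grad = []
    for d in range(len(weights[k])):
        plus = [list(w) for w in weights]
        minus = [list(w) for w in weights]
        plus[k][d] += eps
        minus[k][d] -= eps
        grad.append((total_ce(features, labels, plus) - total_ce(features, labels, minus)) / (2 * eps))
    return grad


# Feature = (intrinsic, bias). 9 bias-aligned + 1 bias-conflicting sample per class.
features = [(1.0, 1.0)] * 9 + [(1.0, -1.0)] + [(-1.0, -1.0)] * 9 + [(-1.0, 1.0)]
labels = [0] * 10 + [1] * 10
weights = [[0.0, 0.0], [0.0, 0.0]]

pulling, forcing = grad_wk_parts(features, labels, weights, k=0)
print(pulling, forcing)  # [-5.0, -4.0] [-5.0, -4.0]
print(numeric_grad_wk(features, labels, weights, k=0))
```

A gradient-descent step subtracts this gradient, so $\mathbf{w}_0$ moves towards (5, 4) scaled by the learning rate, acquiring a large weight on the bias coordinate from the very first update.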

Gradient w.r.t. features. Furthermore, we compute the gradient of $\mathcal{L}_{\rm CE}$ w.r.t. the last-layer features:

$$\frac{\partial\mathcal{L}_{\rm CE}}{\partial\mathbf{z}_{k,i}} = \underbrace{-\big(1-p_{k}(\mathbf{z}_{k,i})\big)\mathbf{w}_{k}}_{\rm pulling\ part} + \underbrace{\sum_{k'\neq k}^{K} p_{k'}(\mathbf{z}_{k,i})\,\mathbf{w}_{k'}}_{\rm forcing\ part} \tag{A.4}$$

In Eq. A.4, the gradient w.r.t. features can likewise be viewed as the combination of a pulling part and a forcing part. The pulling part represents the pull of the classifier weight of the same class (i.e., $\mathbf{w}_{k}$), which guides the features to align with the prediction behavior of the classifier. The forcing part, by contrast, represents the pushing effect of the other classifier weights. As discussed above, the model’s reliance on simple shortcuts forms at an early stage of training, because the classifier weights are misled towards the centers of bias-aligned features. This results in a biased decision rule, which directly affects the formation of the last-layer feature space. Consider the bias-conflicting samples of the $k$-th class: although the pulling part of the gradient supports their convergence towards the correct classifier weight $\mathbf{w}_{k}$, the established shortcut correlation is hard to reverse and eliminate, as demonstrated with the Neural Collapse metrics in Section 3.2.
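Eq. A.4 can be checked the same way for a single sample. In the hypothetical two-class setting below, the classifier has already absorbed the shortcut: $\mathbf{w}_0=(1,2)$ and $\mathbf{w}_1=(-1,-2)$ place twice as much weight on the bias coordinate as on the intrinsic one. For a bias-conflicting sample of class 0, $\mathbf{z}=(1,-1)$, the model predicts class 1, and both gradient parts are parallel to $\mathbf{w}_0$, so the update on $\mathbf{z}$ is dictated by the shortcut-heavy weights (its bias component is exactly twice its intrinsic component). The values and function names are illustrative assumptions only.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sample_ce(z, label, weights):
    """Cross-entropy of a single feature z (Eq. A.1)."""
    return -math.log(softmax([dot(z, w) for w in weights])[label])


def grad_z_parts(z, label, weights):
    """Pulling and forcing parts of dL/dz_{k,i} as in Eq. A.4."""
    p = softmax([dot(z, w) for w in weights])
    pulling = [-(1.0 - p[label]) * v for v in weights[label]]
    forcing = [0.0] * len(z)
    for k_prime, w in enumerate(weights):
        if k_prime != label:
            forcing = [g + p[k_prime] * v for g, v in zip(forcing, w)]
    return pulling, forcing


def numeric_grad_z(z, label, weights, eps=1e-6):
    """Central finite difference of the per-sample loss w.r.t. z."""
    grad = []
    for d in range(len(z)):
        plus, minus = list(z), list(z)
        plus[d] += eps
        minus[d] -= eps
        grad.append((sample_ce(plus, label, weights) - sample_ce(minus, label, weights)) / (2 * eps))
    return grad


weights = [[1.0, 2.0], [-1.0, -2.0]]  # hypothetical shortcut-heavy classifier
z = [1.0, -1.0]  # bias-conflicting sample of class 0
pulling, forcing = grad_z_parts(z, 0, weights)
total = [a + b for a, b in zip(pulling, forcing)]
print(softmax([dot(z, w) for w in weights]))  # class 1 wins: the sample is misclassified
print(total, numeric_grad_z(z, 0, weights))
```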

A.2 Analysis of ETF-Debias

In light of the analysis of vanilla training, our proposed debiasing framework, ETF-Debias, turns the easy-to-follow shortcut into prime features, which guide the model to skip actively learning the shortcut and to focus directly on the intrinsic correlations. We provided a brief theoretical justification of our method in Section 4.3; the detailed illustrations are as follows.

When trained on biased datasets, we assume each class $k\in\{1,\dots,K\}$ is strongly correlated with a bias attribute $b_{k}$, i.e., classes $[1,\dots,K]$ correspond to bias attributes $[b_{1},\dots,b_{K}]$. Based on the mechanism of prime training, we denote the $i$-th feature of the $k$-th class as $\widetilde{\mathbf{z}}_{k,i}=[\mathbf{z}_{k,i},\mathbf{m}_{i,b}]\in\mathbb{R}^{2d}$, where $\mathbf{z}_{k,i}\in\mathbb{R}^{d}$ represents the learnable features and $\mathbf{m}_{i,b}\in\mathbb{R}^{d}$ represents the prime feature retrieved based on the bias attribute $b$ of the input sample.
With this definition, a bias-aligned sample $\mathbf{x}_{k,i}$ of the $k$-th class has the feature $\widetilde{\mathbf{z}}_{k,i}=[\mathbf{z}_{k,i},\mathbf{m}_{b_{k}}]$, and a bias-conflicting sample has the feature $\widetilde{\mathbf{z}}_{k,i}=[\mathbf{z}_{k,i},\mathbf{m}_{b_{k'}}]$ with $k'\neq k$.
To keep the same form, we denote the classifier weight as $\widetilde{\mathbf{w}}_{k}=[\mathbf{w}_{k},\mathbf{a}_{k}]\in\mathbb{R}^{2d}$, where $\mathbf{w}_{k}\in\mathbb{R}^{d}$ is the weight for intrinsic correlations and $\mathbf{a}_{k}\in\mathbb{R}^{d}$ is the weight for shortcut features. In the detailed convergence results of Neural Collapse (Appendix C), we observe that, due to the fixed prime features and their strong correlation with the bias attributes, $\mathbf{a}_{k}$ quickly collapses onto its bias-correlated prime feature $\mathbf{m}_{b_{k}}$ and can be viewed as constant after just a few training steps. Thus, the cross-entropy loss can be rewritten as:

$$\mathcal{L}_{\rm CE}(\widetilde{\mathbf{z}}_{k,i},\widetilde{\mathbf{W}}) = -\log\left(\frac{\exp(\widetilde{\mathbf{z}}_{k,i}^{\mathrm{T}}\widetilde{\mathbf{w}}_{k})}{\sum_{k'=1}^{K}\exp(\widetilde{\mathbf{z}}_{k,i}^{\mathrm{T}}\widetilde{\mathbf{w}}_{k'})}\right) = -\log\left(\frac{\exp([\mathbf{z}_{k,i},\mathbf{m}_{i,b}]^{\mathrm{T}}[\mathbf{w}_{k},\mathbf{a}_{k}])}{\sum_{k'=1}^{K}\exp([\mathbf{z}_{k,i},\mathbf{m}_{i,b}]^{\mathrm{T}}[\mathbf{w}_{k'},\mathbf{a}_{k'}])}\right) \tag{A.5}$$
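The re-weighting effect argued in this subsection can be previewed numerically from Eq. A.5, using $[\mathbf{z},\mathbf{m}]^{\mathrm{T}}[\mathbf{w},\mathbf{a}]=\mathbf{z}^{\mathrm{T}}\mathbf{w}+\mathbf{m}^{\mathrm{T}}\mathbf{a}$. The sketch below assumes $K=3$, prime features taken as the vertices of a 2-D simplex ETF, and collapsed shortcut weights $\mathbf{a}_k = s\,\mathbf{m}_{b_k}$ with a hypothetical scale $s$; the intrinsic logits $\mathbf{z}^{\mathrm{T}}\mathbf{w}_k$ are set to zero, as near initialization. It reports $1-p_k$, the coefficient with which a sample of class $k$ enters the pulling part of the gradient: the fixed primes shrink it for bias-aligned samples and enlarge it for bias-conflicting ones. The names `PRIMES` and `pulling_coefficient` are illustrative.

```python
import math

# Hypothetical prime features: vertices of a 2-D simplex ETF for K = 3 bias attributes.
PRIMES = [(1.0, 0.0), (-0.5, math.sqrt(3) / 2), (-0.5, -math.sqrt(3) / 2)]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pulling_coefficient(label, bias, prime_scale, intrinsic_logits=(0.0, 0.0, 0.0)):
    """1 - p_label for a feature [z, m_bias] under Eq. A.5.

    Assumes class k is correlated with bias attribute k and that the shortcut
    weight has collapsed to a_k = prime_scale * m_k, so each logit is
    z^T w_k + prime_scale * m_bias^T m_k.
    """
    m_b = PRIMES[bias]
    logits = [
        intrinsic_logits[k] + prime_scale * (m_b[0] * m_k[0] + m_b[1] * m_k[1])
        for k, m_k in enumerate(PRIMES)
    ]
    return 1.0 - softmax(logits)[label]


for scale in (0.0, 2.0):
    aligned = pulling_coefficient(label=0, bias=0, prime_scale=scale)
    conflicting = pulling_coefficient(label=0, bias=1, prime_scale=scale)
    print(f"s={scale}: aligned {aligned:.3f}, conflicting {conflicting:.3f}")
```

Without primes ($s=0$) both samples enter with the same coefficient 2/3; with $s=2$ the bias-conflicting sample's coefficient becomes roughly ten times that of the bias-aligned one, which acts as an implicit per-sample weight. The exact ratio depends on the assumed scale.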

Gradient w.r.t. classifier weights. Within the avoid-shortcut learning framework, we show that the introduced prime mechanism implicitly plays the role of re-weighting: it weakens the mutual convergence between bias-aligned features and the classifier weights, and amplifies the learning of intrinsic correlations. We first analyze the gradient of $\mathcal{L}_{\rm CE}$ w.r.t. the classifier weights $\widetilde{\mathbf{W}}$:

$$
\begin{aligned}
\frac{\partial\mathcal{L}_{\rm CE}}{\partial\widetilde{\mathbf{w}}_{k}}
&=\sum_{i=1}^{n_{k}}-\bigl(1-p_{k}(\widetilde{\mathbf{z}}_{k,i})\bigr)\widetilde{\mathbf{z}}_{k,i}+\sum_{k'\neq k}^{K}\sum_{j=1}^{n_{k'}}p_{k}(\widetilde{\mathbf{z}}_{k',j})\,\widetilde{\mathbf{z}}_{k',j}\\
&=\sum_{i=1}^{n_{k}}-\left(1-\frac{\exp([\mathbf{z}_{k,i},\mathbf{m}_{i,b}]^{\mathrm{T}}[\mathbf{w}_{k},\mathbf{a}_{k}])}{\sum_{k'=1}^{K}\exp(\widetilde{\mathbf{z}}_{k,i}^{\mathrm{T}}\widetilde{\mathbf{w}}_{k'})}\right)\widetilde{\mathbf{z}}_{k,i}+\sum_{k'\neq k}^{K}\sum_{j=1}^{n_{k'}}\frac{\exp([\mathbf{z}_{k',j},\mathbf{m}_{j,b'}]^{\mathrm{T}}[\mathbf{w}_{k},\mathbf{a}_{k}])}{\sum_{k''=1}^{K}\exp(\widetilde{\mathbf{z}}_{k',j}^{\mathrm{T}}\widetilde{\mathbf{w}}_{k''})}\,\widetilde{\mathbf{z}}_{k',j} &&\text{(A.6)}\\
&\leq\sum_{i=1}^{n_{k}}-\left(1-\frac{\exp(\mathbf{z}_{k,i}^{\mathrm{T}}\mathbf{w}_{k})}{\sum_{k'=1}^{K}\exp(\widetilde{\mathbf{z}}_{k,i}^{\mathrm{T}}\widetilde{\mathbf{w}}_{k'})}-\frac{\exp(\mathbf{m}_{i,b}^{\mathrm{T}}\mathbf{a}_{k})}{\sum_{k'=1}^{K}\exp(\widetilde{\mathbf{z}}_{k,i}^{\mathrm{T}}\widetilde{\mathbf{w}}_{k'})}\right)\widetilde{\mathbf{z}}_{k,i}\\
&\quad+\sum_{k'\neq k}^{K}\sum_{j=1}^{n_{k'}}\left(\frac{\exp(\mathbf{z}_{k',j}^{\mathrm{T}}\mathbf{w}_{k})}{\sum_{k''=1}^{K}\exp(\widetilde{\mathbf{z}}_{k',j}^{\mathrm{T}}\widetilde{\mathbf{w}}_{k''})}+\frac{\exp(\mathbf{m}_{j,b'}^{\mathrm{T}}\mathbf{a}_{k})}{\sum_{k''=1}^{K}\exp(\widetilde{\mathbf{z}}_{k',j}^{\mathrm{T}}\widetilde{\mathbf{w}}_{k''})}\right)\widetilde{\mathbf{z}}_{k',j} &&\text{(A.7)}\\
&=\underbrace{\sum_{i=1}^{n_{k}}-\bigl(1-p_{k}^{(b)}(\mathbf{m}_{i,b})-p_{k}^{(l)}(\mathbf{z}_{k,i})\bigr)\widetilde{\mathbf{z}}_{k,i}}_{\rm pulling\ part}+\underbrace{\sum_{k'\neq k}^{K}\sum_{j=1}^{n_{k'}}\bigl(p_{k}^{(b)}(\mathbf{m}_{j,b'})+p_{k}^{(l)}(\mathbf{z}_{k',j})\bigr)\widetilde{\mathbf{z}}_{k',j}}_{\rm forcing\ part} &&\text{(A.8)}
\end{aligned}
$$
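The first line of Eq. A.6 is the standard softmax cross-entropy gradient with respect to one classifier row $\widetilde{\mathbf{w}}_k$. The following minimal sketch, assuming plain Python and a hypothetical toy batch, checks that gradient against central finite differences and returns the pulling (class-$k$ samples) and forcing (other-class samples) contributions separately. It only checks the exact gradient, not the inequality in Eq. A.7, and the helper names (`grad_wk`, `pulling_coeffs`, and so on) are illustrative, not from the paper.

```python
import math

def dot(u, v):
    """Inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    mx = max(logits)
    exps = [math.exp(l - mx) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]

def total_ce(Z, labels, W_tilde):
    """Cross-entropy summed over a batch of augmented features z_tilde = [z, m]."""
    loss = 0.0
    for z, y in zip(Z, labels):
        p = softmax([dot(z, w) for w in W_tilde])
        loss -= math.log(p[y])
    return loss

def grad_wk(Z, labels, W_tilde, k):
    """Analytic gradient w.r.t. w_tilde_k, returned as (pulling, forcing) parts."""
    dim = len(W_tilde[k])
    pulling, forcing = [0.0] * dim, [0.0] * dim
    for z, y in zip(Z, labels):
        p_k = softmax([dot(z, w) for w in W_tilde])[k]
        if y == k:
            # Samples of class k contribute -(1 - p_k(z_tilde)) * z_tilde.
            pulling = [g - (1.0 - p_k) * x for g, x in zip(pulling, z)]
        else:
            # Samples of other classes contribute +p_k(z_tilde) * z_tilde.
            forcing = [g + p_k * x for g, x in zip(forcing, z)]
    return pulling, forcing

def numeric_grad_wk(Z, labels, W_tilde, k, eps=1e-6):
    """Central finite-difference estimate of the same gradient."""
    grad = []
    for d in range(len(W_tilde[k])):
        plus = [list(w) for w in W_tilde]
        minus = [list(w) for w in W_tilde]
        plus[k][d] += eps
        minus[k][d] -= eps
        diff = total_ce(Z, labels, plus) - total_ce(Z, labels, minus)
        grad.append(diff / (2 * eps))
    return grad

def pulling_coeffs(Z, labels, W_tilde, k):
    """Coefficients 1 - p_k(z_tilde_{k,i}) weighting each class-k sample."""
    return [1.0 - softmax([dot(z, w) for w in W_tilde])[k]
            for z, y in zip(Z, labels) if y == k]

# Hypothetical batch: each row is [z (2-D), m (1-D prime)]; rows of W_tilde are [w_k, a_k].
Z = [[1.0, -0.5, 0.3], [0.2, 0.8, 0.3], [-0.7, 0.1, -0.4]]
labels = [0, 0, 1]
W_tilde = [[0.2, 0.1, 1.0], [-0.4, 0.5, 0.0]]

pull, force = grad_wk(Z, labels, W_tilde, k=0)
analytic = [a + b for a, b in zip(pull, force)]
numeric = numeric_grad_wk(Z, labels, W_tilde, k=0)
max_err = max(abs(a - b) for a, b in zip(analytic, numeric))

# Same weights, but with the class-0 prime weight a_0 set to zero.
W_no_prime = [[0.2, 0.1, 0.0], [-0.4, 0.5, 0.0]]
with_prime = pulling_coeffs(Z, labels, W_tilde, k=0)
without_prime = pulling_coeffs(Z, labels, W_no_prime, k=0)
print(max_err, with_prime, without_prime)
```

Setting $\mathbf{a}_0$ to zero raises every pulling coefficient $1-p_k(\widetilde{\mathbf{z}}_{k,i})$, because the prime no longer adds $0.3$ to the ground-truth logit of the two class-0 samples. This is only a small numerical illustration of the re-weighting effect described above, not a proof of it.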

Since $p_k(\cdot)$ is a softmax output, the coefficients satisfy $0<1-p_{k}(\widetilde{\mathbf{z}}_{k,i})<1$ and $0<p_{k}(\widetilde{\mathbf{z}}_{k',j})<1$, and the scaling step from Eq. A.6 to Eq. A.7 is based on Jensen's inequality. The terms $p_{k}^{(l)}$ and $p_{k}^{(b)}$ are the predicted probabilities for class labels and bias attributes, respectively, calculated with softmax:

$$p_{k}^{(l)}(\mathbf{z}_{k,i})=\frac{\exp(\mathbf{z}_{k,i}^{\mathrm{T}}\mathbf{w}_{k})}{\sum_{k'=1}^{K}\exp([\mathbf{z}_{k,i},\mathbf{m}_{i,b}]^{\mathrm{T}}[\mathbf{w}_{k'},\mathbf{a}_{k'}])}\tag{A.9}$$
pk(b)(𝐦i,b)=exp(