[llvm] 3564791 - [IndVarSimplify] Fix `IndVarSimplify` to skip unfolding predicates when the loop contains control convergence operations. (#165643)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 26 09:04:45 PST 2025


Author: Lucie Choi
Date: 2025-11-26T09:04:41-08:00
New Revision: 356479191ca001df47136c89cc9a761c64a6323c

URL: https://github.com/llvm/llvm-project/commit/356479191ca001df47136c89cc9a761c64a6323c
DIFF: https://github.com/llvm/llvm-project/commit/356479191ca001df47136c89cc9a761c64a6323c.diff

LOG: [IndVarSimplify] Fix `IndVarSimplify` to skip unfolding predicates when the loop contains control convergence operations. (#165643)

Skip constant folding the loop predicates if the loop contains control
convergence tokens referenced outside the loop.

Fixes https://github.com/llvm/llvm-project/issues/164496.

Verified that
[loop_peeling.test](https://github.com/llvm/offload-test-suite/pull/473)
passes with the fix.

Similar control convergence issues are found in other passes:
https://github.com/llvm/llvm-project/issues/165642

HLSL used for tests:
```hlsl
RWStructuredBuffer<uint> Out : register(u0);

[numthreads(8,1,1)]
void main(uint3 TID : SV_GroupThreadID) {
    for (uint i = 0; i < 8; i++) {
        if (i == TID.x) {
            Out[TID.x] = WaveActiveMax(TID.x);
            break;
        }
    }
}
```
With a nested loop:
```hlsl
RWStructuredBuffer<uint> Out : register(u0);

[numthreads(8,8,1)]
void main(uint3 TID : SV_GroupThreadID) {
    for (uint i = 0; i < 8; i++) {
        for (uint j = 0; j < 8; j++) {
            if (i == TID.x && j == TID.y) {
                uint index = TID.x * 8 + TID.y;
                Out[index] = WaveActiveMax(index);
                break;
            }
        }
    }
}
```

Added: 
    llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll
    llvm/test/Transforms/IndVarSimplify/skip-predication-nested-convergence.ll

Modified: 
    llvm/lib/Transforms/Scalar/IndVarSimplify.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index b46527eb1057b..19d801acd928e 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -1855,7 +1855,7 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
   // is that enough for *all* side effects?
   bool HasThreadLocalSideEffects = false;
   for (BasicBlock *BB : L->blocks())
-    for (auto &I : *BB)
+    for (auto &I : *BB) {
       // TODO:isGuaranteedToTransfer
       if (I.mayHaveSideEffects()) {
         if (!LoopPredicationTraps)
@@ -1873,6 +1873,18 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
         }
       }
 
+      // Skip if the loop has tokens referenced outside the loop to avoid
+      // changing convergence behavior.
+      if (I.getType()->isTokenTy()) {
+        for (User *U : I.users()) {
+          Instruction *UserInst = dyn_cast<Instruction>(U);
+          if (UserInst && !L->contains(UserInst)) {
+            return false;
+          }
+        }
+      }
+    }
+
   bool Changed = false;
   // Finally, do the actual predication for all predicatable blocks.  A couple
   // of notes here:

diff  --git a/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll b/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll
new file mode 100644
index 0000000000000..59b84a3c082c2
--- /dev/null
+++ b/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll
@@ -0,0 +1,64 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s
+
+; Loop with body using loop convergence token should be skipped by IndVarSimplify.
+
+declare token @llvm.experimental.convergence.entry() #0
+
+define void @loop(i32 %tid, ptr %array) #0 {
+; CHECK-LABEL: @loop(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    br label [[FOR_COND_I:%.*]]
+; CHECK:       for.cond.i:
+; CHECK-NEXT:    [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC_I:%.*]], [[FOR_BODY_I:%.*]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ]
+; CHECK-NEXT:    [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8
+; CHECK-NEXT:    br i1 [[CMP_I]], label [[FOR_BODY_I]], label [[EXIT_LOOPEXIT:%.*]]
+; CHECK:       for.body.i:
+; CHECK-NEXT:    [[CMP1_I:%.*]] = icmp eq i32 [[I_0_I]], [[TID:%.*]]
+; CHECK-NEXT:    [[INC_I]] = add nuw nsw i32 [[I_0_I]], 1
+; CHECK-NEXT:    br i1 [[CMP1_I]], label [[IF_THEN_I:%.*]], label [[FOR_COND_I]]
+; CHECK:       exit.loopexit:
+; CHECK-NEXT:    br label [[EXIT:%.*]]
+; CHECK:       if.then.i:
+; CHECK-NEXT:    [[HLSL_WAVE_ACTIVE_MAX2_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[TID]]) [ "convergencectrl"(token [[TMP1]]) ]
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[ARRAY:%.*]], i32 [[TID]]
+; CHECK-NEXT:    store i32 [[HLSL_WAVE_ACTIVE_MAX2_I]], ptr [[TMP2]], align 4
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = tail call token @llvm.experimental.convergence.entry()
+  br label %for.cond.i
+
+for.cond.i:
+  %i.0.i = phi i32 [ 0, %entry ], [ %inc.i, %for.body.i ]
+  %2 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+  %cmp.i = icmp ult i32 %i.0.i, 8
+  br i1 %cmp.i, label %for.body.i, label %exit.loopexit
+
+for.body.i:
+  %cmp1.i = icmp eq i32 %i.0.i, %tid
+  %inc.i = add nuw nsw i32 %i.0.i, 1
+  br i1 %cmp1.i, label %if.then.i, label %for.cond.i
+
+exit.loopexit:
+  br label %exit
+
+if.then.i:
+  %hlsl.wave.active.max2.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %tid) [ "convergencectrl"(token %2) ]
+  %3 = getelementptr inbounds i32, ptr %array, i32 %tid
+  store i32 %hlsl.wave.active.max2.i, ptr %3, align 4
+  br label %exit
+
+exit:
+  ret void
+}
+
+declare token @llvm.experimental.convergence.loop() #0
+
+declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0
+
+attributes #0 = { convergent }

diff  --git a/llvm/test/Transforms/IndVarSimplify/skip-predication-nested-convergence.ll b/llvm/test/Transforms/IndVarSimplify/skip-predication-nested-convergence.ll
new file mode 100644
index 0000000000000..0944205839aca
--- /dev/null
+++ b/llvm/test/Transforms/IndVarSimplify/skip-predication-nested-convergence.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s
+
+; Nested loops with body using loop convergence token should be skipped by IndVarSimplify.
+
+declare token @llvm.experimental.convergence.entry() #0
+
+define void @nested(i32 %tidx, i32 %tidy, ptr %array) #0 {
+; CHECK-LABEL: @nested(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    [[MUL_I:%.*]] = shl nsw i32 [[TIDX:%.*]], 3
+; CHECK-NEXT:    [[ADD_I:%.*]] = add nsw i32 [[MUL_I]], [[TIDY:%.*]]
+; CHECK-NEXT:    br label [[FOR_COND_I:%.*]]
+; CHECK:       for.cond.i:
+; CHECK-NEXT:    [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC10_I:%.*]], [[CLEANUP_I:%.*]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ]
+; CHECK-NEXT:    [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8
+; CHECK-NEXT:    br i1 [[CMP_I]], label [[FOR_COND1_I_PREHEADER:%.*]], label [[EXIT:%.*]]
+; CHECK:       for.cond1.i.preheader:
+; CHECK-NEXT:    [[CMP5_I:%.*]] = icmp eq i32 [[I_0_I]], [[TIDX]]
+; CHECK-NEXT:    br label [[FOR_COND1_I:%.*]]
+; CHECK:       for.cond1.i:
+; CHECK-NEXT:    [[J_0_I:%.*]] = phi i32 [ [[INC_I:%.*]], [[FOR_BODY4_I:%.*]] ], [ 0, [[FOR_COND1_I_PREHEADER]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP1]]) ]
+; CHECK-NEXT:    [[CMP2_I:%.*]] = icmp ult i32 [[J_0_I]], 8
+; CHECK-NEXT:    br i1 [[CMP2_I]], label [[FOR_BODY4_I]], label [[CLEANUP_I_LOOPEXIT:%.*]]
+; CHECK:       for.body4.i:
+; CHECK-NEXT:    [[CMP6_I:%.*]] = icmp eq i32 [[J_0_I]], [[TIDY]]
+; CHECK-NEXT:    [[OR_COND:%.*]] = select i1 [[CMP5_I]], i1 [[CMP6_I]], i1 false
+; CHECK-NEXT:    [[INC_I]] = add nuw nsw i32 [[J_0_I]], 1
+; CHECK-NEXT:    br i1 [[OR_COND]], label [[IF_THEN_I:%.*]], label [[FOR_COND1_I]]
+; CHECK:       cleanup.i.loopexit:
+; CHECK-NEXT:    br label [[CLEANUP_I]]
+; CHECK:       if.then.i:
+; CHECK-NEXT:    [[HLSL_WAVE_ACTIVE_MAX7_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[ADD_I]]) [ "convergencectrl"(token [[TMP2]]) ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[ARRAY:%.*]], i32 [[ADD_I]]
+; CHECK-NEXT:    store i32 [[HLSL_WAVE_ACTIVE_MAX7_I]], ptr [[TMP3]], align 4
+; CHECK-NEXT:    br label [[CLEANUP_I]]
+; CHECK:       cleanup.i:
+; CHECK-NEXT:    [[INC10_I]] = add nuw nsw i32 [[I_0_I]], 1
+; CHECK-NEXT:    br label [[FOR_COND_I]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = tail call token @llvm.experimental.convergence.entry()
+  %mul.i = shl nsw i32 %tidx, 3
+  %add.i = add nsw i32 %mul.i, %tidy
+  br label %for.cond.i
+
+for.cond.i:
+  %i.0.i = phi i32 [ 0, %entry ], [ %inc10.i, %cleanup.i ]
+  %2 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+  %cmp.i = icmp ult i32 %i.0.i, 8
+  br i1 %cmp.i, label %for.cond1.i.preheader, label %exit
+
+for.cond1.i.preheader:
+  %cmp5.i = icmp eq i32 %i.0.i, %tidx
+  br label %for.cond1.i
+
+for.cond1.i:
+  %j.0.i = phi i32 [ %inc.i, %for.body4.i ], [ 0, %for.cond1.i.preheader ]
+  %3 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %2) ]
+  %cmp2.i = icmp ult i32 %j.0.i, 8
+  br i1 %cmp2.i, label %for.body4.i, label %cleanup.i.loopexit
+
+for.body4.i:
+  %cmp6.i = icmp eq i32 %j.0.i, %tidy
+  %or.cond = select i1 %cmp5.i, i1 %cmp6.i, i1 false
+  %inc.i = add nsw i32 %j.0.i, 1
+  br i1 %or.cond, label %if.then.i, label %for.cond1.i
+
+cleanup.i.loopexit:
+  br label %cleanup.i
+
+if.then.i:
+  %hlsl.wave.active.max7.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %add.i) [ "convergencectrl"(token %3) ]
+  %4 = getelementptr inbounds i32, ptr %array, i32 %add.i
+  store i32 %hlsl.wave.active.max7.i, ptr %4, align 4
+  br label %cleanup.i
+
+cleanup.i:
+  %inc10.i = add nsw i32 %i.0.i, 1
+  br label %for.cond.i
+
+exit:
+  ret void
+}
+
+declare token @llvm.experimental.convergence.loop() #0
+
+declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0
+
+attributes #0 = { convergent }


        


More information about the llvm-commits mailing list