[llvm] 3564791 - [IndVarSimplify] Fix `IndVarSimplify` to skip unfolding predicates when the loop contains control convergence operations. (#165643)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 26 09:04:45 PST 2025
Author: Lucie Choi
Date: 2025-11-26T09:04:41-08:00
New Revision: 356479191ca001df47136c89cc9a761c64a6323c
URL: https://github.com/llvm/llvm-project/commit/356479191ca001df47136c89cc9a761c64a6323c
DIFF: https://github.com/llvm/llvm-project/commit/356479191ca001df47136c89cc9a761c64a6323c.diff
LOG: [IndVarSimplify] Fix `IndVarSimplify` to skip unfolding predicates when the loop contains control convergence operations. (#165643)
Skip predicating loop exits (`IndVarSimplify::predicateLoopExits`), which would
otherwise fold the exit conditions, when the loop defines convergence control
tokens that are referenced outside the loop.
Fixes https://github.com/llvm/llvm-project/issues/164496.
Verified that
[loop_peeling.test](https://github.com/llvm/offload-test-suite/pull/473)
passes with this fix.
Similar convergence control issues have been found in other passes; see
https://github.com/llvm/llvm-project/issues/165642
HLSL source the tests are based on:
```hlsl
RWStructuredBuffer<uint> Out : register(u0);

// Thread TID.x takes the `if` on the iteration where i == TID.x and then
// leaves the loop, so the wave operation is reached on a different loop
// iteration by each thread.
[numthreads(8,1,1)]
void main(uint3 TID : SV_GroupThreadID) {
    for (uint i = 0; i < 8; i++) {
        if (i == TID.x) {
            Out[TID.x] = WaveActiveMax(TID.x);
            break; // early exit from the loop
        }
    }
}
```
With a nested loop:
```hlsl
RWStructuredBuffer<uint> Out : register(u0);

// Nested variant: the thread with (TID.x, TID.y) == (i, j) performs the wave
// operation and then breaks out of the inner `j` loop only; the outer `i`
// loop keeps iterating.
[numthreads(8,8,1)]
void main(uint3 TID : SV_GroupThreadID) {
    for (uint i = 0; i < 8; i++) {
        for (uint j = 0; j < 8; j++) {
            if (i == TID.x && j == TID.y) {
                uint index = TID.x * 8 + TID.y;
                Out[index] = WaveActiveMax(index);
                break; // leaves the inner loop only
            }
        }
    }
}
```
Added:
llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll
llvm/test/Transforms/IndVarSimplify/skip-predication-nested-convergence.ll
Modified:
llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index b46527eb1057b..19d801acd928e 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -1855,7 +1855,7 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
// is that enough for *all* side effects?
bool HasThreadLocalSideEffects = false;
for (BasicBlock *BB : L->blocks())
- for (auto &I : *BB)
+ for (auto &I : *BB) {
// TODO:isGuaranteedToTransfer
if (I.mayHaveSideEffects()) {
if (!LoopPredicationTraps)
@@ -1873,6 +1873,18 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
}
}
+ // Skip if the loop has tokens referenced outside the loop to avoid
+ // changing convergence behavior.
+ if (I.getType()->isTokenTy()) {
+ for (User *U : I.users()) {
+ Instruction *UserInst = dyn_cast<Instruction>(U);
+ if (UserInst && !L->contains(UserInst)) {
+ return false;
+ }
+ }
+ }
+ }
+
bool Changed = false;
// Finally, do the actual predication for all predicatable blocks. A couple
// of notes here:
diff --git a/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll b/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll
new file mode 100644
index 0000000000000..59b84a3c082c2
--- /dev/null
+++ b/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll
@@ -0,0 +1,64 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s
+
+; IndVarSimplify must skip exit predication for a loop whose convergence token is used outside the loop.
+
+declare token @llvm.experimental.convergence.entry() #0
+
+define void @loop(i32 %tid, ptr %array) #0 { ; exits must stay unpredicated: see the token use in if.then.i
+; CHECK-LABEL: @loop(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT: br label [[FOR_COND_I:%.*]]
+; CHECK: for.cond.i:
+; CHECK-NEXT: [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC_I:%.*]], [[FOR_BODY_I:%.*]] ]
+; CHECK-NEXT: [[TMP1:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ]
+; CHECK-NEXT: [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8
+; CHECK-NEXT: br i1 [[CMP_I]], label [[FOR_BODY_I]], label [[EXIT_LOOPEXIT:%.*]]
+; CHECK: for.body.i:
+; CHECK-NEXT: [[CMP1_I:%.*]] = icmp eq i32 [[I_0_I]], [[TID:%.*]]
+; CHECK-NEXT: [[INC_I]] = add nuw nsw i32 [[I_0_I]], 1
+; CHECK-NEXT: br i1 [[CMP1_I]], label [[IF_THEN_I:%.*]], label [[FOR_COND_I]]
+; CHECK: exit.loopexit:
+; CHECK-NEXT: br label [[EXIT:%.*]]
+; CHECK: if.then.i:
+; CHECK-NEXT: [[HLSL_WAVE_ACTIVE_MAX2_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[TID]]) [ "convergencectrl"(token [[TMP1]]) ]
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[ARRAY:%.*]], i32 [[TID]]
+; CHECK-NEXT: store i32 [[HLSL_WAVE_ACTIVE_MAX2_I]], ptr [[TMP2]], align 4
+; CHECK-NEXT: br label [[EXIT]]
+; CHECK: exit:
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = tail call token @llvm.experimental.convergence.entry()
+ br label %for.cond.i
+
+for.cond.i:
+ %i.0.i = phi i32 [ 0, %entry ], [ %inc.i, %for.body.i ]
+ %2 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ] ; token defined inside the loop
+ %cmp.i = icmp ult i32 %i.0.i, 8
+ br i1 %cmp.i, label %for.body.i, label %exit.loopexit
+
+for.body.i:
+ %cmp1.i = icmp eq i32 %i.0.i, %tid
+ %inc.i = add nuw nsw i32 %i.0.i, 1
+ br i1 %cmp1.i, label %if.then.i, label %for.cond.i ; early exit from the loop to if.then.i
+
+exit.loopexit:
+ br label %exit
+
+if.then.i:
+ %hlsl.wave.active.max2.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %tid) [ "convergencectrl"(token %2) ] ; in-loop token %2 used outside the loop: exit predication must be skipped
+ %3 = getelementptr inbounds i32, ptr %array, i32 %tid
+ store i32 %hlsl.wave.active.max2.i, ptr %3, align 4
+ br label %exit
+
+exit:
+ ret void
+}
+
+declare token @llvm.experimental.convergence.loop() #0
+
+declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0
+
+attributes #0 = { convergent }
diff --git a/llvm/test/Transforms/IndVarSimplify/skip-predication-nested-convergence.ll b/llvm/test/Transforms/IndVarSimplify/skip-predication-nested-convergence.ll
new file mode 100644
index 0000000000000..0944205839aca
--- /dev/null
+++ b/llvm/test/Transforms/IndVarSimplify/skip-predication-nested-convergence.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s
+
+; IndVarSimplify must skip exit predication for the inner loop, whose convergence token is used outside the inner loop.
+
+declare token @llvm.experimental.convergence.entry() #0
+
+define void @nested(i32 %tidx, i32 %tidy, ptr %array) #0 { ; inner-loop token %3 is used outside the inner loop (in if.then.i)
+; CHECK-LABEL: @nested(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT: [[MUL_I:%.*]] = shl nsw i32 [[TIDX:%.*]], 3
+; CHECK-NEXT: [[ADD_I:%.*]] = add nsw i32 [[MUL_I]], [[TIDY:%.*]]
+; CHECK-NEXT: br label [[FOR_COND_I:%.*]]
+; CHECK: for.cond.i:
+; CHECK-NEXT: [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC10_I:%.*]], [[CLEANUP_I:%.*]] ]
+; CHECK-NEXT: [[TMP1:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ]
+; CHECK-NEXT: [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8
+; CHECK-NEXT: br i1 [[CMP_I]], label [[FOR_COND1_I_PREHEADER:%.*]], label [[EXIT:%.*]]
+; CHECK: for.cond1.i.preheader:
+; CHECK-NEXT: [[CMP5_I:%.*]] = icmp eq i32 [[I_0_I]], [[TIDX]]
+; CHECK-NEXT: br label [[FOR_COND1_I:%.*]]
+; CHECK: for.cond1.i:
+; CHECK-NEXT: [[J_0_I:%.*]] = phi i32 [ [[INC_I:%.*]], [[FOR_BODY4_I:%.*]] ], [ 0, [[FOR_COND1_I_PREHEADER]] ]
+; CHECK-NEXT: [[TMP2:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP1]]) ]
+; CHECK-NEXT: [[CMP2_I:%.*]] = icmp ult i32 [[J_0_I]], 8
+; CHECK-NEXT: br i1 [[CMP2_I]], label [[FOR_BODY4_I]], label [[CLEANUP_I_LOOPEXIT:%.*]]
+; CHECK: for.body4.i:
+; CHECK-NEXT: [[CMP6_I:%.*]] = icmp eq i32 [[J_0_I]], [[TIDY]]
+; CHECK-NEXT: [[OR_COND:%.*]] = select i1 [[CMP5_I]], i1 [[CMP6_I]], i1 false
+; CHECK-NEXT: [[INC_I]] = add nuw nsw i32 [[J_0_I]], 1
+; CHECK-NEXT: br i1 [[OR_COND]], label [[IF_THEN_I:%.*]], label [[FOR_COND1_I]]
+; CHECK: cleanup.i.loopexit:
+; CHECK-NEXT: br label [[CLEANUP_I]]
+; CHECK: if.then.i:
+; CHECK-NEXT: [[HLSL_WAVE_ACTIVE_MAX7_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[ADD_I]]) [ "convergencectrl"(token [[TMP2]]) ]
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[ARRAY:%.*]], i32 [[ADD_I]]
+; CHECK-NEXT: store i32 [[HLSL_WAVE_ACTIVE_MAX7_I]], ptr [[TMP3]], align 4
+; CHECK-NEXT: br label [[CLEANUP_I]]
+; CHECK: cleanup.i:
+; CHECK-NEXT: [[INC10_I]] = add nuw nsw i32 [[I_0_I]], 1
+; CHECK-NEXT: br label [[FOR_COND_I]]
+; CHECK: exit:
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = tail call token @llvm.experimental.convergence.entry()
+ %mul.i = shl nsw i32 %tidx, 3
+ %add.i = add nsw i32 %mul.i, %tidy
+ br label %for.cond.i
+
+for.cond.i:
+ %i.0.i = phi i32 [ 0, %entry ], [ %inc10.i, %cleanup.i ]
+ %2 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ] ; outer-loop token
+ %cmp.i = icmp ult i32 %i.0.i, 8
+ br i1 %cmp.i, label %for.cond1.i.preheader, label %exit
+
+for.cond1.i.preheader:
+ %cmp5.i = icmp eq i32 %i.0.i, %tidx
+ br label %for.cond1.i
+
+for.cond1.i:
+ %j.0.i = phi i32 [ %inc.i, %for.body4.i ], [ 0, %for.cond1.i.preheader ]
+ %3 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %2) ] ; inner-loop token (takes outer-loop token %2)
+ %cmp2.i = icmp ult i32 %j.0.i, 8
+ br i1 %cmp2.i, label %for.body4.i, label %cleanup.i.loopexit
+
+for.body4.i:
+ %cmp6.i = icmp eq i32 %j.0.i, %tidy
+ %or.cond = select i1 %cmp5.i, i1 %cmp6.i, i1 false
+ %inc.i = add nsw i32 %j.0.i, 1
+ br i1 %or.cond, label %if.then.i, label %for.cond1.i ; early exit from the inner loop to if.then.i
+
+cleanup.i.loopexit:
+ br label %cleanup.i
+
+if.then.i:
+ %hlsl.wave.active.max7.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %add.i) [ "convergencectrl"(token %3) ] ; inner-loop token %3 used outside the inner loop: its exits must not be predicated
+ %4 = getelementptr inbounds i32, ptr %array, i32 %add.i
+ store i32 %hlsl.wave.active.max7.i, ptr %4, align 4
+ br label %cleanup.i
+
+cleanup.i:
+ %inc10.i = add nsw i32 %i.0.i, 1
+ br label %for.cond.i
+
+exit:
+ ret void
+}
+
+declare token @llvm.experimental.convergence.loop() #0
+
+declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0
+
+attributes #0 = { convergent }
More information about the llvm-commits
mailing list