[Mlir-commits] [mlir] [MLIR][SCF] Propagate loop annotation during while op lowering (PR #151746)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 1 11:42:10 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Thomas Raoux (ThomasRaoux)
<details>
<summary>Changes</summary>
This expands on https://github.com/llvm/llvm-project/pull/102562.
It also allows propagating LLVM dialect attributes (such as llvm.loop_annotation) during scf.while lowering.
---
Full diff: https://github.com/llvm/llvm-project/pull/151746.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp (+20-15)
- (modified) mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir (+42-1)
``````````diff
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 807be7e1003c0..ae943f3c82100 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -312,6 +312,19 @@ struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
} // namespace
+static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
+ // Copy every attribute of the SCF loop op whose value belongs to the LLVM
+ // dialect (e.g. llvm.loop_annotation) onto the loop's latch branch `brOp`,
+ // installing them as brOp's discardable attributes. LLVM requires loop
+ // metadata on the "latch" block: the one holding the back-edge to the header.
+ SmallVector<NamedAttribute> llvmAttrs;
+ llvm::copy_if(scfOp->getAttrs(), std::back_inserter(llvmAttrs),
+ [](auto attr) {
+ return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
+ });
+ brOp->setDiscardableAttrs(llvmAttrs);
+}
+
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
@@ -350,17 +363,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
auto branchOp =
cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried);
- // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
- // llvm.loop_annotation attribute.
- // LLVM requires the loop metadata to be attached on the "latch" block. Which
- // is the back-edge to the header block (conditionBlock)
- SmallVector<NamedAttribute> llvmAttrs;
- llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs),
- [](auto attr) {
- return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
- });
- branchOp->setDiscardableAttrs(llvmAttrs);
-
+ propagateLoopAttrs(forOp, branchOp);
rewriter.eraseOp(terminator);
// Compute loop bounds before branching to the condition.
@@ -589,9 +592,10 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
rewriter.setInsertionPointToEnd(after);
auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
- rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
+ auto latch =rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
yieldOp.getResults());
+ propagateLoopAttrs(whileOp, latch);
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
rewriter.replaceOp(whileOp, args);
@@ -631,10 +635,11 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
// Loop around the "before" region based on condition.
rewriter.setInsertionPointToEnd(before);
auto condOp = cast<ConditionOp>(before->getTerminator());
- cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(),
- before, condOp.getArgs(), continuation,
- ValueRange());
+ auto latch = cf::CondBranchOp::create(
+ rewriter, condOp.getLoc(), condOp.getCondition(), before,
+ condOp.getArgs(), continuation, ValueRange());
+ propagateLoopAttrs(whileOp, latch);
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
rewriter.replaceOp(whileOp, condOp.getArgs());
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index e6fdb7ab5ecd8..ef0fa083a021a 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -708,4 +708,45 @@ func.func @simple_std_for_loops_annotation(%arg0 : index, %arg1 : index, %arg2 :
} {llvm.loop_annotation = #full_unroll}
} {llvm.loop_annotation = #no_unroll}
return
-}
\ No newline at end of file
+}
+
+// -----
+
+// Loop annotation on an scf.while whose 'after' region is empty. The CHECKs
+// expect the annotation on the cf.cond_br that branches back to the 'before'
+// block, i.e. the latch of the do-while style lowering.
+// NOTE(review): the function name does not reflect that this exercises the
+// do-while style latch; consider renaming.
+// CHECK: #[[LOOP_UNROLL_DISABLE:.*]] = #llvm.loop_unroll<disable = true>
+// CHECK: #[[NO_UNROLL:.*]] = #llvm.loop_annotation<unroll = #[[LOOP_UNROLL_DISABLE]]>
+// CHECK: func @simple_while_loops_annotation
+// CHECK: cf.br
+// CHECK: cf.cond_br {{.*}} {llvm.loop_annotation = #[[NO_UNROLL]]}
+// CHECK: return
+#no_unroll = #llvm.loop_annotation<unroll = <disable = true>>
+func.func @simple_while_loops_annotation(%arg0 : i1) {
+ scf.while : () -> () {
+ scf.condition(%arg0)
+ } do {
+ scf.yield
+ } attributes {llvm.loop_annotation = #no_unroll}
+ return
+}
+
+// -----
+
+// Loop annotation on an scf.while whose 'after' region yields a value other
+// than its block argument. The CHECKs expect the annotation on the
+// unconditional cf.br ending the 'after' block (the latch of the regular
+// while lowering), not on the cf.cond_br.
+// NOTE(review): despite the "do_while" name, this exercises the regular while
+// latch; consider renaming.
+// CHECK: #[[LOOP_UNROLL_DISABLE:.*]] = #llvm.loop_unroll<disable = true>
+// CHECK: #[[NO_UNROLL:.*]] = #llvm.loop_annotation<unroll = #[[LOOP_UNROLL_DISABLE]]>
+// CHECK: func @do_while_loops_annotation
+// CHECK: cf.br
+// CHECK: cf.cond_br
+// CHECK: cf.br {{.*}} {llvm.loop_annotation = #[[NO_UNROLL]]}
+// CHECK: return
+#no_unroll = #llvm.loop_annotation<unroll = <disable = true>>
+func.func @do_while_loops_annotation() {
+ %c0_i32 = arith.constant 0 : i32
+ scf.while (%arg2 = %c0_i32) : (i32) -> (i32) {
+ %0 = "test.make_condition"() : () -> i1
+ scf.condition(%0) %c0_i32 : i32
+ } do {
+ ^bb0(%arg2: i32):
+ scf.yield %c0_i32: i32
+ } attributes {llvm.loop_annotation = #no_unroll}
+ return
+}
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/151746
More information about the Mlir-commits
mailing list