[Mlir-commits] [mlir] [MLIR][SCF] Add support for vectorization hints in `scf-to-cf` lowering. (PR #134201)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 4 10:38:42 PDT 2025


https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/134201

>From 40cf4898fb6beb13a6bfd36e2b5c6d7f9ba441b3 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Thu, 3 Apr 2025 11:48:51 +0800
Subject: [PATCH 1/2] [MLIR][SCF] Add support for vectorization hints in
 `scf-to-cf` lowering and provide an option to control it.

---
 mlir/include/mlir/Conversion/Passes.td        |  7 ++++-
 .../SCFToControlFlow/SCFToControlFlow.h       |  3 +-
 .../SCFToControlFlow/SCFToControlFlow.cpp     | 28 ++++++++++++++++---
 .../SCFToControlFlow/convert-to-cfg.mlir      | 22 +++++++++++----
 4 files changed, 48 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..5660c55cfb4a4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -984,7 +984,12 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> {
 def SCFToControlFlowPass : Pass<"convert-scf-to-cf"> {
   let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured"
                 " control flow with a CFG";
-  let dependentDialects = ["cf::ControlFlowDialect"];
+  let dependentDialects = ["cf::ControlFlowDialect", "LLVM::LLVMDialect"];
+
+  let options = [Option<"enableVectorizeHits", "enable-vectorize-hits", "bool",
+                        /*default=*/"false",
+                        "Add vectorization hints when convert SCF parallel "
+                        "loop to ControlFlow dialect">];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h b/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h
index 2def01d208f72..e062debda9211 100644
--- a/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h
+++ b/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h
@@ -20,7 +20,8 @@ class RewritePatternSet;
 
 /// Collect a set of patterns to convert SCF operations to CFG branch-based
 /// operations within the ControlFlow dialect.
-void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns);
+void populateSCFToControlFlowConversionPatterns(
+    RewritePatternSet &patterns, bool enableVectorizeHits = false);
 
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 114d634629d77..224b510294a10 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -38,6 +38,7 @@ namespace {
 
 struct SCFToControlFlowPass
     : public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> {
+  using Base::Base;
   void runOnOperation() override;
 };
 
@@ -212,6 +213,11 @@ struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
   using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
 
+  bool enableVectorizeHits;
+
+  ParallelLowering(mlir::MLIRContext *ctx, bool enableVectorizeHits)
+      : OpRewritePattern(ctx), enableVectorizeHits(enableVectorizeHits) {}
+
   LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override;
 };
@@ -487,6 +493,13 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
     return failure();
   }
 
+  auto vecAttr = LLVM::LoopVectorizeAttr::get(
+      rewriter.getContext(),
+      /* disable */ rewriter.getBoolAttr(false), {}, {}, {}, {}, {}, {});
+  auto loopAnnotation = LLVM::LoopAnnotationAttr::get(
+      rewriter.getContext(), {}, /*vectorize=*/vecAttr, {}, {}, {}, {}, {}, {},
+      {}, {}, {}, {}, {}, {}, {});
+
   // For a parallel loop, we essentially need to create an n-dimensional loop
   // nest. We do this by translating to scf.for ops and have those lowered in
   // a further rewrite. If a parallel loop contains reductions (and thus returns
@@ -517,6 +530,11 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
       rewriter.create<scf::YieldOp>(loc, forOp.getResults());
     }
 
+    if (enableVectorizeHits)
+      forOp->setAttr(LLVM::BrOp::getLoopAnnotationAttrName(OperationName(
+                         LLVM::BrOp::getOperationName(), getContext())),
+                     loopAnnotation);
+
     rewriter.setInsertionPointToStart(forOp.getBody());
   }
 
@@ -706,16 +724,18 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
 }
 
 void mlir::populateSCFToControlFlowConversionPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
-               WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
+    RewritePatternSet &patterns, bool enableVectorizeHits) {
+  patterns.add<ForallLowering, ForLowering, IfLowering, WhileLowering,
+               ExecuteRegionLowering, IndexSwitchLowering>(
       patterns.getContext());
+  patterns.add<ParallelLowering>(patterns.getContext(), enableVectorizeHits);
   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
 }
 
 void SCFToControlFlowPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
-  populateSCFToControlFlowConversionPatterns(patterns);
+  populateSCFToControlFlowConversionPatterns(patterns,
+                                             enableVectorizeHits.getValue());
 
   // Configure conversion to lower out SCF operations.
   ConversionTarget target(getContext());
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index 9ea0093eff786..c8c9635aa530f 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -1,4 +1,9 @@
-// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s --check-prefixes=CHECK,NO-VEC
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf="enable-vectorize-hits=true" \
+// RUN:          -split-input-file %s | FileCheck %s --check-prefixes=CHECK,VEC
+
+// VEC: #loop_vectorize = #llvm.loop_vectorize<disable = false>
+// VEC-NEXT: #[[$VEC_ATTR:.+]] = #llvm.loop_annotation<vectorize = #loop_vectorize>
 
 // CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
 //  CHECK-NEXT:  cf.br ^bb1(%{{.*}} : index)
@@ -332,7 +337,8 @@ func.func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
   // variable and the current partially reduced value.
   // CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
   // CHECK:   %[[COMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]]
-  // CHECK:   cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
+  // NO-VEC:   cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
+  // VEC:   cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]] {loop_annotation = #[[$VEC_ATTR]]}
 
   // Bodies of scf.reduce operations are folded into the main loop body. The
   // result of this partial reduction is passed as argument to the condition
@@ -366,11 +372,13 @@ func.func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
   // CHECK:   %[[INIT2:.*]] = arith.constant 42
   // CHECK:   cf.br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
   // CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
-  // CHECK:   cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
+  // NO-VEC:   cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
+  // VEC:   cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]] {loop_annotation = #[[$VEC_ATTR]]}
   // CHECK: ^[[BODY_OUT]]:
   // CHECK:   cf.br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
   // CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
-  // CHECK:   cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
+  // NO-VEC:   cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
+  // VEC:   cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]] {loop_annotation = #[[$VEC_ATTR]]}
   // CHECK: ^[[BODY_IN]]:
   // CHECK:   %[[REDUCE1:.*]] = arith.addf %[[ITER_ARG1_IN]], %{{.*}}
   // CHECK:   %[[REDUCE2:.*]] = arith.ori %[[ITER_ARG2_IN]], %{{.*}}
@@ -551,7 +559,8 @@ func.func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1,
   // CHECK:   cf.br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
   // CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index):
   // CHECK:   %[[LOOP_COND:.*]] = arith.cmpi slt, %[[LOOP_IV]], %[[ARG1]] : index
-  // CHECK:   cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
+  // NO-VEC:   cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
+  // VEC:   cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]] {loop_annotation = #[[$VEC_ATTR]]}
   // CHECK: ^[[LOOP_BODY]]:
   // CHECK:   cf.cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
   // CHECK: ^[[IF1_THEN]]:
@@ -660,7 +669,8 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
 //       CHECK:   cf.br ^[[bb1:.*]](%[[c0]] : index)
 //       CHECK: ^[[bb1]](%[[arg0:.*]]: index):
 //       CHECK:   %[[cmpi:.*]] = arith.cmpi slt, %[[arg0]], %[[num_threads]]
-//       CHECK:   cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]]
+//       NO-VEC:   cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]]
+//       VEC:   cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]] {loop_annotation = #[[$VEC_ATTR]]}
 //       CHECK: ^[[bb2]]:
 //       CHECK:   "test.foo"(%[[arg0]])
 //       CHECK:   %[[addi:.*]] = arith.addi %[[arg0]], %[[c1]]

>From 719ae0227ed56d4316da9ebad00eeaf711432580 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Fri, 4 Apr 2025 22:54:23 +0800
Subject: [PATCH 2/2] Fix typos: rename `enableVectorizeHits` to `enableVectorizeHints`

---
 mlir/include/mlir/Conversion/Passes.td           |  3 ++-
 .../SCFToControlFlow/SCFToControlFlow.h          |  2 +-
 .../SCFToControlFlow/SCFToControlFlow.cpp        | 16 ++++++++--------
 .../SCFToControlFlow/convert-to-cfg.mlir         |  2 +-
 4 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5660c55cfb4a4..0d1ef4dc89829 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -986,7 +986,8 @@ def SCFToControlFlowPass : Pass<"convert-scf-to-cf"> {
                 " control flow with a CFG";
   let dependentDialects = ["cf::ControlFlowDialect", "LLVM::LLVMDialect"];
 
-  let options = [Option<"enableVectorizeHits", "enable-vectorize-hits", "bool",
+  let options = [Option<"enableVectorizeHints", "enable-vectorize-hints",
+                        "bool",
                         /*default=*/"false",
                         "Add vectorization hints when convert SCF parallel "
                         "loop to ControlFlow dialect">];
diff --git a/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h b/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h
index e062debda9211..ca1185d6bb3b5 100644
--- a/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h
+++ b/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h
@@ -21,7 +21,7 @@ class RewritePatternSet;
 /// Collect a set of patterns to convert SCF operations to CFG branch-based
 /// operations within the ControlFlow dialect.
 void populateSCFToControlFlowConversionPatterns(
-    RewritePatternSet &patterns, bool enableVectorizeHits = false);
+    RewritePatternSet &patterns, bool enableVectorizeHints = false);
 
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 224b510294a10..071067adcdb4b 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -213,10 +213,10 @@ struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
   using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
 
-  bool enableVectorizeHits;
+  bool enableVectorizeHints;
 
-  ParallelLowering(mlir::MLIRContext *ctx, bool enableVectorizeHits)
-      : OpRewritePattern(ctx), enableVectorizeHits(enableVectorizeHits) {}
+  ParallelLowering(mlir::MLIRContext *ctx, bool enableVectorizeHints)
+      : OpRewritePattern(ctx), enableVectorizeHints(enableVectorizeHints) {}
 
   LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override;
@@ -495,7 +495,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
 
   auto vecAttr = LLVM::LoopVectorizeAttr::get(
       rewriter.getContext(),
-      /* disable */ rewriter.getBoolAttr(false), {}, {}, {}, {}, {}, {});
+      /*disable=*/rewriter.getBoolAttr(false), {}, {}, {}, {}, {}, {});
   auto loopAnnotation = LLVM::LoopAnnotationAttr::get(
       rewriter.getContext(), {}, /*vectorize=*/vecAttr, {}, {}, {}, {}, {}, {},
       {}, {}, {}, {}, {}, {}, {});
@@ -530,7 +530,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
       rewriter.create<scf::YieldOp>(loc, forOp.getResults());
     }
 
-    if (enableVectorizeHits)
+    if (enableVectorizeHints)
       forOp->setAttr(LLVM::BrOp::getLoopAnnotationAttrName(OperationName(
                          LLVM::BrOp::getOperationName(), getContext())),
                      loopAnnotation);
@@ -724,18 +724,18 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
 }
 
 void mlir::populateSCFToControlFlowConversionPatterns(
-    RewritePatternSet &patterns, bool enableVectorizeHits) {
+    RewritePatternSet &patterns, bool enableVectorizeHints) {
   patterns.add<ForallLowering, ForLowering, IfLowering, WhileLowering,
                ExecuteRegionLowering, IndexSwitchLowering>(
       patterns.getContext());
-  patterns.add<ParallelLowering>(patterns.getContext(), enableVectorizeHits);
+  patterns.add<ParallelLowering>(patterns.getContext(), enableVectorizeHints);
   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
 }
 
 void SCFToControlFlowPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   populateSCFToControlFlowConversionPatterns(patterns,
-                                             enableVectorizeHits.getValue());
+                                             enableVectorizeHints.getValue());
 
   // Configure conversion to lower out SCF operations.
   ConversionTarget target(getContext());
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index c8c9635aa530f..8db76d6c7466e 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s --check-prefixes=CHECK,NO-VEC
-// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf="enable-vectorize-hits=true" \
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf="enable-vectorize-hints=true" \
 // RUN:          -split-input-file %s | FileCheck %s --check-prefixes=CHECK,VEC
 
 // VEC: #loop_vectorize = #llvm.loop_vectorize<disable = false>



More information about the Mlir-commits mailing list