[Mlir-commits] [mlir] [MLIR][SCF] Add support for vectorization hints in `scf-to-cf` lowering. (PR #134201)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 2 22:25:34 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: MingYan (NexMing)
<details>
<summary>Changes</summary>
Add vectorization hints when converting SCF parallel loops to the ControlFlow dialect, and provide a pass option to enable them.
---
Full diff: https://github.com/llvm/llvm-project/pull/134201.diff
4 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+6-1)
- (modified) mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h (+2-1)
- (modified) mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp (+24-4)
- (modified) mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir (+16-6)
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..5660c55cfb4a4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -984,7 +984,12 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> {
def SCFToControlFlowPass : Pass<"convert-scf-to-cf"> {
let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured"
" control flow with a CFG";
- let dependentDialects = ["cf::ControlFlowDialect"];
+ let dependentDialects = ["cf::ControlFlowDialect", "LLVM::LLVMDialect"];
+
+ let options = [Option<"enableVectorizeHits", "enable-vectorize-hits", "bool",
+ /*default=*/"false",
+ "Add vectorization hints when convert SCF parallel "
+ "loop to ControlFlow dialect">];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h b/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h
index 2def01d208f72..e062debda9211 100644
--- a/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h
+++ b/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h
@@ -20,7 +20,8 @@ class RewritePatternSet;
/// Collect a set of patterns to convert SCF operations to CFG branch-based
/// operations within the ControlFlow dialect.
-void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns);
+void populateSCFToControlFlowConversionPatterns(
+ RewritePatternSet &patterns, bool enableVectorizeHits = false);
} // namespace mlir
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 114d634629d77..224b510294a10 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -38,6 +38,7 @@ namespace {
struct SCFToControlFlowPass
: public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> {
+ using Base::Base;
void runOnOperation() override;
};
@@ -212,6 +213,11 @@ struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
+ bool enableVectorizeHits;
+
+ ParallelLowering(mlir::MLIRContext *ctx, bool enableVectorizeHits)
+ : OpRewritePattern(ctx), enableVectorizeHits(enableVectorizeHits) {}
+
LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
PatternRewriter &rewriter) const override;
};
@@ -487,6 +493,13 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
return failure();
}
+ auto vecAttr = LLVM::LoopVectorizeAttr::get(
+ rewriter.getContext(),
+ /* disable */ rewriter.getBoolAttr(false), {}, {}, {}, {}, {}, {});
+ auto loopAnnotation = LLVM::LoopAnnotationAttr::get(
+ rewriter.getContext(), {}, /*vectorize=*/vecAttr, {}, {}, {}, {}, {}, {},
+ {}, {}, {}, {}, {}, {}, {});
+
// For a parallel loop, we essentially need to create an n-dimensional loop
// nest. We do this by translating to scf.for ops and have those lowered in
// a further rewrite. If a parallel loop contains reductions (and thus returns
@@ -517,6 +530,11 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
rewriter.create<scf::YieldOp>(loc, forOp.getResults());
}
+ if (enableVectorizeHits)
+ forOp->setAttr(LLVM::BrOp::getLoopAnnotationAttrName(OperationName(
+ LLVM::BrOp::getOperationName(), getContext())),
+ loopAnnotation);
+
rewriter.setInsertionPointToStart(forOp.getBody());
}
@@ -706,16 +724,18 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
}
void mlir::populateSCFToControlFlowConversionPatterns(
- RewritePatternSet &patterns) {
- patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
- WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
+ RewritePatternSet &patterns, bool enableVectorizeHits) {
+ patterns.add<ForallLowering, ForLowering, IfLowering, WhileLowering,
+ ExecuteRegionLowering, IndexSwitchLowering>(
patterns.getContext());
+ patterns.add<ParallelLowering>(patterns.getContext(), enableVectorizeHits);
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}
void SCFToControlFlowPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
- populateSCFToControlFlowConversionPatterns(patterns);
+ populateSCFToControlFlowConversionPatterns(patterns,
+ enableVectorizeHits.getValue());
// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index 9ea0093eff786..c8c9635aa530f 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -1,4 +1,9 @@
-// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s --check-prefixes=CHECK,NO-VEC
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf="enable-vectorize-hits=true" \
+// RUN: -split-input-file %s | FileCheck %s --check-prefixes=CHECK,VEC
+
+// VEC: #loop_vectorize = #llvm.loop_vectorize<disable = false>
+// VEC-NEXT: #[[$VEC_ATTR:.+]] = #llvm.loop_annotation<vectorize = #loop_vectorize>
// CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: cf.br ^bb1(%{{.*}} : index)
@@ -332,7 +337,8 @@ func.func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
// variable and the current partially reduced value.
// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
// CHECK: %[[COMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]]
- // CHECK: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
+ // NO-VEC: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
+ // VEC: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]] {loop_annotation = #[[$VEC_ATTR]]}
// Bodies of scf.reduce operations are folded into the main loop body. The
// result of this partial reduction is passed as argument to the condition
@@ -366,11 +372,13 @@ func.func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK: %[[INIT2:.*]] = arith.constant 42
// CHECK: cf.br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
// CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
- // CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
+ // NO-VEC: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
+ // VEC: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]] {loop_annotation = #[[$VEC_ATTR]]}
// CHECK: ^[[BODY_OUT]]:
// CHECK: cf.br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
// CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
- // CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
+ // NO-VEC: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
+ // VEC: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]] {loop_annotation = #[[$VEC_ATTR]]}
// CHECK: ^[[BODY_IN]]:
// CHECK: %[[REDUCE1:.*]] = arith.addf %[[ITER_ARG1_IN]], %{{.*}}
// CHECK: %[[REDUCE2:.*]] = arith.ori %[[ITER_ARG2_IN]], %{{.*}}
@@ -551,7 +559,8 @@ func.func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1,
// CHECK: cf.br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index)
// CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index):
// CHECK: %[[LOOP_COND:.*]] = arith.cmpi slt, %[[LOOP_IV]], %[[ARG1]] : index
- // CHECK: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
+ // NO-VEC: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]]
+ // VEC: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]] {loop_annotation = #[[$VEC_ATTR]]}
// CHECK: ^[[LOOP_BODY]]:
// CHECK: cf.cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]]
// CHECK: ^[[IF1_THEN]]:
@@ -660,7 +669,8 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
// CHECK: cf.br ^[[bb1:.*]](%[[c0]] : index)
// CHECK: ^[[bb1]](%[[arg0:.*]]: index):
// CHECK: %[[cmpi:.*]] = arith.cmpi slt, %[[arg0]], %[[num_threads]]
-// CHECK: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]]
+// NO-VEC: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]]
+// VEC: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]] {loop_annotation = #[[$VEC_ATTR]]}
// CHECK: ^[[bb2]]:
// CHECK: "test.foo"(%[[arg0]])
// CHECK: %[[addi:.*]] = arith.addi %[[arg0]], %[[c1]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/134201
More information about the Mlir-commits
mailing list