[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (PR #173978)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 5 07:57:24 PST 2026


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: C/C++ code formatter, clang-format found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --diff_from_common_commit
``````````

:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 3d3c601d9..44c27abb2 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -265,11 +265,9 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
   omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>());
   return decl;
 }
-/// Returns true if the type is supported by llvm.atomicrmw. 
+/// Returns true if the type is supported by llvm.atomicrmw.
 /// LLVM IR does not support atomic operations on vector types.
-static bool supportsAtomic(Type type) {
-  return !isa<VectorType>(type);
-}
+static bool supportsAtomic(Type type) { return !isa<VectorType>(type); }
 
 /// Creates an OpenMP reduction declaration that corresponds to the given SCF
 /// reduction and returns it. Recognizes common reductions in order to identify
@@ -353,21 +351,21 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
     return createDecl(builder, symbolTable, reduce, reductionIndex,
                       getAttr(builder.getFloatAttr(elType, 1.0)));
 
-  if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction))
-    return createDecl(builder, symbolTable, reduce, reductionIndex,
-                      getAttr(builder.getIntegerAttr(elType, 1)));
-
-  // Match select-based min/max reductions.
-  bool isMin;
-  if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
-          reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
-          {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
-      matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
-          reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
-          {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
-    return createDecl(builder, symbolTable, reduce, reductionIndex,
-                      minMaxValueForFloat(type, !isMin));
-  }
+    if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction))
+      return createDecl(builder, symbolTable, reduce, reductionIndex,
+                        getAttr(builder.getIntegerAttr(elType, 1)));
+
+    // Match select-based min/max reductions.
+    bool isMin;
+    if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
+            reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
+            {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
+        matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
+            reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
+            {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
+      return createDecl(builder, symbolTable, reduce, reductionIndex,
+                        minMaxValueForFloat(type, !isMin));
+    }
   if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
           reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
           {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
@@ -376,12 +374,12 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
           {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
-                           minMaxValueForSignedInt(type, !isMin));
-    return supportsAtomic(type)
-               ? addAtomicRMW(builder,
-                              isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
-                              decl, reduce, reductionIndex)
-               : decl;
+                   minMaxValueForSignedInt(type, !isMin));
+    return supportsAtomic(type) ? addAtomicRMW(builder,
+                                               isMin ? LLVM::AtomicBinOp::min
+                                                     : LLVM::AtomicBinOp::max,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
           reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
@@ -391,12 +389,12 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
           {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
-                           minMaxValueForUnsignedInt(type, !isMin));
-    return supportsAtomic(type)
-               ? addAtomicRMW(builder,
-                              isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
-                              decl, reduce, reductionIndex)
-               : decl;
+                   minMaxValueForUnsignedInt(type, !isMin));
+    return supportsAtomic(type) ? addAtomicRMW(builder,
+                                               isMin ? LLVM::AtomicBinOp::umin
+                                                     : LLVM::AtomicBinOp::umax,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
 
   return nullptr;
@@ -485,7 +483,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         Operation *cloneOp = builder.clone(op, mapper);
         if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
           assert(yieldOp && yieldOp.getResults().size() == 1 &&
-                    "expect YieldOp in reduction region to return one result");
+                 "expect YieldOp in reduction region to return one result");
           Value redVal = yieldOp.getResults()[0];
           LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
           rewriter.eraseOp(yieldOp);

``````````

</details>


https://github.com/llvm/llvm-project/pull/173978


More information about the Mlir-commits mailing list