[llvm-branch-commits] [mlir] fe68b17 - [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (#173978)

Cullen Rhodes via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 15 01:50:01 PST 2026


Author: Aniket Singh
Date: 2026-01-15T09:49:53Z
New Revision: fe68b17f46d470c2aa5223bb3cc4fec0d14801f9

URL: https://github.com/llvm/llvm-project/commit/fe68b17f46d470c2aa5223bb3cc4fec0d14801f9
DIFF: https://github.com/llvm/llvm-project/commit/fe68b17f46d470c2aa5223bb3cc4fec0d14801f9.diff

LOG: [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (#173978)

This patch fixes a crash in the SCF to OpenMP conversion pass when
encountering scf.parallel with vector reductions.

- Extracts scalar element types for bitwidth calculations.
- Uses DenseElementsAttr for vector splat initializers.
- Bypasses llvm.atomicrmw for vector types (not supported in LLVM IR).

Fixes #173860

---------

Co-authored-by: Aniket Singh <amiket.singh.3200.00 at gmail.com>

Added: 
    mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir

Modified: 
    mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..5fcaea7f39c3c 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -150,32 +150,48 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
   llvm_unreachable("unknown float type");
 }
 
+/// Returns `val` splatted into a DenseElementsAttr of `type` when `type` is a
+/// VectorType, and `val` itself otherwise; `val` is the per-element value.
+static Attribute getSplatOrScalarAttr(Type type, Attribute val) {
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return DenseElementsAttr::get(vecType, val);
+  return val;
+}
+
 /// Returns an attribute with the minimum (if `min` is set) or the maximum value
-/// (otherwise) for the given float type.
+/// (otherwise) for the given float type; vector types yield a splat of it.
 static Attribute minMaxValueForFloat(Type type, bool min) {
-  auto fltType = cast<FloatType>(type);
-  return FloatAttr::get(
-      type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
+  Type elType = getElementTypeOrSelf(type);
+  auto fltType = cast<FloatType>(elType);
+  auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
+
+  return getSplatOrScalarAttr(type, FloatAttr::get(elType, val));
 }
 
 /// Returns an attribute with the signed integer minimum (if `min` is set) or
 /// the maximum value (otherwise) for the given integer type, regardless of its
-/// signedness semantics (only the width is considered).
+/// signedness semantics (only the width is considered). Vectors yield a splat.
 static Attribute minMaxValueForSignedInt(Type type, bool min) {
-  auto intType = cast<IntegerType>(type);
+  Type elType = getElementTypeOrSelf(type);
+  auto intType = cast<IntegerType>(elType);
   unsigned bitwidth = intType.getWidth();
-  return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
-                                    : llvm::APInt::getSignedMaxValue(bitwidth));
+  auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
+                 : llvm::APInt::getSignedMaxValue(bitwidth);
+
+  return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
 }
 
 /// Returns an attribute with the unsigned integer minimum (if `min` is set) or
 /// the maximum value (otherwise) for the given integer type, regardless of its
-/// signedness semantics (only the width is considered).
+/// signedness semantics (only the width is considered). Vectors yield a splat.
 static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
-  auto intType = cast<IntegerType>(type);
+  Type elType = getElementTypeOrSelf(type);
+  auto intType = cast<IntegerType>(elType);
   unsigned bitwidth = intType.getWidth();
-  return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
-                                    : llvm::APInt::getAllOnes(bitwidth));
+  auto val =
+      min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
+
+  return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
 }
 
 /// Creates an OpenMP reduction declaration and inserts it into the provided
@@ -203,7 +219,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
   Operation *terminator =
       &reduce.getReductions()[reductionIndex].front().back();
   assert(isa<scf::ReduceReturnOp>(terminator) &&
-         "expected reduce op to be terminated by redure return");
+         "expected reduce op to be terminated by reduce return");
   builder.setInsertionPoint(terminator);
   builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
                                            terminator->getOperands());
@@ -237,6 +253,11 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
   return decl;
 }
 
+/// Returns true if llvm.atomicrmw may be used for a reduction of `type`.
+/// All vector types are conservatively rejected. NOTE(review): LangRef may
+/// allow fixed FP vectors for fadd/fsub/fmin/fmax; confirm before relaxing.
+static bool supportsAtomic(Type type) { return !isa<VectorType>(type); }
+
 /// Creates an OpenMP reduction declaration that corresponds to the given SCF
 /// reduction and returns it. Recognizes common reductions in order to identify
 /// the neutral value, necessary for the OpenMP declaration. If the reduction
@@ -261,91 +282,116 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
   // Match simple binary reductions that can be expressed with atomicrmw.
   Type type = reduce.getOperands()[reductionIndex].getType();
   Block &reduction = reduce.getReductions()[reductionIndex].front();
+
+  // Neutral values are built from the element type (splatted for vectors).
+  Type elType = getElementTypeOrSelf(type);
+
+  // Arithmetic Reductions
   if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getFloatAttr(type, 0.0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
-                        reductionIndex);
+    omp::DeclareReductionOp decl = createDecl(
+        builder, symbolTable, reduce, reductionIndex,
+        getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 0.0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
-                        reductionIndex);
+    omp::DeclareReductionOp decl = createDecl(
+        builder, symbolTable, reduce, reductionIndex,
+        getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
-                        reductionIndex);
+    omp::DeclareReductionOp decl = createDecl(
+        builder, symbolTable, reduce, reductionIndex,
+        getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
-                        reductionIndex);
+    omp::DeclareReductionOp decl = createDecl(
+        builder, symbolTable, reduce, reductionIndex,
+        getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
+    APInt allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
     omp::DeclareReductionOp decl = createDecl(
         builder, symbolTable, reduce, reductionIndex,
-        builder.getIntegerAttr(
-            type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
-                        reductionIndex);
+        getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, allOnes)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
 
   // Match simple binary reductions that cannot be expressed with atomicrmw.
   // TODO: add atomic region using cmpxchg (which needs atomic load to be
   // available as an op).
   if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
-    return createDecl(builder, symbolTable, reduce, reductionIndex,
-                      builder.getFloatAttr(type, 1.0));
+    return createDecl(
+        builder, symbolTable, reduce, reductionIndex,
+        getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 1.0)));
   }
+
   if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
-    return createDecl(builder, symbolTable, reduce, reductionIndex,
-                      builder.getIntegerAttr(type, 1));
+    return createDecl(
+        builder, symbolTable, reduce, reductionIndex,
+        getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 1)));
   }
 
   // Match select-based min/max reductions.
   bool isMin;
-  if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
+  // Floating Point Min/Max
+  if (matchSelectReduction<arith::CmpFOp, arith::SelectOp,
+                           arith::CmpFPredicate>(
           reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
           {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
       matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
           reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
           {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
     return createDecl(builder, symbolTable, reduce, reductionIndex,
                       minMaxValueForFloat(type, !isMin));
   }
-  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
+
+  // Signed Integer Min/Max
+  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
+                           arith::CmpIPredicate>(
           reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
           {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
       matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
           reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
           {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
                    minMaxValueForSignedInt(type, !isMin));
-    return addAtomicRMW(builder,
-                        isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
-                        decl, reduce, reductionIndex);
+    return supportsAtomic(type) ? addAtomicRMW(builder,
+                                               isMin ? LLVM::AtomicBinOp::min
+                                                     : LLVM::AtomicBinOp::max,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
-  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
+
+  // Unsigned Integer Min/Max
+  if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
+                           arith::CmpIPredicate>(
           reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
           {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
       matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
-          reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
+          reduction, {LLVM::ICmpPredicate::ult, LLVM::ICmpPredicate::ule},
           {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
                    minMaxValueForUnsignedInt(type, !isMin));
-    return addAtomicRMW(
-        builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
-        decl, reduce, reductionIndex);
+    return supportsAtomic(type) ? addAtomicRMW(builder,
+                                               isMin ? LLVM::AtomicBinOp::umin
+                                                     : LLVM::AtomicBinOp::umax,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
 
   return nullptr;

diff  --git a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
new file mode 100644
index 0000000000000..018f8a03c8e34
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s --convert-scf-to-openmp -split-input-file | FileCheck %s
+
+// Vector reductions use a splatted neutral value and get no atomic region.
+
+// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1>
+// CHECK: init {
+// CHECK:   %[[INIT:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
+// CHECK:   omp.yield(%[[INIT]] : vector<2xi1>)
+// CHECK: }
+// CHECK: combiner {
+// CHECK: ^bb0(%[[ARG0:.*]]: vector<2xi1>, %[[ARG1:.*]]: vector<2xi1>):
+// CHECK:   %[[RES:.*]] = arith.andi %[[ARG0]], %[[ARG1]] : vector<2xi1>
+// CHECK:   omp.yield(%[[RES]] : vector<2xi1>)
+// CHECK: }
+// CHECK-NOT: atomic
+
+func.func @vector_and_reduction() {
+  %v_mask = vector.constant_mask [1] : vector<2xi1>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %result = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init(%v_mask) -> vector<2xi1> {
+    %val = vector.constant_mask [1] : vector<2xi1>
+    scf.reduce (%val : vector<2xi1>) {
+    ^bb0(%lhs: vector<2xi1>, %rhs: vector<2xi1>):
+      %0 = arith.andi %lhs, %rhs : vector<2xi1>
+      scf.reduce.return %0 : vector<2xi1>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<4xf32>
+// CHECK: init {
+// CHECK:   %[[INIT:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4xf32>) : vector<4xf32>
+// CHECK:   omp.yield(%[[INIT]] : vector<4xf32>)
+// CHECK: }
+// CHECK: combiner {
+// CHECK:   arith.addf {{.*}} : vector<4xf32>
+// CHECK: }
+// CHECK-NOT: atomic
+
+func.func @vector_addf_reduction(%init: vector<4xf32>, %val: vector<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %result = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init(%init) -> vector<4xf32> {
+    scf.reduce (%val : vector<4xf32>) {
+    ^bb0(%lhs: vector<4xf32>, %rhs: vector<4xf32>):
+      %0 = arith.addf %lhs, %rhs : vector<4xf32>
+      scf.reduce.return %0 : vector<4xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<4xf32>
+// CHECK: init {
+// CHECK:   llvm.mlir.constant(dense<-3.{{[0-9]+}}{{[eE]}}+38> : vector<4xf32>) : vector<4xf32>
+// CHECK: combiner {
+// CHECK:   arith.cmpf ogt, {{.*}} : vector<4xf32>
+// CHECK:   arith.select {{.*}} : vector<4xi1>, vector<4xf32>
+// CHECK-NOT: atomic
+
+func.func @vector_maxf_reduction(%init: vector<4xf32>, %val: vector<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %result = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init(%init) -> vector<4xf32> {
+    scf.reduce (%val : vector<4xf32>) {
+    ^bb0(%lhs: vector<4xf32>, %rhs: vector<4xf32>):
+      %cmp = arith.cmpf ogt, %lhs, %rhs : vector<4xf32>
+      %max = arith.select %cmp, %lhs, %rhs : vector<4xi1>, vector<4xf32>
+      scf.reduce.return %max : vector<4xf32>
+    }
+  }
+  return
+}


        


More information about the llvm-branch-commits mailing list