[llvm-branch-commits] [mlir] fe68b17 - [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (#173978)
Cullen Rhodes via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 15 01:50:01 PST 2026
Author: Aniket Singh
Date: 2026-01-15T09:49:53Z
New Revision: fe68b17f46d470c2aa5223bb3cc4fec0d14801f9
URL: https://github.com/llvm/llvm-project/commit/fe68b17f46d470c2aa5223bb3cc4fec0d14801f9
DIFF: https://github.com/llvm/llvm-project/commit/fe68b17f46d470c2aa5223bb3cc4fec0d14801f9.diff
LOG: [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (#173978)
This patch fixes a crash in the SCF to OpenMP conversion pass when
encountering scf.parallel with vector reductions.
- Extracts scalar element types for bitwidth calculations.
- Uses DenseElementsAttr for vector splat initializers.
- Bypasses llvm.atomicrmw for all vector types (LLVM IR's atomicrmw does not accept integer vectors, and allows fixed vectors only for some floating-point operations).
Fixes #173860
---------
Co-authored-by: Aniket Singh <amiket.singh.3200.00 at gmail.com>
Added:
mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
Modified:
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..5fcaea7f39c3c 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -150,32 +150,48 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
llvm_unreachable("unknown float type");
}
+/// Helper to create a splat attribute for vector types, or return the scalar
+/// attribute for scalar types.
+static Attribute getSplatOrScalarAttr(Type type, Attribute val) {
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return val;
+}
+
/// Returns an attribute with the minimum (if `min` is set) or the maximum value
/// (otherwise) for the given float type.
static Attribute minMaxValueForFloat(Type type, bool min) {
- auto fltType = cast<FloatType>(type);
- return FloatAttr::get(
- type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
+ Type elType = getElementTypeOrSelf(type);
+ auto fltType = cast<FloatType>(elType);
+ auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
+
+ return getSplatOrScalarAttr(type, FloatAttr::get(elType, val));
}
/// Returns an attribute with the signed integer minimum (if `min` is set) or
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForSignedInt(Type type, bool min) {
- auto intType = cast<IntegerType>(type);
+ Type elType = getElementTypeOrSelf(type);
+ auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
- return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
- : llvm::APInt::getSignedMaxValue(bitwidth));
+ auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
+ : llvm::APInt::getSignedMaxValue(bitwidth);
+
+ return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
}
/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
- auto intType = cast<IntegerType>(type);
+ Type elType = getElementTypeOrSelf(type);
+ auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
- return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
- : llvm::APInt::getAllOnes(bitwidth));
+ auto val =
+ min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
+
+ return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
}
/// Creates an OpenMP reduction declaration and inserts it into the provided
@@ -203,7 +219,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
Operation *terminator =
&reduce.getReductions()[reductionIndex].front().back();
assert(isa<scf::ReduceReturnOp>(terminator) &&
- "expected reduce op to be terminated by redure return");
+ "expected reduce op to be terminated by reduce return");
builder.setInsertionPoint(terminator);
builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
terminator->getOperands());
@@ -237,6 +253,11 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
return decl;
}
+/// Returns true if the type is supported by llvm.atomicrmw.
+/// LLVM IR's atomicrmw only takes vectors for some FP ops (never integer
+/// vectors), so all vector types are rejected here. See LangRef 'atomicrmw'.
+static bool supportsAtomic(Type type) { return !isa<VectorType>(type); }
+
/// Creates an OpenMP reduction declaration that corresponds to the given SCF
/// reduction and returns it. Recognizes common reductions in order to identify
/// the neutral value, necessary for the OpenMP declaration. If the reduction
@@ -261,91 +282,119 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
// Match simple binary reductions that can be expressed with atomicrmw.
Type type = reduce.getOperands()[reductionIndex].getType();
Block &reduction = reduce.getReductions()[reductionIndex].front();
+
+ // Handle scalar element type extraction for vector bitwidth safety.
+ Type elType = getElementTypeOrSelf(type);
+
+ // Arithmetic Reductions
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 0.0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
- reductionIndex);
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 0.0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
- reductionIndex);
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
- reductionIndex);
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
- reductionIndex);
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
+ APInt allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
omp::DeclareReductionOp decl = createDecl(
builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(
- type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
- reductionIndex);
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, allOnes)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
+ decl, reduce, reductionIndex)
+ : decl;
}
// Match simple binary reductions that cannot be expressed with atomicrmw.
// TODO: add atomic region using cmpxchg (which needs atomic load to be
// available as an op).
if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
- return createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 1.0));
+ return createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 1.0)));
}
+
if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
- return createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 1));
+ return createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 1)));
}
// Match select-based min/max reductions.
bool isMin;
- if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
+ // Floating Point Min/Max
+ if (matchSelectReduction<arith::CmpFOp, arith::SelectOp,
+ arith::CmpFPredicate>(
reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
{arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
- matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
- reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
- {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
+ matchSelectReduction<arith::CmpFOp, arith::SelectOp,
+ arith::CmpFPredicate>(
+ reduction, {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE},
+ {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, isMin)) {
return createDecl(builder, symbolTable, reduce, reductionIndex,
minMaxValueForFloat(type, !isMin));
}
- if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
+
+ // Integer Min/Max
+ if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
+ arith::CmpIPredicate>(
reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
{arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
- matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
- reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
- {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
+ matchSelectReduction<arith::CmpIOp, arith::SelectOp,
+ arith::CmpIPredicate>(
+ reduction, {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge},
+ {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, isMin)) {
omp::DeclareReductionOp decl =
createDecl(builder, symbolTable, reduce, reductionIndex,
minMaxValueForSignedInt(type, !isMin));
- return addAtomicRMW(builder,
- isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
- decl, reduce, reductionIndex);
+ return supportsAtomic(type) ? addAtomicRMW(builder,
+ isMin ? LLVM::AtomicBinOp::min
+ : LLVM::AtomicBinOp::max,
+ decl, reduce, reductionIndex)
+ : decl;
}
- if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
+
+ // Unsigned Integer Min/Max
+ if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
+ arith::CmpIPredicate>(
reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
{arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
- matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
- reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
- {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
+ matchSelectReduction<arith::CmpIOp, arith::SelectOp,
+ arith::CmpIPredicate>(
+ reduction, {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge},
+ {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, isMin)) {
omp::DeclareReductionOp decl =
createDecl(builder, symbolTable, reduce, reductionIndex,
minMaxValueForUnsignedInt(type, !isMin));
- return addAtomicRMW(
- builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
- decl, reduce, reductionIndex);
+ return supportsAtomic(type) ? addAtomicRMW(builder,
+ isMin ? LLVM::AtomicBinOp::umin
+ : LLVM::AtomicBinOp::umax,
+ decl, reduce, reductionIndex)
+ : decl;
}
return nullptr;
diff --git a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
new file mode 100644
index 0000000000000..018f8a03c8e34
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s --convert-scf-to-openmp | FileCheck %s
+
+// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1>
+// CHECK: init {
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
+// CHECK: omp.yield(%[[INIT]] : vector<2xi1>)
+// CHECK: }
+// CHECK: combiner {
+// CHECK: ^bb0(%[[ARG0:.*]]: vector<2xi1>, %[[ARG1:.*]]: vector<2xi1>):
+// CHECK: %[[RES:.*]] = arith.andi %[[ARG0]], %[[ARG1]] : vector<2xi1>
+// CHECK: omp.yield(%[[RES]] : vector<2xi1>)
+// CHECK: }
+// CHECK-NOT: atomic
+
+func.func @vector_and_reduction() {
+ %v_mask = vector.constant_mask [1] : vector<2xi1>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %result = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init(%v_mask) -> vector<2xi1> {
+ %val = vector.constant_mask [1] : vector<2xi1>
+ scf.reduce (%val : vector<2xi1>) {
+ ^bb0(%lhs: vector<2xi1>, %rhs: vector<2xi1>):
+ %0 = arith.andi %lhs, %rhs : vector<2xi1>
+ scf.reduce.return %0 : vector<2xi1>
+ }
+ }
+ return
+}
\ No newline at end of file
More information about the llvm-branch-commits
mailing list