[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (PR #173978)
Aniket Singh
llvmlistbot at llvm.org
Mon Jan 12 05:07:19 PST 2026
https://github.com/Aniketsingh54 updated https://github.com/llvm/llvm-project/pull/173978
>From 3b4af6a7c7ee6b200643e93f29ddbd988b5d9564 Mon Sep 17 00:00:00 2001
From: Aniket Singh <amiket.singh.3200.00 at gmail.com>
Date: Tue, 30 Dec 2025 16:43:59 +0530
Subject: [PATCH 1/3] [MLIR][SCFToOpenMP] Fix crash when lowering vector
reductions
This patch fixes a crash in the SCF-to-OpenMP conversion pass when it
encounters an scf.parallel op with vector-typed reductions.
- Uses the scalar element type of a vector when computing bit widths and
  neutral (initial) reduction values.
- Uses a splat DenseElementsAttr as the initializer for vector reductions.
- Skips emitting the llvm.atomicrmw-based atomic region for vector types,
  which this lowering does not support.
Fixes #173860
---
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 158 ++++++++++++------
.../SCFToOpenMP/vector-reduction.mlir | 22 +++
2 files changed, 130 insertions(+), 50 deletions(-)
create mode 100644 mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..3d3c601d92d1b 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -153,29 +153,58 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
/// Returns an attribute with the minimum (if `min` is set) or the maximum value
/// (otherwise) for the given float type.
static Attribute minMaxValueForFloat(Type type, bool min) {
- auto fltType = cast<FloatType>(type);
- return FloatAttr::get(
- type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
+ // If the type is a vector, we need to find the neutral value for the
+ // underlying element type and then create a splat attribute.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ auto fltType = cast<FloatType>(elType);
+ auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
+
+ // For vector types, return a DenseElementsAttr (splat).
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+
+ return FloatAttr::get(type, val);
}
/// Returns an attribute with the signed integer minimum (if `min` is set) or
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForSignedInt(Type type, bool min) {
- auto intType = cast<IntegerType>(type);
+ // Extract scalar element type to handle vector reductions.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
- return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
- : llvm::APInt::getSignedMaxValue(bitwidth));
+ auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
+ : llvm::APInt::getSignedMaxValue(bitwidth);
+
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return IntegerAttr::get(type, val);
}
/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
- auto intType = cast<IntegerType>(type);
+ // Extract scalar element type to handle vector reductions.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
- return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
- : llvm::APInt::getAllOnes(bitwidth));
+ auto val =
+ min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
+
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return IntegerAttr::get(type, val);
}
/// Creates an OpenMP reduction declaration and inserts it into the provided
@@ -203,7 +232,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
Operation *terminator =
&reduce.getReductions()[reductionIndex].front().back();
assert(isa<scf::ReduceReturnOp>(terminator) &&
- "expected reduce op to be terminated by redure return");
+ "expected reduce op to be terminated by reduce return");
builder.setInsertionPoint(terminator);
builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
terminator->getOperands());
@@ -236,6 +265,11 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>());
return decl;
}
+/// Returns true if the type is supported by llvm.atomicrmw.
+/// LLVM IR does not support atomic operations on vector types.
+static bool supportsAtomic(Type type) {
+ return !isa<VectorType>(type);
+}
/// Creates an OpenMP reduction declaration that corresponds to the given SCF
/// reduction and returns it. Recognizes common reductions in order to identify
@@ -261,41 +295,55 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
// Match simple binary reductions that can be expressed with atomicrmw.
Type type = reduce.getOperands()[reductionIndex].getType();
Block &reduction = reduce.getReductions()[reductionIndex].front();
+
+ // Handle scalar element type extraction for vector bitwidth safety.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ // Helper to create splat (for vectors) or scalar attributes.
+ auto getAttr = [&](Attribute val) -> Attribute {
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return val;
+ };
+
+ // Arithmetic Reductions
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 0.0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getFloatAttr(elType, 0.0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
- omp::DeclareReductionOp decl = createDecl(
- builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(
- type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
- reductionIndex);
+ auto allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, allOnes)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
+ decl, reduce, reductionIndex)
+ : decl;
}
// Match simple binary reductions that cannot be expressed with atomicrmw.
@@ -303,12 +351,11 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
// available as an op).
if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
return createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 1.0));
- }
- if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
+ getAttr(builder.getFloatAttr(elType, 1.0)));
+
+ if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction))
return createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 1));
- }
+ getAttr(builder.getIntegerAttr(elType, 1)));
// Match select-based min/max reductions.
bool isMin;
@@ -329,10 +376,12 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
{LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
omp::DeclareReductionOp decl =
createDecl(builder, symbolTable, reduce, reductionIndex,
- minMaxValueForSignedInt(type, !isMin));
- return addAtomicRMW(builder,
- isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
- decl, reduce, reductionIndex);
+ minMaxValueForSignedInt(type, !isMin));
+ return supportsAtomic(type)
+ ? addAtomicRMW(builder,
+ isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
@@ -342,10 +391,12 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
{LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
omp::DeclareReductionOp decl =
createDecl(builder, symbolTable, reduce, reductionIndex,
- minMaxValueForUnsignedInt(type, !isMin));
- return addAtomicRMW(
- builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
- decl, reduce, reductionIndex);
+ minMaxValueForUnsignedInt(type, !isMin));
+ return supportsAtomic(type)
+ ? addAtomicRMW(builder,
+ isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
+ decl, reduce, reductionIndex)
+ : decl;
}
return nullptr;
@@ -370,6 +421,13 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
SmallVector<omp::DeclareReductionOp> ompReductionDecls;
auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
+ // Ensure validity of reduction type for vector bitwidth calculations.
+ Type reductionType = reduce.getOperands()[i].getType();
+ if (auto vecType = dyn_cast<VectorType>(reductionType))
+ (void)vecType.getElementType().getIntOrFloatBitWidth();
+ else
+ (void)reductionType.getIntOrFloatBitWidth();
+
omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
ompReductionDecls.push_back(decl);
if (!decl)
@@ -427,7 +485,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
Operation *cloneOp = builder.clone(op, mapper);
if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
assert(yieldOp && yieldOp.getResults().size() == 1 &&
- "expect YieldOp in reduction region to return one result");
+ "expect YieldOp in reduction region to return one result");
Value redVal = yieldOp.getResults()[0];
LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
rewriter.eraseOp(yieldOp);
diff --git a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
new file mode 100644
index 0000000000000..38d7e3ec2aff1
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s --convert-scf-to-openmp | FileCheck %s
+
+// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1> init
+// CHECK-NEXT: ^bb0(%arg0: vector<2xi1>):
+// CHECK-NEXT: %[[CONST:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
+// CHECK-NEXT: omp.yield(%[[CONST]] : vector<2xi1>)
+
+func.func @vector_and_reduction() {
+ %v_mask = vector.constant_mask [1] : vector<2xi1>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %result = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init(%v_mask) -> vector<2xi1> {
+ %val = vector.constant_mask [1] : vector<2xi1>
+ scf.reduce (%val : vector<2xi1>) {
+ ^bb0(%lhs: vector<2xi1>, %rhs: vector<2xi1>):
+ %0 = arith.andi %lhs, %rhs : vector<2xi1>
+ scf.reduce.return %0 : vector<2xi1>
+ }
+ }
+ return
+}
\ No newline at end of file
>From 5f4e6edc1e952e898dc078066797b83fbc99290f Mon Sep 17 00:00:00 2001
From: Aniket Singh <amiket.singh.3200.00 at gmail.com>
Date: Sat, 10 Jan 2026 02:20:10 +0530
Subject: [PATCH 2/3] Address review comments: Refactor splat logic, restore
braces in mul reductions, and fix vector reduction test
---
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 75 +++++++++----------
.../SCFToOpenMP/vector-reduction.mlir | 15 +++-
2 files changed, 46 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 3d3c601d92d1b..8161d9f4c7a28 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -150,6 +150,14 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
llvm_unreachable("unknown float type");
}
+/// Helper to create a splat attribute for vector types, or return the scalar
+/// attribute for scalar types.
+static Attribute getSplatOrScalarAttr(Type type, Attribute val) {
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return val;
+}
+
/// Returns an attribute with the minimum (if `min` is set) or the maximum value
/// (otherwise) for the given float type.
static Attribute minMaxValueForFloat(Type type, bool min) {
@@ -162,11 +170,7 @@ static Attribute minMaxValueForFloat(Type type, bool min) {
auto fltType = cast<FloatType>(elType);
auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
- // For vector types, return a DenseElementsAttr (splat).
- if (auto vecType = dyn_cast<VectorType>(type))
- return DenseElementsAttr::get(vecType, val);
-
- return FloatAttr::get(type, val);
+ return getSplatOrScalarAttr(type, FloatAttr::get(elType, val));
}
/// Returns an attribute with the signed integer minimum (if `min` is set) or
@@ -183,9 +187,7 @@ static Attribute minMaxValueForSignedInt(Type type, bool min) {
auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
: llvm::APInt::getSignedMaxValue(bitwidth);
- if (auto vecType = dyn_cast<VectorType>(type))
- return DenseElementsAttr::get(vecType, val);
- return IntegerAttr::get(type, val);
+ return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
}
/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
@@ -202,9 +204,7 @@ static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
auto val =
min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
- if (auto vecType = dyn_cast<VectorType>(type))
- return DenseElementsAttr::get(vecType, val);
- return IntegerAttr::get(type, val);
+ return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
}
/// Creates an OpenMP reduction declaration and inserts it into the provided
@@ -265,8 +265,10 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>());
return decl;
}
+
/// Returns true if the type is supported by llvm.atomicrmw.
-/// LLVM IR does not support atomic operations on vector types.
+/// LLVM IR currently does not support atomic operations on vector types.
+/// See LLVM Language Reference Manual on 'atomicrmw'.
static bool supportsAtomic(Type type) {
return !isa<VectorType>(type);
}
@@ -301,46 +303,44 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
if (auto vecType = dyn_cast<VectorType>(type))
elType = vecType.getElementType();
- // Helper to create splat (for vectors) or scalar attributes.
- auto getAttr = [&](Attribute val) -> Attribute {
- if (auto vecType = dyn_cast<VectorType>(type))
- return DenseElementsAttr::get(vecType, val);
- return val;
- };
-
// Arithmetic Reductions
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
- auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
- getAttr(builder.getFloatAttr(elType, 0.0)));
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 0.0)));
return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
decl, reduce, reductionIndex)
: decl;
}
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
- auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
- getAttr(builder.getIntegerAttr(elType, 0)));
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
decl, reduce, reductionIndex)
: decl;
}
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
- auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
- getAttr(builder.getIntegerAttr(elType, 0)));
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
decl, reduce, reductionIndex)
: decl;
}
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
- auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
- getAttr(builder.getIntegerAttr(elType, 0)));
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
decl, reduce, reductionIndex)
: decl;
}
if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
- auto allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
- auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
- getAttr(builder.getIntegerAttr(elType, allOnes)));
+ APInt allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, allOnes)));
return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
decl, reduce, reductionIndex)
: decl;
@@ -351,11 +351,13 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
// available as an op).
if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
return createDecl(builder, symbolTable, reduce, reductionIndex,
- getAttr(builder.getFloatAttr(elType, 1.0)));
+ getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 1.0)));
+ }
- if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction))
+ if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
return createDecl(builder, symbolTable, reduce, reductionIndex,
- getAttr(builder.getIntegerAttr(elType, 1)));
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 1)));
+ }
// Match select-based min/max reductions.
bool isMin;
@@ -421,13 +423,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
SmallVector<omp::DeclareReductionOp> ompReductionDecls;
auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
- // Ensure validity of reduction type for vector bitwidth calculations.
- Type reductionType = reduce.getOperands()[i].getType();
- if (auto vecType = dyn_cast<VectorType>(reductionType))
- (void)vecType.getElementType().getIntOrFloatBitWidth();
- else
- (void)reductionType.getIntOrFloatBitWidth();
-
omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
ompReductionDecls.push_back(decl);
if (!decl)
diff --git a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
index 38d7e3ec2aff1..018f8a03c8e34 100644
--- a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
@@ -1,9 +1,16 @@
// RUN: mlir-opt %s --convert-scf-to-openmp | FileCheck %s
-// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1> init
-// CHECK-NEXT: ^bb0(%arg0: vector<2xi1>):
-// CHECK-NEXT: %[[CONST:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
-// CHECK-NEXT: omp.yield(%[[CONST]] : vector<2xi1>)
+// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1>
+// CHECK: init {
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
+// CHECK: omp.yield(%[[INIT]] : vector<2xi1>)
+// CHECK: }
+// CHECK: combiner {
+// CHECK: ^bb0(%[[ARG0:.*]]: vector<2xi1>, %[[ARG1:.*]]: vector<2xi1>):
+// CHECK: %[[RES:.*]] = arith.andi %[[ARG0]], %[[ARG1]] : vector<2xi1>
+// CHECK: omp.yield(%[[RES]] : vector<2xi1>)
+// CHECK: }
+// CHECK-NOT: atomic
func.func @vector_and_reduction() {
%v_mask = vector.constant_mask [1] : vector<2xi1>
>From 56c1bcb0d7034ee51ec8e64c2a19f31f70ef5a19 Mon Sep 17 00:00:00 2001
From: Aniket Singh <amiket.singh.3200.00 at gmail.com>
Date: Sat, 10 Jan 2026 02:46:47 +0530
Subject: [PATCH 3/3] Address review comments: Factor element-type extraction into a helper
---
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 31 +++++++------------
1 file changed, 12 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 8161d9f4c7a28..77895470e75b8 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -150,6 +150,14 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
llvm_unreachable("unknown float type");
}
+/// Helper to extract the element type from a potential vector type.
+/// If the type is scalar, it returns the type itself.
+static Type getElementTypeOrSelf(Type type) {
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return vecType.getElementType();
+ return type;
+}
+
/// Helper to create a splat attribute for vector types, or return the scalar
/// attribute for scalar types.
static Attribute getSplatOrScalarAttr(Type type, Attribute val) {
@@ -161,12 +169,7 @@ static Attribute getSplatOrScalarAttr(Type type, Attribute val) {
/// Returns an attribute with the minimum (if `min` is set) or the maximum value
/// (otherwise) for the given float type.
static Attribute minMaxValueForFloat(Type type, bool min) {
- // If the type is a vector, we need to find the neutral value for the
- // underlying element type and then create a splat attribute.
- Type elType = type;
- if (auto vecType = dyn_cast<VectorType>(type))
- elType = vecType.getElementType();
-
+ Type elType = getElementTypeOrSelf(type);
auto fltType = cast<FloatType>(elType);
auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
@@ -177,11 +180,7 @@ static Attribute minMaxValueForFloat(Type type, bool min) {
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForSignedInt(Type type, bool min) {
- // Extract scalar element type to handle vector reductions.
- Type elType = type;
- if (auto vecType = dyn_cast<VectorType>(type))
- elType = vecType.getElementType();
-
+ Type elType = getElementTypeOrSelf(type);
auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
@@ -194,11 +193,7 @@ static Attribute minMaxValueForSignedInt(Type type, bool min) {
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
- // Extract scalar element type to handle vector reductions.
- Type elType = type;
- if (auto vecType = dyn_cast<VectorType>(type))
- elType = vecType.getElementType();
-
+ Type elType = getElementTypeOrSelf(type);
auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
auto val =
@@ -299,9 +294,7 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
Block &reduction = reduce.getReductions()[reductionIndex].front();
// Handle scalar element type extraction for vector bitwidth safety.
- Type elType = type;
- if (auto vecType = dyn_cast<VectorType>(type))
- elType = vecType.getElementType();
+ Type elType = getElementTypeOrSelf(type);
// Arithmetic Reductions
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
More information about the Mlir-commits
mailing list