[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (PR #173938)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 29 16:55:40 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: Aniket Singh (Aniketsingh54)
<details>
<summary>Changes</summary>
The SCF to OpenMP conversion pass crashed when encountering scf.parallel with vector reductions because it assumed all reduction types were scalar. This patch:
- Adds support for VectorType in reduction initializers.
- Uses DenseElementsAttr for vector splat initializers.
- Skips generating the llvm.atomicrmw atomic combiner for vector reduction types, because that LLVM instruction does not support them.
Fixes #173860
---
Patch is 23.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/173938.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+128-128)
- (added) mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir (+22)
``````````diff
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..66ea0716a2292 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -153,35 +153,56 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
/// Returns an attribute with the minimum (if `min` is set) or the maximum value
/// (otherwise) for the given float type.
static Attribute minMaxValueForFloat(Type type, bool min) {
- auto fltType = cast<FloatType>(type);
- return FloatAttr::get(
- type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
+ // If the type is a vector, we need to find the neutral value for the
+ // underlying element type and then create a splat attribute.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ auto fltType = cast<FloatType>(elType);
+ auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
+
+ // For vector types, return a DenseElementsAttr (splat).
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+
+ return FloatAttr::get(type, val);
}
-/// Returns an attribute with the signed integer minimum (if `min` is set) or
-/// the maximum value (otherwise) for the given integer type, regardless of its
-/// signedness semantics (only the width is considered).
static Attribute minMaxValueForSignedInt(Type type, bool min) {
- auto intType = cast<IntegerType>(type);
+ // Extract scalar element type to handle vector reductions.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
- return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
- : llvm::APInt::getSignedMaxValue(bitwidth));
+ auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
+ : llvm::APInt::getSignedMaxValue(bitwidth);
+
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return IntegerAttr::get(type, val);
}
-/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
-/// the maximum value (otherwise) for the given integer type, regardless of its
-/// signedness semantics (only the width is considered).
static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
- auto intType = cast<IntegerType>(type);
+ // Extract scalar element type to handle vector reductions.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
- return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
- : llvm::APInt::getAllOnes(bitwidth));
+ auto val =
+ min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
+
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return IntegerAttr::get(type, val);
}
/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the `reductionIndex`-th reduction combiner carried
-/// over from `reduce`.
+/// symbol table.
static omp::DeclareReductionOp
createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
@@ -203,7 +224,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
Operation *terminator =
&reduce.getReductions()[reductionIndex].front().back();
assert(isa<scf::ReduceReturnOp>(terminator) &&
- "expected reduce op to be terminated by redure return");
+ "expected reduce op to be terminated by reduce return");
builder.setInsertionPoint(terminator);
builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
terminator->getOperands());
@@ -213,8 +234,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
return decl;
}
-/// Adds an atomic reduction combiner to the given OpenMP reduction declaration
-/// using llvm.atomicrmw of the given kind.
+/// Adds an atomic reduction combiner using llvm.atomicrmw.
static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
LLVM::AtomicBinOp atomicKind,
omp::DeclareReductionOp decl,
@@ -238,17 +258,13 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
}
/// Creates an OpenMP reduction declaration that corresponds to the given SCF
-/// reduction and returns it. Recognizes common reductions in order to identify
-/// the neutral value, necessary for the OpenMP declaration. If the reduction
-/// cannot be recognized, returns null.
+/// reduction.
static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
scf::ReduceOp reduce,
int64_t reductionIndex) {
Operation *container = SymbolTable::getNearestSymbolTable(reduce);
SymbolTable symbolTable(container);
- // Insert reduction declarations in the symbol-table ancestor before the
- // ancestor of the current insertion point.
Operation *insertionPoint = reduce;
while (insertionPoint->getParentOp() != container)
insertionPoint = insertionPoint->getParentOp();
@@ -258,56 +274,83 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
"expected reduction region to have a single element");
- // Match simple binary reductions that can be expressed with atomicrmw.
Type type = reduce.getOperands()[reductionIndex].getType();
Block &reduction = reduce.getReductions()[reductionIndex].front();
+
+ // Extract scalar element type to handle vector bitwidths correctly.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ // Helper to create splat or scalar integer attributes.
+ auto getIntAttr = [&](Type t, const llvm::APInt &value) -> Attribute {
+ if (auto vecType = dyn_cast<VectorType>(t))
+ return DenseElementsAttr::get(vecType, value);
+ return builder.getIntegerAttr(t, value);
+ };
+
+ // Helper to create splat or scalar float attributes.
+ auto getFloatAttr = [&](Type t, double value) -> Attribute {
+ if (auto vecType = dyn_cast<VectorType>(t))
+ return DenseElementsAttr::get(
+ vecType, builder.getFloatAttr(vecType.getElementType(), value));
+ return builder.getFloatAttr(t, value);
+ };
+
+ // Match simple binary reductions and only add atomicRMW if type is not a
+ // vector.
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 0.0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
- reductionIndex);
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex, getFloatAttr(type, 0.0));
+ return isa<VectorType>(type)
+ ? decl
+ : addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
+ reductionIndex);
}
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
- reductionIndex);
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getIntAttr(type, llvm::APInt::getZero(elType.getIntOrFloatBitWidth())));
+ return isa<VectorType>(type) ? decl
+ : addAtomicRMW(builder, LLVM::AtomicBinOp::add,
+ decl, reduce, reductionIndex);
}
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
- reductionIndex);
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getIntAttr(type, llvm::APInt::getZero(elType.getIntOrFloatBitWidth())));
+ return isa<VectorType>(type) ? decl
+ : addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
+ decl, reduce, reductionIndex);
}
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
- reductionIndex);
- }
- if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
omp::DeclareReductionOp decl = createDecl(
builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(
- type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
- reductionIndex);
+ getIntAttr(type, llvm::APInt::getZero(elType.getIntOrFloatBitWidth())));
+ return isa<VectorType>(type)
+ ? decl
+ : addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
+ reductionIndex);
+ }
+ if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
+ auto allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
+ omp::DeclareReductionOp decl =
+ createDecl(builder, symbolTable, reduce, reductionIndex,
+ getIntAttr(type, allOnes));
+ return isa<VectorType>(type)
+ ? decl
+ : addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
+ reductionIndex);
}
- // Match simple binary reductions that cannot be expressed with atomicrmw.
- // TODO: add atomic region using cmpxchg (which needs atomic load to be
- // available as an op).
if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
return createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 1.0));
+ getFloatAttr(type, 1.0));
}
if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
- return createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 1));
+ return createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getIntAttr(type, llvm::APInt(elType.getIntOrFloatBitWidth(), 1)));
}
// Match select-based min/max reductions.
@@ -330,9 +373,11 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
omp::DeclareReductionOp decl =
createDecl(builder, symbolTable, reduce, reductionIndex,
minMaxValueForSignedInt(type, !isMin));
- return addAtomicRMW(builder,
- isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
- decl, reduce, reductionIndex);
+ return isa<VectorType>(type) ? decl
+ : addAtomicRMW(builder,
+ isMin ? LLVM::AtomicBinOp::min
+ : LLVM::AtomicBinOp::max,
+ decl, reduce, reductionIndex);
}
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
@@ -343,9 +388,11 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
omp::DeclareReductionOp decl =
createDecl(builder, symbolTable, reduce, reductionIndex,
minMaxValueForUnsignedInt(type, !isMin));
- return addAtomicRMW(
- builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
- decl, reduce, reductionIndex);
+ return isa<VectorType>(type) ? decl
+ : addAtomicRMW(builder,
+ isMin ? LLVM::AtomicBinOp::umin
+ : LLVM::AtomicBinOp::umax,
+ decl, reduce, reductionIndex);
}
return nullptr;
@@ -363,13 +410,17 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
PatternRewriter &rewriter) const override {
- // Declare reductions.
- // TODO: consider checking it here is already a compatible reduction
- // declaration and use it instead of redeclaring.
SmallVector<Attribute> reductionSyms;
SmallVector<omp::DeclareReductionOp> ompReductionDecls;
auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
+ // Ensure validity of reduction type for vector bitwidth calculations.
+ Type reductionType = reduce.getOperands()[i].getType();
+ if (auto vecType = dyn_cast<VectorType>(reductionType))
+ (void)vecType.getElementType().getIntOrFloatBitWidth();
+ else
+ (void)reductionType.getIntOrFloatBitWidth();
+
omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
ompReductionDecls.push_back(decl);
if (!decl)
@@ -378,8 +429,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
SymbolRefAttr::get(rewriter.getContext(), decl.getSymName()));
}
- // Allocate reduction variables. Make sure the we don't overflow the stack
- // with local `alloca`s by saving and restoring the stack pointer.
Location loc = parallelOp.getLoc();
Value one =
LLVM::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(64),
@@ -398,24 +447,16 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
reductionVariables.push_back(storage);
}
- // Replace the reduction operations contained in this loop. Must be done
- // here rather than in a separate pattern to have access to the list of
- // reduction variables.
for (auto [x, y, rD] : llvm::zip_equal(
reductionVariables, reduce.getOperands(), ompReductionDecls)) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(reduce);
Region &redRegion = rD.getReductionRegion();
- // The SCF dialect by definition contains only structured operations
- // and hence the SCF reduction region will contain a single block.
- // The ompReductionDecls region is a copy of the SCF reduction region
- // and hence has the same property.
assert(redRegion.hasOneBlock() &&
"expect reduction region to have one block");
Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
Value pvtRedVal = LLVM::LoadOp::create(rewriter, reduce.getLoc(),
rD.getType(), pvtRedVar);
- // Make a copy of the reduction combiner region in the body
mlir::OpBuilder builder(rewriter.getContext());
builder.setInsertionPoint(reduce);
mlir::IRMapping mapper;
@@ -426,8 +467,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
for (auto &op : redRegion.getOps()) {
Operation *cloneOp = builder.clone(op, mapper);
if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
- assert(yieldOp && yieldOp.getResults().size() == 1 &&
- "expect YieldOp in reduction region to return one result");
Value redVal = yieldOp.getResults()[0];
LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
rewriter.eraseOp(yieldOp);
@@ -442,67 +481,39 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
numThreadsVar = LLVM::ConstantOp::create(
rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
}
- // Create the parallel wrapper.
auto ompParallel = omp::ParallelOp::create(
- rewriter, loc,
- /* allocate_vars = */ llvm::SmallVector<Value>{},
- /* allocator_vars = */ llvm::SmallVector<Value>{},
- /* if_expr = */ Value{},
- /* num_threads = */ numThreadsVar,
- /* private_vars = */ ValueRange(),
- /* private_syms = */ nullptr,
- /* private_needs_barrier = */ nullptr,
- /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
- /* reduction_mod = */ nullptr,
- /* reduction_vars = */ llvm::SmallVector<Value>{},
- /* reduction_byref = */ DenseBoolArrayAttr{},
- /* reduction_syms = */ ArrayAttr{});
+ rewriter, loc, {}, {}, Value{}, numThreadsVar, ValueRange(), nullptr,
+ nullptr, omp::ClauseProcBindKindAttr{}, nullptr, {},
+ DenseBoolArrayAttr{}, ArrayAttr{});
{
-
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&ompParallel.getRegion());
-
- // Replace the loop.
{
OpBuilder::InsertionGuard allocaGuard(rewriter);
- // Create worksharing loop wrapper.
auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc());
if (!reductionVariables.empty()) {
wsloopOp.setReductionSymsAttr(
ArrayAttr::get(rewriter.getContext(), reductionSyms));
wsloopOp.getReductionVarsMutable().append(reductionVariables);
- llvm::SmallVector<bool> reductionByRef;
- // false because these reductions always reduce scalars and so do
- // not need to pass by reference
- reductionByRef.resize(reductionVariables.size(), false);
+ llvm::SmallVector<bool> reductionByRef(reductionVariables.size(),
+ false);
wsloopOp.setReductionByref(
DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef));
}
- omp::TerminatorOp::create(rewriter, loc); // omp.parallel terminator.
-
- // The wrapper's entry block arguments will define the reduction
- // variables.
+ omp::TerminatorOp::create(rewriter, loc);
llvm::SmallVector<mlir::Type> reductionTypes;
- reductionTypes.reserve(reductionVariables.size());
llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
[](mlir::Value v) { return v.getType(); });
rewriter.createBlock(
&wsloopOp.getRegion(), {}, reductionTypes,
llvm::SmallVector<mlir::Location>(reductionVariables.size(),
parallelOp.getLoc()));
-
- // Create loop nest and populate region with contents of scf.parallel.
auto loopOp = omp::LoopNestOp::create(
rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(),
parallelOp.getLowerBound(), parallelOp.getUpperBound(),
- parallelOp.getStep(), /*loop_inclusive=*/false,
- /*tile_sizes=*/nullptr);
-
+ parallelOp.getStep(), /*loop_inclusive=*/false, nullptr);
rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
loopOp.getRegion().begin());
-
- // Remove reduction-related block arguments from omp.loop_nest and
- // redirect uses to the corresponding omp.wsloop block argument.
mlir::Block &loopOpEntryBlock = loopOp.getRegion().fr...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/173938
More information about the Mlir-commits mailing list