[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (PR #173978)
Tom Eccles
llvmlistbot at llvm.org
Mon Jan 5 08:04:17 PST 2026
================
@@ -261,54 +295,67 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
// Match simple binary reductions that can be expressed with atomicrmw.
Type type = reduce.getOperands()[reductionIndex].getType();
Block &reduction = reduce.getReductions()[reductionIndex].front();
+
+ // Handle scalar element type extraction for vector bitwidth safety.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ // Helper to create splat (for vectors) or scalar attributes.
+ auto getAttr = [&](Attribute val) -> Attribute {
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return val;
+ };
+
+ // Arithmetic Reductions
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 0.0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getFloatAttr(elType, 0.0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
- omp::DeclareReductionOp decl = createDecl(
- builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(
- type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
- reductionIndex);
+ auto allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, allOnes)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
+ decl, reduce, reductionIndex)
+ : decl;
}
// Match simple binary reductions that cannot be expressed with atomicrmw.
// TODO: add atomic region using cmpxchg (which needs atomic load to be
// available as an op).
if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
return createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 1.0));
- }
- if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
+ getAttr(builder.getFloatAttr(elType, 1.0)));
----------------
tblah wrote:
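Is a closing brace missing after this line? The diff removes the `}` that closed the `MulFOp` branch, together with the following `MulIOp` condition, and does not add either back.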
https://github.com/llvm/llvm-project/pull/173978
More information about the Mlir-commits mailing list