[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (PR #173978)

Tom Eccles llvmlistbot at llvm.org
Mon Jan 5 08:04:17 PST 2026


================
@@ -261,54 +295,67 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
   // Match simple binary reductions that can be expressed with atomicrmw.
   Type type = reduce.getOperands()[reductionIndex].getType();
   Block &reduction = reduce.getReductions()[reductionIndex].front();
+
+  // Handle scalar element type extraction for vector bitwidth safety.
+  Type elType = type;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    elType = vecType.getElementType();
+
+  // Helper to create splat (for vectors) or scalar attributes.
+  auto getAttr = [&](Attribute val) -> Attribute {
+    if (auto vecType = dyn_cast<VectorType>(type))
+      return DenseElementsAttr::get(vecType, val);
+    return val;
+  };
+
+  // Arithmetic Reductions
   if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getFloatAttr(type, 0.0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
-                        reductionIndex);
+    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+                           getAttr(builder.getFloatAttr(elType, 0.0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
-                        reductionIndex);
+    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+                           getAttr(builder.getIntegerAttr(elType, 0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
-                        reductionIndex);
+    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+                           getAttr(builder.getIntegerAttr(elType, 0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
-                        reductionIndex);
+    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+                           getAttr(builder.getIntegerAttr(elType, 0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
-    omp::DeclareReductionOp decl = createDecl(
-        builder, symbolTable, reduce, reductionIndex,
-        builder.getIntegerAttr(
-            type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
-                        reductionIndex);
+    auto allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
+    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+                           getAttr(builder.getIntegerAttr(elType, allOnes)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
 
   // Match simple binary reductions that cannot be expressed with atomicrmw.
   // TODO: add atomic region using cmpxchg (which needs atomic load to be
   // available as an op).
   if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
     return createDecl(builder, symbolTable, reduce, reductionIndex,
-                      builder.getFloatAttr(type, 1.0));
-  }
-  if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
+                      getAttr(builder.getFloatAttr(elType, 1.0)));
----------------
tblah wrote:

Is a closing brace missing after this line? The hunk above removes the `}` that closed this `if` block.

https://github.com/llvm/llvm-project/pull/173978


More information about the Mlir-commits mailing list