[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (PR #173978)

Aniket Singh llvmlistbot at llvm.org
Fri Jan 9 12:52:52 PST 2026


https://github.com/Aniketsingh54 updated https://github.com/llvm/llvm-project/pull/173978

>From 3b4af6a7c7ee6b200643e93f29ddbd988b5d9564 Mon Sep 17 00:00:00 2001
From: Aniket Singh <amiket.singh.3200.00 at gmail.com>
Date: Tue, 30 Dec 2025 16:43:59 +0530
Subject: [PATCH 1/2] [MLIR][SCFToOpenMP] Fix crash when lowering vector
 reductions

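This patch fixes a crash in the SCF to OpenMP conversion pass when it
encounters an scf.parallel with vector-typed reductions. The reduction
lowering assumed scalar types: it applied cast<IntegerType>/cast<FloatType>
and getIntOrFloatBitWidth() directly to the reduction type, which asserts
when that type is a vector.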

- Uses the vector's element type for bitwidth and neutral-value computation.
- Builds splat DenseElementsAttr initializers for vector reductions.
- Skips the llvm.atomicrmw region for vector types (integer atomicrmw
  operations do not accept vector operands).

Fixes #173860
---
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    | 158 ++++++++++++------
 .../SCFToOpenMP/vector-reduction.mlir         |  22 +++
 2 files changed, 130 insertions(+), 50 deletions(-)
 create mode 100644 mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir

diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..3d3c601d92d1b 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -153,29 +153,58 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
 /// Returns an attribute with the minimum (if `min` is set) or the maximum value
 /// (otherwise) for the given float type.
 static Attribute minMaxValueForFloat(Type type, bool min) {
-  auto fltType = cast<FloatType>(type);
-  return FloatAttr::get(
-      type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
+  // If the type is a vector, we need to find the neutral value for the
+  // underlying element type and then create a splat attribute.
+  Type elType = type;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    elType = vecType.getElementType();
+
+  auto fltType = cast<FloatType>(elType);
+  auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
+
+  // For vector types, return a DenseElementsAttr (splat).
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return DenseElementsAttr::get(vecType, val);
+
+  return FloatAttr::get(type, val);
 }
 
 /// Returns an attribute with the signed integer minimum (if `min` is set) or
 /// the maximum value (otherwise) for the given integer type, regardless of its
 /// signedness semantics (only the width is considered).
 static Attribute minMaxValueForSignedInt(Type type, bool min) {
-  auto intType = cast<IntegerType>(type);
+  // Extract scalar element type to handle vector reductions.
+  Type elType = type;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    elType = vecType.getElementType();
+
+  auto intType = cast<IntegerType>(elType);
   unsigned bitwidth = intType.getWidth();
-  return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
-                                    : llvm::APInt::getSignedMaxValue(bitwidth));
+  auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
+                 : llvm::APInt::getSignedMaxValue(bitwidth);
+
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return DenseElementsAttr::get(vecType, val);
+  return IntegerAttr::get(type, val);
 }
 
 /// Returns an attribute with the unsigned integer minimum (if `min` is set) or
 /// the maximum value (otherwise) for the given integer type, regardless of its
 /// signedness semantics (only the width is considered).
 static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
-  auto intType = cast<IntegerType>(type);
+  // Extract scalar element type to handle vector reductions.
+  Type elType = type;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    elType = vecType.getElementType();
+
+  auto intType = cast<IntegerType>(elType);
   unsigned bitwidth = intType.getWidth();
-  return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
-                                    : llvm::APInt::getAllOnes(bitwidth));
+  auto val =
+      min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
+
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return DenseElementsAttr::get(vecType, val);
+  return IntegerAttr::get(type, val);
 }
 
 /// Creates an OpenMP reduction declaration and inserts it into the provided
@@ -203,7 +232,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
   Operation *terminator =
       &reduce.getReductions()[reductionIndex].front().back();
   assert(isa<scf::ReduceReturnOp>(terminator) &&
-         "expected reduce op to be terminated by redure return");
+         "expected reduce op to be terminated by reduce return");
   builder.setInsertionPoint(terminator);
   builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
                                            terminator->getOperands());
@@ -236,6 +265,11 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
   omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>());
   return decl;
 }
+/// Returns true if the type is supported by llvm.atomicrmw. 
+/// LLVM IR does not support atomic operations on vector types.
+static bool supportsAtomic(Type type) {
+  return !isa<VectorType>(type);
+}
 
 /// Creates an OpenMP reduction declaration that corresponds to the given SCF
 /// reduction and returns it. Recognizes common reductions in order to identify
@@ -261,41 +295,55 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
   // Match simple binary reductions that can be expressed with atomicrmw.
   Type type = reduce.getOperands()[reductionIndex].getType();
   Block &reduction = reduce.getReductions()[reductionIndex].front();
+
+  // Handle scalar element type extraction for vector bitwidth safety.
+  Type elType = type;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    elType = vecType.getElementType();
+
+  // Helper to create splat (for vectors) or scalar attributes.
+  auto getAttr = [&](Attribute val) -> Attribute {
+    if (auto vecType = dyn_cast<VectorType>(type))
+      return DenseElementsAttr::get(vecType, val);
+    return val;
+  };
+
+  // Arithmetic Reductions
   if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getFloatAttr(type, 0.0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
-                        reductionIndex);
+    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+                           getAttr(builder.getFloatAttr(elType, 0.0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
-                        reductionIndex);
+    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+                           getAttr(builder.getIntegerAttr(elType, 0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
-                        reductionIndex);
+    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+                           getAttr(builder.getIntegerAttr(elType, 0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
-    omp::DeclareReductionOp decl =
-        createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
-                        reductionIndex);
+    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+                           getAttr(builder.getIntegerAttr(elType, 0)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
   if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
-    omp::DeclareReductionOp decl = createDecl(
-        builder, symbolTable, reduce, reductionIndex,
-        builder.getIntegerAttr(
-            type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
-                        reductionIndex);
+    auto allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
+    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+                           getAttr(builder.getIntegerAttr(elType, allOnes)));
+    return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
+                                               decl, reduce, reductionIndex)
+                                : decl;
   }
 
   // Match simple binary reductions that cannot be expressed with atomicrmw.
@@ -303,12 +351,11 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
   // available as an op).
   if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
     return createDecl(builder, symbolTable, reduce, reductionIndex,
-                      builder.getFloatAttr(type, 1.0));
-  }
-  if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
+                      getAttr(builder.getFloatAttr(elType, 1.0)));
+
+  if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction))
     return createDecl(builder, symbolTable, reduce, reductionIndex,
-                      builder.getIntegerAttr(type, 1));
-  }
+                      getAttr(builder.getIntegerAttr(elType, 1)));
 
   // Match select-based min/max reductions.
   bool isMin;
@@ -329,10 +376,12 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
           {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
                    minMaxValueForSignedInt(type, !isMin));
+    if (!supportsAtomic(type))
+      return decl;
     return addAtomicRMW(builder,
                         isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
                         decl, reduce, reductionIndex);
   }
   if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
           reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
@@ -342,10 +391,12 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
           {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
                    minMaxValueForUnsignedInt(type, !isMin));
+    if (!supportsAtomic(type))
+      return decl;
     return addAtomicRMW(
         builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
         decl, reduce, reductionIndex);
   }
 
   return nullptr;
@@ -370,6 +421,13 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
     SmallVector<omp::DeclareReductionOp> ompReductionDecls;
     auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
     for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
+      // Ensure validity of reduction type for vector bitwidth calculations.
+      Type reductionType = reduce.getOperands()[i].getType();
+      if (auto vecType = dyn_cast<VectorType>(reductionType))
+        (void)vecType.getElementType().getIntOrFloatBitWidth();
+      else
+        (void)reductionType.getIntOrFloatBitWidth();
+
       omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
       ompReductionDecls.push_back(decl);
       if (!decl)
@@ -427,7 +485,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         Operation *cloneOp = builder.clone(op, mapper);
         if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
           assert(yieldOp && yieldOp.getResults().size() == 1 &&
-                 "expect YieldOp in reduction region to return one result");
+                    "expect YieldOp in reduction region to return one result");
           Value redVal = yieldOp.getResults()[0];
           LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
           rewriter.eraseOp(yieldOp);
diff --git a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
new file mode 100644
index 0000000000000..38d7e3ec2aff1
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s --convert-scf-to-openmp | FileCheck %s
+
+// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1> init
+// CHECK-NEXT:  ^bb0(%arg0: vector<2xi1>):
+// CHECK-NEXT:    %[[CONST:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
+// CHECK-NEXT:    omp.yield(%[[CONST]] : vector<2xi1>)
+
+func.func @vector_and_reduction() {
+  %v_mask = vector.constant_mask [1] : vector<2xi1>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %result = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init(%v_mask) -> vector<2xi1> {
+    %val = vector.constant_mask [1] : vector<2xi1>
+    scf.reduce (%val : vector<2xi1>) {
+    ^bb0(%lhs: vector<2xi1>, %rhs: vector<2xi1>):
+      %0 = arith.andi %lhs, %rhs : vector<2xi1>
+      scf.reduce.return %0 : vector<2xi1>
+    }
+  }
+  return
+}

>From 5f4e6edc1e952e898dc078066797b83fbc99290f Mon Sep 17 00:00:00 2001
From: Aniket Singh <amiket.singh.3200.00 at gmail.com>
Date: Sat, 10 Jan 2026 02:20:10 +0530
Subject: [PATCH 2/2] Address review comments: Refactor splat logic and fix
 vector reduction tests

---
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    | 75 +++++++++----------
 .../SCFToOpenMP/vector-reduction.mlir         | 15 +++-
 2 files changed, 46 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 3d3c601d92d1b..8161d9f4c7a28 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -150,6 +150,14 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
   llvm_unreachable("unknown float type");
 }
 
+/// Helper to create a splat attribute for vector types, or return the scalar
+/// attribute for scalar types.
+static Attribute getSplatOrScalarAttr(Type type, Attribute val) {
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return DenseElementsAttr::get(vecType, val);
+  return val;
+}
+
 /// Returns an attribute with the minimum (if `min` is set) or the maximum value
 /// (otherwise) for the given float type.
 static Attribute minMaxValueForFloat(Type type, bool min) {
@@ -162,11 +170,7 @@ static Attribute minMaxValueForFloat(Type type, bool min) {
   auto fltType = cast<FloatType>(elType);
   auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
 
-  // For vector types, return a DenseElementsAttr (splat).
-  if (auto vecType = dyn_cast<VectorType>(type))
-    return DenseElementsAttr::get(vecType, val);
-
-  return FloatAttr::get(type, val);
+  return getSplatOrScalarAttr(type, FloatAttr::get(elType, val));
 }
 
 /// Returns an attribute with the signed integer minimum (if `min` is set) or
@@ -183,9 +187,7 @@ static Attribute minMaxValueForSignedInt(Type type, bool min) {
   auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
                  : llvm::APInt::getSignedMaxValue(bitwidth);
 
-  if (auto vecType = dyn_cast<VectorType>(type))
-    return DenseElementsAttr::get(vecType, val);
-  return IntegerAttr::get(type, val);
+  return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
 }
 
 /// Returns an attribute with the unsigned integer minimum (if `min` is set) or
@@ -202,9 +204,7 @@ static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
   auto val =
       min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
 
-  if (auto vecType = dyn_cast<VectorType>(type))
-    return DenseElementsAttr::get(vecType, val);
-  return IntegerAttr::get(type, val);
+  return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
 }
 
 /// Creates an OpenMP reduction declaration and inserts it into the provided
@@ -265,8 +265,10 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
   omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>());
   return decl;
 }
+
 /// Returns true if the type is supported by llvm.atomicrmw. 
-/// LLVM IR does not support atomic operations on vector types.
+/// Integer 'atomicrmw' operations accept only scalar operands (see LangRef),
+/// so vector types are rejected and callers emit no atomic region for them.
 static bool supportsAtomic(Type type) {
   return !isa<VectorType>(type);
 }
@@ -301,46 +303,44 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
   if (auto vecType = dyn_cast<VectorType>(type))
     elType = vecType.getElementType();
 
-  // Helper to create splat (for vectors) or scalar attributes.
-  auto getAttr = [&](Attribute val) -> Attribute {
-    if (auto vecType = dyn_cast<VectorType>(type))
-      return DenseElementsAttr::get(vecType, val);
-    return val;
-  };
-
   // Arithmetic Reductions
   if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
-    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
-                           getAttr(builder.getFloatAttr(elType, 0.0)));
+    omp::DeclareReductionOp decl = createDecl(
+        builder, symbolTable, reduce, reductionIndex,
+        getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 0.0)));
     return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
                                                decl, reduce, reductionIndex)
                                 : decl;
   }
   if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
-    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
-                           getAttr(builder.getIntegerAttr(elType, 0)));
+    omp::DeclareReductionOp decl = createDecl(
+        builder, symbolTable, reduce, reductionIndex,
+        getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
     return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
                                                decl, reduce, reductionIndex)
                                 : decl;
   }
   if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
-    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
-                           getAttr(builder.getIntegerAttr(elType, 0)));
+    omp::DeclareReductionOp decl = createDecl(
+        builder, symbolTable, reduce, reductionIndex,
+        getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
     return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
                                                decl, reduce, reductionIndex)
                                 : decl;
   }
   if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
-    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
-                           getAttr(builder.getIntegerAttr(elType, 0)));
+    omp::DeclareReductionOp decl = createDecl(
+        builder, symbolTable, reduce, reductionIndex,
+        getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
     return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
                                                decl, reduce, reductionIndex)
                                 : decl;
   }
   if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
-    auto allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
-    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
-                           getAttr(builder.getIntegerAttr(elType, allOnes)));
+    APInt allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
+    omp::DeclareReductionOp decl = createDecl(
+        builder, symbolTable, reduce, reductionIndex,
+        getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, allOnes)));
     return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
                                                decl, reduce, reductionIndex)
                                 : decl;
@@ -351,11 +351,13 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
   // available as an op).
   if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
     return createDecl(builder, symbolTable, reduce, reductionIndex,
-                      getAttr(builder.getFloatAttr(elType, 1.0)));
+                      getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 1.0)));
+  }
 
-  if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction))
+  if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
     return createDecl(builder, symbolTable, reduce, reductionIndex,
-                      getAttr(builder.getIntegerAttr(elType, 1)));
+                      getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 1)));
+  }
 
   // Match select-based min/max reductions.
   bool isMin;
@@ -421,13 +423,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
     SmallVector<omp::DeclareReductionOp> ompReductionDecls;
     auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
     for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
-      // Ensure validity of reduction type for vector bitwidth calculations.
-      Type reductionType = reduce.getOperands()[i].getType();
-      if (auto vecType = dyn_cast<VectorType>(reductionType))
-        (void)vecType.getElementType().getIntOrFloatBitWidth();
-      else
-        (void)reductionType.getIntOrFloatBitWidth();
-
       omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
       ompReductionDecls.push_back(decl);
       if (!decl)
diff --git a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
index 38d7e3ec2aff1..018f8a03c8e34 100644
--- a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
@@ -1,9 +1,16 @@
 // RUN: mlir-opt %s --convert-scf-to-openmp | FileCheck %s
 
-// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1> init
-// CHECK-NEXT:  ^bb0(%arg0: vector<2xi1>):
-// CHECK-NEXT:    %[[CONST:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
-// CHECK-NEXT:    omp.yield(%[[CONST]] : vector<2xi1>)
+// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1>
+// CHECK: init {
+// CHECK:   %[[INIT:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
+// CHECK:   omp.yield(%[[INIT]] : vector<2xi1>)
+// CHECK: }
+// CHECK: combiner {
+// CHECK: ^bb0(%[[ARG0:.*]]: vector<2xi1>, %[[ARG1:.*]]: vector<2xi1>):
+// CHECK:   %[[RES:.*]] = arith.andi %[[ARG0]], %[[ARG1]] : vector<2xi1>
+// CHECK:   omp.yield(%[[RES]] : vector<2xi1>)
+// CHECK: }
+// CHECK-NOT: atomic
 
 func.func @vector_and_reduction() {
   %v_mask = vector.constant_mask [1] : vector<2xi1>



More information about the Mlir-commits mailing list