[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (PR #173938)

Aniket Singh llvmlistbot at llvm.org
Mon Dec 29 16:57:36 PST 2025


https://github.com/Aniketsingh54 updated https://github.com/llvm/llvm-project/pull/173938

>From 014de0741b47fbf1ae1819fa6afda3fc3ea6be05 Mon Sep 17 00:00:00 2001
From: Aniket Singh <amiket.singh.3200.00 at gmail.com>
Date: Tue, 30 Dec 2025 06:08:07 +0530
Subject: [PATCH] [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions

The SCF to OpenMP conversion pass crashed when encountering scf.parallel
with vector reductions because it assumed all reduction types were scalar.
This patch:
- Adds support for VectorType in reduction initializers.
- Uses splat DenseElementsAttr values as the neutral initializers of vector
  reductions.
- Skips the llvm.atomicrmw combiner region for vector reductions:
  llvm.atomicrmw does not accept integer vector operands, and
  floating-point vector reductions are conservatively handled the same way.

Fixes #173860
---
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    | 256 +++++++++---------
 .../SCFToOpenMP/vector-reduction.mlir         |  22 ++
 2 files changed, 150 insertions(+), 128 deletions(-)
 create mode 100644 mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir

diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..66ea0716a2292 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -153,30 +153,45 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
 /// Returns an attribute with the minimum (if `min` is set) or the maximum value
-/// (otherwise) for the given float type.
+/// (otherwise) for the given float type. If `type` is a vector of floats, the
+/// value is computed for the element type and splat across the vector.
 static Attribute minMaxValueForFloat(Type type, bool min) {
-  auto fltType = cast<FloatType>(type);
-  return FloatAttr::get(
-      type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
+  auto vecType = dyn_cast<VectorType>(type);
+  auto fltType = cast<FloatType>(vecType ? vecType.getElementType() : type);
+  llvm::APFloat value =
+      llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
+  if (vecType)
+    return DenseElementsAttr::get(vecType, value);
+  return FloatAttr::get(type, value);
 }
 
 /// Returns an attribute with the signed integer minimum (if `min` is set) or
 /// the maximum value (otherwise) for the given integer type, regardless of its
-/// signedness semantics (only the width is considered).
+/// signedness semantics (only the width is considered). If `type` is a vector
+/// of integers, the value is computed for the element type and splat.
 static Attribute minMaxValueForSignedInt(Type type, bool min) {
-  auto intType = cast<IntegerType>(type);
+  auto vecType = dyn_cast<VectorType>(type);
+  auto intType = cast<IntegerType>(vecType ? vecType.getElementType() : type);
   unsigned bitwidth = intType.getWidth();
-  return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
-                                    : llvm::APInt::getSignedMaxValue(bitwidth));
+  llvm::APInt value = min ? llvm::APInt::getSignedMinValue(bitwidth)
+                          : llvm::APInt::getSignedMaxValue(bitwidth);
+  if (vecType)
+    return DenseElementsAttr::get(vecType, value);
+  return IntegerAttr::get(type, value);
 }
 
 /// Returns an attribute with the unsigned integer minimum (if `min` is set) or
 /// the maximum value (otherwise) for the given integer type, regardless of its
-/// signedness semantics (only the width is considered).
+/// signedness semantics (only the width is considered). If `type` is a vector
+/// of integers, the value is computed for the element type and splat.
 static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
-  auto intType = cast<IntegerType>(type);
+  auto vecType = dyn_cast<VectorType>(type);
+  auto intType = cast<IntegerType>(vecType ? vecType.getElementType() : type);
   unsigned bitwidth = intType.getWidth();
-  return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
-                                    : llvm::APInt::getAllOnes(bitwidth));
+  llvm::APInt value =
+      min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
+  if (vecType)
+    return DenseElementsAttr::get(vecType, value);
+  return IntegerAttr::get(type, value);
 }
 
 /// Creates an OpenMP reduction declaration and inserts it into the provided
 /// symbol table. The declaration has a constant initializer with the neutral
@@ -203,7 +218,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
   Operation *terminator =
       &reduce.getReductions()[reductionIndex].front().back();
   assert(isa<scf::ReduceReturnOp>(terminator) &&
-         "expected reduce op to be terminated by redure return");
+         "expected reduce op to be terminated by reduce return");
   builder.setInsertionPoint(terminator);
   builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
                                            terminator->getOperands());
@@ -258,56 +273,79 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
   assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
          "expected reduction region to have a single element");
 
-  // Match simple binary reductions that can be expressed with atomicrmw.
   Type type = reduce.getOperands()[reductionIndex].getType();
   Block &reduction = reduce.getReductions()[reductionIndex].front();
+
+  // Vector reductions are supported besides scalars: their neutral values are
+  // splats of the element-type value. Other shaped types (e.g. tensors) have
+  // no LLVM constant representation, so they are rejected like unrecognized
+  // reductions.
+  auto vecType = dyn_cast<VectorType>(type);
+  if (!vecType && isa<ShapedType>(type))
+    return nullptr;
+  Type elType = vecType ? vecType.getElementType() : type;
+
+  // Returns `scalar` for scalar reductions and its splat for vector ones.
+  auto splatIfVector = [&](Attribute scalar) -> Attribute {
+    if (vecType)
+      return DenseElementsAttr::get(vecType, scalar);
+    return scalar;
+  };
+
+  // Attaches an llvm.atomicrmw combiner of kind `atomicKind` to `decl` for
+  // scalar reductions only. llvm.atomicrmw does not accept integer vector
+  // operands, so vector reductions (conservatively including floating-point
+  // ones) keep only the non-atomic combiner.
+  auto withAtomicRMW = [&](omp::DeclareReductionOp decl,
+                           LLVM::AtomicBinOp atomicKind) {
+    if (vecType)
+      return decl;
+    return addAtomicRMW(builder, atomicKind, decl, reduce, reductionIndex);
+  };
+
+  // Match simple binary reductions that can be expressed with atomicrmw.
   if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getFloatAttr(type, 0.0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
-                        reductionIndex);
+                   splatIfVector(builder.getFloatAttr(elType, 0.0)));
+    return withAtomicRMW(decl, LLVM::AtomicBinOp::fadd);
   }
   if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
-                        reductionIndex);
+                   splatIfVector(builder.getIntegerAttr(elType, 0)));
+    return withAtomicRMW(decl, LLVM::AtomicBinOp::add);
   }
   if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
-                        reductionIndex);
+                   splatIfVector(builder.getIntegerAttr(elType, 0)));
+    return withAtomicRMW(decl, LLVM::AtomicBinOp::_or);
   }
   if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
-                   builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
-                        reductionIndex);
+                   splatIfVector(builder.getIntegerAttr(elType, 0)));
+    return withAtomicRMW(decl, LLVM::AtomicBinOp::_xor);
   }
   if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
     omp::DeclareReductionOp decl = createDecl(
         builder, symbolTable, reduce, reductionIndex,
-        builder.getIntegerAttr(
-            type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
-                        reductionIndex);
+        splatIfVector(builder.getIntegerAttr(
+            elType, llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth()))));
+    return withAtomicRMW(decl, LLVM::AtomicBinOp::_and);
   }
 
   // Match simple binary reductions that cannot be expressed with atomicrmw.
   // TODO: add atomic region using cmpxchg (which needs atomic load to be
   // available as an op).
   if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
     return createDecl(builder, symbolTable, reduce, reductionIndex,
-                      builder.getFloatAttr(type, 1.0));
+                      splatIfVector(builder.getFloatAttr(elType, 1.0)));
   }
   if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
     return createDecl(builder, symbolTable, reduce, reductionIndex,
-                      builder.getIntegerAttr(type, 1));
+                      splatIfVector(builder.getIntegerAttr(elType, 1)));
   }
 
   // Match select-based min/max reductions.
@@ -330,6 +368,9 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
                    minMaxValueForSignedInt(type, !isMin));
+    // llvm.atomicrmw does not accept integer vector operands.
+    if (isa<VectorType>(type))
+      return decl;
     return addAtomicRMW(builder,
                         isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
                         decl, reduce, reductionIndex);
@@ -343,6 +384,9 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
     omp::DeclareReductionOp decl =
         createDecl(builder, symbolTable, reduce, reductionIndex,
                    minMaxValueForUnsignedInt(type, !isMin));
+    // llvm.atomicrmw does not accept integer vector operands.
+    if (isa<VectorType>(type))
+      return decl;
     return addAtomicRMW(
         builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
         decl, reduce, reductionIndex);
@@ -472,8 +516,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
               ArrayAttr::get(rewriter.getContext(), reductionSyms));
           wsloopOp.getReductionVarsMutable().append(reductionVariables);
           llvm::SmallVector<bool> reductionByRef;
-          // false because these reductions always reduce scalars and so do
-          // not need to pass by reference
+          // false because these reductions always reduce scalar or vector
+          // values and so do not need to pass by reference
           reductionByRef.resize(reductionVariables.size(), false);
           wsloopOp.setReductionByref(
               DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef));
diff --git a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
new file mode 100644
index 0000000000000..38d7e3ec2aff1
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt %s --convert-scf-to-openmp | FileCheck %s
+
+// CHECK-LABEL: omp.declare_reduction @{{.*}} : vector<2xi1> init
+// CHECK-NEXT:  ^bb0(%arg0: vector<2xi1>):
+// CHECK-NEXT:    %[[CONST:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
+// CHECK-NEXT:    omp.yield(%[[CONST]] : vector<2xi1>)
+// CHECK-NOT:   atomic
+// CHECK-LABEL: func.func @vector_and_reduction
+
+func.func @vector_and_reduction() {
+  %v_mask = vector.constant_mask [1] : vector<2xi1>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %result = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init(%v_mask) -> vector<2xi1> {
+    %val = vector.constant_mask [1] : vector<2xi1>
+    scf.reduce (%val : vector<2xi1>) {
+    ^bb0(%lhs: vector<2xi1>, %rhs: vector<2xi1>):
+      %0 = arith.andi %lhs, %rhs : vector<2xi1>
+      scf.reduce.return %0 : vector<2xi1>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: omp.declare_reduction @{{.*}} : vector<4xf32> init
+// CHECK-NEXT:  ^bb0(%{{.*}}: vector<4xf32>):
+// CHECK-NEXT:    %[[FZERO:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4xf32>) : vector<4xf32>
+// CHECK-NEXT:    omp.yield(%[[FZERO]] : vector<4xf32>)
+// CHECK-NOT:   atomic
+// CHECK-LABEL: func.func @vector_addf_reduction
+
+func.func @vector_addf_reduction(%init: vector<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+  %cst = arith.constant dense<1.000000e+00> : vector<4xf32>
+  %result = scf.parallel (%i) = (%c0) to (%c8) step (%c1) init(%init) -> vector<4xf32> {
+    scf.reduce (%cst : vector<4xf32>) {
+    ^bb0(%lhs: vector<4xf32>, %rhs: vector<4xf32>):
+      %0 = arith.addf %lhs, %rhs : vector<4xf32>
+      scf.reduce.return %0 : vector<4xf32>
+    }
+  }
+  return
+}



More information about the Mlir-commits mailing list