[Mlir-commits] [mlir] [mlir][vector] Refactor multi-reduction patterns (NFC) (PR #183048)

Kunwar Grover llvmlistbot at llvm.org
Tue Feb 24 04:45:53 PST 2026


================
@@ -317,100 +342,94 @@ struct TwoDimMultiReductionToElementWise
     if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
       return failure();
 
+    Value mask = maskingOp ? maskingOp.getMask() : Value();
+
     auto loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
     ArrayRef<int64_t> srcShape =
         multiReductionOp.getSourceVectorType().getShape();
-
-    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
-    if (!elementType.isIntOrIndexOrFloat())
-      return failure();
-
-    OpBuilder::InsertionGuard guard(rewriter);
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
-    Operation *rootOp;
-    Value mask = nullptr;
-    if (maskableOp.isMasked()) {
-      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-      rootOp = maskableOp.getMaskingOp();
-      mask = maskableOp.getMaskingOp().getMask();
-    } else {
-      rootOp = multiReductionOp;
-    }
+    int outerDim = srcShape[0];
 
     Value result = multiReductionOp.getAcc();
-    for (int64_t i = 0; i < srcShape[0]; i++) {
-      auto operand = vector::ExtractOp::create(rewriter, loc,
-                                               multiReductionOp.getSource(), i);
-      Value extractMask = nullptr;
-      if (mask) {
-        extractMask = vector::ExtractOp::create(rewriter, loc, mask, i);
-      }
-      result =
-          makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
-                             result, /*fastmath=*/nullptr, extractMask);
+    for (int64_t i = 0; i < outerDim; i++) {
+      auto v = vector::ExtractOp::create(rewriter, loc, source, i);
+      Value m;
+      m = mask ? Value(vector::ExtractOp::create(rewriter, loc, mask, i))
+               : nullptr;
----------------
Groverkss wrote:

nit: Any reason not to write this as `Value m = ...`?

https://github.com/llvm/llvm-project/pull/183048


More information about the Mlir-commits mailing list