[Mlir-commits] [mlir] [mlir][vector] Refactor multi-reduction patterns (NFC) (PR #183048)
Kunwar Grover
llvmlistbot at llvm.org
Tue Feb 24 04:45:53 PST 2026
================
@@ -317,100 +342,94 @@ struct TwoDimMultiReductionToElementWise
if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
return failure();
+ Value mask = maskingOp ? maskingOp.getMask() : Value();
+
auto loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
ArrayRef<int64_t> srcShape =
multiReductionOp.getSourceVectorType().getShape();
-
- Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
- if (!elementType.isIntOrIndexOrFloat())
- return failure();
-
- OpBuilder::InsertionGuard guard(rewriter);
- auto maskableOp =
- cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
- Operation *rootOp;
- Value mask = nullptr;
- if (maskableOp.isMasked()) {
- rewriter.setInsertionPoint(maskableOp.getMaskingOp());
- rootOp = maskableOp.getMaskingOp();
- mask = maskableOp.getMaskingOp().getMask();
- } else {
- rootOp = multiReductionOp;
- }
+ int outerDim = srcShape[0];
Value result = multiReductionOp.getAcc();
- for (int64_t i = 0; i < srcShape[0]; i++) {
- auto operand = vector::ExtractOp::create(rewriter, loc,
- multiReductionOp.getSource(), i);
- Value extractMask = nullptr;
- if (mask) {
- extractMask = vector::ExtractOp::create(rewriter, loc, mask, i);
- }
- result =
- makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
- result, /*fastmath=*/nullptr, extractMask);
+ for (int64_t i = 0; i < outerDim; i++) {
+ auto v = vector::ExtractOp::create(rewriter, loc, source, i);
+ Value m;
+ m = mask ? Value(vector::ExtractOp::create(rewriter, loc, mask, i))
+ : nullptr;
----------------
Groverkss wrote:
nit: Any reason not to write this as `Value m = ...`?
https://github.com/llvm/llvm-project/pull/183048
More information about the Mlir-commits mailing list