[Mlir-commits] [mlir] [MLIR][Vector] Add support for inner-parallel masked multi-reductions (PR #126722)
Manupa Karunaratne
llvmlistbot at llvm.org
Wed Feb 12 02:49:50 PST 2025
https://github.com/manupak updated https://github.com/llvm/llvm-project/pull/126722
From 6c94d5e1568f087710e0e96cad22fefe953aa69e Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne <manupa.karunaratne at amd.com>
Date: Tue, 11 Feb 2025 04:26:10 -0800
Subject: [PATCH] [MLIR][Vector] Add support for inner-parallel masked
multi-reductions
This commit adds support for lowering the inner-parallel flavor
of masked vector multi-reductions.
---
.../Transforms/LowerVectorMultiReduction.cpp | 27 +++++++++++++------
.../vector-multi-reduction-pass-lowering.mlir | 20 ++++++++++++++
2 files changed, 39 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 0cafc9cd35517..fa2c2d3b55833 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -308,12 +308,6 @@ struct TwoDimMultiReductionToElementWise
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
- auto maskableOp =
- cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
- if (maskableOp.isMasked())
- // TODO: Support masking.
- return failure();
-
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Rank-2 ["parallel", "reduce"] or bail.
if (srcRank != 2)
@@ -330,15 +324,32 @@ struct TwoDimMultiReductionToElementWise
if (!elementType.isIntOrIndexOrFloat())
return failure();
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ Operation *rootOp;
+ Value mask = nullptr;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = multiReductionOp;
+ }
+
Value result = multiReductionOp.getAcc();
for (int64_t i = 0; i < srcShape[0]; i++) {
auto operand = rewriter.create<vector::ExtractOp>(
loc, multiReductionOp.getSource(), i);
+ Value extractMask = nullptr;
+ if (mask) {
+ extractMask = rewriter.create<vector::ExtractOp>(loc, mask, i);
+ }
result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
- operand, result);
+ operand, result, /*fastmath=*/ nullptr, extractMask);
}
- rewriter.replaceOp(multiReductionOp, result);
+ rewriter.replaceOp(rootOp, result);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index 68621ffaac3d2..ddbc5c7bdb2c0 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -41,3 +41,23 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
// ALL-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
// INNER-REDUCTION: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
// INNER-PARALLEL: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
+
+// -----
+
+func.func @vector_multi_reduction_masked(%arg0: vector<2x4xf32>, %acc: vector<2xf32>, %mask: vector<2x4xi1>) -> vector<2xf32> {
+ %0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> } : vector<2x4xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// ALL-LABEL: func @vector_multi_reduction_masked
+// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<2x4xi1>
+// INNER-REDUCTION: %[[INNERVEC:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
+// INNER-REDUCTION: %[[INNERACC:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// INNER-REDUCTION: %[[INNERMASK:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
+// INNER-REDUCTION: vector.mask %[[INNERMASK]] { vector.reduction <mul>, %[[INNERVEC]], %[[INNERACC]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
+// INNER-PARALLEL: %[[TPMASK:.+]] = vector.transpose %[[MASK]], [1, 0] : vector<2x4xi1> to vector<4x2xi1>
+// INNER-PARALLEL: %[[TPINPUT:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+// INNER-PARALLEL: %[[INNERVEC:.+]] = vector.extract %[[TPINPUT]][0] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[INNERMASK:.+]] = vector.extract %[[TPMASK]][0] : vector<2xi1> from vector<4x2xi1>
+// INNER-PARALLEL: %[[REDUCED:.+]] = arith.mulf %[[INNERVEC]], %[[ACC]] : vector<2xf32>
+// INNER-PARALLEL: arith.select %[[INNERMASK]], %[[REDUCED]], %[[ACC]] : vector<2xi1>, vector<2xf32>
More information about the Mlir-commits
mailing list