[Mlir-commits] [mlir] db1e15a - [MLIR][Vector] Add support for inner-parallel masked multi-reductions (#126722)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 14 11:08:27 PST 2025


Author: Manupa Karunaratne
Date: 2025-02-14T11:08:24-08:00
New Revision: db1e15a1da20ff15c9bac8e6d731ec11871431bb

URL: https://github.com/llvm/llvm-project/commit/db1e15a1da20ff15c9bac8e6d731ec11871431bb
DIFF: https://github.com/llvm/llvm-project/commit/db1e15a1da20ff15c9bac8e6d731ec11871431bb.diff

LOG: [MLIR][Vector] Add support for inner-parallel masked multi-reductions (#126722)

This commit adds support for lowering the inner-parallel flavor of masked
vector multi-reductions.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
    mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 0cafc9cd35517..ce524b259d8d4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -308,12 +308,6 @@ struct TwoDimMultiReductionToElementWise
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
-    if (maskableOp.isMasked())
-      // TODO: Support masking.
-      return failure();
-
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     // Rank-2 ["parallel", "reduce"] or bail.
     if (srcRank != 2)
@@ -330,15 +324,33 @@ struct TwoDimMultiReductionToElementWise
     if (!elementType.isIntOrIndexOrFloat())
       return failure();
 
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    Value mask = nullptr;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     Value result = multiReductionOp.getAcc();
     for (int64_t i = 0; i < srcShape[0]; i++) {
       auto operand = rewriter.create<vector::ExtractOp>(
           loc, multiReductionOp.getSource(), i);
-      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
-                                  operand, result);
+      Value extractMask = nullptr;
+      if (mask) {
+        extractMask = rewriter.create<vector::ExtractOp>(loc, mask, i);
+      }
+      result =
+          makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
+                             result, /*fastmath=*/nullptr, extractMask);
     }
 
-    rewriter.replaceOp(multiReductionOp, result);
+    rewriter.replaceOp(rootOp, result);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index 68621ffaac3d2..ddbc5c7bdb2c0 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -41,3 +41,23 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
 //        ALL-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
 // INNER-REDUCTION: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
 //  INNER-PARALLEL: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
+
+// -----
+
+func.func @vector_multi_reduction_masked(%arg0: vector<2x4xf32>, %acc: vector<2xf32>, %mask: vector<2x4xi1>) -> vector<2xf32> {
+    %0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> } : vector<2x4xi1> -> vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+//       ALL-LABEL: func @vector_multi_reduction_masked
+//        ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<2x4xi1>
+// INNER-REDUCTION: %[[INNERVEC:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
+// INNER-REDUCTION: %[[INNERACC:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// INNER-REDUCTION: %[[INNERMASK:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
+// INNER-REDUCTION: vector.mask %[[INNERMASK]] { vector.reduction <mul>, %[[INNERVEC]], %[[INNERACC]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
+//  INNER-PARALLEL: %[[TPMASK:.+]] = vector.transpose %[[MASK]], [1, 0] : vector<2x4xi1> to vector<4x2xi1>
+//  INNER-PARALLEL: %[[TPINPUT:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+//  INNER-PARALLEL: %[[INNERVEC:.+]] = vector.extract %[[TPINPUT]][0] : vector<2xf32> from vector<4x2xf32>
+//  INNER-PARALLEL: %[[INNERMASK:.+]] = vector.extract %[[TPMASK]][0] : vector<2xi1> from vector<4x2xi1>
+//  INNER-PARALLEL: %[[REDUCED:.+]] = arith.mulf %[[INNERVEC]], %[[ACC]] : vector<2xf32>
+//  INNER-PARALLEL: arith.select %[[INNERMASK]], %[[REDUCED]], %[[ACC]] : vector<2xi1>, vector<2xf32>


        


More information about the Mlir-commits mailing list