[Mlir-commits] [mlir] b1d8205 - [mlir][Vector] Add lowering support for 1-D masked multi-reductions

Diego Caballero llvmlistbot at llvm.org
Tue Feb 7 12:05:24 PST 2023


Author: Diego Caballero
Date: 2023-02-07T20:03:38Z
New Revision: b1d82057ed7b5374e9e8ebf5d7c53104e555ee53

URL: https://github.com/llvm/llvm-project/commit/b1d82057ed7b5374e9e8ebf5d7c53104e555ee53
DIFF: https://github.com/llvm/llvm-project/commit/b1d82057ed7b5374e9e8ebf5d7c53104e555ee53.diff

LOG: [mlir][Vector] Add lowering support for 1-D masked multi-reductions

1-D multi-reductions follow a different lowering path (they are
converted to 2-D multi-reductions) so masked variants need to be
supported explicitly.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D143453

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
    mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index 117fdcb84c809..b790d141415aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -385,17 +385,25 @@ struct OneDimMultiReductionToTwoDim
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
-    if (maskableOp.isMasked())
-      // TODO: Support masking.
-      return failure();
-
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     // Rank-1 or bail.
     if (srcRank != 1)
       return failure();
 
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    Value mask;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     auto loc = multiReductionOp.getLoc();
     auto srcVectorType = multiReductionOp.getSourceVectorType();
     auto srcShape = srcVectorType.getShape();
@@ -408,16 +416,27 @@ struct OneDimMultiReductionToTwoDim
 
     // If the unique dim is reduced and we insert a parallel in front, we need a
     // {false, true} mask.
-    SmallVector<bool, 2> mask{false, true};
+    SmallVector<bool, 2> reductionMask{false, true};
 
     /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
     Value castAcc = rewriter.create<vector::BroadcastOp>(
         loc, accType, multiReductionOp.getAcc());
-    Value reduced = rewriter.create<vector::MultiDimReductionOp>(
-        loc, cast, castAcc, mask, multiReductionOp.getKind());
-    rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
+    Value castMask;
+    if (maskableOp.isMasked()) {
+      auto maskType = mask.getType().cast<ShapedType>();
+      auto castMaskType =
+          VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
+                          maskType.getElementType());
+      castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
+    }
+
+    Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
+        loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
+    newOp = vector::maskOperation(rewriter, newOp, castMask);
+
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
                                                    ArrayRef<int64_t>{0});
     return success();
   }

diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 5647089d2ed5b..1a9aaf38ac3b0 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -189,6 +189,26 @@ func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf
 
 // -----
 
+func.func @vectorize_1d_dynamic_reduction(%arg0: tensor<?xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %c0_1 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.create_mask %dim : vector<8xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1], %cst {in_bounds = [true]} : tensor<?xf32>, vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %cst [0] : vector<8xf32> to f32 } : vector<8xi1> -> f32
+  return %4 : f32
+}
+
+// Verify that a 1-D vector.multi_reduction is transformed into a vector.reduction.
+// This transform expands 1-D vectors into 2-D.
+
+// CHECK-LABEL:   func.func @vectorize_1d_dynamic_reduction(
+// CHECK:           %[[VAL_5:.*]] = vector.create_mask {{.*}} : vector<8xi1>
+// CHECK:           %[[VAL_7:.*]] = vector.mask %[[VAL_5]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+
+// -----
+
 func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>


        


More information about the Mlir-commits mailing list