[Mlir-commits] [mlir] [mlir][vector] Add support for masks in castAwayContractionLeadingOneDim (PR #81906)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Mar 22 01:20:31 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/81906

From 734fbd381982459c86512378354b808a29fd201e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 15 Feb 2024 19:12:40 +0000
Subject: [PATCH] [mlir][Vector] Add support for masks in
 castAwayContractionLeadingOneDim

Updates the `CastAwayContractionLeadingOneDim` pattern to inherit from
`MaskableOpRewritePattern` so that it can support masking, and extends
`castAwayContractionLeadingOneDim` to take the optional masking op.

Builds on top of #83827
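
For illustration, this roughly enables rewrites like the following (based on
the new `@cast_away_contraction_leading_one_dim_under_mask` test added below;
indexing maps and iterator types are elided, and the `%lhs`/`%rhs`/`%acc`/
`%mask` names are only for illustration):

  // Before: masked vector.contract with a leading unit dim.
  %0 = vector.mask %mask {
    vector.contract ... %lhs, %rhs, %acc
      : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>

  // After: the unit dim is extracted from the operands and the mask, the
  // contraction is re-masked, and the result is broadcast back.
  %lhs0 = vector.extract %lhs[0] : vector<16x8xf32> from vector<1x16x8xf32>
  %rhs0 = vector.extract %rhs[0] : vector<8x16xf32> from vector<1x8x16xf32>
  %acc0 = vector.extract %acc[0] : vector<16x16xf32> from vector<1x16x16xf32>
  %m    = vector.extract %mask[0] : vector<16x16x8xi1> from vector<1x16x16x8xi1>
  %c    = vector.mask %m {
    vector.contract ... %lhs0, %rhs0, %acc0
      : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
  } : vector<16x16x8xi1> -> vector<16x16xf32>
  %0    = vector.broadcast %c : vector<16x16xf32> to vector<1x16x16xf32>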
---
 .../Vector/Transforms/VectorTransforms.h      |   6 +-
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |  10 +-
 .../Transforms/VectorDropLeadUnitDim.cpp      |  50 +++++----
 .../vector-dropleadunitdim-transforms.mlir    | 104 +++++++++++++-----
 4 files changed, 114 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
index 08d3bb157a0e39..1f7d6411cd5a46 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -110,8 +110,10 @@ void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp);
 
 /// Cast away the leading unit dim, if exists, for the given contract op.
 /// Return success if the transformation applies; return failure otherwise.
-LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
-                                               RewriterBase &rewriter);
+FailureOr<Value>
+castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
+                                 MaskingOpInterface maskingOp,
+                                 RewriterBase &rewriter);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 35e76a8b623a32..2c548fb6740251 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -127,8 +127,8 @@ SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
 /// responsible for providing an updated ("rewritten") version of:
 ///   a. the source Op when mask _is not_ present,
 ///   b. the source Op and the masking Op when mask _is_ present.
-/// Note that the return value from `matchAndRewriteMaskableOp` depends on the
-/// case above.
+/// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that
+/// the return value will depend on the case above.
 template <class SourceOp>
 struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
   using OpRewritePattern<SourceOp>::OpRewritePattern;
@@ -162,9 +162,9 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
   }
 
 public:
-  // Matches SourceOp that can potentially be masked with `maskingOp`. If the
-  // latter is present, returns an updated masking op (with a replacement for
-  // `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
+  // Matches `sourceOp` that can potentially be masked with `maskingOp`. If the
+  // latter is present, returns a replacement for `maskingOp`. Otherwise,
+  // returns a replacement for `sourceOp`.
   virtual FailureOr<Value>
   matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
                             PatternRewriter &rewriter) const = 0;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 74382b027c2f48..593c1e53557ad8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -329,12 +329,10 @@ struct CastAwayTransferWriteLeadingOneDim
 
 } // namespace
 
-LogicalResult
+FailureOr<Value>
 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
+                                               MaskingOpInterface maskingOp,
                                                RewriterBase &rewriter) {
-  // TODO(#78787): Not supported masked op yet.
-  if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
-    return failure();
   VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
   if (oldAccType == nullptr)
     return failure();
@@ -368,6 +366,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
   SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                  contractOp.getAcc()};
   SmallVector<Value> newOperands;
+  auto loc = contractOp.getLoc();
 
   for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
     // Check if the dim to be dropped exists as a leading dim in the operand
@@ -405,7 +404,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
         map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                              contractOp.getContext());
         operands[it.index()] = rewriter.create<vector::TransposeOp>(
-            contractOp.getLoc(), operands[it.index()], perm);
+            loc, operands[it.index()], perm);
       }
     }
     // We have taken care to have the dim to be dropped be
@@ -429,18 +428,29 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
     // Extract if its a valid extraction, otherwise use the operand
     // without extraction.
     newOperands.push_back(
-        validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
-                                                          operands[it.index()],
-                                                          splatZero(dropDim))
+        validExtract ? rewriter.create<vector::ExtractOp>(
+                           loc, operands[it.index()], splatZero(dropDim))
                      : operands[it.index()]);
   }
-  auto newContractOp = rewriter.create<vector::ContractionOp>(
-      contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
+
+  // Depending on whether this vector.contract is masked, the replacing Op
+  // should either be a new vector.contract Op or vector.mask Op.
+  Operation *newOp = rewriter.create<vector::ContractionOp>(
+      loc, newOperands[0], newOperands[1], newOperands[2],
       rewriter.getAffineMapArrayAttr(newIndexingMaps),
       rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
-  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-      contractOp, contractOp->getResultTypes()[0], newContractOp);
-  return success();
+
+  if (maskingOp) {
+    auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(),
+                                                      splatZero(dropDim));
+
+    newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
+  }
+
+  return rewriter
+      .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
+                                   newOp->getResults()[0])
+      .getResult();
 }
 
 namespace {
@@ -450,12 +460,14 @@ namespace {
 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
 /// prior to extract.
 struct CastAwayContractionLeadingOneDim
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
-    return castAwayContractionLeadingOneDim(contractOp, rewriter);
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
+                            MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index af6e636245b04e..4ba51c5953d13c 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -30,6 +30,80 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
 }
 
 // -----
+// CHECK: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_const_mask
+// CHECK:           %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
+// CHECK:           %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK:           %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK:           %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK:           %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
+// CHECK-SAME:        vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} 
+// CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+// CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK:           %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK:           return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+  affine_map<(l, i, j, k) -> (l, i, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+  %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
+  %0 = vector.mask %mask {
+    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+  return %0 : vector<1x16x16xf32>
+}
+
+// -----
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_mask
+// CHECK:           %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK:           %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK:           %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK:           %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
+// CHECK:           %[[CONTRACT:.*]] = vector.mask %[[M]] {
+// CHECK-SAME:      vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32> 
+// CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+  affine_map<(l, i, j, k) -> (l, i, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_mask(
+  %arg0: vector<1x16x8xf32>,
+  %arg1: vector<1x8x16xf32>,
+  %arg2: vector<1x16x16xf32>,
+  %mask: vector<1x16x16x8xi1>) -> vector<1x16x16xf32> {
+  %0 = vector.mask %mask {
+    vector.contract #contraction_trait0 %arg0, %arg1, %arg2  : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+  return %0: vector<1x16x16xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -164,36 +238,6 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
   return %0: vector<1x1x2x16xf32>
 }
 
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-
-// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
-// CHECK:      %[[MASK:.+]] = vector.constant_mask
-// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
-// CHECK:      %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
-// CHECK-SAME:   vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
-// CHECK:      return %[[RET]] : vector<1x16x16xf32>
-
-#contraction_accesses0 = [
-  affine_map<(l, i, j, k) -> (l, i, k)>,
-  affine_map<(l, i, j, k) -> (l, k, j)>,
-  affine_map<(l, i, j, k) -> (l, i, j)>
-]
-#contraction_trait0 = {
-  indexing_maps = #contraction_accesses0,
-  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-}
-
-func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
-  %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
-  %0 = vector.mask %mask {
-    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
-  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
-  return %0 : vector<1x16x16xf32>
-}
 
 // -----
 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims


