[Mlir-commits] [mlir] [mlir] Exclude masked ops in VectorDropLeadUnitDim (PR #76468)

Jerry Wu llvmlistbot at llvm.org
Fri Jan 19 13:22:42 PST 2024


https://github.com/pzread updated https://github.com/llvm/llvm-project/pull/76468

>From b522397626dbf86f73d9707ef2d0061116ebe985 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Fri, 22 Dec 2023 00:00:42 +0000
Subject: [PATCH 1/3] Handle masked op in VectorDropLeadUnitDim patterns

---
 .../Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp  | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 84294e4552a607e..65517295aa72d29 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -223,6 +223,9 @@ struct CastAwayTransferReadLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                 PatternRewriter &rewriter) const override {
+    // Not supported masked op yet.
+    if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
+      return failure();
     // TODO: support 0-d corner case.
     if (read.getTransferRank() == 0)
       return failure();
@@ -274,6 +277,9 @@ struct CastAwayTransferWriteLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                 PatternRewriter &rewriter) const override {
+    // Not supported masked op yet.
+    if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
+      return failure();
     // TODO: support 0-d corner case.
     if (write.getTransferRank() == 0)
       return failure();
@@ -325,6 +331,9 @@ struct CastAwayTransferWriteLeadingOneDim
 LogicalResult
 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                                RewriterBase &rewriter) {
+  // Not supported masked op yet.
+  if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
+    return failure();
   VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
   if (oldAccType == nullptr)
     return failure();

>From c2ee016a6c14a6a890b032870b075f517f68bc43 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Wed, 27 Dec 2023 21:47:08 +0000
Subject: [PATCH 2/3] Add tests

---
 .../vector-dropleadunitdim-transforms.mlir    | 66 +++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 71dffca8f14da59..f601be041681445 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -164,6 +164,37 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
   return %0: vector<1x1x2x16xf32>
 }
 
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
+// CHECK:      %[[MASK:.+]] = vector.constant_mask
+// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
+// CHECK:      %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
+// CHECK-SAME:   vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
+// CHECK:      return %[[RET]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [  // operand maps: lhs (l, i, k), rhs (l, k, j), acc (l, i, j)
+  affine_map<(l, i, j, k) -> (l, i, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+  %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>  // 4-D mask shaped like the (l, i, j, k) = (1, 16, 16, 8) iteration space
+  %0 = vector.mask %mask {  // masked contraction: the leading unit dim must NOT be dropped (see CHECK lines above)
+    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+  return %0 : vector<1x16x16xf32>
+}
+
 // -----
 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
 func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
@@ -253,6 +284,24 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16
 
 // -----
 
+// CHECK-LABEL: func @not_insert_cast_for_transfer_read_under_mask
+// CHECK:      %[[MASK:.+]] = vector.constant_mask
+// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
+// CHECK:      %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
+// CHECK-SAME:   vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
+// CHECK:      return %[[RET]] : vector<1x4xf16>
+func.func @not_insert_cast_for_transfer_read_under_mask(%arg0: memref<1x1x4xf16>) -> vector<1x4xf16> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0. : f16
+  %mask = vector.constant_mask [1, 3] : vector<1x4xi1>
+  %ret = vector.mask %mask {  // masked read: the leading unit dim must NOT be dropped (see CHECK lines above)
+    vector.transfer_read %arg0[%c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x4xf16>, vector<1x4xf16>
+  } : vector<1x4xi1> -> vector<1x4xf16>
+  return %ret: vector<1x4xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
 func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -286,6 +335,23 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
 
 // -----
 
+// CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
+// CHECK:      %[[MASK:.+]] = vector.constant_mask
+// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
+// CHECK:      vector.mask %[[CASTED_MASK]] {
+// CHECK-SAME:   vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
+// CHECK:      return
+func.func @not_insert_cast_for_transfer_write_under_mask(%arg0: memref<1x1x4xf16>, %arg1: vector<1x4xf16>) {
+  %c0 = arith.constant 0 : index
+  %mask = vector.constant_mask [1, 3] : vector<1x4xi1>
+  vector.mask %mask {  // masked write: the leading unit dim must NOT be dropped (see CHECK lines above)
+    vector.transfer_write %arg1, %arg0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x1x4xf16>
+  } : vector<1x4xi1>
+  return
+}
+
+// -----
+
 // CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
 // CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
 func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {

>From 5b534f4d171e23b27a65c10368676015ae95db16 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Fri, 19 Jan 2024 21:22:12 +0000
Subject: [PATCH 3/3] Add TODOs

---
 .../lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 65517295aa72d29..e1ed5d81625d8e9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -223,7 +223,7 @@ struct CastAwayTransferReadLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                 PatternRewriter &rewriter) const override {
-    // Not supported masked op yet.
+    // TODO(#78787): Not supported masked op yet.
     if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
       return failure();
     // TODO: support 0-d corner case.
@@ -277,7 +277,7 @@ struct CastAwayTransferWriteLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                 PatternRewriter &rewriter) const override {
-    // Not supported masked op yet.
+    // TODO(#78787): Not supported masked op yet.
     if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
       return failure();
     // TODO: support 0-d corner case.
@@ -331,7 +331,7 @@ struct CastAwayTransferWriteLeadingOneDim
 LogicalResult
 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                                RewriterBase &rewriter) {
-  // Not supported masked op yet.
+  // TODO(#78787): Not supported masked op yet.
   if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
     return failure();
   VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());



More information about the Mlir-commits mailing list