[Mlir-commits] [mlir] [MLIR][Vector] Implement TransferOpReduceRank as MaskableOpRewritePattern (PR #92426)

Hugo Trachino llvmlistbot at llvm.org
Thu May 16 09:17:46 PDT 2024


https://github.com/nujaa created https://github.com/llvm/llvm-project/pull/92426

Implements `TransferOpReduceRank` as a `MaskableOpRewritePattern`, allowing the pattern to exit gracefully when the matched `vector::transfer_read` is nested inside a `vector::MaskOp`.

Split out of https://github.com/llvm/llvm-project/pull/90835

>From e40bffa399db38743f1031281c66368e5e6791b5 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Tue, 14 May 2024 22:21:23 +0800
Subject: [PATCH 1/4] [MLIR][Vector] Implement transferXXPermutationLowering as
 MaskableOpRewritePattern

---
 .../Vector/Transforms/LowerVectorTransfer.cpp | 64 +++++++++------
 .../vector-transfer-permutation-lowering.mlir | 82 +++++++++++++++++++
 2 files changed, 122 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index b30b43d70bf0f..7f5703b635068 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -90,14 +90,19 @@ namespace {
 /// Note that an alternative is to transform it to linalg.transpose +
 /// vector.transfer_read to do the transpose in memory instead.
 struct TransferReadPermutationLowering
-    : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    // TODO: Support transfer_read inside MaskOp case.
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
@@ -142,9 +147,9 @@ struct TransferReadPermutationLowering
 
     // Transpose result of transfer_read.
     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
-                                                     transposePerm);
-    return success();
+    return rewriter
+        .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
+        .getResult();
   }
 };
 
@@ -165,14 +170,19 @@ struct TransferReadPermutationLowering
 ///     %v = vector.transfer_write %tmp ...
 ///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
 struct TransferWritePermutationLowering
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    // TODO: Support transfer_write inside MaskOp case.
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
@@ -207,11 +217,11 @@ struct TransferWritePermutationLowering
         op.getLoc(), op.getVector(), indices);
     auto newMap = AffineMap::getMinorIdentityMap(
         map.getNumDims(), map.getNumResults(), rewriter.getContext());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
-        op.getMask(), newInBoundsAttr);
-
-    return success();
+    return rewriter
+        .create<vector::TransferWriteOp>(
+            op.getLoc(), newVec, op.getSource(), op.getIndices(),
+            AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr)
+        .getResult();
   }
 };
 
@@ -231,14 +241,19 @@ struct TransferWritePermutationLowering
 ///     vector<1x8x16xf32>
 /// ```
 struct TransferWriteNonPermutationLowering
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    // TODO: Support transfer_write inside MaskOp case.
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     SmallVector<unsigned> permutation;
     AffineMap map = op.getPermutationMap();
@@ -285,10 +300,11 @@ struct TransferWriteNonPermutationLowering
       newInBoundsValues.push_back(op.isDimInBounds(i));
     }
     ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
-        newMask, newInBoundsAttr);
-    return success();
+    return rewriter
+        .create<vector::TransferWriteOp>(
+            op.getLoc(), newVec, op.getSource(), op.getIndices(),
+            AffineMapAttr::get(newMap), newMask, newInBoundsAttr)
+        .getResult();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index e48af3cd7aace..d63d47fe4481d 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -46,6 +46,55 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
     return
 }
 
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
+//  CHECK-SAME:        %[[ARG_0:.*]]: tensor<?x?xf32>,
+//  CHECK-SAME:        %[[ARG_1:.*]]: vector<16xf32>,
+//  CHECK-SAME:        %[[IDX:.*]]: index,
+//  CHECK-SAME:        %[[MASK:.*]]: vector<16xi1>
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+//       CHECK:   return %[[RES]]
+func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
+  %r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+  return %r : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
+//  CHECK-SAME:        %[[ARG_0:.*]]: vector<4x[8]xi16>,
+//  CHECK-SAME:        %[[ARG_1:.*]]: tensor<?x?x?x?xf32>,
+//  CHECK-SAME:        %[[MASK:.*]]: vector<4x[8]xi1>)
+//  CHECK-SAME:        -> tensor<?x?x?x?xf32> {
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+//       CHECK:   return %[[R]] : tensor<?x?x?x?xf32>
+func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask:  vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
+     %c0 = arith.constant 0 : index
+     %r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+
+    return %r : tensor<?x?x?x?xf32>
+}
+
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
+//  CHECK-SAME:      %[[ARG0:.*]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:      %[[ARG1:.*]]: vector<14x8x16xf32>
+//  CHECK-SAME:      %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+func.func @masked_non_permutation_xfer_write_fixed_width(
+    %arg0 : tensor<?x?x?x?xf32>,
+    %v1 : vector<14x8x16xf32>, %dim : index) -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
+  %0 = vector.mask %mask { vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+
+  return %0 : tensor<?x?x?x?xf32>
+}
+
 ///----------------------------------------------------------------------------------------
 /// vector.transfer_read
 ///----------------------------------------------------------------------------------------
@@ -101,6 +150,39 @@ func.func @permutation_with_mask_xfer_read_scalable(%mem: memref<?x?xf32>, %dim_
   return %1 : vector<8x[4]x2xf32>
 }
 
+// transfer_read in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
+//  CHECK-SAME:        %[[ARG_0:.*]]: tensor<?x1xf32>,
+//  CHECK-SAME:        %[[ARG_1:.*]]: vector<4x1xi1>
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.mask %[[ARG_1]] { vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+func.func @masked_permutation_xfer_read_fixed_width(%arg0: tensor<?x1xf32>, %mask : vector<4x1xi1>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>} : tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+  call @test.some_use(%3) : (vector<1x4x4xf32>) -> ()
+  return
+}
+func.func private @test.some_use(vector<1x4x4xf32>)
+
+// CHECK-LABEL:  func.func @masked_permutation_xfer_read_scalable(
+//  CHECK-SAME:      %[[ARG_0:.*]]: tensor<?x?xf32>,
+//  CHECK-SAME:      %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+//   CHECK-NOT:    vector.transpose
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+//       CHECK:   return %[[T_READ]] : vector<8x[4]x2xf32>
+func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+
+  %1 = vector.mask %mask { vector.transfer_read %t[%c0, %c0], %cst_0
+    {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>}
+    : tensor<?x?xf32>, vector<8x[4]x2xf32> } :vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+  return %1 : vector<8x[4]x2xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

>From b62a36d5df9ff67a4a3eef727495135263e9e35b Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Tue, 14 May 2024 22:25:02 +0800
Subject: [PATCH 2/4] Fixup: check less in negative tests.

---
 .../Vector/vector-transfer-permutation-lowering.mlir        | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index d63d47fe4481d..a7acffbbbf397 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -54,7 +54,6 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
 //  CHECK-SAME:        %[[MASK:.*]]: vector<16xi1>
 //   CHECK-NOT:   vector.transpose
 //       CHECK:   %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
-//       CHECK:   return %[[RES]]
 func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
   %r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
   return %r : tensor<?x?xf32>
@@ -66,9 +65,7 @@ func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val:
 //  CHECK-SAME:        %[[MASK:.*]]: vector<4x[8]xi1>)
 //  CHECK-SAME:        -> tensor<?x?x?x?xf32> {
 //   CHECK-NOT:   vector.transpose
-//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
 //       CHECK:   %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
-//       CHECK:   return %[[R]] : tensor<?x?x?x?xf32>
 func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask:  vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
      %c0 = arith.constant 0 : index
      %r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
@@ -83,7 +80,6 @@ func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t:
 //  CHECK-SAME:      %[[ARG1:.*]]: vector<14x8x16xf32>
 //  CHECK-SAME:      %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
 //   CHECK-NOT:   vector.broadcast
-//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
 //       CHECK:   %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
 func.func @masked_non_permutation_xfer_write_fixed_width(
     %arg0 : tensor<?x?x?x?xf32>,
@@ -169,9 +165,7 @@ func.func private @test.some_use(vector<1x4x4xf32>)
 //  CHECK-SAME:      %[[ARG_0:.*]]: tensor<?x?xf32>,
 //  CHECK-SAME:      %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
 //   CHECK-NOT:    vector.transpose
-//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
 //       CHECK:    %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
-//       CHECK:   return %[[T_READ]] : vector<8x[4]x2xf32>
 func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
 
   %c0 = arith.constant 0 : index

>From 90a6f29e770f217cae0171c5f90b9910f7c5986c Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Wed, 15 May 2024 00:24:13 +0800
Subject: [PATCH 3/4] Fixup MaskableOpRewritePattern when transfer_write has no
 result

---
 .../Vector/Transforms/LowerVectorTransfer.cpp | 26 ++++++++++++-------
 .../vector-transfer-permutation-lowering.mlir |  8 +++---
 2 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 7f5703b635068..81f7591a7d86f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -217,11 +217,14 @@ struct TransferWritePermutationLowering
         op.getLoc(), op.getVector(), indices);
     auto newMap = AffineMap::getMinorIdentityMap(
         map.getNumDims(), map.getNumResults(), rewriter.getContext());
-    return rewriter
-        .create<vector::TransferWriteOp>(
-            op.getLoc(), newVec, op.getSource(), op.getIndices(),
-            AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr)
-        .getResult();
+    auto newWrite = rewriter.create<vector::TransferWriteOp>(
+        op.getLoc(), newVec, op.getSource(), op.getIndices(),
+        AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
+    if (newWrite.hasPureTensorSemantics())
+      return newWrite.getResult();
+    // In memref case, MaskableOpRewritePattern cannot replaceOp with result.
+    rewriter.eraseOp(op);
+    return failure();
   }
 };
 
@@ -300,11 +303,14 @@ struct TransferWriteNonPermutationLowering
       newInBoundsValues.push_back(op.isDimInBounds(i));
     }
     ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
-    return rewriter
-        .create<vector::TransferWriteOp>(
-            op.getLoc(), newVec, op.getSource(), op.getIndices(),
-            AffineMapAttr::get(newMap), newMask, newInBoundsAttr)
-        .getResult();
+    auto newWrite = rewriter.create<vector::TransferWriteOp>(
+        op.getLoc(), newVec, op.getSource(), op.getIndices(),
+        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
+    if (newWrite.hasPureTensorSemantics())
+      return newWrite.getResult();
+    // In memref case, MaskableOpRewritePattern cannot replaceOp with result.
+    rewriter.eraseOp(op);
+    return failure();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index a7acffbbbf397..349dc1ab31d4c 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -53,7 +53,7 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
 //  CHECK-SAME:        %[[IDX:.*]]: index,
 //  CHECK-SAME:        %[[MASK:.*]]: vector<16xi1>
 //   CHECK-NOT:   vector.transpose
-//       CHECK:   %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+//       CHECK:   %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]]{{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
 func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
   %r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
   return %r : tensor<?x?xf32>
@@ -65,7 +65,7 @@ func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val:
 //  CHECK-SAME:        %[[MASK:.*]]: vector<4x[8]xi1>)
 //  CHECK-SAME:        -> tensor<?x?x?x?xf32> {
 //   CHECK-NOT:   vector.transpose
-//       CHECK:   %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+//       CHECK:   %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]]{{.*}} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
 func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask:  vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
      %c0 = arith.constant 0 : index
      %r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
@@ -80,7 +80,7 @@ func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t:
 //  CHECK-SAME:      %[[ARG1:.*]]: vector<14x8x16xf32>
 //  CHECK-SAME:      %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
 //   CHECK-NOT:   vector.broadcast
-//       CHECK:   %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+//       CHECK:   %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
 func.func @masked_non_permutation_xfer_write_fixed_width(
     %arg0 : tensor<?x?x?x?xf32>,
     %v1 : vector<14x8x16xf32>, %dim : index) -> tensor<?x?x?x?xf32> {
@@ -165,7 +165,7 @@ func.func private @test.some_use(vector<1x4x4xf32>)
 //  CHECK-SAME:      %[[ARG_0:.*]]: tensor<?x?xf32>,
 //  CHECK-SAME:      %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
 //   CHECK-NOT:    vector.transpose
-//       CHECK:    %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+//       CHECK:    %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
 func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
 
   %c0 = arith.constant 0 : index

>From 7ab6c6e7438476c997fbd8f02c3c88cbb13996ab Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Mon, 13 May 2024 21:54:04 +0800
Subject: [PATCH 4/4] [MLIR][Vector] Implement TransferOpReduceRank as
 MaskableOpRewritePattern

---
 .../Vector/Transforms/LowerVectorTransfer.cpp | 25 +++++++++++--------
 .../vector-transfer-to-vector-load-store.mlir | 15 +++++++++++
 2 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 81f7591a7d86f..63d3ec91e512f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -322,14 +322,19 @@ struct TransferWriteNonPermutationLowering
 ///     %v = vector.transfer_read ...
 ///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
 ///     vector.broadcast %v
-struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct TransferOpReduceRank
+    : public MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     AffineMap map = op.getPermutationMap();
     unsigned numLeadingBroadcast = 0;
@@ -369,9 +374,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
             op.getLoc(), originalVecType.getElementType(), op.getSource(),
             op.getIndices());
       }
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
-                                                       newRead);
-      return success();
+      return rewriter
+          .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+          .getVector();
     }
 
     SmallVector<int64_t> newShape(
@@ -393,9 +398,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
         newInBoundsAttr);
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
-                                                     newRead);
-    return success();
+    return rewriter
+        .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+        .getVector();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 2f2bdcaab5b3e..7c50bfa155472 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -219,6 +219,21 @@ func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vecto
   return %res : vector<4x4xf32>
 }
 
+// CHECK-LABEL:   func @masked_transfer_read_reduce_rank_with_broadcast(
+// CHECK-SAME:                                %[[MEM:.*]]: memref<8x8x8x8xf32>,
+// CHECK-SAME:                                %[[MASK:.*]]: vector<4x4xi1>,
+// CHECK-SAME:                                %[[IDX:.*]]: index) -> vector<4x4x4x4xf32> {
+//      CHECK:      %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %cst {in_bounds = [true, true, true, true], permutation_map = #map2} : memref<8x8x8x8xf32>, vector<4x4x4x4xf32> } : vector<4x4xi1> -> vector<4x4x4x4xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<4x4x4x4xf32>
+#rank_reducing = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
+func.func @masked_transfer_read_reduce_rank_with_broadcast(%mem : memref<8x8x8x8xf32>, %mask : vector<4x4xi1>, %i : index) -> vector<4x4x4x4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.mask %mask {vector.transfer_read %mem[%i, %i, %i, %i], %cf0
+    {in_bounds = [true, true, true, true], permutation_map = #rank_reducing}
+      : memref<8x8x8x8xf32>, vector<4x4x4x4xf32>} : vector<4x4xi1> -> vector<4x4x4x4xf32>
+  return %res : vector<4x4x4x4xf32>
+}
+
 // More complex broadcasting case (here a `vector.load` is generated).
 // CHECK-LABEL:   func @transfer_broadcasting_complex(
 // CHECK-SAME:                                %[[MEM:.*]]: memref<10x20x30x8x8xf32>,



More information about the Mlir-commits mailing list