[Mlir-commits] [mlir] [MLIR][Vector] Implement TransferOpReduceRank as MaskableOpRewritePattern (PR #92426)

Hugo Trachino llvmlistbot at llvm.org
Wed May 22 01:53:15 PDT 2024


https://github.com/nujaa updated https://github.com/llvm/llvm-project/pull/92426

>From d4388526945d0950c0425a822cc1a041bfcd610a Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Mon, 13 May 2024 21:54:04 +0800
Subject: [PATCH 1/2] [MLIR][Vector] Implement TransferOpReduceRank as
 MaskableOpRewritePattern

---
 .../Vector/Transforms/LowerVectorTransfer.cpp | 25 +++++++++++--------
 .../vector-transfer-to-vector-load-store.mlir | 15 +++++++++++
 2 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index c59012266ceb3..d2692103e76e6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -322,14 +322,19 @@ struct TransferWriteNonPermutationLowering
 ///     %v = vector.transfer_read ...
 ///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
 ///     vector.broadcast %v
-struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct TransferOpReduceRank
+    : public MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp op,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp op,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    if (maskOp)
+      return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
     AffineMap map = op.getPermutationMap();
     unsigned numLeadingBroadcast = 0;
@@ -369,9 +374,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
             op.getLoc(), originalVecType.getElementType(), op.getSource(),
             op.getIndices());
       }
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
-                                                       newRead);
-      return success();
+      return rewriter
+          .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+          .getVector();
     }
 
     SmallVector<int64_t> newShape(
@@ -393,9 +398,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
         newInBoundsAttr);
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
-                                                     newRead);
-    return success();
+    return rewriter
+        .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+        .getVector();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 2f2bdcaab5b3e..7c50bfa155472 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -219,6 +219,21 @@ func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vecto
   return %res : vector<4x4xf32>
 }
 
+// CHECK-LABEL:   func @masked_transfer_read_reduce_rank_with_broadcast(
+// CHECK-SAME:                                %[[MEM:.*]]: memref<8x8x8x8xf32>,
+// CHECK-SAME:                                %[[MASK:.*]]: vector<4x4xi1>,
+// CHECK-SAME:                                %[[IDX:.*]]: index) -> vector<4x4x4x4xf32> {
+//      CHECK:      %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %cst {in_bounds = [true, true, true, true], permutation_map = #map2} : memref<8x8x8x8xf32>, vector<4x4x4x4xf32> } : vector<4x4xi1> -> vector<4x4x4x4xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<4x4x4x4xf32>
+#rank_reducing = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
+func.func @masked_transfer_read_reduce_rank_with_broadcast(%mem : memref<8x8x8x8xf32>, %mask : vector<4x4xi1>, %i : index) -> vector<4x4x4x4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.mask %mask {vector.transfer_read %mem[%i, %i, %i, %i], %cf0
+    {in_bounds = [true, true, true, true], permutation_map = #rank_reducing}
+      : memref<8x8x8x8xf32>, vector<4x4x4x4xf32>} : vector<4x4xi1> -> vector<4x4x4x4xf32>
+  return %res : vector<4x4x4x4xf32>
+}
+
 // More complex broadcasting case (here a `vector.load` is generated).
 // CHECK-LABEL:   func @transfer_broadcasting_complex(
 // CHECK-SAME:                                %[[MEM:.*]]: memref<10x20x30x8x8xf32>,

>From d1acbb16b849d23b902ac62687dde58349e99de9 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 16 May 2024 23:55:48 +0800
Subject: [PATCH 2/2] Move test to vector-transfer-permutation-lowering.mlir

---
 .../Vector/Transforms/LowerVectorTransfer.cpp |  1 +
 .../vector-transfer-permutation-lowering.mlir | 46 +++++++++++++++++++
 .../vector-transfer-to-vector-load-store.mlir | 15 ------
 3 files changed, 47 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index d2692103e76e6..6bfb2eb689418 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -333,6 +333,7 @@ struct TransferOpReduceRank
     // TODO: support 0-d corner case.
     if (op.getTransferRank() == 0)
       return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+    // TODO: support masked case.
     if (maskOp)
       return rewriter.notifyMatchFailure(op, "Masked case not supported");
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 349dc1ab31d4c..3fba65079ec0b 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -187,3 +187,49 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+
+//       CHECK:   #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
+//       CHECK:   func.func @transfer_read_reduce_rank_scalable(
+//  CHECK-SAME:       %[[ARG_0:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
+//       CHECK:     %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:     %[[TFR:.*]] = vector.transfer_read %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
+//       CHECK:     %[[BC:.*]] = vector.broadcast %[[TFR]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
+//       CHECK:     return %[[BC]] : vector<8x[4]x2x3xf32>
+func.func @transfer_read_reduce_rank_scalable(%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %1 = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
+    {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
+    : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
+  return %1 : vector<8x[4]x2x3xf32>
+}
+
+// Masked case not supported.
+// CHECK-LABEL:   func.func @masked_transfer_read_reduce_rank(
+//  CHECK-SAME:       %[[ARG_0:.*]]: memref<?x?x?x?xf32>,
+//  CHECK-SAME:       %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
+//   CHECK-NOT:     vector.broadcast
+//       CHECK:     %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %arg0{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
+func.func @masked_transfer_read_reduce_rank(%mem: memref<?x?x?x?xf32>, %dim: index) -> vector<8x[4]x2x3xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %mask = vector.create_mask %dim, %dim: vector<[4]x3xi1>
+  %res = vector.mask %mask { vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
+    {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
+    : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
+  return %res : vector<8x[4]x2x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.any_op
+    transform.yield
+  }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 7c50bfa155472..2f2bdcaab5b3e 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -219,21 +219,6 @@ func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vecto
   return %res : vector<4x4xf32>
 }
 
-// CHECK-LABEL:   func @masked_transfer_read_reduce_rank_with_broadcast(
-// CHECK-SAME:                                %[[MEM:.*]]: memref<8x8x8x8xf32>,
-// CHECK-SAME:                                %[[MASK:.*]]: vector<4x4xi1>,
-// CHECK-SAME:                                %[[IDX:.*]]: index) -> vector<4x4x4x4xf32> {
-//      CHECK:      %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %cst {in_bounds = [true, true, true, true], permutation_map = #map2} : memref<8x8x8x8xf32>, vector<4x4x4x4xf32> } : vector<4x4xi1> -> vector<4x4x4x4xf32>
-// CHECK-NEXT:      return %[[RES]] : vector<4x4x4x4xf32>
-#rank_reducing = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
-func.func @masked_transfer_read_reduce_rank_with_broadcast(%mem : memref<8x8x8x8xf32>, %mask : vector<4x4xi1>, %i : index) -> vector<4x4x4x4xf32> {
-  %cf0 = arith.constant 0.0 : f32
-  %res = vector.mask %mask {vector.transfer_read %mem[%i, %i, %i, %i], %cf0
-    {in_bounds = [true, true, true, true], permutation_map = #rank_reducing}
-      : memref<8x8x8x8xf32>, vector<4x4x4x4xf32>} : vector<4x4xi1> -> vector<4x4x4x4xf32>
-  return %res : vector<4x4x4x4xf32>
-}
-
 // More complex broadcasting case (here a `vector.load` is generated).
 // CHECK-LABEL:   func @transfer_broadcasting_complex(
 // CHECK-SAME:                                %[[MEM:.*]]: memref<10x20x30x8x8xf32>,



More information about the Mlir-commits mailing list