[Mlir-commits] [mlir] 3669d07 - [mlir][linalg] Only apply masking on xfer_write when needed.

Hanhan Wang llvmlistbot at llvm.org
Wed May 24 18:24:32 PDT 2023


Author: Hanhan Wang
Date: 2023-05-24T18:24:19-07:00
New Revision: 3669d07987a7fe142db6d911b64a356bd9b5a0a3

URL: https://github.com/llvm/llvm-project/commit/3669d07987a7fe142db6d911b64a356bd9b5a0a3
DIFF: https://github.com/llvm/llvm-project/commit/3669d07987a7fe142db6d911b64a356bd9b5a0a3.diff

LOG: [mlir][linalg] Only apply masking on xfer_write when needed.

If the input vector sizes are the same as the tensor.pad result shape, masking
is not needed. Otherwise, masking is needed, and the mask operands should match
the shape of the tensor.empty op. A condensed sketch of the new check is shown
below; the full change is in the diff.
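
A minimal sketch of the shape-comparison predicate added by this patch, pulled
out as a standalone helper for illustration (the helper name `needsWriteMask`
is hypothetical; the actual change inlines this check in
vectorizeAsTensorPadOp, as the diff below shows):

  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/ADT/STLExtras.h"
  #include <cstdint>

  // Returns true when any requested vector size differs from the corresponding
  // static dimension of the pad result, i.e. when the transfer_write must be
  // masked.
  static bool needsWriteMask(llvm::ArrayRef<int64_t> inputVectorSizes,
                             llvm::ArrayRef<int64_t> padResultShape) {
    return llvm::any_of(
        llvm::zip_equal(inputVectorSizes, padResultShape),
        [](auto it) { return std::get<0>(it) != std::get<1>(it); });
  }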

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D151391

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization-masked.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b9ccbd28038a8..aae36035eeece 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1325,7 +1325,14 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
       /*source=*/emptyOp,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/SmallVector<bool>(rank, true));
-  write = mlir::vector::maskOperation(rewriter, write, mask);
+  bool needMaskForWrite = llvm::any_of(
+      llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
+      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  if (needMaskForWrite) {
+    Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
+        loc, maskType, reifiedReturnShapes[0]);
+    write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
+  }
   newResults.push_back(write->getResult(0));
   return success();
 }

diff --git a/mlir/test/Dialect/Linalg/vectorization-masked.mlir b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
index 7b363e7df61db..65b8b5b38461e 100644
--- a/mlir/test/Dialect/Linalg/vectorization-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
@@ -405,14 +405,17 @@ transform.sequence failures(propagate) {
 
 // -----
 
-// CHECK-LABEL: func @test_masked_vectorize_dynamic_pad
+//       CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s1 + s0)>
+//       CHECK: func @test_masked_vectorize_dynamic_pad
 func.func @test_masked_vectorize_dynamic_pad(
   %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
     -> tensor<?x?xf32>
 {
   //  CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
   //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-  //      CHECK: %[[empty:.*]] = tensor.empty({{.+}}) : tensor<?x?xf32>
+  //  CHECK-DAG: %[[res_d0:.+]] = affine.apply #[[MAP]]()
+  //  CHECK-DAG: %[[res_d1:.+]] = affine.apply #[[MAP]]()
+  //  CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
   //      CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
@@ -421,7 +424,8 @@ func.func @test_masked_vectorize_dynamic_pad(
   // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
   // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
-  //      CHECK: %[[masked_write:.*]] = vector.mask %[[mask]] {
+  //      CHECK: %[[mask_2:.*]] = vector.create_mask %[[res_d0]], %[[res_d1]] : vector<2x4xi1>
+  //      CHECK: %[[masked_write:.*]] = vector.mask %[[mask_2]] {
   // CHECK-SAME: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
   // CHECK-SAME:   {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32>
   //      CHECK: return %[[masked_write]] : tensor<?x?xf32>
