[Mlir-commits] [mlir] [mlir][vector] Add folders for full constant transfer masks (PR #71676)

Quinn Dawkins llvmlistbot at llvm.org
Wed Nov 8 05:50:18 PST 2023


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/71676

When the mask dimension sizes of a `vector.constant_mask` equal the shape of its result vector, every mask bit is set, so any transfer op consuming that mask behaves exactly as if it were unmasked. Drop the mask in such cases.

>From c75afe7def0e1ff1a378941dd375660efdd4d0fb Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sun, 5 Nov 2023 16:03:10 -0500
Subject: [PATCH] [MLIR][Vector] Add folders for full constant transfer masks

When the mask dimension sizes of a `vector.constant_mask` equal the
shape of its result vector, every mask bit is set, so any transfer op
consuming that mask behaves exactly as if it were unmasked. Drop the
mask in such cases.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 10 +++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 37 +++++++++++++++++++
 mlir/test/Dialect/Vector/canonicalize.mlir    | 23 ++++++++++++
 3 files changed, 70 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bbff7e2d18b01b4..3a6d81b59aeb1f9 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2276,6 +2276,16 @@ def Vector_ConstantMaskOp :
     ```
   }];
 
+  let extraClassDeclaration = [{
+    /// Return the vector type of the mask value this op produces.
+    VectorType getVectorType() {
+      return llvm::cast<VectorType>(getOperation()->getResult(0).getType());
+    }
+
+    /// Return true if every element of the produced mask is set to `1`.
+    bool isFullMask();
+  }];
+
   let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 69cbdcd3f536f98..9a2b71e459f56af 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3937,6 +3937,23 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
   return success();
 }
 
+/// Remove the mask operand from `op` when the mask is known to be all ones.
+template <typename TransferOp>
+static LogicalResult foldTransferFullMask(TransferOp op) {
+  // Only masks materialized by `vector.constant_mask` are recognized here.
+  Value maskValue = op.getMask();
+  if (!maskValue)
+    return failure();
+  auto maskOp = maskValue.getDefiningOp<vector::ConstantMaskOp>();
+  if (!maskOp || !maskOp.isFullMask())
+    return failure();
+
+  // With every lane enabled the mask has no effect on the transfer, so the
+  // op is updated in place without it.
+  op.getMaskMutable().clear();
+  return success();
+}
+
 ///  ```
 ///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
 ///    : vector<1x4xf32>, tensor<4x4xf32>
@@ -3969,6 +3986,8 @@ OpFoldResult TransferReadOp::fold(FoldAdaptor) {
   /// transfer_read(memrefcast) -> transfer_read
   if (succeeded(foldTransferInBoundsAttribute(*this)))
     return getResult();
+  if (succeeded(foldTransferFullMask(*this)))
+    return getResult();
   if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
   if (succeeded(tensor::foldTensorCast(*this)))
@@ -4334,6 +4353,8 @@ LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
     return success();
   if (succeeded(foldTransferInBoundsAttribute(*this)))
     return success();
+  if (succeeded(foldTransferFullMask(*this)))
+    return success();
   return memref::foldMemRefCast(*this);
 }
 
@@ -5601,6 +5622,22 @@ LogicalResult ConstantMaskOp::verify() {
   return success();
 }
 
+bool ConstantMaskOp::isFullMask() {
+  VectorType maskType = getVectorType();
+  ArrayAttr dimSizes = getMaskDimSizes();
+  // A 0-D mask carries a single size; it is full exactly when that size is 1.
+  if (maskType.getRank() == 0) {
+    assert(dimSizes.size() == 1 && "invalid sizes for zero rank mask");
+    return llvm::cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
+  }
+  // Otherwise each mask dim size must reach the matching vector dim size.
+  return llvm::all_of(
+      llvm::zip_equal(maskType.getShape(), dimSizes), [](auto pair) {
+        auto [vecSize, sizeAttr] = pair;
+        return llvm::cast<IntegerAttr>(sizeAttr).getInt() >= vecSize;
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // CreateMaskOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 163fdd67b0cfd34..1021c73cc57d341 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -854,6 +854,54 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)

 // -----

+// CHECK-LABEL: fold_vector_transfer_masks
+func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  %f0 = arith.constant 0.0 : f32
+
+  // The all-ones mask is dropped from both transfers and then erased as dead.
+  // CHECK-NOT: vector.constant_mask
+  %mask = vector.constant_mask [8, 4] : vector<8x4xi1>
+
+  // CHECK: vector.transfer_read %{{.*}}, %[[F0]] {permutation_map
+  %1 = vector.transfer_read %A[%c0, %c0], %f0, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf32>, vector<4x8xf32>
+
+  // CHECK: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] {permutation_map
+  vector.transfer_write %1, %A[%c0, %c0], %mask
+      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<4x8xf32>, memref<?x?xf32>
+
+  // CHECK: return
+  return %1 : vector<4x8xf32>
+}
+
+// -----
+
+// A mask that does not cover the whole vector must be kept on both transfers.
+// CHECK-LABEL: no_fold_vector_transfer_partial_masks
+func.func @no_fold_vector_transfer_partial_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>) {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  // CHECK: %[[MASK:.+]] = vector.constant_mask [8, 2] : vector<8x4xi1>
+  %mask = vector.constant_mask [8, 2] : vector<8x4xi1>
+
+  // CHECK: vector.transfer_read {{.*}}, %[[MASK]] {permutation_map
+  %1 = vector.transfer_read %A[%c0, %c0], %f0, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf32>, vector<4x8xf32>
+
+  // CHECK: vector.transfer_write {{.*}}, %[[MASK]] {permutation_map
+  vector.transfer_write %1, %A[%c0, %c0], %mask
+      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<4x8xf32>, memref<?x?xf32>
+
+  // CHECK: return
+  return %1 : vector<4x8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = arith.constant 0 : index



More information about the Mlir-commits mailing list