[Mlir-commits] [mlir] [mlir][vector] Add folders for full constant transfer masks (PR #71676)
Quinn Dawkins
llvmlistbot at llvm.org
Wed Nov 8 05:50:18 PST 2023
https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/71676
When the mask bounds of a `vector.constant_mask` exactly equal the shape of the vector, any transfer op consuming that mask will be unaffected by it. Drop the mask in such cases.
>From c75afe7def0e1ff1a378941dd375660efdd4d0fb Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sun, 5 Nov 2023 16:03:10 -0500
Subject: [PATCH] [MLIR][Vector] Add folders for full constant transfer masks
When the mask bounds of a `vector.constant_mask` exactly equal the shape
of the vector, any transfer op consuming that mask will be unaffected by
it. Drop the mask in such cases.
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 10 +++++
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 37 +++++++++++++++++++
mlir/test/Dialect/Vector/canonicalize.mlir | 23 ++++++++++++
3 files changed, 70 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bbff7e2d18b01b4..3a6d81b59aeb1f9 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2276,6 +2276,16 @@ def Vector_ConstantMaskOp :
```
}];
+ let extraClassDeclaration = [{
+ /// Return the type of this op's first (index 0) result, cast to
+ /// `VectorType`.
+ VectorType getVectorType() {
+ return cast<VectorType>(getOperation()->getResultTypes()[0]);
+ }
+
+ /// Return whether the mask is a uniform vector of `1`s, i.e. whether it
+ /// enables every element of the result vector. Only declared here; the
+ /// definition lives out-of-line in the dialect's C++ source.
+ bool isFullMask();
+ }];
+
let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 69cbdcd3f536f98..9a2b71e459f56af 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3937,6 +3937,23 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
return success();
}
+/// Drops the mask operand of the transfer op `op` when that mask is defined
+/// by a `vector.constant_mask` for which `isFullMask()` holds. The op is
+/// updated in place. Returns success() only when the mask was removed and
+/// failure() when there is no mask, the mask is not produced by a constant
+/// mask op, or the constant mask is not full.
+template <typename TransferOp>
+static LogicalResult foldTransferFullMask(TransferOp op) {
+  if (auto maskValue = op.getMask()) {
+    // Only a mask produced directly by a constant mask op is considered.
+    auto maskDef = maskValue.template getDefiningOp<vector::ConstantMaskOp>();
+    if (maskDef && maskDef.isFullMask()) {
+      op.getMaskMutable().clear();
+      return success();
+    }
+  }
+  return failure();
+}
+
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
/// : vector<1x4xf32>, tensor<4x4xf32>
@@ -3969,6 +3986,8 @@ OpFoldResult TransferReadOp::fold(FoldAdaptor) {
/// transfer_read(memrefcast) -> transfer_read
if (succeeded(foldTransferInBoundsAttribute(*this)))
return getResult();
+ if (succeeded(foldTransferFullMask(*this)))
+ return getResult();
if (succeeded(memref::foldMemRefCast(*this)))
return getResult();
if (succeeded(tensor::foldTensorCast(*this)))
@@ -4334,6 +4353,8 @@ LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
return success();
if (succeeded(foldTransferInBoundsAttribute(*this)))
return success();
+ if (succeeded(foldTransferFullMask(*this)))
+ return success();
return memref::foldMemRefCast(*this);
}
@@ -5601,6 +5622,22 @@ LogicalResult ConstantMaskOp::verify() {
return success();
}
+/// Reports whether this constant mask enables every element of its result
+/// vector: for each dimension the mask dim size must not be smaller than the
+/// corresponding result dimension. A rank-0 result is handled separately.
+bool ConstantMaskOp::isFullMask() {
+  VectorType maskType = getVectorType();
+
+  // A 0-D vector is expected to carry exactly one mask size; the mask is
+  // full iff that single size is 1.
+  if (maskType.getRank() == 0) {
+    assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
+    return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
+  }
+
+  // Walk result dims and mask sizes in lock-step; any dimension that is only
+  // partially covered makes the mask non-full.
+  return llvm::all_of(
+      llvm::zip_equal(maskType.getShape(), getMaskDimSizes()),
+      [](const auto &dimAndSize) {
+        int64_t vectorDim = std::get<0>(dimAndSize);
+        int64_t maskDim =
+            llvm::cast<IntegerAttr>(std::get<1>(dimAndSize)).getInt();
+        return maskDim >= vectorDim;
+      });
+}
+
//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 163fdd67b0cfd34..1021c73cc57d341 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -854,6 +854,29 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
// -----
+// Verifies that a constant mask whose sizes ([8, 4]) equal the shape of its
+// vector<8x4xi1> type is dropped from both the transfer_read and the
+// transfer_write that consume it: the expected output has the attribute
+// dictionary directly after the padding value / the indices, with no mask
+// operand in between.
+// CHECK-LABEL: fold_vector_transfer_masks
+func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>) {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+ %f0 = arith.constant 0.0 : f32
+
+ // Full mask: every mask dim size matches the corresponding dim of the
+ // mask's vector type.
+ %mask = vector.constant_mask [8, 4] : vector<8x4xi1>
+
+ // CHECK: vector.transfer_read %{{.*}}, %[[F0]] {permutation_map
+ %1 = vector.transfer_read %A[%c0, %c0], %f0, %mask
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf32>, vector<4x8xf32>
+
+ // CHECK: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] {permutation_map
+ vector.transfer_write %1, %A[%c0, %c0], %mask
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<4x8xf32>, memref<?x?xf32>
+
+ // CHECK: return
+ return %1 : vector<4x8xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfers
func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list