[Mlir-commits] [mlir] [mlir][vector] Support more mask types in foldTransferFullMask() (PR #96761)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Jun 27 04:14:02 PDT 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/96761
>From fd499c37b0df2d90eac71ffaf11ab3ffe2bf01e5 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 26 Jun 2024 12:51:01 +0000
Subject: [PATCH 1/2] [mlir][vector] Support more mask types in
foldTransferFullMask()
Using the existing `getMaskFormat()` helper, this fold can be extended to support
`arith.constant` masks (e.g. `arith.constant dense<true>`) in addition to
`vector.constant_mask`.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 6 +-----
mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++++++++++--
2 files changed, 11 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6734c80f2760d..149723f51cc12 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4172,11 +4172,7 @@ static LogicalResult foldTransferFullMask(TransferOp op) {
if (!mask)
return failure();
- auto constantMask = mask.template getDefiningOp<vector::ConstantMaskOp>();
- if (!constantMask)
- return failure();
-
- if (!constantMask.isAllOnesMask())
+ if (getMaskFormat(mask) != MaskFormat::AllTrue)
return failure();
op.getMaskMutable().clear();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8181f1a8c5d13..ecd49df3b2141 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -868,7 +868,7 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
// -----
// CHECK-LABEL: fold_vector_transfer_masks
-func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>) {
+func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
@@ -876,6 +876,8 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>)
%mask = vector.constant_mask [8, 4] : vector<8x4xi1>
+ %mask_splat = arith.constant dense<true> : vector<4x[4]xi1>
+
// CHECK: vector.transfer_read %{{.*}}, %[[F0]] {permutation_map
%1 = vector.transfer_read %A[%c0, %c0], %f0, %mask
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf32>, vector<4x8xf32>
@@ -884,8 +886,14 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>)
vector.transfer_write %1, %A[%c0, %c0], %mask
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<4x8xf32>, memref<?x?xf32>
+ // CHECK: vector.transfer_read %{{.*}}, %[[F0]] :
+ %2 = vector.transfer_read %A[%c0, %c0], %f0, %mask_splat : memref<?x?xf32>, vector<4x[4]xf32>
+
+ // CHECK: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] :
+ vector.transfer_write %2, %A[%c0, %c0], %mask_splat : vector<4x[4]xf32>, memref<?x?xf32>
+
// CHECK: return
- return %1 : vector<4x8xf32>
+ return %1, %2 : vector<4x8xf32>, vector<4x[4]xf32>
}
// -----
>From fb57ecd8e6ac1fa9cbede6163a6137f141f4c09f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 27 Jun 2024 11:13:01 +0000
Subject: [PATCH 2/2] Fixups
---
mlir/test/Dialect/Vector/canonicalize.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index ecd49df3b2141..fc5651f5bb02f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -876,7 +876,7 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>,
%mask = vector.constant_mask [8, 4] : vector<8x4xi1>
- %mask_splat = arith.constant dense<true> : vector<4x[4]xi1>
+ %arith_all_true_mask = arith.constant dense<true> : vector<4x[4]xi1>
// CHECK: vector.transfer_read %{{.*}}, %[[F0]] {permutation_map
%1 = vector.transfer_read %A[%c0, %c0], %f0, %mask
@@ -887,10 +887,10 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>,
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<4x8xf32>, memref<?x?xf32>
// CHECK: vector.transfer_read %{{.*}}, %[[F0]] :
- %2 = vector.transfer_read %A[%c0, %c0], %f0, %mask_splat : memref<?x?xf32>, vector<4x[4]xf32>
+ %2 = vector.transfer_read %A[%c0, %c0], %f0, %arith_all_true_mask : memref<?x?xf32>, vector<4x[4]xf32>
// CHECK: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] :
- vector.transfer_write %2, %A[%c0, %c0], %mask_splat : vector<4x[4]xf32>, memref<?x?xf32>
+ vector.transfer_write %2, %A[%c0, %c0], %arith_all_true_mask : vector<4x[4]xf32>, memref<?x?xf32>
// CHECK: return
return %1, %2 : vector<4x8xf32>, vector<4x[4]xf32>
More information about the Mlir-commits
mailing list