[Mlir-commits] [mlir] 2731d26 - [mlir][vector] Support more mask types in foldTransferFullMask() (#96761)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 27 04:14:34 PDT 2024


Author: Benjamin Maxwell
Date: 2024-06-27T12:14:30+01:00
New Revision: 2731d26948384f1ec2c30ce6692c60e9414ea2ec

URL: https://github.com/llvm/llvm-project/commit/2731d26948384f1ec2c30ce6692c60e9414ea2ec
DIFF: https://github.com/llvm/llvm-project/commit/2731d26948384f1ec2c30ce6692c60e9414ea2ec.diff

LOG: [mlir][vector] Support more mask types in foldTransferFullMask() (#96761)

By using the existing `getMaskFormat()` helper, `foldTransferFullMask()` can
also fold all-true `arith.constant` masks, in addition to all-true
`vector.constant_mask` masks.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6734c80f2760d..149723f51cc12 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4172,11 +4172,7 @@ static LogicalResult foldTransferFullMask(TransferOp op) {
   if (!mask)
     return failure();
 
-  auto constantMask = mask.template getDefiningOp<vector::ConstantMaskOp>();
-  if (!constantMask)
-    return failure();
-
-  if (!constantMask.isAllOnesMask())
+  if (getMaskFormat(mask) != MaskFormat::AllTrue)
     return failure();
 
   op.getMaskMutable().clear();

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8181f1a8c5d13..fc5651f5bb02f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -868,7 +868,7 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
 // -----
 
 // CHECK-LABEL: fold_vector_transfer_masks
-func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>) {
+func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
@@ -876,6 +876,8 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>)
 
   %mask = vector.constant_mask [8, 4] : vector<8x4xi1>
 
+  %arith_all_true_mask = arith.constant dense<true> : vector<4x[4]xi1>
+
   // CHECK: vector.transfer_read %{{.*}}, %[[F0]] {permutation_map
   %1 = vector.transfer_read %A[%c0, %c0], %f0, %mask
       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf32>, vector<4x8xf32>
@@ -884,8 +886,14 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>)
   vector.transfer_write %1, %A[%c0, %c0], %mask
       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<4x8xf32>, memref<?x?xf32>
 
+  // CHECK: vector.transfer_read %{{.*}}, %[[F0]] :
+  %2 = vector.transfer_read %A[%c0, %c0], %f0, %arith_all_true_mask : memref<?x?xf32>, vector<4x[4]xf32>
+
+  // CHECK: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] :
+  vector.transfer_write %2, %A[%c0, %c0], %arith_all_true_mask : vector<4x[4]xf32>, memref<?x?xf32>
+
   // CHECK: return
-  return %1 : vector<4x8xf32>
+  return %1, %2 : vector<4x8xf32>, vector<4x[4]xf32>
 }
 
 // -----


        


More information about the Mlir-commits mailing list