[Mlir-commits] [mlir] bcd65ba - [mlir][Vector] Verify that masked ops implement MaskableOpInterface (#108123)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 19 10:17:16 PDT 2024


Author: Diego Caballero
Date: 2024-09-19T10:17:13-07:00
New Revision: bcd65ba6129bea92485432fdd09874bc3fc6671e

URL: https://github.com/llvm/llvm-project/commit/bcd65ba6129bea92485432fdd09874bc3fc6671e
DIFF: https://github.com/llvm/llvm-project/commit/bcd65ba6129bea92485432fdd09874bc3fc6671e.diff

LOG: [mlir][Vector] Verify that masked ops implement MaskableOpInterface (#108123)

This PR fixes a bug in `MaskOp::verify` that allowed `vector.mask` to
mask operations that did not implement the MaskableOpInterface.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 816447713de417..1438ddd1028bb9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6131,7 +6131,9 @@ LogicalResult MaskOp::verify() {
   Block &block = getMaskRegion().getBlocks().front();
   if (block.getOperations().empty())
     return emitOpError("expects a terminator within the mask region");
-  if (block.getOperations().size() > 2)
+
+  unsigned numMaskRegionOps = block.getOperations().size();
+  if (numMaskRegionOps > 2)
     return emitOpError("expects only one operation to mask");
 
   // Terminator checks.
@@ -6143,11 +6145,14 @@ LogicalResult MaskOp::verify() {
     return emitOpError(
         "expects number of results to match mask region yielded values");
 
-  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
   // Empty vector.mask. Nothing else to check.
-  if (!maskableOp)
+  if (numMaskRegionOps == 1)
     return success();
 
+  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
+  if (!maskableOp)
+    return emitOpError("expects a MaskableOpInterface within the mask region");
+
   // Result checks.
   if (maskableOp->getNumResults() != getNumResults())
     return emitOpError("expects number of results to match maskable operation "

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e71a6eb02ea46c..b7c78de4b5bd89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2471,13 +2471,15 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
 // -----
 
 // CHECK-LABEL: func @all_true_vector_mask
-//  CHECK-SAME:     %[[IN:.*]]: vector<3x4xf32>
-func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> {
+//  CHECK-SAME:     %[[IN:.*]]: tensor<3x4xf32>
+func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {
 //   CHECK-NOT:   vector.mask
-//       CHECK:   %[[ADD:.*]] = arith.addf %[[IN]], %[[IN]] : vector<3x4xf32>
-//       CHECK:   return %[[ADD]] : vector<3x4xf32>
+//       CHECK:   %[[LD:.*]] = vector.transfer_read %[[IN]]
+//       CHECK:   return %[[LD]] : vector<3x4xf32>
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
   %all_true = vector.constant_mask [3, 4] : vector<3x4xi1>
-  %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
+  %0 = vector.mask %all_true { vector.transfer_read %ta[%c0, %c0], %cf0 : tensor<3x4xf32>, vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
 

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c95b8bd5ed6147..e2bc5ef6128e7d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1724,6 +1724,14 @@ func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor<?xf3
   vector.mask %m0, %pt0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> vector<16xf32>
   return
 }
+// -----
+
+func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32> {
+   %m0 = vector.constant_mask [2, 2] : vector<3x4xi1>
+  // expected-error@+1 {{'vector.mask' op expects a MaskableOpInterface within the mask region}}
+   %0 = vector.mask %m0 { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
+   return %0 : vector<3x4xf32>
+}
 
 // -----
 


        


More information about the Mlir-commits mailing list