[Mlir-commits] [mlir] [mlir][Vector] Verify that masked ops implement MaskableOpInterface (PR #108123)
Diego Caballero
llvmlistbot at llvm.org
Tue Sep 10 17:49:19 PDT 2024
https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/108123
This PR fixes a bug in `MaskOp::verify` that allowed `vector.mask` to mask operations that do not implement `MaskableOpInterface`.
>From 6c4bc5fb2f94a5fa1fe3fe64cb876ceb9cf61973 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Tue, 10 Sep 2024 17:43:47 -0700
Subject: [PATCH] [mlir][Vector] Verify that masked ops implement
MaskableOpInterface
This PR fixes a bug in `MaskOp::verify` that allowed `vector.mask` to mask
operations that do not implement `MaskableOpInterface`.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 ++++++++---
mlir/test/Dialect/Vector/canonicalize.mlir | 12 +++++++-----
mlir/test/Dialect/Vector/invalid.mlir | 8 ++++++++
3 files changed, 23 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..62f9943e93b9cf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6124,7 +6124,9 @@ LogicalResult MaskOp::verify() {
Block &block = getMaskRegion().getBlocks().front();
if (block.getOperations().empty())
return emitOpError("expects a terminator within the mask region");
- if (block.getOperations().size() > 2)
+
+ unsigned numMaskRegionOps = block.getOperations().size();
+ if (numMaskRegionOps > 2)
return emitOpError("expects only one operation to mask");
// Terminator checks.
@@ -6136,11 +6138,14 @@ LogicalResult MaskOp::verify() {
return emitOpError(
"expects number of results to match mask region yielded values");
- auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
// Empty vector.mask. Nothing else to check.
- if (!maskableOp)
+ if (numMaskRegionOps == 1)
return success();
+ auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
+ if (!maskableOp)
+ return emitOpError("expects a MaskableOpInterface within the mask region");
+
// Result checks.
if (maskableOp->getNumResults() != getNumResults())
return emitOpError("expects number of results to match maskable operation "
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e71a6eb02ea46c..b7c78de4b5bd89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2471,13 +2471,15 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
// -----
// CHECK-LABEL: func @all_true_vector_mask
-// CHECK-SAME: %[[IN:.*]]: vector<3x4xf32>
-func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> {
+// CHECK-SAME: %[[IN:.*]]: tensor<3x4xf32>
+func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {
// CHECK-NOT: vector.mask
-// CHECK: %[[ADD:.*]] = arith.addf %[[IN]], %[[IN]] : vector<3x4xf32>
-// CHECK: return %[[ADD]] : vector<3x4xf32>
+// CHECK: %[[LD:.*]] = vector.transfer_read %[[IN]]
+// CHECK: return %[[LD]] : vector<3x4xf32>
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
%all_true = vector.constant_mask [3, 4] : vector<3x4xi1>
- %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
+ %0 = vector.mask %all_true { vector.transfer_read %ta[%c0, %c0], %cf0 : tensor<3x4xf32>, vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
return %0 : vector<3x4xf32>
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c95b8bd5ed6147..e2bc5ef6128e7d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1724,6 +1724,14 @@ func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor<?xf3
vector.mask %m0, %pt0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> vector<16xf32>
return
}
+// -----
+
+func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32> {
+ %m0 = vector.constant_mask [2, 2] : vector<3x4xi1>
+ // expected-error@+1 {{'vector.mask' op expects a MaskableOpInterface within the mask region}}
+ %0 = vector.mask %m0 { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
+ return %0 : vector<3x4xf32>
+}
// -----
More information about the Mlir-commits mailing list