[Mlir-commits] [mlir] f94131a - [mlir][vector] Support multiple result types in vector.mask
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 13 07:59:44 PST 2023
Author: Matthias Springer
Date: 2023-01-13T16:59:36+01:00
New Revision: f94131a2a502118e0507164dcef160d1cfecb316
URL: https://github.com/llvm/llvm-project/commit/f94131a2a502118e0507164dcef160d1cfecb316
DIFF: https://github.com/llvm/llvm-project/commit/f94131a2a502118e0507164dcef160d1cfecb316.diff
LOG: [mlir][vector] Support multiple result types in vector.mask
The verifier already had support for multiple result types, but the op definition assumed a single, optional result.
Differential Revision: https://reviews.llvm.org/D141683
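The log above says the verifier already accepted several results while the ODS definition only allowed a single, optional one. The standalone C++ sketch below models that mismatch without using MLIR. `satisfiesOptional`, `satisfiesVariadic` and `checkMaskResults` are hypothetical names. The result check is an assumed simplification of what such a verifier compares (the count, order and types of the maskable operation's results). It is not the upstream implementation.

```cpp
#include <cassert>  // used by the accompanying checks
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical simplification: an MLIR type is modeled by its printed name.
using TypeName = std::string;

// Model of the old ODS constraint `Optional<AnyType>`: zero or one result.
bool satisfiesOptional(const std::vector<TypeName> &results) {
  return results.size() <= 1;
}

// Model of the new ODS constraint `Variadic<AnyType>`: any number of results.
bool satisfiesVariadic(const std::vector<TypeName> &) { return true; }

// Hypothetical verifier-style check: the mask op must return exactly the
// results of the operation it masks (same count, same order, same types).
// Returns an empty string on success, otherwise an error message.
std::string checkMaskResults(const std::vector<TypeName> &maskableResults,
                             const std::vector<TypeName> &maskResults) {
  if (maskableResults.size() != maskResults.size())
    return "result count does not match the maskable operation";
  for (std::size_t i = 0; i < maskResults.size(); ++i)
    if (maskableResults[i] != maskResults[i])
      return "result type does not match the maskable operation";
  return "";
}
```

A maskable operation with two results passes `checkMaskResults` but fails `satisfiesOptional`. That is the gap the switch to `Variadic` closes.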
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 04af8d3d80af2..5a14f0da52b16 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2287,10 +2287,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
The `vector.mask` is a `MaskingOpInterface` operation that predicates the
execution of another operation. It takes an `i1` vector mask and an
optional passthru vector as arguments.
- A `vector.yield`-terminated region encloses the operation to be masked.
- Values used within the region are captured from above. Only one *maskable*
- operation can be masked with a `vector.mask` operation at a time. An
- operation is *maskable* if it implements the `MaskableOpInterface`.
+
+ An implicitly `vector.yield`-terminated region encloses the operation to be
+ masked. Values used within the region are captured from above. Only one
+ *maskable* operation can be masked with a `vector.mask` operation at a time.
+ An operation is *maskable* if it implements the `MaskableOpInterface`. The
+ terminator yields all results of the maskable operation to the result of
+ this operation.
The vector mask argument holds a bit for each vector lane and determines
which vector lanes should execute the maskable operation and which ones
@@ -2321,12 +2324,16 @@ def Vector_MaskOp : Vector_Op<"mask", [
```
vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
```
+
+ ```
+ vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+ ```
}];
// TODO: Support multiple results and passthru values.
let arguments = (ins VectorOf<[I1]>:$mask,
Optional<AnyType>:$passthru);
- let results = (outs Optional<AnyType>:$results);
+ let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$maskRegion);
let skipDefaultBuilders = 1;
@@ -2334,10 +2341,10 @@ def Vector_MaskOp : Vector_Op<"mask", [
OpBuilder<(ins "Value":$mask,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$maskRegion)>,
- OpBuilder<(ins "Type":$resultType, "Value":$mask,
+ OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$maskRegion)>,
- OpBuilder<(ins "Type":$resultType, "Value":$mask,
+ OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
"Value":$passthru,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$maskRegion)>
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e2a3e619164e2..f00d8494e3151 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5288,20 +5288,20 @@ void MaskOp::build(
}
void MaskOp::build(
- OpBuilder &builder, OperationState &result, Type resultType, Value mask,
- function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
- build(builder, result, resultType, mask, /*passthru=*/Value(),
+ OpBuilder &builder, OperationState &result, TypeRange resultTypes,
+ Value mask, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+ build(builder, result, resultTypes, mask, /*passthru=*/Value(),
maskRegionBuilder);
}
void MaskOp::build(
- OpBuilder &builder, OperationState &result, Type resultType, Value mask,
- Value passthru,
+ OpBuilder &builder, OperationState &result, TypeRange resultTypes,
+ Value mask, Value passthru,
function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
build(builder, result, mask, maskRegionBuilder);
if (passthru)
result.addOperands(passthru);
- result.addTypes(resultType);
+ result.addTypes(resultTypes);
}
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
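The `MaskOp::build` hunk replaces a single `Type resultType` with a `TypeRange resultTypes` and forwards it to `result.addTypes`. The standalone C++ sketch below mirrors that control flow: add the mask operand, add the passthru operand only if present, then append every result type. `ToyOperationState` and `buildMaskOp` are hypothetical stand-ins, not MLIR classes. The region built by `maskRegionBuilder` is omitted.

```cpp
#include <cassert>  // used by the accompanying checks
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::OperationState. It only records operands
// and result types, both modeled by their printed names.
struct ToyOperationState {
  std::vector<std::string> operands;
  std::vector<std::string> resultTypes;

  void addOperands(const std::string &value) { operands.push_back(value); }

  // Appends every type in the range, which may be empty.
  void addTypes(const std::vector<std::string> &types) {
    resultTypes.insert(resultTypes.end(), types.begin(), types.end());
  }
};

// Toy analogue of the TypeRange-based MaskOp::build overloads.
void buildMaskOp(ToyOperationState &state,
                 const std::vector<std::string> &resultTypes,
                 const std::string &mask,
                 const std::optional<std::string> &passthru = std::nullopt) {
  state.addOperands(mask);      // the i1 mask vector
  if (passthru)                 // optional passthru operand
    state.addOperands(*passthru);
  state.addTypes(resultTypes);  // zero, one, or many results
}
```

An empty range reproduces the zero-result `memref` form of the `transfer_write` example, and a one-element range reproduces the `tensor` form. A maskable operation with two or more results had no matching overload before this change.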