[Mlir-commits] [mlir] f94131a - [mlir][vector] Support multiple result types in vector.mask

Matthias Springer llvmlistbot at llvm.org
Fri Jan 13 07:59:44 PST 2023


Author: Matthias Springer
Date: 2023-01-13T16:59:36+01:00
New Revision: f94131a2a502118e0507164dcef160d1cfecb316

URL: https://github.com/llvm/llvm-project/commit/f94131a2a502118e0507164dcef160d1cfecb316
DIFF: https://github.com/llvm/llvm-project/commit/f94131a2a502118e0507164dcef160d1cfecb316.diff

LOG: [mlir][vector] Support multiple result types in vector.mask

The verifier already supported multiple result types, but the op definition allowed at most one (optional) result. Make the result variadic and have the result-type builders take a `TypeRange`.

Differential Revision: https://reviews.llvm.org/D141683

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 04af8d3d80af2..5a14f0da52b16 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2287,10 +2287,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
     The `vector.mask` is a `MaskingOpInterface` operation that predicates the
     execution of another operation. It takes an `i1` vector mask and an
     optional passthru vector as arguments.
-    A `vector.yield`-terminated region encloses the operation to be masked.
-    Values used within the region are captured from above. Only one *maskable*
-    operation can be masked with a `vector.mask` operation at a time. An
-    operation is *maskable* if it implements the `MaskableOpInterface`.
+
+    An implicitly `vector.yield`-terminated region encloses the operation to be
+    masked. Values used within the region are captured from above. Only one
+    *maskable* operation can be masked with a `vector.mask` operation at a time.
+    An operation is *maskable* if it implements the `MaskableOpInterface`. The
+    terminator yields all results of the maskable operation to the results of
+    this operation.
 
     The vector mask argument holds a bit for each vector lane and determines
     which vector lanes should execute the maskable operation and which ones
@@ -2321,12 +2324,16 @@ def Vector_MaskOp : Vector_Op<"mask", [
     ```
       vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
     ```
+
+    ```
+      vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+    ```
   }];
 
   // TODO: Support multiple results and passthru values.
   let arguments = (ins VectorOf<[I1]>:$mask,
                    Optional<AnyType>:$passthru);
-  let results = (outs Optional<AnyType>:$results);
+  let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$maskRegion);
 
   let skipDefaultBuilders = 1;
@@ -2334,10 +2341,10 @@ def Vector_MaskOp : Vector_Op<"mask", [
     OpBuilder<(ins "Value":$mask,
                    CArg<"function_ref<void(OpBuilder &, Location)>",
                         "buildTerminatedBody">:$maskRegion)>,
-    OpBuilder<(ins "Type":$resultType, "Value":$mask,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
                    CArg<"function_ref<void(OpBuilder &, Location)>",
                         "buildTerminatedBody">:$maskRegion)>,
-    OpBuilder<(ins "Type":$resultType, "Value":$mask,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
                    "Value":$passthru,
                    CArg<"function_ref<void(OpBuilder &, Location)>",
                         "buildTerminatedBody">:$maskRegion)>

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e2a3e619164e2..f00d8494e3151 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5288,20 +5288,20 @@ void MaskOp::build(
 }
 
 void MaskOp::build(
-    OpBuilder &builder, OperationState &result, Type resultType, Value mask,
-    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
-  build(builder, result, resultType, mask, /*passthru=*/Value(),
+    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
+    Value mask, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+  build(builder, result, resultTypes, mask, /*passthru=*/Value(),
         maskRegionBuilder);
 }
 
 void MaskOp::build(
-    OpBuilder &builder, OperationState &result, Type resultType, Value mask,
-    Value passthru,
+    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
+    Value mask, Value passthru,
     function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
   build(builder, result, mask, maskRegionBuilder);
   if (passthru)
     result.addOperands(passthru);
-  result.addTypes(resultType);
+  result.addTypes(resultTypes);
 }
 
 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {


        


More information about the Mlir-commits mailing list