[Mlir-commits] [mlir] f439b31 - [mlir][Linalg] Split reduction transform op

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jun 21 05:11:53 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-21T05:01:26-07:00
New Revision: f439b31971a755c19777ae2c6c4114c4d02a54dd

URL: https://github.com/llvm/llvm-project/commit/f439b31971a755c19777ae2c6c4114c4d02a54dd
DIFF: https://github.com/llvm/llvm-project/commit/f439b31971a755c19777ae2c6c4114c4d02a54dd.diff

LOG: [mlir][Linalg] Split reduction transform op

This revision separates the `LinalgSplitReduction` pattern, whose application is based on attributes,
from its implementation.
A transform dialect op extension is added to control the application of the transformation at a finer granularity.

Differential Revision: https://reviews.llvm.org/D128165

Added: 
    mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index fca389b438a37..15186ecd36940 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -153,6 +153,75 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   }];
 }
 
+def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
+       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+        TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Indicates that the given `target` op should be transformed with the
+    `splitReduction` transformation, parameterized by the `split_factor` and
+    `insert_split_dimension` attributes. The transformation splits the
+    single reduction dimension of a linalg op into a parallel and a
+    reduction dimension; a new `linalg.generic` op performs the rest of the
+    reduction.
+
+    Example:
+
+    ```
+      %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                            affine_map<(d0) -> ()>],
+            iterator_types = ["reduction"]}
+      ins(%in : tensor<32xf32>)
+      outs(%out : tensor<f32>) {
+      ^bb0(%arg1: f32, %arg2: f32):
+        %y = arith.addf %arg1, %arg2 : f32
+        linalg.yield %y : f32
+      } -> tensor<f32>
+    ```
+
+    To:
+
+    ```
+      %cst = arith.constant 0.000000e+00 : f32
+      %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
+      %1 = linalg.init_tensor [4] : tensor<4xf32>
+      %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32>
+      %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                            affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
+        ^bb0(%arg3: f32, %arg4: f32):
+        %5 = arith.addf %arg3, %arg4 : f32
+        linalg.yield %5 : f32
+      } -> tensor<4xf32>
+      %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                            affine_map<(d0) -> ()>],
+        iterator_types = ["reduction"]}
+        ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
+        ^bb0(%arg3: f32, %arg4: f32):
+        %5 = arith.addf %arg3, %arg4 : f32
+        linalg.yield %5 : f32
+      } -> tensor<f32>
+    ```
+
+    This op returns handles to the fill op used to initialize the neutral
+    element, the split op and the result-combining op.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64Attr, "{}">:$split_factor,
+                   DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
+  let results = (outs PDL_Operation:$fill_op,
+                      PDL_Operation:$split_linalg_op,
+                      PDL_Operation:$combining_linalg_op);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+        ::mlir::linalg::LinalgOp target);
+  }];
+}
+
 def TileOp : Op<Transform_Dialect, "structured.tile",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d5657b3bec4df..7e2d58939da1c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1466,6 +1466,7 @@ class TilingPatterns<OpTy, OpTypes...> {
 /// reduction dimension. The dimension index is used to control where the extra
 /// dimension is added to the intermediate tensor shape. If the ratio value is
 /// less or equal to 1 then nothing will be done.
+// TODO: don't use unsigned unless doing bit manipulation.
 using ControlSplitReductionFn =
     std::function<std::pair<int64_t, unsigned>(LinalgOp op)>;
 
@@ -1519,6 +1520,18 @@ splitReduction(PatternRewriter &b, LinalgOp op,
                const ControlSplitReductionFn &controlSplitReductionFn,
                const LinalgTransformationFilter &f);
 
+/// Filterless version of the above. On success, returns the split op, the op
+/// combining its partial results, and the fill op initializing the expanded
+/// temporary tensor with the neutral element of the reduction.
+struct SplitReductionResult {
+  FillOp fillOp;
+  LinalgOp splitLinalgOp;
+  LinalgOp resultCombiningLinalgOp;
+};
+FailureOr<SplitReductionResult>
+splitReduction(PatternRewriter &b, LinalgOp op,
+               const ControlSplitReductionFn &controlSplitReductionFn);
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index acb6769ce6314..7392a289135f6 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -98,13 +98,15 @@ class LLVM_NODISCARD DiagnosedSilenceableFailure {
 
   /// Streams the given values into the diagnotic. Expects this object to be a
   /// silencable failure.
-  template <typename T> DiagnosedSilenceableFailure &operator<<(T &&value) & {
+  template <typename T>
+  DiagnosedSilenceableFailure &operator<<(T &&value) & {
     assert(isSilenceableFailure() &&
            "can only append output in silencable failure state");
     *diagnostic << std::forward<T>(value);
     return *this;
   }
-  template <typename T> DiagnosedSilenceableFailure &&operator<<(T &&value) && {
+  template <typename T>
+  DiagnosedSilenceableFailure &&operator<<(T &&value) && {
     return std::move(this->operator<<(std::forward<T>(value)));
   }
 
@@ -577,16 +579,17 @@ class PossibleTopLevelTransformOpTrait
 };
 
 /// Trait implementing the TransformOpInterface for operations applying a
-/// transformation to a single operation handle and producing a single operation
-/// handle. The op must implement a method with one of the following signatures:
+/// transformation to a single operation handle and producing zero or more
+/// operation handles (a multi-result op returns one operation per result).
+/// The op must implement a method with one of the following signatures:
 ///   - FailureOr<convertible-to-Operation*> applyToOne(OpTy)
+///   - FailureOr<SmallVector<Operation *>> applyToOne(OpTy)
 ///   - LogicalResult applyToOne(OpTy)
 /// to perform a transformation that is applied in turn to all payload IR
 /// operations that correspond to the handle of the transform IR operation.
 /// In the functions above, OpTy is either Operation * or a concrete payload IR
 /// Op class that the transformation is applied to (NOT the class of the
-/// transform IR op). The op is expected to have one operand and zero or one
-/// results.
+/// transform IR op). The op is expected to have a single operand.
 template <typename OpTy>
 class TransformEachOpTrait
     : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
@@ -713,33 +716,53 @@ namespace transform {
 namespace detail {
 /// Appends `result` to the vector assuming it corresponds to the success state
 /// in `FailureOr<convertible-to-Operation*>`. If `result` is just a
-/// `LogicalResult`, does nothing.
+/// `LogicalResult`, appends an empty vector (also on failure).
 template <typename Ty>
 std::enable_if_t<std::is_same<Ty, LogicalResult>::value, LogicalResult>
-appendTransformResultToVector(Ty result,
-                              SmallVectorImpl<Operation *> &results) {
+appendTransformResultToVector(
+    Ty result, SmallVectorImpl<SmallVector<Operation *>> &results) {
+  results.push_back(SmallVector<Operation *>());
   return result;
 }
+
 template <typename Ty>
-std::enable_if_t<!std::is_same<Ty, LogicalResult>::value, LogicalResult>
-appendTransformResultToVector(Ty result,
-                              SmallVectorImpl<Operation *> &results) {
-  static_assert(
-      std::is_convertible<typename Ty::value_type, Operation *>::value,
-      "expected transform function to return operations");
+std::enable_if_t<
+    llvm::conjunction<
+        llvm::negation<std::is_same<Ty, LogicalResult>>,
+        std::is_convertible<typename Ty::value_type, Operation *>>::value,
+    LogicalResult>
+appendTransformResultToVector(
+    Ty result, SmallVectorImpl<SmallVector<Operation *>> &results) {
   if (failed(result))
     return failure();
-
-  results.push_back(*result);
+  results.push_back(SmallVector<Operation *>{*result});
   return success();
 }
 
-/// Applies a one-to-one transform to each of the given targets. Puts the
-/// results of transforms, if any, in `results` in the same order. Fails if any
-/// of the application fails. Individual transforms must be callable with
-/// one of the following signatures:
+/// Appends the whole container held by a successful `FailureOr<container>`.
+template <typename ContainerTy>
+std::enable_if_t<
+    llvm::conjunction<
+        llvm::negation<std::is_same<ContainerTy, LogicalResult>>,
+        llvm::negation<std::is_convertible<typename ContainerTy::value_type,
+                                           Operation *>>>::value,
+    LogicalResult>
+appendTransformResultToVector(
+    ContainerTy resultContainer,
+    SmallVectorImpl<SmallVector<Operation *>> &results) {
+  if (failed(resultContainer))
+    return failure();
+  results.push_back(*resultContainer);
+  return success();
+}
+/// Applies a one-to-one or one-to-many transform to each of the given targets,
+/// collecting the results of each application in `results` in target order.
+/// Fails if any application fails. `transform` must be callable as one of:
 ///   - FailureOr<convertible-to-Operation*>(OpTy)
 ///   - LogicalResult(OpTy)
+///   - FailureOr<SmallVector<Operation *>>(OpTy), whose vector becomes one
+///     entry of `results` (the LogicalResult form contributes an empty entry
+///     and the single-operation form a one-element entry),
 /// where OpTy is either
 ///   - Operation *, in which case the transform is always applied;
 ///   - a concrete Op class, in which case a check is performed whether
@@ -748,7 +771,8 @@ appendTransformResultToVector(Ty result,
 template <typename FnTy>
 DiagnosedSilenceableFailure
 applyTransformToEach(ArrayRef<Operation *> targets,
-                     SmallVectorImpl<Operation *> &results, FnTy transform) {
+                     SmallVectorImpl<SmallVector<Operation *>> &results,
+                     FnTy transform) {
   using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
   static_assert(std::is_convertible<OpTy, Operation *>::value,
                 "expected transform function to take an operation");
@@ -782,17 +806,36 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
       decltype(&OpTy::applyToOne)>::template arg_t<0>;
   ArrayRef<Operation *> targets =
       state.getPayloadOps(this->getOperation()->getOperand(0));
-  SmallVector<Operation *> results;
+  SmallVector<SmallVector<Operation *>, 1> results;
+  // `applyTransformToEach` appends one vector of produced operations per
+  // target, in target order.
   DiagnosedSilenceableFailure result = detail::applyTransformToEach(
       targets, results, [&](TransformOpType specificOp) {
         return static_cast<OpTy *>(this)->applyToOne(specificOp);
       });
   if (!result.succeeded())
     return result;
-
-  if (OpTy::template hasTrait<OpTrait::OneResult>()) {
-    transformResults.set(
-        this->getOperation()->getResult(0).template cast<OpResult>(), results);
+  Operation *transformOp = this->getOperation();
+  unsigned numResults = transformOp->getNumResults();
+  // Ops without results have nothing to record.
+  if (numResults == 0)
+    return DiagnosedSilenceableFailure::success();
+  // Transpose per-target results into per-handle lists so that each result
+  // handle is set once, with the ops produced for all targets (possibly none).
+  SmallVector<SmallVector<Operation *>> perResultOps(numResults);
+  for (const SmallVector<Operation *> &oneTargetResults : results) {
+    if (oneTargetResults.size() != numResults) {
+      Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error);
+      diag << "unexpected number of results (got " << oneTargetResults.size()
+           << " expected " << numResults << ")";
+      return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
+    }
+    for (unsigned i = 0; i < numResults; ++i)
+      perResultOps[i].push_back(oneTargetResults[i]);
+  }
+  // Set every handle, even when it maps to an empty list of payload ops.
+  for (unsigned i = 0; i < numResults; ++i) {
+    transformResults.set(transformOp->getResult(i), perResultOps[i]);
   }
   return DiagnosedSilenceableFailure::success();
 }
@@ -802,9 +845,6 @@ mlir::LogicalResult
 mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(Operation *op) {
   static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
                 "expected single-operand op");
-  static_assert(OpTy::template hasTrait<OpTrait::OneResult>() ||
-                    OpTy::template hasTrait<OpTrait::ZeroResults>(),
-                "expected zero- or single-result op");
   if (!op->getName().getInterface<TransformOpInterface>()) {
     return op->emitError() << "TransformEachOpTrait should only be attached to "
                               "ops that implement TransformOpInterface";

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a5e865b4b96d7..db7b1808c658f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -394,6 +394,32 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
   return result->op;
 }
 
+//===----------------------------------------------------------------------===//
+// SplitReductionOp
+//===----------------------------------------------------------------------===//
+
+/// Splits the reduction of `target` as configured by the `split_factor` and
+/// `insert_split_dimension` attributes. On success, returns the fill op, the
+/// split op and the result-combining op, matching the order of op results.
+FailureOr<SmallVector<Operation *>>
+transform::SplitReductionOp::applyToOne(LinalgOp target) {
+  // The attributes alone determine the split; the control function ignores
+  // the payload op it is queried with.
+  int64_t splitFactor = getSplitFactor();
+  unsigned insertSplitDimension = getInsertSplitDimension();
+  ControlSplitReductionFn controlFn = [=](LinalgOp) {
+    return std::pair<int64_t, unsigned>(splitFactor, insertSplitDimension);
+  };
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<SplitReductionResult> res =
+      splitReduction(rewriter, target, controlFn);
+  if (succeeded(res))
+    return SmallVector<Operation *>{res->fillOp, res->splitLinalgOp,
+                                    res->resultCombiningLinalgOp};
+  return getOperation()->emitError("failed to apply");
+}
+
 //===----------------------------------------------------------------------===//
 // TileOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index da3c52039d848..226b35d4495ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -64,11 +64,36 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
       op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
       !op.hasOnlyProjectedPermutations())
     return b.notifyMatchFailure(op, "precondition not met");
+
+  FailureOr<SplitReductionResult> res =
+      splitReduction(b, op, controlSplitReductionFn);
+  if (failed(res))
+    return failure();
+
+  filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
+  filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);
+
+  return res->splitLinalgOp;
+}
+
+FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
+    PatternRewriter &b, LinalgOp op,
+    const ControlSplitReductionFn &controlSplitReductionFn) {
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(op);
+
+  // Check the preconditions here too: this entry point is also called without
+  // the filtered wrapper above, e.g. by the transform dialect.
+  if (!op.hasTensorSemantics() || op.getNumReductionLoops() != 1 ||
+      op.getNumOutputs() != 1 || !op.hasOnlyProjectedPermutations())
+    return b.notifyMatchFailure(op, "precondition not met");
+
   std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
   int64_t ratio = control.first;
   unsigned insertDimIndex = control.second;
   if (ratio <= 1)
     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
+
   SmallVector<unsigned> dims;
   op.getReductionDims(dims);
   assert(dims.size() == 1);
@@ -79,14 +104,16 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
       reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size())
     return b.notifyMatchFailure(
         op, "Reduction dimension not divisible by split ratio");
+
   SmallVector<Operation *, 4> combinerOps;
   if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
       combinerOps.size() != 1)
     return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
+
   Operation *reductionOp = combinerOps[0];
   Optional<Attribute> identity = getIdentity(reductionOp);
   if (!identity)
-    return b.notifyMatchFailure(op, "Unknown identity value for the redution");
+    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
 
   Location loc = op->getLoc();
   SmallVector<Value> newInputs;
@@ -127,6 +154,7 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
         loc, newType, operand->get(), reassociation);
     newInputs.push_back(newInput);
   }
+
   // Calculate the new output map and shape, we insert the new dimension based
   // on the index returned by `controlSplitReductionFn`.
   SmallVector<int64_t> newOutputShape;
@@ -169,8 +197,8 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
   b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                        genericOp.region().begin());
 
-  // Then create a new reduction that only reduce the newly added dimension from
-  // the previous op.
+  // Then create a new reduction that only reduces the newly added dimension
+  // from the previous op.
   unsigned intermRank = newOutputShape.size();
   AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
   SmallVector<Value> outputOperands = op.getOutputOperands();
@@ -197,9 +225,10 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
         b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
       });
   b.replaceOp(op, reduction.getResults());
-  filter.replaceLinalgTransformationFilter(b, genericOp);
-  filter.replaceLinalgTransformationFilter(b, reduction);
-  return cast<LinalgOp>(genericOp.getOperation());
+
+  return SplitReductionResult{identityTensor.getDefiningOp<FillOp>(),
+                              cast<LinalgOp>(genericOp.getOperation()),
+                              reduction};
 }
 
 namespace {

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
new file mode 100644
index 0000000000000..e7a6a14ebd0f8
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+
+// CHECK-LABEL: func.func @matmul_split
+func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  // Expect a partial reduction into tensor<16x32x4xf32>, then its combiner.
+  //      CHECK: linalg.generic 
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<16x4x64xf32>, tensor<4x64x32xf32>)
+  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>) {
+
+  //      CHECK: linalg.generic 
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>)
+  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32xf32>) {
+  %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
+                    outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: the rewrite is unused, but pdl.pattern requires it as terminator.
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1:3 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
+  }
+}

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 45e01e82ab362..6a54b0ce6f59e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -378,3 +378,10 @@ transform.with_pdl_patterns {
     }
   }
 }
+// -----
+// applyToOne returns no handle for an op declaring three results.
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{unexpected number of results (got 0 expected 3)}}
+  transform.test_wrong_number_of_results %arg0
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 5c971634ab09f..43c181651a42e 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -226,6 +226,11 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+FailureOr<SmallVector<Operation *>>
+mlir::test::TestWrongNumberOfResultsOp::applyToOne(Operation *) {
+  return SmallVector<Operation *>(); // Deliberately none: the op declares 3.
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
index 891249c854e0e..c38693fb2b3a6 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
@@ -16,7 +16,7 @@
 
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
 
 namespace mlir {
 class DialectRegistry;

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index a8dab8be106fb..d811a57d3c112 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -128,4 +128,24 @@ def TestEmitRemarkAndEraseOperandOp
   let cppNamespace = "::mlir::test";
 }
 
+def TestWrongNumberOfResultsOp
+  : Op<Transform_Dialect, "test_wrong_number_of_results",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Test op whose `applyToOne` returns no handle although the op declares three
+    results; used to check that `TransformEachOpTrait` reports the mismatch.
+  }];
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$a,
+                      PDL_Operation:$b,
+                      PDL_Operation:$c);
+  let assemblyFormat = "$target attr-dict";
+  let cppNamespace = "::mlir::test";
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+        ::mlir::Operation *target);
+  }];
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD


        


More information about the Mlir-commits mailing list