[Mlir-commits] [mlir] f439b31 - [mlir][Linalg] Split reduction transform op
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jun 21 05:11:53 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-21T05:01:26-07:00
New Revision: f439b31971a755c19777ae2c6c4114c4d02a54dd
URL: https://github.com/llvm/llvm-project/commit/f439b31971a755c19777ae2c6c4114c4d02a54dd
DIFF: https://github.com/llvm/llvm-project/commit/f439b31971a755c19777ae2c6c4114c4d02a54dd.diff
LOG: [mlir][Linalg] Split reduction transform op
This revision separates the implementation of the split-reduction transformation from the
`LinalgSplitReduction` pattern, whose application is driven by attribute-based filters.
A Transform dialect op, `transform.structured.split_reduction`, is added to control the application of the transformation at a finer granularity.
Differential Revision: https://reviews.llvm.org/D128165
Added:
mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index fca389b438a37..15186ecd36940 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -153,6 +153,81 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
}];
}
+def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let description = [{
+ Indicates that the given `target` op should be transformed with the
+ `splitReduction` transformation, using the split factor and the position
+ of the newly inserted dimension provided as attributes.
+
+ The `splitReduction` transformation splits the single reduction dimension
+ of a linalg op into a parallel and a reduction dimension, producing
+ partial reductions. A new `linalg.generic` op is created to perform the
+ rest of the reduction.
+
+ The `split_factor` attribute must be greater than 1 and divide the size
+ of the reduction dimension; `insert_split_dimension` selects where the
+ extra parallel dimension is added to the intermediate tensor shape.
+
+ Example:
+
+ ```
+ %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"]}
+ ins(%in : tensor<32xf32>)
+ outs(%out : tensor<f32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %y = arith.addf %arg1, %arg2 : f32
+ linalg.yield %y : f32
+ } -> tensor<f32>
+ ```
+
+ To:
+
+ ```
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
+ %1 = linalg.init_tensor [4] : tensor<4xf32>
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32>
+ %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32):
+ %5 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %5 : f32
+ } -> tensor<4xf32>
+ %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"]}
+ ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
+ ^bb0(%arg3: f32, %arg4: f32):
+ %5 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %5 : f32
+ } -> tensor<f32>
+ ```
+
+ This op returns handles to the fill op used to initialize the neutral
+ element, the split op and the result-combining op.
+ }];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64Attr, "{}">:$split_factor,
+ DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
+ let results = (outs PDL_Operation:$fill_op,
+ PDL_Operation:$split_linalg_op,
+ PDL_Operation:$combining_linalg_op);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+}
+
def TileOp : Op<Transform_Dialect, "structured.tile",
[DeclareOpInterfaceMethods<TransformOpInterface>,
FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d5657b3bec4df..7e2d58939da1c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1466,6 +1466,7 @@ class TilingPatterns<OpTy, OpTypes...> {
/// reduction dimension. The dimension index is used to control where the extra
/// dimension is added to the intermediate tensor shape. If the ratio value is
/// less or equal to 1 then nothing will be done.
+// TODO: don't use unsigned unless doing bit manipulation.
using ControlSplitReductionFn =
std::function<std::pair<int64_t, unsigned>(LinalgOp op)>;
@@ -1519,6 +1520,24 @@ splitReduction(PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn,
const LinalgTransformationFilter &f);
+/// Ops created by a successful split reduction: the `fillOp` initializing the
+/// temporary expanded tensor with the neutral element of the reduction, the
+/// `splitLinalgOp` computing partial reductions into that tensor, and the
+/// `resultCombiningLinalgOp` reducing the partial results into the original
+/// output.
+struct SplitReductionResult {
+ FillOp fillOp;
+ LinalgOp splitLinalgOp;
+ LinalgOp resultCombiningLinalgOp;
+};
+
+/// Filterless version of the above: splits the reduction of `op` without
+/// checking or updating transformation markers, replaces `op` with the
+/// result-combining op and returns the newly created ops.
+FailureOr<SplitReductionResult>
+splitReduction(PatternRewriter &b, LinalgOp op,
+ const ControlSplitReductionFn &controlSplitReductionFn);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index acb6769ce6314..7392a289135f6 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -98,13 +98,15 @@ class LLVM_NODISCARD DiagnosedSilenceableFailure {
/// Streams the given values into the diagnotic. Expects this object to be a
/// silencable failure.
- template <typename T> DiagnosedSilenceableFailure &operator<<(T &&value) & {
+ template <typename T>
+ DiagnosedSilenceableFailure &operator<<(T &&value) & {
assert(isSilenceableFailure() &&
"can only append output in silencable failure state");
*diagnostic << std::forward<T>(value);
return *this;
}
- template <typename T> DiagnosedSilenceableFailure &&operator<<(T &&value) && {
+ template <typename T>
+ DiagnosedSilenceableFailure &&operator<<(T &&value) && {
return std::move(this->operator<<(std::forward<T>(value)));
}
@@ -577,16 +579,19 @@ class PossibleTopLevelTransformOpTrait
};
/// Trait implementing the TransformOpInterface for operations applying a
-/// transformation to a single operation handle and producing a single operation
-/// handle. The op must implement a method with one of the following signatures:
+/// transformation to a single operation handle and producing one or multiple
+/// operation handles.
+/// The op must implement a method with one of the following signatures:
/// - FailureOr<convertible-to-Operation*> applyToOne(OpTy)
+/// - FailureOr<SmallVector<convertible-to-Operation*>> applyToOne(OpTy)
/// - LogicalResult applyToOne(OpTy)
/// to perform a transformation that is applied in turn to all payload IR
/// operations that correspond to the handle of the transform IR operation.
/// In the functions above, OpTy is either Operation * or a concrete payload IR
/// Op class that the transformation is applied to (NOT the class of the
-/// transform IR op). The op is expected to have one operand and zero or one
-/// results.
+/// transform IR op). The op is expected to have a single operand. When the op
+/// has several results, each `applyToOne` call must return exactly one payload
+/// op per result; each result handle then refers to the ops of all targets.
template <typename OpTy>
class TransformEachOpTrait
: public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
@@ -713,33 +718,56 @@ namespace transform {
namespace detail {
/// Appends `result` to the vector assuming it corresponds to the success state
/// in `FailureOr<convertible-to-Operation*>`. If `result` is just a
-/// `LogicalResult`, does nothing.
+/// `LogicalResult`, appends an empty vector.
template <typename Ty>
std::enable_if_t<std::is_same<Ty, LogicalResult>::value, LogicalResult>
-appendTransformResultToVector(Ty result,
- SmallVectorImpl<Operation *> &results) {
+appendTransformResultToVector(
+ Ty result, SmallVectorImpl<SmallVector<Operation *>> &results) {
+ results.push_back(SmallVector<Operation *>());
return result;
}
+
template <typename Ty>
-std::enable_if_t<!std::is_same<Ty, LogicalResult>::value, LogicalResult>
-appendTransformResultToVector(Ty result,
- SmallVectorImpl<Operation *> &results) {
- static_assert(
- std::is_convertible<typename Ty::value_type, Operation *>::value,
- "expected transform function to return operations");
+std::enable_if_t<
+ llvm::conjunction<
+ llvm::negation<std::is_same<Ty, LogicalResult>>,
+ std::is_convertible<typename Ty::value_type, Operation *>>::value,
+ LogicalResult>
+appendTransformResultToVector(
+ Ty result, SmallVectorImpl<SmallVector<Operation *>> &results) {
if (failed(result))
return failure();
-
- results.push_back(*result);
+ results.push_back(SmallVector<Operation *>{*result});
return success();
}
-/// Applies a one-to-one transform to each of the given targets. Puts the
-/// results of transforms, if any, in `results` in the same order. Fails if any
-/// of the application fails. Individual transforms must be callable with
-/// one of the following signatures:
+/// Appends the operations held by `resultContainer`, assuming it is in the
+/// success state of `FailureOr<SmallVector<convertible-to-Operation*>>`.
+template <typename ContainerTy>
+std::enable_if_t<
+ llvm::conjunction<
+ llvm::negation<std::is_same<ContainerTy, LogicalResult>>,
+ llvm::negation<std::is_convertible<typename ContainerTy::value_type,
+ Operation *>>>::value,
+ LogicalResult>
+appendTransformResultToVector(
+ ContainerTy resultContainer,
+ SmallVectorImpl<SmallVector<Operation *>> &results) {
+ if (failed(resultContainer))
+ return failure();
+ results.push_back(*resultContainer);
+ return success();
+}
+
+/// Applies a one-to-one or a one-to-many transform to each of the given
+/// targets. Appends one entry per target to `results`, in the same order,
+/// holding the operations produced for that target (empty for transforms
+/// returning LogicalResult). Fails if any of the applications fails.
+/// Individual transforms must be callable with one of the following
+/// signatures:
/// - FailureOr<convertible-to-Operation*>(OpTy)
/// - LogicalResult(OpTy)
+/// - FailureOr<SmallVector<convertible-to-Operation*>>(OpTy)
/// where OpTy is either
/// - Operation *, in which case the transform is always applied;
/// - a concrete Op class, in which case a check is performed whether
@@ -748,7 +776,8 @@ appendTransformResultToVector(Ty result,
template <typename FnTy>
DiagnosedSilenceableFailure
applyTransformToEach(ArrayRef<Operation *> targets,
- SmallVectorImpl<Operation *> &results, FnTy transform) {
+ SmallVectorImpl<SmallVector<Operation *>> &results,
+ FnTy transform) {
using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
static_assert(std::is_convertible<OpTy, Operation *>::value,
"expected transform function to take an operation");
@@ -782,17 +811,41 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
decltype(&OpTy::applyToOne)>::template arg_t<0>;
ArrayRef<Operation *> targets =
state.getPayloadOps(this->getOperation()->getOperand(0));
- SmallVector<Operation *> results;
+ // One entry per target: the payload ops produced for that target.
+ SmallVector<SmallVector<Operation *>, 1> results;
DiagnosedSilenceableFailure result = detail::applyTransformToEach(
targets, results, [&](TransformOpType specificOp) {
return static_cast<OpTy *>(this)->applyToOne(specificOp);
});
if (!result.succeeded())
return result;
-
- if (OpTy::template hasTrait<OpTrait::OneResult>()) {
- transformResults.set(
- this->getOperation()->getResult(0).template cast<OpResult>(), results);
+ if (OpTy::template hasTrait<OpTrait::ZeroResults>())
+ return DiagnosedSilenceableFailure::success();
+
+ // Regroup the per-target payload ops by transform op result: each result
+ // handle is set exactly once and associated with the payload ops produced
+ // for all targets, in target order.
+ unsigned numResults = this->getOperation()->getNumResults();
+ SmallVector<SmallVector<Operation *>> opsPerResult(numResults);
+ for (const SmallVector<Operation *> &oneTargetResults : results) {
+ if (OpTy::template hasTrait<OpTrait::OneResult>()) {
+ llvm::append_range(opsPerResult[0], oneTargetResults);
+ continue;
+ }
+ if (numResults != oneTargetResults.size()) {
+ Diagnostic diag(this->getOperation()->getLoc(),
+ DiagnosticSeverity::Error);
+ diag << "unexpected number of results (got " << oneTargetResults.size()
+ << " expected " << numResults << ")";
+ return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
+ }
+ for (unsigned i = 0; i < numResults; ++i)
+ opsPerResult[i].push_back(oneTargetResults[i]);
+ }
+ for (unsigned i = 0; i < numResults; ++i) {
+ transformResults.set(
+ this->getOperation()->getResult(i).template cast<OpResult>(),
+ opsPerResult[i]);
}
return DiagnosedSilenceableFailure::success();
}
@@ -802,9 +855,6 @@ mlir::LogicalResult
mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(Operation *op) {
static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
"expected single-operand op");
- static_assert(OpTy::template hasTrait<OpTrait::OneResult>() ||
- OpTy::template hasTrait<OpTrait::ZeroResults>(),
- "expected zero- or single-result op");
if (!op->getName().getInterface<TransformOpInterface>()) {
return op->emitError() << "TransformEachOpTrait should only be attached to "
"ops that implement TransformOpInterface";
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a5e865b4b96d7..db7b1808c658f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -394,6 +394,30 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
return result->op;
}
+//===----------------------------------------------------------------------===//
+// SplitReductionOp
+//===----------------------------------------------------------------------===//
+
+/// Splits the reduction of `target` according to the `split_factor` and
+/// `insert_split_dimension` attributes. Returns handles to the fill op, the
+/// split linalg op and the result-combining linalg op, in that order.
+FailureOr<SmallVector<Operation *>>
+transform::SplitReductionOp::applyToOne(LinalgOp target) {
+ ControlSplitReductionFn splitFn = [&](LinalgOp) {
+ return std::pair<int64_t, unsigned>(getSplitFactor(),
+ getInsertSplitDimension());
+ };
+ // `splitReduction` positions (and restores) the insertion point itself.
+ SimpleRewriter rewriter(getContext());
+ FailureOr<SplitReductionResult> splitResult =
+ splitReduction(rewriter, target, splitFn);
+ if (failed(splitResult))
+ return getOperation()->emitError("failed to apply");
+ return SmallVector<Operation *>{splitResult->fillOp,
+ splitResult->splitLinalgOp,
+ splitResult->resultCombiningLinalgOp};
+}
+
//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index da3c52039d848..226b35d4495ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -64,11 +64,38 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
!op.hasOnlyProjectedPermutations())
return b.notifyMatchFailure(op, "precondition not met");
+
+ FailureOr<SplitReductionResult> res =
+ splitReduction(b, op, controlSplitReductionFn);
+ if (failed(res))
+ return failure();
+
+ filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
+ filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);
+
+ return res->splitLinalgOp;
+}
+
+FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
+ PatternRewriter &b, LinalgOp op,
+ const ControlSplitReductionFn &controlSplitReductionFn) {
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPoint(op);
+
+ // This overload is also called directly (e.g. by the transform dialect),
+ // bypassing the checks of the filter-based overload: re-check the structural
+ // preconditions so that unsupported ops fail gracefully instead of hitting
+ // the single-reduction-dimension assertion below.
+ if (!op.hasTensorSemantics() || op.getNumReductionLoops() != 1 ||
+ op.getNumOutputs() != 1 || !op.hasOnlyProjectedPermutations())
+ return b.notifyMatchFailure(op, "precondition not met");
+
std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
int64_t ratio = control.first;
unsigned insertDimIndex = control.second;
if (ratio <= 1)
return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
+
SmallVector<unsigned> dims;
op.getReductionDims(dims);
assert(dims.size() == 1);
@@ -79,14 +106,16 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size())
return b.notifyMatchFailure(
op, "Reduction dimension not divisible by split ratio");
+
SmallVector<Operation *, 4> combinerOps;
if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
combinerOps.size() != 1)
return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
+
Operation *reductionOp = combinerOps[0];
Optional<Attribute> identity = getIdentity(reductionOp);
if (!identity)
- return b.notifyMatchFailure(op, "Unknown identity value for the redution");
+ return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
Location loc = op->getLoc();
SmallVector<Value> newInputs;
@@ -127,6 +156,7 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
loc, newType, operand->get(), reassociation);
newInputs.push_back(newInput);
}
+
// Calculate the new output map and shape, we insert the new dimension based
// on the index returned by `controlSplitReductionFn`.
SmallVector<int64_t> newOutputShape;
@@ -169,8 +199,8 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
genericOp.region().begin());
- // Then create a new reduction that only reduce the newly added dimension from
- // the previous op.
+ // Then create a new reduction that only reduces the newly added dimension
+ // from the previous op.
unsigned intermRank = newOutputShape.size();
AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
SmallVector<Value> outputOperands = op.getOutputOperands();
@@ -197,9 +227,10 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
});
b.replaceOp(op, reduction.getResults());
- filter.replaceLinalgTransformationFilter(b, genericOp);
- filter.replaceLinalgTransformationFilter(b, reduction);
- return cast<LinalgOp>(genericOp.getOperation());
+
+ return SplitReductionResult{identityTensor.getDefiningOp<FillOp>(),
+ cast<LinalgOp>(genericOp.getOperation()),
+ reduction};
}
namespace {
diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
new file mode 100644
index 0000000000000..e7a6a14ebd0f8
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+
+// CHECK-LABEL: func.func @matmul_split
+func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+
+ // The intermediate 16x32x4 tensor is initialized with the neutral element.
+ // CHECK: linalg.fill
+ // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>)
+
+ // CHECK: linalg.generic
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+ // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<16x4x64xf32>, tensor<4x64x32xf32>)
+ // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>) {
+
+ // CHECK: linalg.generic
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+ // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>)
+ // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32xf32>) {
+ %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
+ outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+ return %0: tensor<16x32xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1:3 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 45e01e82ab362..6a54b0ce6f59e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -378,3 +378,12 @@ transform.with_pdl_patterns {
}
}
}
+// -----
+
+// `test_wrong_number_of_results` declares three results but produces no
+// payload op: the mismatch must be reported as an error.
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+ // expected-error @below {{unexpected number of results (got 0 expected 3)}}
+ transform.test_wrong_number_of_results %arg0
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 5c971634ab09f..43c181651a42e 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -226,6 +226,13 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
return DiagnosedSilenceableFailure::success();
}
+/// Deliberately returns no payload op even though the op declares three
+/// results, so that the result-count check of TransformEachOpTrait fires.
+FailureOr<SmallVector<Operation *>>
+mlir::test::TestWrongNumberOfResultsOp::applyToOne(Operation *) {
+ return SmallVector<Operation *>{};
+}
+
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
index 891249c854e0e..c38693fb2b3a6 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h
@@ -16,7 +16,7 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
namespace mlir {
class DialectRegistry;
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index a8dab8be106fb..d811a57d3c112 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -128,4 +128,25 @@ def TestEmitRemarkAndEraseOperandOp
let cppNamespace = "::mlir::test";
}
+def TestWrongNumberOfResultsOp
+ : Op<Transform_Dialect, "test_wrong_number_of_results",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let description = [{
+ Test op declaring three results whose `applyToOne` is expected to return
+ a number of payload ops different from the number of results, to exercise
+ the result-count check of `TransformEachOpTrait`.
+ }];
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$a,
+ PDL_Operation:$b,
+ PDL_Operation:$c);
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "::mlir::test";
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+ ::mlir::Operation *target);
+ }];
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
More information about the Mlir-commits
mailing list