[Mlir-commits] [mlir] 42b1603 - [mlir][transform] Add an op for replacing values with function calls (#78398)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 19 10:21:56 PST 2024
Author: Quinn Dawkins
Date: 2024-01-19T13:21:52-05:00
New Revision: 42b160356fe5d3b41bf07c428d0142d3721b1d44
URL: https://github.com/llvm/llvm-project/commit/42b160356fe5d3b41bf07c428d0142d3721b1d44
DIFF: https://github.com/llvm/llvm-project/commit/42b160356fe5d3b41bf07c428d0142d3721b1d44.diff
LOG: [mlir][transform] Add an op for replacing values with function calls (#78398)
Adds `transform.func.cast_and_call` that takes a set of inputs and
outputs and replaces the uses of those outputs with a call to a function
at a specified insertion point.
The idea with this operation is to allow users to author independent IR
outside of a to-be-compiled module, and then match and replace a slice
of the program with a call to the external function.
Additionally adds a mechanism for populating a type converter with a set
of conversion materialization functions that allow insertion of
casts on the inputs/outputs to and from the types of the function
signature.
Added:
mlir/test/Dialect/Func/func-transform.mlir
mlir/test/Dialect/Tensor/transform-op-casting.mlir
Modified:
mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 7a7e991c7861886..c36fdd150556208 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -12,6 +12,8 @@
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/RegionKindInterface.td"
include "mlir/IR/OpBase.td"
def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
@@ -26,4 +28,74 @@ def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def CastAndCallOp : Op<Transform_Dialect,
+    "func.cast_and_call",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     AttrSizedOperandSegments,
+     ReportTrackingListenerFailuresOpTrait]
+        # GraphRegionNoTerminator.traits> {
+  let summary = "Casts values to the signature of a function and replaces them "
+                "with a call";
+  let description = [{
+    This transform takes value handles to a set of `inputs` and `outputs` and
+    attempts to cast them to the function signature of the attached function
+    op, then builds a call to the function and replaces the users of the
+    outputs. It is the responsibility of the user to ensure that the slice of
+    the program replaced by this operation makes sense, i.e. there is no
+    verification that the inputs to this operation have any relation to the
+    outputs outside of basic dominance requirements needed for the call.
+
+    The call is inserted immediately before the single `insertion_point` op,
+    or immediately after it when `after` is specified. The result handle
+    points to the newly created call op.
+
+    The casting materialization functions are specified in the graph region of
+    this op. They must implement the `TypeConverterBuilderOpInterface`. The
+    region ops populate a single type converter in region order, so their
+    order only matters when several of them provide a materialization for the
+    same conversion. Inputs are cast to the function argument types with
+    source materializations, and call results are cast back to the output
+    types with target materializations.
+
+    The target function can be specified by a symbol name or by a handle to the
+    operation. Exactly one of the two must be provided.
+
+    This transform only reads the operand handles and only replaces the users of
+    the outputs with the results of the call. No handles are consumed and no
+    operations are removed. Users are expected to run cleanup separately if
+    desired.
+
+    Warning: The replacement of the uses of the outputs could invalidate certain
+    restricted value handle types (e.g. `transform.block_arg` if it existed, by
+    replacing the use with something not coming from a block argument). The
+    value will still exist in such cases but wouldn't verify against the type.
+    See the discussion here for more information:
+    https://github.com/llvm/llvm-project/pull/78398#discussion_r1455070087
+
+    This transform will emit a silenceable failure if:
+      - The set of outputs isn't unique
+      - The handle for the insertion point does not include exactly one operation
+      - The insertion point op does not dominate all of the output users
+      - Not all of the inputs dominate the insertion point op
+      - The function signature does not match the number of inputs/outputs
+
+    This transform will emit a definite failure if it fails to resolve the
+    target function, or if it fails to materialize the conversion casts of
+    either the inputs to the function argument types, or the call results to
+    the output types.
+  }];
+
+  let arguments = (ins
+    TransformHandleTypeInterface:$insertion_point,
+    UnitAttr:$insert_after,
+    Optional<TransformValueHandleTypeInterface>:$inputs,
+    Optional<TransformValueHandleTypeInterface>:$outputs,
+    OptionalAttr<SymbolRefAttr>:$function_name,
+    Optional<TransformHandleTypeInterface>:$function);
+  let results = (outs TransformHandleTypeInterface:$result);
+  let regions = (region MaxSizedRegion<1>:$conversions);
+
+  let assemblyFormat = [{
+    ($function_name^)? ($function^)?
+    ( `(` $inputs^ `)` )?
+    ( `->` $outputs^ )?
+    (`after` $insert_after^):(`before`)? $insertion_point
+    ($conversions^)? attr-dict `:` functional-type(operands, results)
+  }];
+  let hasVerifier = 1;
+}
+
#endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 76309b9b8a9640d..29383a3825be883 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -18,7 +18,8 @@ include "mlir/IR/OpBase.td"
def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
"apply_conversion_patterns.memref.memref_to_llvm_type_converter",
[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
- ["getTypeConverterType"]>]> {
+ ["getTypeConverter",
+ "getTypeConverterType"]>]> {
let description = [{
This operation provides an "LLVMTypeConverter" that lowers memref types to
LLVM types.
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 8556d9570fd1200..39e1d7fa3494a39 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -169,4 +169,22 @@ def MakeLoopIndependentOp
}];
}
+def TypeConversionCastShapeDynamicDimsOp : Op<Transform_Dialect,
+    "type_conversion.tensor.cast_shape_dynamic_dims",
+    [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+                               ["populateTypeMaterializations"]>]> {
+  let summary = "Populates a type converter with tensor.cast materializations";
+  let description = [{
+    Populates a type converter with conversion materialization functions that
+    cast a tensor value between two cast-compatible tensors. See `tensor.cast`
+    for more information on cast compatibility between tensors.
+
+    Both the source and the target materialization only handle a single input
+    value, and they produce no value when the two tensor types are not
+    cast-compatible.
+
+    If `ignore_dynamic_info` is not set, this will set an additional constraint
+    that source materializations do not cast dynamic dimensions to static ones.
+    Target materializations are not subject to this constraint.
+  }];
+  let arguments = (ins UnitAttr:$ignore_dynamic_info);
+
+  let assemblyFormat =
+      "(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
+}
+
#endif // TENSOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index f29efaee620d845..8f7b8f1999e0c59 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -284,8 +284,14 @@ def TypeConverterBuilderOpInterface
: OpInterface<"TypeConverterBuilderOpInterface"> {
let description = [{
This interface should be implemented by ops that specify a type converter
- for a dialect conversion. Such ops can be used with
- "apply_conversion_patterns".
+ for a dialect conversion, or to populate a type converter with
+ conversions.
+
+ When such ops are intended to be used with "apply_conversion_patterns" or
+ other operations that expect a type converter, a non-default implementation
+ of `getTypeConverter` should be implemented. For use with "cast_and_call"
+ like ops that construct a type converter iteratively, non-default
+ `populateTypeMaterializations` should be implemented.
}];
let cppNamespace = "::mlir::transform";
@@ -297,7 +303,11 @@ def TypeConverterBuilderOpInterface
}],
/*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
/*name=*/"getTypeConverter",
- /*arguments=*/(ins)
+ /*arguments=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return std::make_unique<::mlir::TypeConverter>();
+ }]
>,
StaticInterfaceMethod<
/*desc=*/[{
@@ -310,6 +320,17 @@ def TypeConverterBuilderOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{ return "TypeConverter"; }]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate the given type converter with source/target materialization
+ functions.
+ }],
+ /*returnType=*/"void",
+ /*name=*/"populateTypeMaterializations",
+ /*arguments=*/(ins "::mlir::TypeConverter &":$converter),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{ return; }]
+ >,
];
}
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 9e9b6bcea790def..9e79b086c0be841 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -36,6 +37,196 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
return success();
}
+//===----------------------------------------------------------------------===//
+// CastAndCallOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
+                                transform::TransformResults &results,
+                                transform::TransformState &state) {
+  // Payload values passed as call operands, in handle order.
+  SmallVector<Value> inputs;
+  if (getInputs())
+    llvm::append_range(inputs, state.getPayloadValues(getInputs()));
+
+  // Payload values whose uses get replaced by the call results, in handle
+  // order.
+  SetVector<Value> outputs;
+  if (getOutputs()) {
+    auto outputValues = state.getPayloadValues(getOutputs());
+    for (Value output : outputValues)
+      outputs.insert(output);
+
+    // Verify that the set of output values to be replaced is unique.
+    if (outputs.size() != llvm::range_size(outputValues)) {
+      return emitSilenceableFailure(getLoc())
+             << "cast and call output values must be unique";
+    }
+  }
+
+  // Get the insertion point for the call.
+  auto insertionOps = state.getPayloadOps(getInsertionPoint());
+  if (!llvm::hasSingleElement(insertionOps)) {
+    return emitSilenceableFailure(getLoc())
+           << "Only one op can be specified as an insertion point";
+  }
+  bool insertAfter = getInsertAfter();
+  Operation *insertionPoint = *insertionOps.begin();
+
+  // Check that all inputs dominate the insertion point, and the insertion
+  // point dominates all users of the outputs. No IR has been modified yet, so
+  // violations are reported as silenceable failures, as documented for this
+  // op.
+  DominanceInfo dom(insertionPoint);
+  for (Value output : outputs) {
+    for (Operation *user : output.getUsers()) {
+      // If we are inserting after the insertion point operation, the
+      // insertion point operation must properly dominate the user. Otherwise
+      // basic dominance is enough.
+      bool doesDominate = insertAfter
+                              ? dom.properlyDominates(insertionPoint, user)
+                              : dom.dominates(insertionPoint, user);
+      if (!doesDominate) {
+        return emitSilenceableFailure(getLoc())
+               << "User " << user << " is not dominated by insertion point "
+               << insertionPoint;
+      }
+    }
+  }
+
+  for (Value input : inputs) {
+    // If we are inserting before the insertion point operation, the
+    // input must properly dominate the insertion point operation. Otherwise
+    // basic dominance is enough.
+    bool doesDominate = insertAfter
+                            ? dom.dominates(input, insertionPoint)
+                            : dom.properlyDominates(input, insertionPoint);
+    if (!doesDominate) {
+      return emitSilenceableFailure(getLoc())
+             << "input " << input << " does not dominate insertion point "
+             << insertionPoint;
+    }
+  }
+
+  // Get the function to call. This can either be specified by symbol or as a
+  // transform handle; the verifier guarantees that exactly one is present.
+  func::FuncOp targetFunction = nullptr;
+  if (getFunctionName()) {
+    targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
+        insertionPoint, *getFunctionName());
+    if (!targetFunction) {
+      return emitDefiniteFailure()
+             << "unresolved symbol " << *getFunctionName();
+    }
+  } else if (getFunction()) {
+    auto payloadOps = state.getPayloadOps(getFunction());
+    if (!llvm::hasSingleElement(payloadOps)) {
+      return emitDefiniteFailure() << "requires a single function to call";
+    }
+    targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
+    if (!targetFunction) {
+      return emitDefiniteFailure() << "invalid non-function callee";
+    }
+  } else {
+    llvm_unreachable("Invalid CastAndCall op without a function to call");
+  }
+
+  // Verify that the function argument and result counts match the inputs and
+  // outputs given to this op. Use the FuncOp accessors (not
+  // `Operation::getNumResults`, which counts the results of the `func.func` op
+  // itself) so the counts reflect the function signature.
+  unsigned numArguments = targetFunction.getNumArguments();
+  unsigned numResults = targetFunction.getNumResults();
+  if (numArguments != inputs.size()) {
+    return emitSilenceableFailure(targetFunction.getLoc())
+           << "mismatch between number of function arguments "
+           << numArguments << " and number of inputs " << inputs.size();
+  }
+  if (numResults != outputs.size()) {
+    return emitSilenceableFailure(targetFunction.getLoc())
+           << "mismatch between number of function results " << numResults
+           << " and number of outputs " << outputs.size();
+  }
+
+  // Gather all specified converters. The region ops populate a single type
+  // converter in region order.
+  mlir::TypeConverter converter;
+  if (!getRegion().empty()) {
+    for (Operation &op : getRegion().front()) {
+      cast<transform::TypeConverterBuilderOpInterface>(&op)
+          .populateTypeMaterializations(converter);
+    }
+  }
+
+  if (insertAfter)
+    rewriter.setInsertionPointAfter(insertionPoint);
+  else
+    rewriter.setInsertionPoint(insertionPoint);
+
+  // Cast the inputs to the argument types. `input` is bound by reference into
+  // `inputs`, so the assignment below updates the call operands in place.
+  for (auto [input, type] :
+       llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
+    if (input.getType() != type) {
+      Value newInput = converter.materializeSourceConversion(
+          rewriter, input.getLoc(), type, input);
+      if (!newInput) {
+        return emitDefiniteFailure() << "Failed to materialize conversion of "
+                                     << input << " to type " << type;
+      }
+      input = newInput;
+    }
+  }
+
+  auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
+                                              targetFunction, inputs);
+
+  // Cast the call results back to the expected types. If any conversions fail
+  // this is a definite failure as the call has been constructed at this point.
+  for (auto [output, newOutput] :
+       llvm::zip_equal(outputs, callOp.getResults())) {
+    Value convertedOutput = newOutput;
+    if (output.getType() != newOutput.getType()) {
+      convertedOutput = converter.materializeTargetConversion(
+          rewriter, output.getLoc(), output.getType(), newOutput);
+      if (!convertedOutput) {
+        return emitDefiniteFailure()
+               << "Failed to materialize conversion of " << newOutput
+               << " to type " << output.getType();
+      }
+    }
+    // The call itself may use `output` as an operand; keep that use intact.
+    rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
+  }
+  results.set(cast<OpResult>(getResult()), {callOp});
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::CastAndCallOp::verify() {
+  // Every op in the (optional) conversions region must be able to populate a
+  // type converter; report the first offender with a note on its location.
+  Region &conversions = getRegion();
+  if (!conversions.empty()) {
+    for (Operation &child : conversions.front()) {
+      if (isa<transform::TypeConverterBuilderOpInterface>(&child))
+        continue;
+      InFlightDiagnostic diag = emitOpError()
+                                << "expected children ops to implement "
+                                   "TypeConverterBuilderOpInterface";
+      diag.attachNote(child.getLoc()) << "op without interface";
+      return diag;
+    }
+  }
+
+  // Exactly one way of naming the callee must be used.
+  bool hasHandle = static_cast<bool>(getFunction());
+  bool hasName = static_cast<bool>(getFunctionName());
+  if (!hasHandle && !hasName)
+    return emitOpError() << "expected a function handle or name to call";
+  if (hasHandle && hasName)
+    return emitOpError() << "function handle and name are mutually exclusive";
+  return success();
+}
+
+void transform::CastAndCallOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // Optional operand handles are only read, and only when present; none of
+  // the operand handles is consumed.
+  auto readIfPresent = [&](Value handle) {
+    if (handle)
+      transform::onlyReadsHandle(handle, effects);
+  };
+
+  transform::onlyReadsHandle(getInsertionPoint(), effects);
+  readIfPresent(getInputs());
+  readIfPresent(getOutputs());
+  readIfPresent(getFunction());
+
+  // A fresh handle to the call is produced and the payload is rewritten.
+  transform::producesHandle(getResult(), effects);
+  transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index ed274238704713c..38f1824a3634a35 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -15,6 +15,8 @@
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace tensor;
@@ -128,6 +130,44 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
tensor::populateRewriteAsConstantPatterns(patterns);
}
+//===----------------------------------------------------------------------===//
+// TypeConversionCastShapeDynamicDimsOp
+//===----------------------------------------------------------------------===//
+
+void transform::TypeConversionCastShapeDynamicDimsOp::
+    populateTypeMaterializations(TypeConverter &converter) {
+  // Shared helper: build a tensor.cast from `value` to `resultType` if the two
+  // types are cast-compatible, otherwise report "not applicable".
+  auto buildCast = [](OpBuilder &builder, Type resultType, Value value,
+                      Location loc) -> std::optional<Value> {
+    if (!tensor::CastOp::areCastCompatible(value.getType(), resultType))
+      return std::nullopt;
+    return builder.create<tensor::CastOp>(loc, resultType, value).getResult();
+  };
+
+  bool ignoreDynamicInfo = getIgnoreDynamicInfo();
+
+  // Source materialization: unless `ignore_dynamic_info` is set, refuse casts
+  // that would turn dynamic dimensions into static ones.
+  converter.addSourceMaterialization(
+      [ignoreDynamicInfo, buildCast](OpBuilder &builder, Type resultType,
+                                     ValueRange inputs,
+                                     Location loc) -> std::optional<Value> {
+        if (inputs.size() != 1)
+          return std::nullopt;
+        Value source = inputs.front();
+        if (!ignoreDynamicInfo &&
+            !tensor::preservesStaticInformation(resultType, source.getType()))
+          return std::nullopt;
+        return buildCast(builder, resultType, source, loc);
+      });
+
+  // Target materialization: any cast-compatible single-value conversion.
+  converter.addTargetMaterialization(
+      [buildCast](OpBuilder &builder, Type resultType, ValueRange inputs,
+                  Location loc) -> std::optional<Value> {
+        if (inputs.size() != 1)
+          return std::nullopt;
+        return buildCast(builder, resultType, inputs.front(), loc);
+      });
+}
+
//===----------------------------------------------------------------------===//
// MakeLoopIndependentOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 485d4448e7c3683..f2a57383cc5bf95 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -16,10 +16,12 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -30,11 +32,13 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
#include <optional>
#define DEBUG_TYPE "transform-dialect"
diff --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir
new file mode 100644
index 000000000000000..6aab07b0cb38a06
--- /dev/null
+++ b/mlir/test/Dialect/Func/func-transform.mlir
@@ -0,0 +1,120 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
+
+// Zero-argument, zero-result callees: @second is selected by symbol name and
+// called before "test.foo"; @third is selected through an op handle and called
+// after "test.foo". No inputs or outputs are given.
+// CHECK-LABEL: func.func @basic_cast_and_call
+func.func @basic_cast_and_call() {
+ // CHECK-NEXT: call @second()
+ "test.foo"() : () -> ()
+ // CHECK-NEXT: test.foo
+ // CHECK-NEXT: call @third()
+ func.return
+}
+
+func.func @second() {
+ "test.bar"() : () -> ()
+ func.return
+}
+
+func.func private @third()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // %f#0 is the function containing "test.foo"; %f#2 is used as the callee
+ // handle for @third.
+ %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %f:3 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+ transform.func.cast_and_call @second before %foo : (!transform.any_op) -> !transform.any_op
+ transform.func.cast_and_call %f#2 after %foo : (!transform.any_op, !transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// One input and one output: the call is inserted before "test.bar", takes the
+// result of "test.foo", and replaces the uses of the "test.bar" result (here
+// the return operand). The transform does not erase "test.bar" itself.
+// CHECK-LABEL: func.func @non_empty_arg_and_out
+func.func @non_empty_arg_and_out(%arg0 : index) -> i32 {
+ // CHECK-NEXT: %[[FOO:.+]] = "test.foo"
+ %0 = "test.foo"(%arg0) : (index) -> (index)
+ // CHECK-NEXT: %[[CALL:.+]] = call @second(%[[FOO]]) : (index) -> i32
+ %1 = "test.bar"(%0) : (index) -> (i32)
+ // CHECK: return %[[CALL]] : i32
+ func.return %1 : i32
+}
+
+func.func private @second(%arg1 : index) -> i32
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+ %bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+ %in = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
+ %out = transform.get_result %bar[0] : (!transform.any_op) -> !transform.any_value
+ transform.func.cast_and_call %f#1(%in) -> %out before %bar
+ : (!transform.any_op, !transform.any_value,
+ !transform.any_value, !transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Two inputs and two outputs: the same "test.foo" result is passed for both
+// arguments through a merged value handle, and result #0 of each of the two
+// "test.bar" ops is replaced. The call is inserted after "test.foo".
+// CHECK-LABEL: func.func @multi_arg_and_result
+func.func @multi_arg_and_result(%arg0 : index) -> (index, index) {
+ // CHECK-NEXT: %[[FOO:.+]] = "test.foo"
+ %0 = "test.foo"(%arg0) : (index) -> (index)
+ %1 = "test.bar"(%0) : (index) -> (index)
+ %2 = "test.bar"(%0) : (index) -> (index)
+ // CHECK: %[[CALL:.+]]:2 = call @second(%[[FOO]], %[[FOO]]) : (index, index) -> (index, index)
+ // CHECK: return %[[CALL]]#0, %[[CALL]]#1 : index, index
+ func.return %1, %2 : index, index
+}
+
+func.func private @second(%arg1: index, %arg2: index) -> (index, index)
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+ %bars = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+ %in0 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
+ %in1 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
+ %ins = transform.merge_handles %in0, %in1 : !transform.any_value
+
+ // %bars holds both "test.bar" ops, so this handle maps to two values.
+ %outs = transform.get_result %bars[0] : (!transform.any_op) -> !transform.any_value
+
+ transform.func.cast_and_call %f#1(%ins) -> %outs after %foo
+ : (!transform.any_op, !transform.any_value,
+ !transform.any_value, !transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The insertion point ("test.bar") is nested inside the region of "test.foo".
+// The call is created inside that region and uses a value defined in the
+// enclosing function body.
+// CHECK-LABEL: func.func @nested_call
+func.func @nested_call() {
+ // CHECK-NEXT: %[[ARG:.+]] = "test.arg"
+ // CHECK-NEXT: test.foo
+ %0 = "test.arg"() : () -> (index)
+ "test.foo"() ({
+ // CHECK-NEXT: call @second(%[[ARG]]) : (index) -> ()
+ "test.bar"(%0) : (index) -> ()
+ }) : () -> ()
+ func.return
+}
+
+func.func private @second(%arg1: index) -> ()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %arg = transform.structured.match ops{["test.arg"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+ %bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+ %in = transform.get_result %arg[0] : (!transform.any_op) -> !transform.any_value
+
+ transform.func.cast_and_call %f#1(%in) before %bar
+ : (!transform.any_op, !transform.any_value, !transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Tensor/transform-op-casting.mlir b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
new file mode 100644
index 000000000000000..16a1fa2b0ba9c75
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
+
+// The static operands of tensor.concat are cast to the dynamic argument types
+// of @concat_replacement (source materializations), and the dynamic call
+// result is cast back to the original static type (target materialization).
+// The now-dead tensor.concat is then removed by apply_dce.
+func.func @cast_to_dynamic(%arg0: tensor<10x13xf32>, %arg1: tensor<3x13xf32>) -> tensor<13x13xf32> {
+ %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<10x13xf32>, tensor<3x13xf32>) -> tensor<13x13xf32>
+ func.return %0 : tensor<13x13xf32>
+}
+
+func.func private @concat_replacement(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %concat = transform.structured.match ops{["tensor.concat"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+ %ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value
+ %out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value
+ transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
+ transform.type_conversion.tensor.cast_shape_dynamic_dims
+ } : (!transform.any_op, !transform.any_value,
+ !transform.any_value, !transform.any_op) -> !transform.any_op
+ transform.apply_dce to %f#0 : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @cast_to_dynamic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<10x13xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x13xf32>
+// CHECK-DAG: %[[CAST0:.+]] = tensor.cast %[[ARG0]] : tensor<10x13xf32> to tensor<?x?xf32>
+// CHECK-DAG: %[[CAST1:.+]] = tensor.cast %[[ARG1]] : tensor<3x13xf32> to tensor<?x?xf32>
+// CHECK: %[[CALL:.+]] = call @concat_replacement(%[[CAST0]], %[[CAST1]])
+// CHECK: %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<?x?xf32> to tensor<13x13xf32>
+// CHECK: return %[[CAST_RES]] : tensor<13x13xf32>
+
+// -----
+
+// Casting the dynamic operand to the static argument type of
+// @collapse_replacement turns dynamic dimensions into static ones, which the
+// source materialization only permits when `ignore_dynamic_info` is set.
+func.func @cast_to_static(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+ func.return %0 : tensor<?xf32>
+}
+
+func.func private @collapse_replacement(%arg0: tensor<4x5xf32>) -> tensor<20xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %collapse = transform.structured.match ops{["tensor.collapse_shape"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+ %ins = transform.get_operand %collapse[all] : (!transform.any_op) -> !transform.any_value
+ %out = transform.get_result %collapse[all] : (!transform.any_op) -> !transform.any_value
+ transform.func.cast_and_call %f#1(%ins) -> %out before %collapse {
+ transform.type_conversion.tensor.cast_shape_dynamic_dims ignore_dynamic_info
+ } : (!transform.any_op, !transform.any_value,
+ !transform.any_value, !transform.any_op) -> !transform.any_op
+ transform.apply_dce to %f#0 : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @cast_to_static
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[CAST_IN:.+]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x5xf32>
+// CHECK: %[[CALL:.+]] = call @collapse_replacement(%[[CAST_IN]])
+// CHECK: %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<20xf32> to tensor<?xf32>
+// CHECK: return %[[CAST_RES]] : tensor<?xf32>
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 54036f7929d1b8f..c00cc560e83e9b6 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -502,7 +502,8 @@ def ApplyTestConversionPatternsOp
def TestTypeConverterOp
: Op<Transform_Dialect, "apply_conversion_patterns.transform.test_type_converter",
- [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface>]> {
+ [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+ ["getTypeConverter"]>]> {
let arguments = (ins);
let results = (outs);
let assemblyFormat = "attr-dict";
More information about the Mlir-commits
mailing list