[Mlir-commits] [mlir] 42b1603 - [mlir][transform] Add an op for replacing values with function calls (#78398)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 19 10:21:56 PST 2024


Author: Quinn Dawkins
Date: 2024-01-19T13:21:52-05:00
New Revision: 42b160356fe5d3b41bf07c428d0142d3721b1d44

URL: https://github.com/llvm/llvm-project/commit/42b160356fe5d3b41bf07c428d0142d3721b1d44
DIFF: https://github.com/llvm/llvm-project/commit/42b160356fe5d3b41bf07c428d0142d3721b1d44.diff

LOG: [mlir][transform] Add an op for replacing values with function calls (#78398)

Adds `transform.func.cast_and_call` that takes a set of inputs and
outputs and replaces the uses of those outputs with a call to a function
at a specified insertion point.

The idea with this operation is to allow users to author independent IR
outside of a to-be-compiled module, and then match and replace a slice
of the program with a call to the external function.

Additionally adds a mechanism for populating a type converter with a set
of conversion materialization functions that allow insertion of
casts on the inputs/outputs to and from the types of the function
signature.

Added: 
    mlir/test/Dialect/Func/func-transform.mlir
    mlir/test/Dialect/Tensor/transform-op-casting.mlir

Modified: 
    mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
    mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
    mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
    mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 7a7e991c7861886..c36fdd150556208 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -12,6 +12,8 @@
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/RegionKindInterface.td"
 include "mlir/IR/OpBase.td"
 
 def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
@@ -26,4 +28,74 @@ def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def CastAndCallOp : Op<Transform_Dialect,
+    "func.cast_and_call",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     AttrSizedOperandSegments,
+     ReportTrackingListenerFailuresOpTrait]
+        # GraphRegionNoTerminator.traits> {
+  let summary = "Casts values to the signature of a function and replaces them "
+                "with a call";
+  let description = [{
+    This transform takes value handles to a set of `inputs` and `outputs` and
+    attempts to cast them to the function signature of the attached function
+    op, then builds a call to the function and replaces the uses of the
+    outputs. It is the responsibility of the user to ensure that the slice of
+    the program replaced by this operation makes sense, i.e. there is no
+    verification that the inputs to this operation have any relation to the
+    outputs outside of basic dominance requirements needed for the call.
+
+    The casting materialization functions are specified in the graph region of
+    this op. They must implement the `TypeConverterBuilderOpInterface`. The
+    order of ops within the region is irrelevant.
+
+    The target function can be specified by a symbol name or by a handle to the
+    operation.
+
+    This transform only reads the operand handles and only replaces the uses of
+    the outputs with the results of the call. No handles are consumed and no
+    operations are removed. Users are expected to run cleanup separately if
+    desired.
+
+    Warning: The replacement of the uses of the outputs could invalidate certain
+    restricted value handle types (e.g. `transform.block_arg` if it existed, by
+    replacing the use with something not coming from a block argument). The
+    value will still exist in such cases but wouldn't verify against the type.
+    See the discussion here for more information:
+    https://github.com/llvm/llvm-project/pull/78398#discussion_r1455070087
+
+    This transform will emit a silenceable failure if:
+     - The set of outputs isn't unique
+     - The handle for the insertion point does not include exactly one operation
+     - The insertion point op does not dominate all of the output users
+     - Not all of the inputs dominate the insertion point op
+     - The function signature does not match the number of inputs/outputs
+
+    This transform will emit a definite failure if it fails to resolve the
+    target function, or if it fails to materialize the conversion casts of
+    either the inputs to the function argument types, or the call results to
+    the output types.
+  }];
+
+  let arguments = (ins
+    TransformHandleTypeInterface:$insertion_point,
+    UnitAttr:$insert_after,
+    Optional<TransformValueHandleTypeInterface>:$inputs,
+    Optional<TransformValueHandleTypeInterface>:$outputs,
+    OptionalAttr<SymbolRefAttr>:$function_name,
+    Optional<TransformHandleTypeInterface>:$function);
+  let results = (outs TransformHandleTypeInterface:$result);
+  let regions = (region MaxSizedRegion<1>:$conversions);
+
+  let assemblyFormat = [{
+    ($function_name^)? ($function^)?
+    ( `(` $inputs^ `)` )?
+    ( `->` $outputs^ )?
+    (`after` $insert_after^):(`before`)? $insertion_point
+    ($conversions^)? attr-dict `:` functional-type(operands, results)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // FUNC_TRANSFORM_OPS

diff  --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 76309b9b8a9640d..29383a3825be883 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -18,7 +18,8 @@ include "mlir/IR/OpBase.td"
 def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
     "apply_conversion_patterns.memref.memref_to_llvm_type_converter",
     [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
-                               ["getTypeConverterType"]>]> {
+                               ["getTypeConverter",
+                                "getTypeConverterType"]>]> {
   let description = [{
     This operation provides an "LLVMTypeConverter" that lowers memref types to
     LLVM types.

diff  --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 8556d9570fd1200..39e1d7fa3494a39 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -169,4 +169,22 @@ def MakeLoopIndependentOp
   }];
 }
 
+def TypeConversionCastShapeDynamicDimsOp
+    : Op<Transform_Dialect, "type_conversion.tensor.cast_shape_dynamic_dims",
+         [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+                                    ["populateTypeMaterializations"]>]> {
+  let arguments = (ins UnitAttr:$ignore_dynamic_info);
+  let assemblyFormat =
+      "(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
+
+  let description = [{
+    Populates a type converter with conversion materialization functions that
+    cast a tensor value between two cast-compatible tensors. See `tensor.cast`
+    for more information on cast compatibility between tensors.
+
+    If `ignore_dynamic_info` is not set, this will set an additional constraint
+    that source materializations do not cast dynamic dimensions to static ones.
+  }];
+}
+
 #endif // TENSOR_TRANSFORM_OPS

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index f29efaee620d845..8f7b8f1999e0c59 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -284,8 +284,14 @@ def TypeConverterBuilderOpInterface
     : OpInterface<"TypeConverterBuilderOpInterface"> {
   let description = [{
     This interface should be implemented by ops that specify a type converter
-    for a dialect conversion. Such ops can be used with
-    "apply_conversion_patterns".
+    for a dialect conversion, or to populate a type converter with
+    conversions.
+
+    When such ops are intended to be used with "apply_conversion_patterns" or
+    other operations that expect a type converter, they should provide a
+    non-default implementation of `getTypeConverter`. For use with
+    "cast_and_call"-like ops that build a type converter incrementally, they
+    should provide a non-default `populateTypeMaterializations`.
   }];
 
   let cppNamespace = "::mlir::transform";
@@ -297,7 +303,11 @@ def TypeConverterBuilderOpInterface
       }],
       /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
       /*name=*/"getTypeConverter",
-      /*arguments=*/(ins)
+      /*arguments=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return std::make_unique<::mlir::TypeConverter>();
+      }]
     >,
     StaticInterfaceMethod<
       /*desc=*/[{
@@ -310,6 +320,17 @@ def TypeConverterBuilderOpInterface
       /*methodBody=*/"",
       /*defaultImplementation=*/[{ return "TypeConverter"; }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the given type converter with source/target materialization
+        functions.
+      }],
+      /*returnType=*/"void",
+      /*name=*/"populateTypeMaterializations",
+      /*arguments=*/(ins "::mlir::TypeConverter &":$converter),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return; }]
+    >,
   ];
 }
 

diff  --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 9e9b6bcea790def..9e79b086c0be841 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 
@@ -36,6 +37,196 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// CastAndCallOp
+//===----------------------------------------------------------------------===//
+
+// Casts the payload inputs to the callee signature, builds the call at the
+// insertion point and redirects the uses of the outputs to its results.
+DiagnosedSilenceableFailure
+transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
+                                transform::TransformResults &results,
+                                transform::TransformState &state) {
+  SmallVector<Value> inputs;
+  if (getInputs())
+    llvm::append_range(inputs, state.getPayloadValues(getInputs()));
+
+  SetVector<Value> outputs;
+  if (getOutputs()) {
+    for (auto output : state.getPayloadValues(getOutputs()))
+      outputs.insert(output);
+
+    // Verify that the set of output values to be replaced is unique.
+    if (outputs.size() !=
+        llvm::range_size(state.getPayloadValues(getOutputs()))) {
+      return emitSilenceableFailure(getLoc())
+             << "cast and call output values must be unique";
+    }
+  }
+
+  // Get the insertion point for the call.
+  auto insertionOps = state.getPayloadOps(getInsertionPoint());
+  if (!llvm::hasSingleElement(insertionOps)) {
+    return emitSilenceableFailure(getLoc())
+           << "Only one op can be specified as an insertion point";
+  }
+  bool insertAfter = getInsertAfter();
+  Operation *insertionPoint = *insertionOps.begin();
+
+  // Check that all inputs dominate the insertion point and that it dominates
+  // all users of the outputs. Nothing was modified yet: fail silenceably.
+  DominanceInfo dom(insertionPoint);
+  for (Value output : outputs) {
+    for (Operation *user : output.getUsers()) {
+      // When inserting after the insertion point, it must properly dominate
+      // the user and must not enclose it (the call would land after the
+      // user's ancestor). Otherwise basic dominance is enough.
+      bool doesDominate =
+          insertAfter ? dom.properlyDominates(insertionPoint, user,
+                                              /*enclosingOpOk=*/false)
+                      : dom.dominates(insertionPoint, user);
+      if (!doesDominate) {
+        return emitSilenceableFailure(getLoc())
+               << "User " << user << " is not dominated by insertion point "
+               << insertionPoint;
+      }
+    }
+  }
+
+  for (Value input : inputs) {
+    // If we are inserting before the insertion point operation, the
+    // input must properly dominate the insertion point operation. Otherwise
+    // basic dominance is enough.
+    bool doesDominate = insertAfter
+                            ? dom.dominates(input, insertionPoint)
+                            : dom.properlyDominates(input, insertionPoint);
+    if (!doesDominate) {
+      return emitSilenceableFailure(getLoc())
+             << "input " << input << " does not dominate insertion point "
+             << insertionPoint;
+    }
+  }
+
+  // Get the function to call. This can either be specified by symbol or as a
+  // transform handle.
+  func::FuncOp targetFunction = nullptr;
+  if (getFunctionName()) {
+    targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
+        insertionPoint, *getFunctionName());
+    if (!targetFunction) {
+      return emitDefiniteFailure()
+             << "unresolved symbol " << *getFunctionName();
+    }
+  } else if (getFunction()) {
+    auto payloadOps = state.getPayloadOps(getFunction());
+    if (!llvm::hasSingleElement(payloadOps))
+      return emitDefiniteFailure() << "requires a single function to call";
+    targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
+    if (!targetFunction)
+      return emitDefiniteFailure() << "invalid non-function callee";
+  } else {
+    llvm_unreachable("Invalid CastAndCall op without a function to call");
+  }
+
+  // Verify that the function argument and result lengths match the inputs and
+  // outputs given to this op.
+  if (targetFunction.getNumArguments() != inputs.size()) {
+    return emitSilenceableFailure(targetFunction.getLoc())
+           << "mismatch between number of function arguments "
+           << targetFunction.getNumArguments() << " and number of inputs "
+           << inputs.size();
+  }
+  if (targetFunction.getNumResults() != outputs.size()) {
+    return emitSilenceableFailure(targetFunction.getLoc())
+           << "mismatch between number of function results "
+           << targetFunction.getNumResults() << " and number of outputs "
+           << outputs.size();
+  }
+
+  // Gather all specified converters.
+  mlir::TypeConverter converter;
+  if (!getRegion().empty()) {
+    for (Operation &op : getRegion().front()) {
+      cast<transform::TypeConverterBuilderOpInterface>(&op)
+          .populateTypeMaterializations(converter);
+    }
+  }
+
+  if (insertAfter)
+    rewriter.setInsertionPointAfter(insertionPoint);
+  else
+    rewriter.setInsertionPoint(insertionPoint);
+
+  for (auto [input, type] :
+       llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
+    if (input.getType() != type) {
+      Value newInput = converter.materializeSourceConversion(
+          rewriter, input.getLoc(), type, input);
+      if (!newInput)
+        return emitDefiniteFailure() << "Failed to materialize conversion of "
+                                     << input << " to type " << type;
+      // `input` is a reference into `inputs`, so this updates the call operand.
+      input = newInput;
+    }
+  }
+
+  auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
+                                              targetFunction, inputs);
+
+  // Cast the call results back to the expected types. If any conversions fail
+  // this is a definite failure as the call has been constructed at this point.
+  for (auto [output, newOutput] :
+       llvm::zip_equal(outputs, callOp.getResults())) {
+    Value convertedOutput = newOutput;
+    if (output.getType() != newOutput.getType()) {
+      convertedOutput = converter.materializeTargetConversion(
+          rewriter, output.getLoc(), output.getType(), newOutput);
+      if (!convertedOutput) {
+        return emitDefiniteFailure()
+               << "Failed to materialize conversion of " << newOutput
+               << " to type " << output.getType();
+      }
+    }
+    rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
+  }
+  results.set(cast<OpResult>(getResult()), {callOp});
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::CastAndCallOp::verify() {
+  // Reject non-builder ops in the conversion region, then check the callee.
+  for (Block &block : getRegion()) {
+    auto it = llvm::find_if(block, [](Operation &child) {
+      return !isa<transform::TypeConverterBuilderOpInterface>(&child);
+    });
+    if (it == block.end())
+      continue;
+    InFlightDiagnostic diag = emitOpError()
+                              << "expected children ops to implement "
+                                 "TypeConverterBuilderOpInterface";
+    diag.attachNote(it->getLoc()) << "op without interface";
+    return diag;
+  }
+  if (!getFunction() && !getFunctionName())
+    return emitOpError() << "expected a function handle or name to call";
+  if (getFunction() && getFunctionName())
+    return emitOpError() << "function handle and name are mutually exclusive";
+  return success();
+}
+
+void transform::CastAndCallOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // Operand handles are only read (optional ones only when present), the
+  // result handle is produced, and the payload is reported as modified.
+  transform::onlyReadsHandle(getInsertionPoint(), effects);
+  Value optionalHandles[] = {getInputs(), getOutputs(), getFunction()};
+  for (Value handle : optionalHandles)
+    if (handle)
+      transform::onlyReadsHandle(handle, effects);
+  transform::producesHandle(getResult(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index ed274238704713c..38f1824a3634a35 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -15,6 +15,8 @@
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 using namespace tensor;
@@ -128,6 +130,44 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
   tensor::populateRewriteAsConstantPatterns(patterns);
 }
 
+//===----------------------------------------------------------------------===//
+// TypeConversionCastShapeDynamicDimsOp
+//===----------------------------------------------------------------------===//
+
+void transform::TypeConversionCastShapeDynamicDimsOp::
+    populateTypeMaterializations(TypeConverter &converter) {
+  // Returns a materialization callback that bridges a single tensor value to
+  // `resultType` with one `tensor.cast`. It declines (std::nullopt) when there
+  // is not exactly one input, when the types are not cast compatible, or, if
+  // `checkStaticInfo` is set, when the input lacks static information present
+  // in `resultType` (i.e. a dynamic dimension would be cast to a static one).
+  auto makeCastMaterialization = [](bool checkStaticInfo) {
+    return [checkStaticInfo](OpBuilder &builder, Type resultType,
+                             ValueRange inputs,
+                             Location loc) -> std::optional<Value> {
+      if (inputs.size() != 1)
+        return std::nullopt;
+      Value input = inputs.front();
+      Type inputType = input.getType();
+      // Argument order: `resultType` is the reference whose static
+      // information the input type has to preserve.
+      if (checkStaticInfo &&
+          !tensor::preservesStaticInformation(resultType, inputType))
+        return std::nullopt;
+      if (!tensor::CastOp::areCastCompatible(inputType, resultType))
+        return std::nullopt;
+      return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
+    };
+  };
+
+  // Only source materializations honor `ignore_dynamic_info`; target
+  // materializations accept any cast-compatible pair of tensor types.
+  converter.addSourceMaterialization(
+      makeCastMaterialization(/*checkStaticInfo=*/!getIgnoreDynamicInfo()));
+  converter.addTargetMaterialization(
+      makeCastMaterialization(/*checkStaticInfo=*/false));
+}
+
 //===----------------------------------------------------------------------===//
 // MakeLoopIndependentOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 485d4448e7c3683..f2a57383cc5bf95 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -16,10 +16,12 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -30,11 +32,13 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 #include <optional>
 
 #define DEBUG_TYPE "transform-dialect"

diff  --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir
new file mode 100644
index 000000000000000..6aab07b0cb38a06
--- /dev/null
+++ b/mlir/test/Dialect/Func/func-transform.mlir
@@ -0,0 +1,120 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @basic_cast_and_call
+func.func @basic_cast_and_call() {
+  // CHECK-NEXT: call @second()
+  "test.foo"() : () -> ()
+  // CHECK-NEXT: test.foo
+  // CHECK-NEXT: call @third()
+  func.return
+}
+
+func.func @second() {  // callee resolved by symbol name
+  "test.bar"() : () -> ()
+  func.return
+}
+
+func.func private @third()  // callee resolved through a handle (%f#2)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:3 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    transform.func.cast_and_call @second before %foo : (!transform.any_op) -> !transform.any_op  // no inputs/outputs: only inserts a call
+    transform.func.cast_and_call %f#2 after %foo : (!transform.any_op, !transform.any_op) -> !transform.any_op  // same, placed after test.foo
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @non_empty_arg_and_out
+func.func @non_empty_arg_and_out(%arg0 : index) -> i32 {
+  // CHECK-NEXT: %[[FOO:.+]] = "test.foo"
+  %0 = "test.foo"(%arg0) : (index) -> (index)
+  // CHECK-NEXT: %[[CALL:.+]] = call @second(%[[FOO]]) : (index) -> i32
+  %1 = "test.bar"(%0) : (index) -> (i32)
+  // CHECK: return %[[CALL]] : i32
+  func.return %1 : i32
+}
+
+func.func private @second(%arg1 : index) -> i32  // signature already matches: no casts needed
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %in = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value  // becomes the call operand
+    %out = transform.get_result %bar[0] : (!transform.any_op) -> !transform.any_value  // its uses move to the call result
+    transform.func.cast_and_call %f#1(%in) -> %out before %bar
+        : (!transform.any_op, !transform.any_value,
+           !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @multi_arg_and_result
+func.func @multi_arg_and_result(%arg0 : index) -> (index, index) {
+  // CHECK-NEXT: %[[FOO:.+]] = "test.foo"
+  %0 = "test.foo"(%arg0) : (index) -> (index)
+  %1 = "test.bar"(%0) : (index) -> (index)
+  %2 = "test.bar"(%0) : (index) -> (index)
+  // CHECK: %[[CALL:.+]]:2 = call @second(%[[FOO]], %[[FOO]]) : (index, index) -> (index, index)
+  // CHECK: return %[[CALL]]#0, %[[CALL]]#1 : index, index
+  func.return %1, %2 : index, index
+}
+
+func.func private @second(%arg1: index, %arg2: index) -> (index, index)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %bars = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %in0 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
+    %in1 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value  // same value again: inputs may repeat
+    %ins = transform.merge_handles %in0, %in1 : !transform.any_value
+
+    %outs = transform.get_result %bars[0] : (!transform.any_op) -> !transform.any_value  // result #0 of each test.bar
+
+    transform.func.cast_and_call %f#1(%ins) -> %outs after %foo  // call placed right after test.foo
+        : (!transform.any_op, !transform.any_value,
+           !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @nested_call
+func.func @nested_call() {
+  // CHECK-NEXT: %[[ARG:.+]] = "test.arg"
+  // CHECK-NEXT: test.foo
+  %0 = "test.arg"() : () -> (index)
+  "test.foo"() ({
+    // CHECK-NEXT: call @second(%[[ARG]]) : (index) -> ()
+    "test.bar"(%0) : (index) -> ()  // insertion point, nested in test.foo's region
+  }) : () -> ()
+}  // NOTE(review): the body has no func.return; it ends with the unregistered test.foo. Confirm this is intended.
+
+func.func private @second(%arg1: index) -> ()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %arg = transform.structured.match ops{["test.arg"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %in = transform.get_result %arg[0] : (!transform.any_op) -> !transform.any_value  // defined outside the region, dominates %bar
+
+    transform.func.cast_and_call %f#1(%in) before %bar
+        : (!transform.any_op, !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

diff  --git a/mlir/test/Dialect/Tensor/transform-op-casting.mlir b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
new file mode 100644
index 000000000000000..16a1fa2b0ba9c75
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
+
+func.func @cast_to_dynamic(%arg0: tensor<10x13xf32>, %arg1: tensor<3x13xf32>) -> tensor<13x13xf32> {
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<10x13xf32>, tensor<3x13xf32>) -> tensor<13x13xf32>
+  func.return %0 : tensor<13x13xf32>
+}
+
+func.func private @concat_replacement(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32>  // dynamically shaped replacement
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %concat = transform.structured.match ops{["tensor.concat"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value
+    transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
+      transform.type_conversion.tensor.cast_shape_dynamic_dims  // inputs go static to dynamic, so no ignore_dynamic_info
+    } : (!transform.any_op, !transform.any_value,
+         !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.apply_dce to %f#0 : !transform.any_op  // clean up ops left dead by the rewrite
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @cast_to_dynamic
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<10x13xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x13xf32>
+//   CHECK-DAG:   %[[CAST0:.+]] = tensor.cast %[[ARG0]] : tensor<10x13xf32> to tensor<?x?xf32>
+//   CHECK-DAG:   %[[CAST1:.+]] = tensor.cast %[[ARG1]] : tensor<3x13xf32> to tensor<?x?xf32>
+//       CHECK:   %[[CALL:.+]] = call @concat_replacement(%[[CAST0]], %[[CAST1]])
+//       CHECK:   %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<?x?xf32> to tensor<13x13xf32>
+//       CHECK:   return %[[CAST_RES]] : tensor<13x13xf32>
+
+// -----
+
+func.func @cast_to_static(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+func.func private @collapse_replacement(%arg0: tensor<4x5xf32>) -> tensor<20xf32>  // statically shaped replacement
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %collapse = transform.structured.match ops{["tensor.collapse_shape"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %ins = transform.get_operand %collapse[all] : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %collapse[all] : (!transform.any_op) -> !transform.any_value
+    transform.func.cast_and_call %f#1(%ins) -> %out before %collapse {
+      transform.type_conversion.tensor.cast_shape_dynamic_dims ignore_dynamic_info  // required: the input cast goes dynamic to static
+    } : (!transform.any_op, !transform.any_value,
+         !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.apply_dce to %f#0 : !transform.any_op  // clean up ops left dead by the rewrite
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @cast_to_static
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[CAST_IN:.+]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x5xf32>
+//       CHECK:   %[[CALL:.+]] = call @collapse_replacement(%[[CAST_IN]])
+//       CHECK:   %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<20xf32> to tensor<?xf32>
+//       CHECK:   return %[[CAST_RES]] : tensor<?xf32>

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 54036f7929d1b8f..c00cc560e83e9b6 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -502,7 +502,8 @@ def ApplyTestConversionPatternsOp
 
 def TestTypeConverterOp
   : Op<Transform_Dialect, "apply_conversion_patterns.transform.test_type_converter",
-      [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface>]> {
+      [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+                                 ["getTypeConverter"]>]> {
   let arguments = (ins);
   let results = (outs);
   let assemblyFormat = "attr-dict";


        


More information about the Mlir-commits mailing list