[Mlir-commits] [mlir] [mlir][transform] Add an op for replacing values with function calls (PR #78398)

Quinn Dawkins llvmlistbot at llvm.org
Thu Jan 18 06:34:47 PST 2024


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/78398

>From 2ea50df34f0263a8f0a99a60b855f8e52e0fceb2 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Tue, 16 Jan 2024 15:34:23 -0500
Subject: [PATCH 1/2] [mlir][transform] Add an op for replacing values with
 function calls

Adds `transform.func.cast_and_call` that takes a set of inputs and
outputs and replaces the uses of those outputs with a call to a function
at a specified insertion point.

The idea with this operation is to allow users to author independent IR
outside of a to-be-compiled module, and then match and replace a slice of
the program with a call to the external function.

Additionally adds a mechanism for populating a type converter with a set
of conversion materialization functions that allow insertion of
casts on the inputs/outputs to and from the types of the function
signature.
---
 .../Func/TransformOps/FuncTransformOps.td     |  65 ++++++
 .../Tensor/TransformOps/TensorTransformOps.td |  13 ++
 .../Transform/IR/TransformInterfaces.td       |  22 ++
 .../Func/TransformOps/FuncTransformOps.cpp    | 197 ++++++++++++++++++
 .../TransformOps/TensorTransformOps.cpp       |  40 ++++
 .../lib/Dialect/Transform/IR/TransformOps.cpp |   4 +
 mlir/test/Dialect/Func/func-transform.mlir    | 120 +++++++++++
 .../Dialect/Tensor/transform-op-casting.mlir  |  65 ++++++
 8 files changed, 526 insertions(+)
 create mode 100644 mlir/test/Dialect/Func/func-transform.mlir
 create mode 100644 mlir/test/Dialect/Tensor/transform-op-casting.mlir

diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 7a7e991c786188..e5086c26c55a4f 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -12,6 +12,8 @@
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/RegionKindInterface.td"
 include "mlir/IR/OpBase.td"
 
 def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
@@ -26,4 +28,67 @@ def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+// Transform op that builds a `func.call` at a payload insertion point, casting
+// the given input values to the callee argument types and replacing the uses
+// of the given output values with the (cast) call results.
+def CastAndCallOp : Op<Transform_Dialect,
+    "func.cast_and_call",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     AttrSizedOperandSegments,
+     ReportTrackingListenerFailuresOpTrait]
+        # GraphRegionNoTerminator.traits> {
+  let summary = "Casts values to the signature of a function and replaces them "
+                "with a call";
+  let description = [{
+    This transform takes a set of |input| and |output| value handles and
+    attempts to cast them to the function signature of the attached function
+    op, then builds a call to the function and replaces the users of the
+    outputs. It is the responsibility of the user to ensure that the slice of
+    the program replaced by this operation makes sense, i.e. there is no
+    verification that the inputs to this operation have any relation to the
+    outputs outside of basic dominance requirements needed for the replacement.
+
+    The casting materialization functions are specified in the graph region of
+    this op. They must implement the `TypeConversionOpInterface`. The order of
+    ops within the region is irrelevant.
+
+    The target function can be specified by a symbol name or by a handle to the
+    operation.
+
+    This transform only reads the target handles and only replaces the users of
+    the outputs with the results of the call. No handles are consumed and no
+    operations are removed. Users are expected to run cleanup separately if
+    desired.
+
+    This transform will emit a silenceable failure if:
+     - The set of outputs isn't unique
+     - The handle for the insertion point does not include exactly one operation
+     - The insertion point op does not dominate any of the output users
+     - The insertion point op is not dominated by any of the inputs
+     - The function signature does not match the number of inputs/outputs
+     - Any of the input conversions fail to be materialized
+
+    This transform will emit a definite failure if it fails to resolve the
+    target function, or if it fails to materialize the conversion from the call
+    results to the output types.
+  }];
+
+  // Exactly one of `function_name` and `function` must be provided; this is
+  // enforced by the C++ verifier.
+  let arguments = (ins
+    // Handle to the single payload op the call is inserted before (default)
+    // or after (when `insert_after` is set).
+    TransformHandleTypeInterface:$insertion_point,
+    UnitAttr:$insert_after,
+    // Values passed as call operands, cast to the callee argument types.
+    Optional<TransformValueHandleTypeInterface>:$inputs,
+    // Values whose uses are replaced by the (cast) call results.
+    Optional<TransformValueHandleTypeInterface>:$outputs,
+    // Callee given by symbol, looked up from the insertion point op.
+    OptionalAttr<SymbolRefAttr>:$function_name,
+    // Callee given by a handle to a single func.func payload op.
+    Optional<TransformHandleTypeInterface>:$function);
+  // Handle to the newly created call op.
+  let results = (outs TransformHandleTypeInterface:$result);
+  // Optional graph region holding the type-materialization ops.
+  let regions = (region MaxSizedRegion<1>:$conversions);
+
+  // Callee (symbol or handle), optional parenthesized inputs, optional `->`
+  // outputs, then `before`/`after` the insertion point handle.
+  let assemblyFormat = [{
+    ($function_name^)? ($function^)?
+    ( `(` $inputs^ `)` )?
+    ( `->` $outputs^ )?
+    (`after` $insert_after^):(`before`)? $insertion_point
+    ($conversions^)? attr-dict `:` functional-type(operands, results)
+  }];
+  // The C++ verifier checks the region contents and the callee operands.
+  let hasVerifier = 1;
+}
+
 #endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 8556d9570fd120..28e9249c82e309 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -169,4 +169,17 @@ def MakeLoopIndependentOp
   }];
 }
 
+// Populates a type converter with `tensor.cast`-based materializations.
+def TypeConversionCastOp : Op<Transform_Dialect,
+    "type_conversion.tensor.cast",
+    [DeclareOpInterfaceMethods<TypeConversionOpInterface>]> {
+  let summary = "Materializes type conversions with tensor.cast";
+  let description = [{
+    Populates a type converter with source and target materializations that
+    insert a `tensor.cast` between two cast-compatible tensor types. When used
+    inside `transform.func.cast_and_call`, the source materialization casts the
+    inputs to the callee argument types and the target materialization casts
+    the call results back to the types of the replaced outputs.
+
+    By default the source materialization refuses to cast a value to a type
+    that carries more static information than the value's own type (e.g.
+    `tensor<?xf32>` to `tensor<4xf32>`); setting `ignore_dynamic_info` allows
+    such casts. The target materialization only requires the two types to be
+    cast compatible.
+  }];
+  let arguments = (ins UnitAttr:$ignore_dynamic_info);
+
+  let assemblyFormat =
+      "(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
+}
+
 #endif // TENSOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index f29efaee620d84..3b601f42a6452d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -280,6 +280,28 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
   ];
 }
 
+def TypeConversionOpInterface : OpInterface<"TypeConversionOpInterface"> {
+  let description = [{
+    This interface should be implemented by ops that populate the type casting
+    of a `transform.func.cast_and_call` op. It provides a method to populate a
+    type converter with source/target materialization functions.
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the given type converter with source/target materialization
+        functions.
+      }],
+      /*returnType=*/"void",
+      /*name=*/"populateTypeMaterializations",
+      /*arguments=*/(ins "::mlir::TypeConverter &":$converter)
+    >
+  ];
+}
+
 def TypeConverterBuilderOpInterface
     : OpInterface<"TypeConverterBuilderOpInterface"> {
   let description = [{
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 9e9b6bcea790de..14b6e633520d6c 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 
@@ -36,6 +37,202 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// CastAndCallOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a call to the target function at the insertion point, casting the
+/// payload `inputs` to the callee argument types and replacing all uses of the
+/// payload `outputs` with the (cast) call results. Failures detected before
+/// any IR is modified are silenceable; failures after IR has been created
+/// (input/output cast materialization) and callee resolution are definite.
+DiagnosedSilenceableFailure
+transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
+                                transform::TransformResults &results,
+                                transform::TransformState &state) {
+  // Resolve the optional value handles to their payload values.
+  SmallVector<Value> inputs;
+  if (getInputs())
+    llvm::append_range(inputs, state.getPayloadValues(getInputs()));
+  SmallVector<Value> outputs;
+  if (getOutputs())
+    llvm::append_range(outputs, state.getPayloadValues(getOutputs()));
+
+  // Verify that the set of output values to be replaced is unique.
+  llvm::SmallDenseSet<Value> outputSet;
+  outputSet.insert(outputs.begin(), outputs.end());
+  if (outputSet.size() != outputs.size()) {
+    return emitSilenceableFailure(getLoc())
+           << "cast and call output values must be unique";
+  }
+
+  // Get the insertion point for the call.
+  auto insertionOps = state.getPayloadOps(getInsertionPoint());
+  if (!llvm::hasSingleElement(insertionOps)) {
+    return emitSilenceableFailure(getLoc())
+           << "Only one op can be specified as an insertion point";
+  }
+  bool insertAfter = getInsertAfter();
+  Operation *insertionPoint = *insertionOps.begin();
+
+  // Check that all inputs dominate the insertion point, and the insertion
+  // point dominates all users of the outputs. No payload IR has been modified
+  // yet, so these are silenceable failures as documented on the op.
+  DominanceInfo dom(insertionPoint);
+  for (Value output : outputs) {
+    for (Operation *user : output.getUsers()) {
+      // If we are inserting after the insertion point operation, the
+      // insertion point operation must properly dominate the user. Otherwise
+      // basic dominance is enough.
+      bool doesDominate = insertAfter
+                              ? dom.properlyDominates(insertionPoint, user)
+                              : dom.dominates(insertionPoint, user);
+      if (!doesDominate) {
+        return emitSilenceableFailure(getLoc())
+               << "User " << user << " is not dominated by insertion point "
+               << insertionPoint;
+      }
+    }
+  }
+
+  for (Value input : inputs) {
+    // If we are inserting before the insertion point operation, the
+    // input must properly dominate the insertion point operation. Otherwise
+    // basic dominance is enough.
+    bool doesDominate = insertAfter
+                            ? dom.dominates(input, insertionPoint)
+                            : dom.properlyDominates(input, insertionPoint);
+    if (!doesDominate) {
+      return emitSilenceableFailure(getLoc())
+             << "input " << input << " does not dominate insertion point "
+             << insertionPoint;
+    }
+  }
+
+  // Get the function to call. It is specified either by symbol (looked up
+  // from the insertion point) or as a transform handle; the verifier
+  // guarantees exactly one of the two is present.
+  func::FuncOp targetFunction = nullptr;
+  if (getFunctionName()) {
+    targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
+        insertionPoint, *getFunctionName());
+    if (!targetFunction) {
+      return emitDefiniteFailure()
+             << "unresolved symbol " << *getFunctionName();
+    }
+  } else if (getFunction()) {
+    auto payloadOps = state.getPayloadOps(getFunction());
+    if (!llvm::hasSingleElement(payloadOps)) {
+      return emitDefiniteFailure() << "requires a single function to call";
+    }
+    targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
+    if (!targetFunction) {
+      return emitDefiniteFailure() << "invalid non-function callee";
+    }
+  } else {
+    llvm_unreachable("Invalid CastAndCall op without a function to call");
+  }
+  assert(targetFunction && "no target function found");
+
+  // Verify that the function argument and result lengths match the inputs and
+  // outputs given to this op. Note: use the function-type accessors, not
+  // `Operation::getNumResults()` (the func op itself has no results).
+  if (targetFunction.getNumArguments() != inputs.size()) {
+    return emitSilenceableFailure(targetFunction.getLoc())
+           << "mismatch between number of function arguments "
+           << targetFunction.getNumArguments() << " and number of inputs "
+           << inputs.size();
+  }
+  if (targetFunction.getNumResults() != outputs.size()) {
+    return emitSilenceableFailure(targetFunction.getLoc())
+           << "mismatch between number of function results "
+           << targetFunction.getNumResults() << " and number of outputs "
+           << outputs.size();
+  }
+
+  // Gather all specified converters. The verifier guarantees that every op in
+  // the region implements the interface.
+  mlir::TypeConverter converter;
+  if (!getRegion().empty()) {
+    for (Operation &op : getRegion().front()) {
+      cast<transform::TypeConversionOpInterface>(&op)
+          .populateTypeMaterializations(converter);
+    }
+  }
+
+  // Build all new IR through the transform rewriter (not a detached
+  // OpBuilder) so that its listener observes the created ops and the use
+  // replacements below.
+  OpBuilder::InsertionGuard guard(rewriter);
+  if (insertAfter)
+    rewriter.setInsertionPointAfter(insertionPoint);
+  else
+    rewriter.setInsertionPoint(insertionPoint);
+
+  // Cast the inputs to the function argument types. Casts for earlier inputs
+  // are already in the IR when a later one fails, so this failure is definite.
+  for (auto [idx, type] :
+       llvm::enumerate(targetFunction.getArgumentTypes())) {
+    Value input = inputs[idx];
+    if (input.getType() == type)
+      continue;
+    Value newInput = converter.materializeSourceConversion(
+        rewriter, input.getLoc(), type, input);
+    if (!newInput) {
+      return emitDefiniteFailure()
+             << "Failed to materialize conversion of " << input << " to type "
+             << type;
+    }
+    inputs[idx] = newInput;
+  }
+
+  auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
+                                              targetFunction, inputs);
+
+  // Cast the call results back to the expected types. If any conversions fail
+  // this is a definite failure as the call has been constructed at this point.
+  for (auto [output, newOutput] :
+       llvm::zip_equal(outputs, callOp.getResults())) {
+    Value convertedOutput = newOutput;
+    if (output.getType() != newOutput.getType()) {
+      convertedOutput = converter.materializeTargetConversion(
+          rewriter, output.getLoc(), output.getType(), newOutput);
+      if (!convertedOutput) {
+        return emitDefiniteFailure()
+               << "Failed to materialize conversion of " << newOutput
+               << " to type " << output.getType();
+      }
+    }
+    rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
+  }
+  results.set(cast<OpResult>(getResult()), {callOp});
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Checks that every op in the conversion region implements
+/// TypeConversionOpInterface and that the callee is named by exactly one of
+/// a function handle or a symbol name.
+LogicalResult transform::CastAndCallOp::verify() {
+  Region &conversions = getRegion();
+  if (!conversions.empty()) {
+    Block &body = conversions.front();
+    auto offending = llvm::find_if(body, [](Operation &child) {
+      return !isa<transform::TypeConversionOpInterface>(&child);
+    });
+    if (offending != body.end()) {
+      InFlightDiagnostic diag = emitOpError()
+                                << "expected children ops to implement "
+                                   "TypeConversionOpInterface";
+      diag.attachNote(offending->getLoc()) << "op without interface";
+      return diag;
+    }
+  }
+
+  const bool hasHandle = static_cast<bool>(getFunction());
+  const bool hasName = static_cast<bool>(getFunctionName());
+  if (!hasHandle && !hasName)
+    return emitOpError() << "expected a function handle or name to call";
+  if (hasHandle && hasName)
+    return emitOpError() << "function handle and name are mutually exclusive";
+  return success();
+}
+
+/// Every operand handle is only read (none is consumed); the op produces its
+/// result handle and modifies the payload IR.
+void transform::CastAndCallOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getInsertionPoint(), effects);
+  // Optional operands, in operand order; absent ones are null and skipped.
+  Value optionalHandles[] = {getInputs(), getOutputs(), getFunction()};
+  for (Value handle : optionalHandles) {
+    if (handle)
+      transform::onlyReadsHandle(handle, effects);
+  }
+  transform::producesHandle(getResult(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index ed274238704713..0c89ba2a1f1895 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -15,6 +15,8 @@
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 using namespace tensor;
@@ -128,6 +130,44 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
   tensor::populateRewriteAsConstantPatterns(patterns);
 }
 
+//===----------------------------------------------------------------------===//
+// TypeConversionCastOp
+//===----------------------------------------------------------------------===//
+
+/// Registers source and target materializations that insert a tensor.cast
+/// from a single input value to the requested result type.
+void transform::TypeConversionCastOp::populateTypeMaterializations(
+    TypeConverter &converter) {
+  // Returns a materialization callback. It bails out unless there is exactly
+  // one input that is cast compatible with `resultType`. When
+  // `checkStaticInfo` is set, it also bails out unless
+  // `tensor::preservesStaticInformation(resultType, inputType)` holds.
+  auto makeCastMaterialization = [](bool checkStaticInfo) {
+    return [checkStaticInfo](OpBuilder &builder, Type resultType,
+                             ValueRange inputs,
+                             Location loc) -> std::optional<Value> {
+      if (inputs.size() != 1)
+        return std::nullopt;
+      Value source = inputs.front();
+      Type sourceType = source.getType();
+      bool staticInfoOk =
+          !checkStaticInfo ||
+          tensor::preservesStaticInformation(resultType, sourceType);
+      if (!staticInfoOk ||
+          !tensor::CastOp::areCastCompatible(sourceType, resultType))
+        return std::nullopt;
+      return builder.create<tensor::CastOp>(loc, resultType, source)
+          .getResult();
+    };
+  };
+  // The static-information check on sources is skipped when
+  // `ignore_dynamic_info` is set; targets never perform it.
+  converter.addSourceMaterialization(
+      makeCastMaterialization(/*checkStaticInfo=*/!getIgnoreDynamicInfo()));
+  converter.addTargetMaterialization(
+      makeCastMaterialization(/*checkStaticInfo=*/false));
+}
+
 //===----------------------------------------------------------------------===//
 // MakeLoopIndependentOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 485d4448e7c368..f2a57383cc5bf9 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -16,10 +16,12 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -30,11 +32,13 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 #include <optional>
 
 #define DEBUG_TYPE "transform-dialect"
diff --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir
new file mode 100644
index 00000000000000..6aab07b0cb38a0
--- /dev/null
+++ b/mlir/test/Dialect/Func/func-transform.mlir
@@ -0,0 +1,120 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
+
+// No inputs or outputs: a call to @second (given by symbol) is inserted before
+// test.foo and a call to @third (given by handle) is inserted after it.
+// CHECK-LABEL: func.func @basic_cast_and_call
+func.func @basic_cast_and_call() {
+  // CHECK-NEXT: call @second()
+  "test.foo"() : () -> ()
+  // CHECK-NEXT: test.foo
+  // CHECK-NEXT: call @third()
+  func.return
+}
+
+func.func @second() {
+  "test.bar"() : () -> ()
+  func.return
+}
+
+func.func private @third()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:3 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    transform.func.cast_and_call @second before %foo : (!transform.any_op) -> !transform.any_op
+    transform.func.cast_and_call %f#2 after %foo : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// One input and one output: the uses of the test.bar result are replaced by
+// the call result.
+// CHECK-LABEL: func.func @non_empty_arg_and_out
+func.func @non_empty_arg_and_out(%arg0 : index) -> i32 {
+  // CHECK-NEXT: %[[FOO:.+]] = "test.foo"
+  %0 = "test.foo"(%arg0) : (index) -> (index)
+  // CHECK-NEXT: %[[CALL:.+]] = call @second(%[[FOO]]) : (index) -> i32
+  %1 = "test.bar"(%0) : (index) -> (i32)
+  // CHECK: return %[[CALL]] : i32
+  func.return %1 : i32
+}
+
+func.func private @second(%arg1 : index) -> i32
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %in = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %bar[0] : (!transform.any_op) -> !transform.any_value
+    transform.func.cast_and_call %f#1(%in) -> %out before %bar
+        : (!transform.any_op, !transform.any_value,
+           !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Multiple inputs (the same value passed twice) and multiple outputs, with the
+// call inserted after the op defining the inputs.
+// CHECK-LABEL: func.func @multi_arg_and_result
+func.func @multi_arg_and_result(%arg0 : index) -> (index, index) {
+  // CHECK-NEXT: %[[FOO:.+]] = "test.foo"
+  %0 = "test.foo"(%arg0) : (index) -> (index)
+  %1 = "test.bar"(%0) : (index) -> (index)
+  %2 = "test.bar"(%0) : (index) -> (index)
+  // CHECK: %[[CALL:.+]]:2 = call @second(%[[FOO]], %[[FOO]]) : (index, index) -> (index, index)
+  // CHECK: return %[[CALL]]#0, %[[CALL]]#1 : index, index
+  func.return %1, %2 : index, index
+}
+
+func.func private @second(%arg1: index, %arg2: index) -> (index, index)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %bars = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %in0 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
+    %in1 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
+    %ins = transform.merge_handles %in0, %in1 : !transform.any_value
+
+    %outs = transform.get_result %bars[0] : (!transform.any_op) -> !transform.any_value
+
+    transform.func.cast_and_call %f#1(%ins) -> %outs after %foo
+        : (!transform.any_op, !transform.any_value,
+           !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// The insertion point is nested inside a region, while the input is defined
+// outside of it.
+// CHECK-LABEL: func.func @nested_call
+func.func @nested_call() {
+  // CHECK-NEXT: %[[ARG:.+]] = "test.arg"
+  // CHECK-NEXT: test.foo
+  %0 = "test.arg"() : () -> (index)
+  "test.foo"() ({
+    // CHECK-NEXT: call @second(%[[ARG]]) : (index) -> ()
+    "test.bar"(%0) : (index) -> ()
+  }) : () -> ()
+  func.return
+}
+
+func.func private @second(%arg1: index) -> ()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %arg = transform.structured.match ops{["test.arg"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %in = transform.get_result %arg[0] : (!transform.any_op) -> !transform.any_value
+
+    transform.func.cast_and_call %f#1(%in) before %bar
+        : (!transform.any_op, !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Tensor/transform-op-casting.mlir b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
new file mode 100644
index 00000000000000..fd2fc8a1883a3c
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
+
+// Statically shaped operands are cast to the dynamically shaped callee
+// arguments, and the dynamic call result is cast back to the original static
+// result type.
+func.func @cast_to_dynamic(%arg0: tensor<10x13xf32>, %arg1: tensor<3x13xf32>) -> tensor<13x13xf32> {
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<10x13xf32>, tensor<3x13xf32>) -> tensor<13x13xf32>
+  func.return %0 : tensor<13x13xf32>
+}
+
+func.func private @concat_replacement(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %concat = transform.structured.match ops{["tensor.concat"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value
+    transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
+      transform.type_conversion.tensor.cast
+    } : (!transform.any_op, !transform.any_value,
+         !transform.any_value, !transform.any_op) -> !transform.any_op
+    // The replaced tensor.concat is left in place by cast_and_call; DCE
+    // removes it.
+    transform.apply_dce to %f#0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @cast_to_dynamic
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<10x13xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x13xf32>
+//   CHECK-DAG:   %[[CAST0:.+]] = tensor.cast %[[ARG0]] : tensor<10x13xf32> to tensor<?x?xf32>
+//   CHECK-DAG:   %[[CAST1:.+]] = tensor.cast %[[ARG1]] : tensor<3x13xf32> to tensor<?x?xf32>
+//       CHECK:   %[[CALL:.+]] = call @concat_replacement(%[[CAST0]], %[[CAST1]])
+//       CHECK:   %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<?x?xf32> to tensor<13x13xf32>
+//       CHECK:   return %[[CAST_RES]] : tensor<13x13xf32>
+
+// -----
+
+// A dynamically shaped operand is cast to the statically shaped callee
+// argument; this direction needs `ignore_dynamic_info`. The static call result
+// is cast back to the original dynamic result type.
+func.func @cast_to_static(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+func.func private @collapse_replacement(%arg0: tensor<4x5xf32>) -> tensor<20xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    // NOTE(review): despite its name, %concat holds the tensor.collapse_shape op.
+    %concat = transform.structured.match ops{["tensor.collapse_shape"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value
+    transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
+      transform.type_conversion.tensor.cast ignore_dynamic_info
+    } : (!transform.any_op, !transform.any_value,
+         !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.apply_dce to %f#0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @cast_to_static
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[CAST_IN:.+]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x5xf32>
+//       CHECK:   %[[CALL:.+]] = call @collapse_replacement(%[[CAST_IN]])
+//       CHECK:   %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<20xf32> to tensor<?xf32>
+//       CHECK:   return %[[CAST_RES]] : tensor<?xf32>

>From e6211958bc210b909de4c76f17c169e3fb44ece8 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 18 Jan 2024 00:10:21 -0500
Subject: [PATCH 2/2] Collapse TypeConversion interface into converter builder
 interface and address comments

---
 .../Func/TransformOps/FuncTransformOps.td     | 16 ++---
 .../MemRef/TransformOps/MemRefTransformOps.td |  3 +-
 .../Tensor/TransformOps/TensorTransformOps.td | 15 +++--
 .../Transform/IR/TransformInterfaces.td       | 43 ++++++--------
 .../Func/TransformOps/FuncTransformOps.cpp    | 58 +++++++++----------
 .../TransformOps/TensorTransformOps.cpp       |  6 +-
 .../Dialect/Tensor/transform-op-casting.mlir  | 12 ++--
 .../TestTransformDialectExtension.td          |  3 +-
 8 files changed, 75 insertions(+), 81 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index e5086c26c55a4f..afb08ebd5eb435 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -38,22 +38,22 @@ def CastAndCallOp : Op<Transform_Dialect,
   let summary = "Casts values to the signature of a function and replaces them "
                 "with a call";
   let description = [{
-    This transform takes a set of |input| and |output| value handles and
+    This transform takes value handles to a set of `inputs` and `outputs` and
     attempts to cast them to the function signature of the attached function
     op, then builds a call to the function and replaces the users of the
     outputs. It is the responsibility of the user to ensure that the slice of
     the program replaced by this operation makes sense, i.e. there is no
     verification that the inputs to this operation have any relation to the
-    outputs outside of basic dominance requirements needed for the replacement.
+    outputs outside of basic dominance requirements needed for the call.
 
     The casting materialization functions are specified in the graph region of
-    this op. They must implement the `TypeConversionOpInterface`. The order of
-    ops within the region is irrelevant.
+    this op. They must implement the `TypeConverterBuilderOpInterface`. The
+    order of ops within the region is irrelevant.
 
     The target function can be specified by a symbol name or by a handle to the
     operation.
 
-    This transform only reads the target handles and only replaces the users of
+    This transform only reads the operand handles and only replaces the users of
     the outputs with the results of the call. No handles are consumed and no
     operations are removed. Users are expected to run cleanup separately if
     desired.
@@ -64,11 +64,11 @@ def CastAndCallOp : Op<Transform_Dialect,
      - The insertion point op does not dominate any of the output users
      - The insertion point op is not dominated by any of the inputs
      - The function signature does not match the number of inputs/outputs
-     - Any of the input conversions fail to be materialized
 
     This transform will emit a definite failure if it fails to resolve the
-    target function, or if it fails to materialize the conversion from the call
-    results to the output types.
+    target function, or if it fails to materialize the conversion casts of
+    either the inputs to the function argument types, or the call results to
+    the output types.
   }];
 
   let arguments = (ins
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 76309b9b8a9640..29383a3825be88 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -18,7 +18,8 @@ include "mlir/IR/OpBase.td"
 def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
     "apply_conversion_patterns.memref.memref_to_llvm_type_converter",
     [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
-                               ["getTypeConverterType"]>]> {
+                               ["getTypeConverter",
+                                "getTypeConverterType"]>]> {
   let description = [{
     This operation provides an "LLVMTypeConverter" that lowers memref types to
     LLVM types.
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 28e9249c82e309..39e1d7fa3494a3 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -169,12 +169,17 @@ def MakeLoopIndependentOp
   }];
 }
 
-def TypeConversionCastOp : Op<Transform_Dialect,
-    "type_conversion.tensor.cast",
-    [DeclareOpInterfaceMethods<TypeConversionOpInterface>]> {
+def TypeConversionCastShapeDynamicDimsOp : Op<Transform_Dialect,
+    "type_conversion.tensor.cast_shape_dynamic_dims",
+    [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+                               ["populateTypeMaterializations"]>]> {
   let description = [{
-    Indicates that tensor ops (such as tensor.generate) should be replaced with
-    constants (arith.constant) when possible.
+    Populates a type converter with conversion materialization functions that
+    cast a tensor value between two cast-compatible tensor types. See
+    `tensor.cast` for more information on cast compatibility between tensors.
+
+    If `ignore_dynamic_info` is not set, this will set an additional constraint
+    that source materializations do not cast dynamic dimensions to static ones.
   }];
   let arguments = (ins UnitAttr:$ignore_dynamic_info);
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index 3b601f42a6452d..1ef094436881aa 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -280,34 +280,12 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
   ];
 }
 
-def TypeConversionOpInterface : OpInterface<"TypeConversionOpInterface"> {
-  let description = [{
-    This interface should be implemented by ops that populate type casting
-    of a `transform.cast_and_inline` op. It provides a method to populate a
-    type converter with source/target materialization patterns.
-  }];
-
-  let cppNamespace = "::mlir::transform";
-
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/[{
-        Populate the given type converter with source/target materialization
-        functions.
-      }],
-      /*returnType=*/"void",
-      /*name=*/"populateTypeMaterializations",
-      /*arguments=*/(ins "::mlir::TypeConverter &":$converter)
-    >
-  ];
-}
-
 def TypeConverterBuilderOpInterface
     : OpInterface<"TypeConverterBuilderOpInterface"> {
   let description = [{
     This interface should be implemented by ops that specify a type converter
-    for a dialect conversion. Such ops can be used with
-    "apply_conversion_patterns".
+    for a dialect conversion, or to populate a type converter with
+    conversions. Such ops can be used with "apply_conversion_patterns".
   }];
 
   let cppNamespace = "::mlir::transform";
@@ -319,7 +297,11 @@ def TypeConverterBuilderOpInterface
       }],
       /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
       /*name=*/"getTypeConverter",
-      /*arguments=*/(ins)
+      /*arguments=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return std::make_unique<::mlir::TypeConverter>();
+      }]
     >,
     StaticInterfaceMethod<
       /*desc=*/[{
@@ -332,6 +314,17 @@ def TypeConverterBuilderOpInterface
       /*methodBody=*/"",
       /*defaultImplementation=*/[{ return "TypeConverter"; }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the given type converter with source/target materialization
+        functions.
+      }],
+      /*returnType=*/"void",
+      /*name=*/"populateTypeMaterializations",
+      /*arguments=*/(ins "::mlir::TypeConverter &":$converter),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return; }]
+    >,
   ];
 }
 
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 14b6e633520d6c..9e79b086c0be84 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -47,21 +47,19 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformState &state) {
   SmallVector<Value> inputs;
   if (getInputs())
-    for (Value input : state.getPayloadValues(getInputs()))
-      inputs.push_back(input);
-  SmallVector<Value> outputs;
-  if (getOutputs())
-    for (Value output : state.getPayloadValues(getOutputs()))
-      outputs.push_back(output);
+    llvm::append_range(inputs, state.getPayloadValues(getInputs()));
 
-  // Verify that the set of output values to be replaced is unique.
-  llvm::SmallDenseSet<Value> outputSet;
-  for (Value output : outputs) {
-    outputSet.insert(output);
-  }
-  if (outputSet.size() != outputs.size()) {
-    return emitSilenceableFailure(getLoc())
-           << "cast and call output values must be unique";
+  SetVector<Value> outputs;
+  if (getOutputs()) {
+    for (auto output : state.getPayloadValues(getOutputs()))
+      outputs.insert(output);
+
+    // Verify that the set of output values to be replaced is unique.
+    if (outputs.size() !=
+        llvm::range_size(state.getPayloadValues(getOutputs()))) {
+      return emitSilenceableFailure(getLoc())
+             << "cast and call output values must be unique";
+    }
   }
 
   // Get the insertion point for the call.
@@ -106,7 +104,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     }
   }
 
-  // Get the function to inline. This can either be specified by symbol or as a
+  // Get the function to call. This can either be specified by symbol or as a
   // transform handle.
   func::FuncOp targetFunction = nullptr;
   if (getFunctionName()) {
@@ -129,7 +127,6 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     llvm_unreachable("Invalid CastAndCall op without a function to call");
     return emitDefiniteFailure();
   }
-  assert(targetFunction && "no target function found");
 
   // Verify that the function argument and result lengths match the inputs and
   // outputs given to this op.
@@ -147,37 +144,34 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
   }
 
   // Gather all specified converters.
-  MLIRContext *ctx = insertionPoint->getContext();
   mlir::TypeConverter converter;
   if (!getRegion().empty()) {
     for (Operation &op : getRegion().front()) {
-      cast<transform::TypeConversionOpInterface>(&op)
+      cast<transform::TypeConverterBuilderOpInterface>(&op)
           .populateTypeMaterializations(converter);
     }
   }
 
-  OpBuilder builder(ctx);
   if (insertAfter)
-    builder.setInsertionPointAfter(insertionPoint);
+    rewriter.setInsertionPointAfter(insertionPoint);
   else
-    builder.setInsertionPoint(insertionPoint);
+    rewriter.setInsertionPoint(insertionPoint);
 
   for (auto [input, type] :
        llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
     if (input.getType() != type) {
       Value newInput = converter.materializeSourceConversion(
-          builder, input.getLoc(), type, input);
+          rewriter, input.getLoc(), type, input);
       if (!newInput) {
-        return emitSilenceableFailure(input.getLoc())
-               << "Failed to materialize conversion of " << input << " to type "
-               << type;
+        return emitDefiniteFailure() << "Failed to materialize conversion of "
+                                     << input << " to type " << type;
       }
       input = newInput;
     }
   }
 
-  auto callOp = builder.create<func::CallOp>(insertionPoint->getLoc(),
-                                             targetFunction, inputs);
+  auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
+                                              targetFunction, inputs);
 
   // Cast the call results back to the expected types. If any conversions fail
   // this is a definite failure as the call has been constructed at this point.
@@ -186,14 +180,14 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     Value convertedOutput = newOutput;
     if (output.getType() != newOutput.getType()) {
       convertedOutput = converter.materializeTargetConversion(
-          builder, output.getLoc(), output.getType(), newOutput);
+          rewriter, output.getLoc(), output.getType(), newOutput);
       if (!convertedOutput) {
-        return emitSilenceableFailure(output.getLoc())
+        return emitDefiniteFailure()
                << "Failed to materialize conversion of " << newOutput
                << " to type " << output.getType();
       }
     }
-    output.replaceAllUsesExcept(convertedOutput, callOp);
+    rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
   }
   results.set(cast<OpResult>(getResult()), {callOp});
   return DiagnosedSilenceableFailure::success();
@@ -202,10 +196,10 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
 LogicalResult transform::CastAndCallOp::verify() {
   if (!getRegion().empty()) {
     for (Operation &op : getRegion().front()) {
-      if (!isa<transform::TypeConversionOpInterface>(&op)) {
+      if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
         InFlightDiagnostic diag = emitOpError()
                                   << "expected children ops to implement "
-                                     "TypeConversionOpInterface";
+                                     "TypeConverterBuilderOpInterface";
         diag.attachNote(op.getLoc()) << "op without interface";
         return diag;
       }
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 0c89ba2a1f1895..38f1824a3634a3 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -131,11 +131,11 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
 }
 
 //===----------------------------------------------------------------------===//
-// TypeConversionCastOp
+// TypeConversionCastShapeDynamicDimsOp
 //===----------------------------------------------------------------------===//
 
-void transform::TypeConversionCastOp::populateTypeMaterializations(
-    TypeConverter &converter) {
+void transform::TypeConversionCastShapeDynamicDimsOp::
+    populateTypeMaterializations(TypeConverter &converter) {
   bool ignoreDynamicInfo = getIgnoreDynamicInfo();
   converter.addSourceMaterialization([ignoreDynamicInfo](
                                          OpBuilder &builder, Type resultType,
diff --git a/mlir/test/Dialect/Tensor/transform-op-casting.mlir b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
index fd2fc8a1883a3c..16a1fa2b0ba9c7 100644
--- a/mlir/test/Dialect/Tensor/transform-op-casting.mlir
+++ b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
@@ -12,10 +12,10 @@ module attributes {transform.with_named_sequence} {
     %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %concat = transform.structured.match ops{["tensor.concat"]} in %f#0 : (!transform.any_op) -> !transform.any_op
-    %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value
-    %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value
+    %ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value
     transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
-      transform.type_conversion.tensor.cast
+      transform.type_conversion.tensor.cast_shape_dynamic_dims
     } : (!transform.any_op, !transform.any_value,
          !transform.any_value, !transform.any_op) -> !transform.any_op
     transform.apply_dce to %f#0 : !transform.any_op
@@ -46,10 +46,10 @@ module attributes {transform.with_named_sequence} {
     %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %concat = transform.structured.match ops{["tensor.collapse_shape"]} in %f#0 : (!transform.any_op) -> !transform.any_op
-    %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value
-    %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value
+    %ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value
     transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
-      transform.type_conversion.tensor.cast ignore_dynamic_info
+      transform.type_conversion.tensor.cast_shape_dynamic_dims ignore_dynamic_info
     } : (!transform.any_op, !transform.any_value,
          !transform.any_value, !transform.any_op) -> !transform.any_op
     transform.apply_dce to %f#0 : !transform.any_op
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 54036f7929d1b8..c00cc560e83e9b 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -502,7 +502,8 @@ def ApplyTestConversionPatternsOp
 
 def TestTypeConverterOp
   : Op<Transform_Dialect, "apply_conversion_patterns.transform.test_type_converter",
-      [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface>]> {
+      [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+                                 ["getTypeConverter"]>]> {
   let arguments = (ins);
   let results = (outs);
   let assemblyFormat = "attr-dict";



More information about the Mlir-commits mailing list