[Mlir-commits] [mlir] [mlir][transform] Add an op for replacing values with function calls (PR #78398)

Quinn Dawkins llvmlistbot at llvm.org
Wed Jan 17 21:10:58 PST 2024


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/78398

>From 7ec7d9397d3fc34e237abd24ac7f5e94a7178fe9 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Tue, 16 Jan 2024 23:53:33 -0500
Subject: [PATCH 1/3] [mlir][transform] Add transform.get_operand op

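Similar to `transform.get_result`, except it returns a handle to the
operands of the target selected by a positional specification: an
explicit list of positions (`%op[0, 1]`), an inverted list
(`%op[except(0)]`), or all operands (`%op[all]`).

Additionally updates `get_result` to accept the same positional
specification in place of the single `result_number`. This makes the
use case of wanting to get all of the results of an operation easier by
no longer requiring the user to reconstruct the list of results
one-by-one. The parsing, printing and verification helpers for this
syntax move from the Linalg match ops to MatchInterfaces so that both
sets of ops can share them.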
---
 .../Linalg/TransformOps/LinalgMatchOps.td     |   4 +-
 .../Dialect/Transform/IR/MatchInterfaces.h    |  51 +++++-
 .../mlir/Dialect/Transform/IR/TransformOps.td |  71 ++++++--
 .../Linalg/TransformOps/LinalgMatchOps.cpp    | 170 +-----------------
 .../Dialect/Transform/IR/MatchInterfaces.cpp  | 135 ++++++++++++++
 .../lib/Dialect/Transform/IR/TransformOps.cpp |  57 +++++-
 .../Dialect/Transform/test-interpreter.mlir   |  94 +++++++++-
 7 files changed, 395 insertions(+), 187 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index 9e108529ec129b..162dd05f93030f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -288,7 +288,7 @@ def MatchStructuredDimOp : Op<Transform_Dialect, "match.structured.dim", [
   let results = (outs Optional<TransformParamTypeInterface>:$result);
   let assemblyFormat =
       "$operand_handle `[`"
-      "custom<StructuredTransformDims>($raw_dim_list, $is_inverted, $is_all)"
+      "custom<TransformMatchDims>($raw_dim_list, $is_inverted, $is_all)"
       "`]` attr-dict `:` "
       "custom<SemiFunctionType>(type($operand_handle), type($result))";
 
@@ -347,7 +347,7 @@ class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
       (outs Optional<AnyTypeOf<[TransformAnyHandle,Transform_AffineMapParamType]>>:$result);
   let assemblyFormat =
       "$operand_handle `[`"
-      "custom<StructuredTransformDims>($raw_position_list, $is_inverted, $is_all)"
+      "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
       "`]` attr-dict "
       "`:` custom<SemiFunctionType>(type($operand_handle), type($result))";
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index b155b110677d6c..36aeb4583029c9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -9,11 +9,12 @@
 #ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
 #define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
 
+#include <optional>
+#include <type_traits>
+
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/STLExtras.h"
-#include <optional>
-#include <type_traits>
 
 namespace mlir {
 namespace transform {
@@ -168,6 +169,52 @@ class SingleValueMatcherOpTrait
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Printing/parsing for positional specification matchers
+//===----------------------------------------------------------------------===//
+
+/// Parses a positional index specification for transform match operations.
+/// The following forms are accepted:
+///
+///  - `all`: sets `isAll` and returns;
+///  - comma-separated-integer-list: populates `rawDimList` with the values;
+///  - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
+///  with the values and sets `isInverted`.
+ParseResult parseTransformMatchDims(OpAsmParser &parser,
+                                    DenseI64ArrayAttr &rawDimList,
+                                    UnitAttr &isInverted, UnitAttr &isAll);
+
+/// Prints a positional index specification for transform match operations.
+void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
+                             DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
+                             UnitAttr isAll);
+
+//===----------------------------------------------------------------------===//
+// Utilities for positional specification matchers
+//===----------------------------------------------------------------------===//
+
+/// Checks if the positional specification defined is valid and reports errors
+/// otherwise.
+LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
+                                         bool inverted, bool all);
+
+/// Populates `result` with the positional identifiers relative to `maxNumber`.
+/// If `isAll` is set, the result will contain all numbers from `0` to
+/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
+/// values from `rawList` are interpreted as counting backwards from
+/// `maxNumber`, i.e., `-1` is interpreted as `maxNumber - 1`, while positive
+/// numbers remain as is. If `isInverted` is set, populates `result` with those
+/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
+/// `rawList`. If `rawList` contains values that are greater than or equal to
+/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
+/// given location. `maxNumber` must be positive. If `rawList` contains
+/// duplicate numbers or numbers that become duplicate after negative value
+/// remapping, emits a silenceable error.
+DiagnosedSilenceableFailure
+expandTargetSpecification(Location loc, bool isAll, bool isInverted,
+                          ArrayRef<int64_t> rawList, int64_t maxNumber,
+                          SmallVectorImpl<int64_t> &result);
+
 } // namespace transform
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 2627553cff6953..9f513822ed0a4e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -728,24 +728,75 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
                        "functional-type(operands, results)";
 }
 
+def GetOperandOp : TransformDialectOp<"get_operand",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
+  let summary = "Get a handle to the operand(s) of the targeted op";
+  let description = [{
+    The handle defined by this Transform op corresponds to the operands of the
+    given `target` operation specified by the given set of positions. There are
+    three possible modes:
+
+     - Position list directly, i.e. `%target[0, 1, 2]`. This will return the
+       operands at the specified positions.
+     - Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
+       all operands except those at the given positions.
+     - All, i.e. `%target[all]`. This will return all operands of the operation.
+    
+    This transform produces a silenceable failure if any of the operand indices
+    exceeds the number of operands in the target. It reads the target handle and
+    produces the result handle.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       DenseI64ArrayAttr:$raw_position_list,
+                       UnitAttr:$is_inverted,
+                       UnitAttr:$is_all);
+  let results = (outs TransformValueHandleTypeInterface:$result);
+  let assemblyFormat =
+      "$target `[`"
+      "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
+      "`]` attr-dict `:` functional-type(operands, results)";
+  let hasVerifier = 1;
+}
+
 def GetResultOp : TransformDialectOp<"get_result",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
-  let summary = "Get handle to the a result of the targeted op";
+  let summary = "Get a handle to the result(s) of the targeted op";
   let description = [{
-    The handle defined by this Transform op corresponds to the OpResult with
-    `result_number` that is defined by the given `target` operation.
-
-    This transform produces a silenceable failure if the targeted operation
-    does not have enough results. It reads the target handle and produces the
-    result handle.
+    The handle defined by this Transform op corresponds to the results of the
+    given `target` operation specified by the given set of positions. There are
+    three possible modes:
+
+     - Position list directly, i.e. `%target[0, 1, 2]`. This will return the
+       results at the specified positions.
+     - Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
+       all results except those at the given positions.
+     - All, i.e. `%target[all]`. This will return all results of the operation.
+
+    Negative positions count backwards from the last result, i.e. `-1` refers
+    to the last result; a single position such as `%target[0]` selects exactly
+    one result. When the target handle is associated with several payload
+    operations, the selected results of each of them are appended to the
+    result handle in the order of the payload operations.
+
+    This transform produces a silenceable failure if any of the positions is
+    out of range for the results of a targeted operation, or if the same
+    position is selected more than once after negative positions are
+    resolved. It reads the target handle and produces the result handle.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       I64Attr:$result_number);
+                       DenseI64ArrayAttr:$raw_position_list,
+                       UnitAttr:$is_inverted,
+                       UnitAttr:$is_all);
   let results = (outs TransformValueHandleTypeInterface:$result);
-  let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
-                       "functional-type(operands, results)";
+  let assemblyFormat =
+      "$target `[`"
+      "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
+      "`]` attr-dict `:` functional-type(operands, results)";
+  let hasVerifier = 1;
 }
 
 def GetTypeOp : TransformDialectOp<"get_type",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index fb021ed76242e9..115da4b90e063a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -330,91 +330,6 @@ static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Populates `result` with the positional identifiers relative to `maxNumber`.
-/// If `isAll` is set, the result will contain all numbers from `0` to
-/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
-/// values from `rawList` are  are interpreted as counting backwards from
-/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
-/// numbers remain as is. If `isInverted` is set, populates `result` with those
-/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
-/// `rawList`. If `rawList` contains values that are greater than or equal to
-/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
-/// given location. `maxNumber` must be positive. If `rawList` contains
-/// duplicate numbers or numbers that become duplicate after negative value
-/// remapping, emits a silenceable error.
-static DiagnosedSilenceableFailure
-expandTargetSpecification(Location loc, bool isAll, bool isInverted,
-                          ArrayRef<int64_t> rawList, int64_t maxNumber,
-                          SmallVectorImpl<int64_t> &result) {
-  assert(maxNumber > 0 && "expected size to be positive");
-  assert(!(isAll && isInverted) && "cannot invert all");
-  if (isAll) {
-    result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
-    return DiagnosedSilenceableFailure::success();
-  }
-
-  SmallVector<int64_t> expanded;
-  llvm::SmallDenseSet<int64_t> visited;
-  expanded.reserve(rawList.size());
-  SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
-  for (int64_t raw : rawList) {
-    int64_t updated = raw < 0 ? maxNumber + raw : raw;
-    if (updated >= maxNumber) {
-      return emitSilenceableFailure(loc)
-             << "position overflow " << updated << " (updated from " << raw
-             << ") for maximum " << maxNumber;
-    }
-    if (updated < 0) {
-      return emitSilenceableFailure(loc) << "position underflow " << updated
-                                         << " (updated from " << raw << ")";
-    }
-    if (!visited.insert(updated).second) {
-      return emitSilenceableFailure(loc) << "repeated position " << updated
-                                         << " (updated from " << raw << ")";
-    }
-    target.push_back(updated);
-  }
-
-  if (!isInverted)
-    return DiagnosedSilenceableFailure::success();
-
-  result.reserve(result.size() + (maxNumber - expanded.size()));
-  for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
-    if (llvm::is_contained(expanded, candidate))
-      continue;
-    result.push_back(candidate);
-  }
-
-  return DiagnosedSilenceableFailure::success();
-}
-
-/// Checks if the positional specification defined is valid and reports errors
-/// otherwise.
-LogicalResult verifyStructuredTransformDimsOp(Operation *op,
-                                              ArrayRef<int64_t> raw,
-                                              bool inverted, bool all) {
-  if (all) {
-    if (inverted) {
-      return op->emitOpError()
-             << "cannot request both 'all' and 'inverted' values in the list";
-    }
-    if (!raw.empty()) {
-      return op->emitOpError()
-             << "cannot both request 'all' and specific values in the list";
-    }
-  }
-  if (!all && raw.empty()) {
-    return op->emitOpError() << "must request specific values in the list if "
-                                "'all' is not specified";
-  }
-  SmallVector<int64_t> rawVector = llvm::to_vector(raw);
-  auto *it = std::unique(rawVector.begin(), rawVector.end());
-  if (it != rawVector.end())
-    return op->emitOpError() << "expected the listed values to be unique";
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // MatchStructuredDimOp
 //===----------------------------------------------------------------------===//
@@ -475,8 +390,8 @@ LogicalResult transform::MatchStructuredDimOp::verify() {
     return emitOpError() << "cannot request the same dimension to be both "
                             "parallel and reduction";
   }
-  return verifyStructuredTransformDimsOp(getOperation(), getRawDimList(),
-                                         getIsInverted(), getIsAll());
+  return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
+                                    getIsInverted(), getIsAll());
 }
 
 //===----------------------------------------------------------------------===//
@@ -592,8 +507,8 @@ LogicalResult verifyStructuredOperandOp(OpTy op) {
 LogicalResult transform::MatchStructuredInputOp::verify() {
   if (failed(verifyStructuredOperandOp(*this)))
     return failure();
-  return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
-                                         getIsInverted(), getIsAll());
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
 }
 
 //===----------------------------------------------------------------------===//
@@ -665,8 +580,8 @@ DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
 LogicalResult transform::MatchStructuredInitOp::verify() {
   if (failed(verifyStructuredOperandOp(*this)))
     return failure();
-  return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
-                                         getIsInverted(), getIsAll());
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
 }
 
 //===----------------------------------------------------------------------===//
@@ -793,78 +708,5 @@ void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
   build(builder, state, ValueRange());
 }
 
-//===----------------------------------------------------------------------===//
-// Printing and parsing for structured match ops.
-//===----------------------------------------------------------------------===//
-
-/// Keyword syntax for positional specification inversion.
-constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
-
-/// Keyword syntax for full inclusion in positional specification.
-constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
-
-/// Parses a positional specification for structured transform operations. The
-/// following forms are accepted:
-///
-///  - `all`: sets `isAll` and returns;
-///  - comma-separated-integer-list: populates `rawDimList` with the values;
-///  - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
-///  with the values and sets `isInverted`.
-static ParseResult parseStructuredTransformDims(OpAsmParser &parser,
-                                                DenseI64ArrayAttr &rawDimList,
-                                                UnitAttr &isInverted,
-                                                UnitAttr &isAll) {
-  Builder &builder = parser.getBuilder();
-  if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
-    rawDimList = builder.getDenseI64ArrayAttr({});
-    isInverted = nullptr;
-    isAll = builder.getUnitAttr();
-    return success();
-  }
-
-  isAll = nullptr;
-  isInverted = nullptr;
-  if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
-    isInverted = builder.getUnitAttr();
-  }
-
-  if (isInverted) {
-    if (parser.parseLParen().failed())
-      return failure();
-  }
-
-  SmallVector<int64_t> values;
-  ParseResult listResult = parser.parseCommaSeparatedList(
-      [&]() { return parser.parseInteger(values.emplace_back()); });
-  if (listResult.failed())
-    return failure();
-
-  rawDimList = builder.getDenseI64ArrayAttr(values);
-
-  if (isInverted) {
-    if (parser.parseRParen().failed())
-      return failure();
-  }
-  return success();
-}
-
-/// Prints a positional specification for structured transform operations.
-static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
-                                         DenseI64ArrayAttr rawDimList,
-                                         UnitAttr isInverted, UnitAttr isAll) {
-  if (isAll) {
-    printer << kDimAllKeyword;
-    return;
-  }
-  if (isInverted) {
-    printer << kDimExceptKeyword << "(";
-  }
-  llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
-                        [&](int64_t value) { printer << value; });
-  if (isInverted) {
-    printer << ")";
-  }
-}
-
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
index 6049ee767e1bb7..b9b6dabc26216e 100644
--- a/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
@@ -10,6 +10,141 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Printing and parsing for match ops.
+//===----------------------------------------------------------------------===//
+
+/// Keyword syntax for positional specification inversion.
+constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
+
+/// Keyword syntax for full inclusion in positional specification.
+constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
+
+ParseResult transform::parseTransformMatchDims(OpAsmParser &parser,
+                                               DenseI64ArrayAttr &rawDimList,
+                                               UnitAttr &isInverted,
+                                               UnitAttr &isAll) {
+  Builder &builder = parser.getBuilder();
+  if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
+    rawDimList = builder.getDenseI64ArrayAttr({});
+    isInverted = nullptr;
+    isAll = builder.getUnitAttr();
+    return success();
+  }
+
+  isAll = nullptr;
+  isInverted = nullptr;
+  if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
+    isInverted = builder.getUnitAttr();
+  }
+
+  if (isInverted) {
+    if (parser.parseLParen().failed())
+      return failure();
+  }
+
+  SmallVector<int64_t> values;
+  ParseResult listResult = parser.parseCommaSeparatedList(
+      [&]() { return parser.parseInteger(values.emplace_back()); });
+  if (listResult.failed())
+    return failure();
+
+  rawDimList = builder.getDenseI64ArrayAttr(values);
+
+  if (isInverted) {
+    if (parser.parseRParen().failed())
+      return failure();
+  }
+  return success();
+}
+
+void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
+                                        DenseI64ArrayAttr rawDimList,
+                                        UnitAttr isInverted, UnitAttr isAll) {
+  if (isAll) {
+    printer << kDimAllKeyword;
+    return;
+  }
+  if (isInverted) {
+    printer << kDimExceptKeyword << "(";
+  }
+  llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
+                        [&](int64_t value) { printer << value; });
+  if (isInverted) {
+    printer << ")";
+  }
+}
+
+LogicalResult transform::verifyTransformMatchDimsOp(Operation *op,
+                                                    ArrayRef<int64_t> raw,
+                                                    bool inverted, bool all) {
+  if (all) {
+    if (inverted) {
+      return op->emitOpError()
+             << "cannot request both 'all' and 'inverted' values in the list";
+    }
+    if (!raw.empty()) {
+      return op->emitOpError()
+             << "cannot both request 'all' and specific values in the list";
+    }
+  }
+  if (!all && raw.empty()) {
+    return op->emitOpError() << "must request specific values in the list if "
+                                "'all' is not specified";
+  }
+  // Sort a copy so duplicates become adjacent wherever they appear in `raw`.
+  SmallVector<int64_t> rawVector = llvm::to_vector(raw);
+  llvm::sort(rawVector);
+  if (std::adjacent_find(rawVector.begin(), rawVector.end()) != rawVector.end())
+    return op->emitOpError() << "expected the listed values to be unique";
+  return success();
+}
+
+DiagnosedSilenceableFailure transform::expandTargetSpecification(
+    Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList,
+    int64_t maxNumber, SmallVectorImpl<int64_t> &result) {
+  // Zero is allowed since targeted ops may have no operands or results.
+  assert(maxNumber >= 0 && "expected size to be non-negative");
+  assert(!(isAll && isInverted) && "cannot invert all");
+  if (isAll) {
+    result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
+    return DiagnosedSilenceableFailure::success();
+  }
+  // Normalize negative positions; reject out-of-range or repeated ones.
+  SmallVector<int64_t> expanded;
+  llvm::SmallDenseSet<int64_t> visited;
+  expanded.reserve(rawList.size());
+  SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
+  for (int64_t raw : rawList) {
+    int64_t updated = raw < 0 ? maxNumber + raw : raw;
+    if (updated >= maxNumber) {
+      return emitSilenceableFailure(loc)
+             << "position overflow " << updated << " (updated from " << raw
+             << ") for maximum " << maxNumber;
+    }
+    if (updated < 0) {
+      return emitSilenceableFailure(loc) << "position underflow " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    if (!visited.insert(updated).second) {
+      return emitSilenceableFailure(loc) << "repeated position " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    target.push_back(updated);
+  }
+
+  if (!isInverted)
+    return DiagnosedSilenceableFailure::success();
+  // `visited` holds exactly the excluded positions; keep everything else.
+  result.reserve(result.size() + (maxNumber - expanded.size()));
+  for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
+    if (visited.contains(candidate))
+      continue;
+    result.push_back(candidate);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Generated interface implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index b80fc09751d2aa..485d4448e7c368 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1464,6 +1464,39 @@ transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// GetOperandOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
+                               transform::TransformResults &results,
+                               transform::TransformState &state) {
+  SmallVector<Value> operands;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    SmallVector<int64_t> operandPositions;
+    DiagnosedSilenceableFailure diag = expandTargetSpecification(
+        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
+        target->getNumOperands(), operandPositions);
+    if (diag.isSilenceableFailure()) {
+      diag.attachNote(target->getLoc())
+          << "while considering positions of this payload operation";
+      return diag;
+    }
+    llvm::append_range(operands,
+                       llvm::map_range(operandPositions, [&](int64_t pos) {
+                         return target->getOperand(pos);
+                       }));
+  }
+  results.setValues(cast<OpResult>(getResult()), operands);
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::GetOperandOp::verify() {
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
+}
+
 //===----------------------------------------------------------------------===//
 // GetResultOp
 //===----------------------------------------------------------------------===//
@@ -1472,21 +1505,31 @@ DiagnosedSilenceableFailure
 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
                               transform::TransformResults &results,
                               transform::TransformState &state) {
-  int64_t resultNumber = getResultNumber();
   SmallVector<Value> opResults;
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    if (resultNumber >= target->getNumResults()) {
-      DiagnosedSilenceableFailure diag =
-          emitSilenceableError() << "targeted op does not have enough results";
-      diag.attachNote(target->getLoc()) << "target op";
+    SmallVector<int64_t> resultPositions;
+    DiagnosedSilenceableFailure diag = expandTargetSpecification(
+        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
+        target->getNumResults(), resultPositions);
+    if (diag.isSilenceableFailure()) {
+      diag.attachNote(target->getLoc())
+          << "while considering positions of this payload operation";
       return diag;
     }
-    opResults.push_back(target->getOpResult(resultNumber));
+    llvm::append_range(opResults,
+                       llvm::map_range(resultPositions, [&](int64_t pos) {
+                         return target->getResult(pos);
+                       }));
   }
-  results.setValues(llvm::cast<OpResult>(getResult()), opResults);
+  results.setValues(cast<OpResult>(getResult()), opResults);
   return DiagnosedSilenceableFailure::success();
 }
 
+LogicalResult transform::GetResultOp::verify() {
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
+}
+
 //===----------------------------------------------------------------------===//
 // GetTypeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 96f2122e976df5..de5807b2874b27 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1483,6 +1483,78 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// expected-remark @below {{addi operand}}
+// expected-note @below {{value handle points to a block argument #0}}
+func.func @get_operand_of_op(%arg0: index, %arg1: index) -> index {
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %operand = transform.get_operand %addi[0] : (!transform.any_op) -> !transform.any_value
+    transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @get_out_of_bounds_operand_of_op(%arg0: index, %arg1: index) -> index {
+  // expected-note @below {{while considering positions of this payload operation}}
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{position overflow 2 (updated from 2) for maximum 2}}
+    %operand = transform.get_operand %addi[2] : (!transform.any_op) -> !transform.any_value
+    transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-remark @below {{addi operand}}
+// expected-note @below {{value handle points to a block argument #1}}
+func.func @get_inverted_operand_of_op(%arg0: index, %arg1: index) -> index {
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %operand = transform.get_operand %addi[except(0)] : (!transform.any_op) -> !transform.any_value
+    transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @get_multiple_operands_of_op(%arg0: index, %arg1: index) -> index {
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addui = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %operands = transform.get_operand %addui[all] : (!transform.any_op) -> !transform.any_value
+    %p = transform.num_associations %operands : (!transform.any_value) -> !transform.param<i64>
+    // expected-remark @below {{2}}
+    transform.debug.emit_param_as_remark %p : !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @get_result_of_op(%arg0: index, %arg1: index) -> index {
   // expected-remark @below {{addi result}}
   // expected-note @below {{value handle points to an op result #0}}
@@ -1502,7 +1574,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
-  // expected-note @below {{target op}}
+  // expected-note @below {{while considering positions of this payload operation}}
   %r = arith.addi %arg0, %arg1 : index
   return %r : index
 }
@@ -1510,7 +1582,7 @@ func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error @below {{targeted op does not have enough results}}
+    // expected-error @below {{position overflow 1 (updated from 1) for maximum 1}}
     %result = transform.get_result %addi[1] : (!transform.any_op) -> !transform.any_value
     transform.debug.emit_remark_at %result, "addi result" : !transform.any_value
     transform.yield
@@ -1537,6 +1609,24 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @get_multiple_result_of_op(%arg0: index, %arg1: index) -> (index, i1) {
+  %r, %b = arith.addui_extended %arg0, %arg1 : index, i1
+  return %r, %b : index, i1
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addui = transform.structured.match ops{["arith.addui_extended"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %results = transform.get_result %addui[all] : (!transform.any_op) -> !transform.any_value
+    %p = transform.num_associations %results : (!transform.any_value) -> !transform.param<i64>
+    // expected-remark @below {{2}}
+    transform.debug.emit_param_as_remark %p : !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
 // expected-note @below {{target value}}
 func.func @get_result_of_op_bbarg(%arg0: index, %arg1: index) -> index {
   %r = arith.addi %arg0, %arg1 : index

>From 84d680b66412869f29ac294d98f6fba7e0729db2 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Tue, 16 Jan 2024 15:34:23 -0500
Subject: [PATCH 2/3] [mlir][transform] Add an op for replacing values with
 function calls

Adds a `transform.func.cast_and_call` op that takes handles to a set of
input and output values and replaces the uses of those outputs with the
results of a call to a function, inserted at a specified insertion point.

The purpose of this operation is to let users author IR independently,
outside of the module being compiled, and then match a slice of the
program and replace it with a call to that external function.

Additionally adds a mechanism for populating a type converter with
conversion materialization functions, which insert casts from the input
values to the function's argument types and from the call results back
to the types of the replaced outputs.
---
 .../Func/TransformOps/FuncTransformOps.td     |  65 ++++++
 .../Tensor/TransformOps/TensorTransformOps.td |  13 ++
 .../Transform/IR/TransformInterfaces.td       |  22 ++
 .../Func/TransformOps/FuncTransformOps.cpp    | 197 ++++++++++++++++++
 .../TransformOps/TensorTransformOps.cpp       |  40 ++++
 .../lib/Dialect/Transform/IR/TransformOps.cpp |   4 +
 mlir/test/Dialect/Func/func-transform.mlir    | 120 +++++++++++
 .../Dialect/Tensor/transform-op-casting.mlir  |  65 ++++++
 8 files changed, 526 insertions(+)
 create mode 100644 mlir/test/Dialect/Func/func-transform.mlir
 create mode 100644 mlir/test/Dialect/Tensor/transform-op-casting.mlir

diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 7a7e991c786188..e5086c26c55a4f 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -12,6 +12,8 @@
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/RegionKindInterface.td"
 include "mlir/IR/OpBase.td"
 
 def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
@@ -26,4 +28,67 @@ def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def CastAndCallOp : Op<Transform_Dialect,
+    "func.cast_and_call",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     AttrSizedOperandSegments,
+     ReportTrackingListenerFailuresOpTrait]
+        # GraphRegionNoTerminator.traits> {
+  let summary = "Casts values to the signature of a function and replaces them "
+                "with a call";
+  let description = [{
+    This transform takes a set of |input| and |output| value handles and
+    attempts to cast them to the function signature of the attached function
+    op, then builds a call to the function and replaces the users of the
+    outputs. It is the responsibility of the user to ensure that the slice of
+    the program replaced by this operation makes sense, i.e. there is no
+    verification that the inputs to this operation have any relation to the
+    outputs outside of basic dominance requirements needed for the replacement.
+
+    The casting materialization functions are specified in the graph region of
+    this op. They must implement the `TypeConversionOpInterface`. The order of
+    ops within the region is irrelevant.
+
+    The target function can be specified by a symbol name or by a handle to the
+    operation.
+
+    This transform only reads the target handles and only replaces the users of
+    the outputs with the results of the call. No handles are consumed and no
+    operations are removed. Users are expected to run cleanup separately if
+    desired.
+
+    This transform will emit a silenceable failure if:
+     - The set of outputs isn't unique
+     - The handle for the insertion point does not include exactly one operation
+     - The insertion point op does not dominate any of the output users
+     - The insertion point op is not dominated by any of the inputs
+     - The function signature does not match the number of inputs/outputs
+     - Any of the input conversions fail to be materialized
+
+    This transform will emit a definite failure if it fails to resolve the
+    target function, or if it fails to materialize the conversion from the call
+    results to the output types.
+  }];
+
+  let arguments = (ins
+    TransformHandleTypeInterface:$insertion_point,
+    UnitAttr:$insert_after,
+    Optional<TransformValueHandleTypeInterface>:$inputs,
+    Optional<TransformValueHandleTypeInterface>:$outputs,
+    OptionalAttr<SymbolRefAttr>:$function_name,
+    Optional<TransformHandleTypeInterface>:$function);
+  let results = (outs TransformHandleTypeInterface:$result);
+  let regions = (region MaxSizedRegion<1>:$conversions);
+
+  let assemblyFormat = [{
+    ($function_name^)? ($function^)?
+    ( `(` $inputs^ `)` )?
+    ( `->` $outputs^ )?
+    (`after` $insert_after^):(`before`)? $insertion_point
+    ($conversions^)? attr-dict `:` functional-type(operands, results)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 8556d9570fd120..28e9249c82e309 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -169,4 +169,17 @@ def MakeLoopIndependentOp
   }];
 }
 
+def TypeConversionCastOp : Op<Transform_Dialect,
+    "type_conversion.tensor.cast",
+    [DeclareOpInterfaceMethods<TypeConversionOpInterface>]> {
+  let description = [{
+    Indicates that tensor ops (such as tensor.generate) should be replaced with
+    constants (arith.constant) when possible.
+  }];
+  let arguments = (ins UnitAttr:$ignore_dynamic_info);
+
+  let assemblyFormat =
+      "(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
+}
+
 #endif // TENSOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index f29efaee620d84..3b601f42a6452d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -280,6 +280,28 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
   ];
 }
 
+def TypeConversionOpInterface : OpInterface<"TypeConversionOpInterface"> {
+  let description = [{
+    This interface should be implemented by ops that populate type casting
+    of a `transform.cast_and_inline` op. It provides a method to populate a
+    type converter with source/target materialization patterns.
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the given type converter with source/target materialization
+        functions.
+      }],
+      /*returnType=*/"void",
+      /*name=*/"populateTypeMaterializations",
+      /*arguments=*/(ins "::mlir::TypeConverter &":$converter)
+    >
+  ];
+}
+
 def TypeConverterBuilderOpInterface
     : OpInterface<"TypeConverterBuilderOpInterface"> {
   let description = [{
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 9e9b6bcea790de..14b6e633520d6c 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 
@@ -36,6 +37,202 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// CastAndCallOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
+                                transform::TransformResults &results,
+                                transform::TransformState &state) {
+  SmallVector<Value> inputs;
+  if (getInputs())
+    for (Value input : state.getPayloadValues(getInputs()))
+      inputs.push_back(input);
+  SmallVector<Value> outputs;
+  if (getOutputs())
+    for (Value output : state.getPayloadValues(getOutputs()))
+      outputs.push_back(output);
+
+  // Verify that the set of output values to be replaced is unique.
+  llvm::SmallDenseSet<Value> outputSet;
+  for (Value output : outputs) {
+    outputSet.insert(output);
+  }
+  if (outputSet.size() != outputs.size()) {
+    return emitSilenceableFailure(getLoc())
+           << "cast and call output values must be unique";
+  }
+
+  // Get the insertion point for the call.
+  auto insertionOps = state.getPayloadOps(getInsertionPoint());
+  if (!llvm::hasSingleElement(insertionOps)) {
+    return emitSilenceableFailure(getLoc())
+           << "Only one op can be specified as an insertion point";
+  }
+  bool insertAfter = getInsertAfter();
+  Operation *insertionPoint = *insertionOps.begin();
+
+  // Check that all inputs dominate the insertion point, and the insertion
+  // point dominates all users of the outputs.
+  DominanceInfo dom(insertionPoint);
+  for (Value output : outputs) {
+    for (Operation *user : output.getUsers()) {
+      // If we are inserting after the insertion point operation, the
+      // insertion point operation must properly dominate the user. Otherwise
+      // basic dominance is enough.
+      bool doesDominate = insertAfter
+                              ? dom.properlyDominates(insertionPoint, user)
+                              : dom.dominates(insertionPoint, user);
+      if (!doesDominate) {
+        return emitDefiniteFailure()
+               << "User " << user << " is not dominated by insertion point "
+               << insertionPoint;
+      }
+    }
+  }
+
+  for (Value input : inputs) {
+    // If we are inserting before the insertion point operation, the
+    // input must properly dominate the insertion point operation. Otherwise
+    // basic dominance is enough.
+    bool doesDominate = insertAfter
+                            ? dom.dominates(input, insertionPoint)
+                            : dom.properlyDominates(input, insertionPoint);
+    if (!doesDominate) {
+      return emitDefiniteFailure()
+             << "input " << input << " does not dominate insertion point "
+             << insertionPoint;
+    }
+  }
+
+  // Get the function to inline. This can either be specified by symbol or as a
+  // transform handle.
+  func::FuncOp targetFunction = nullptr;
+  if (getFunctionName()) {
+    targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
+        insertionPoint, *getFunctionName());
+    if (!targetFunction) {
+      return emitDefiniteFailure()
+             << "unresolved symbol " << *getFunctionName();
+    }
+  } else if (getFunction()) {
+    auto payloadOps = state.getPayloadOps(getFunction());
+    if (!llvm::hasSingleElement(payloadOps)) {
+      return emitDefiniteFailure() << "requires a single function to call";
+    }
+    targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
+    if (!targetFunction) {
+      return emitDefiniteFailure() << "invalid non-function callee";
+    }
+  } else {
+    llvm_unreachable("Invalid CastAndCall op without a function to call");
+    return emitDefiniteFailure();
+  }
+  assert(targetFunction && "no target function found");
+
+  // Verify that the function argument and result lengths match the inputs and
+  // outputs given to this op.
+  if (targetFunction.getNumArguments() != inputs.size()) {
+    return emitSilenceableFailure(targetFunction.getLoc())
+           << "mismatch between number of function arguments "
+           << targetFunction.getNumArguments() << " and number of inputs "
+           << inputs.size();
+  }
+  if (targetFunction.getNumResults() != outputs.size()) {
+    return emitSilenceableFailure(targetFunction.getLoc())
+           << "mismatch between number of function results "
+           << targetFunction->getNumResults() << " and number of outputs "
+           << outputs.size();
+  }
+
+  // Gather all specified converters.
+  MLIRContext *ctx = insertionPoint->getContext();
+  mlir::TypeConverter converter;
+  if (!getRegion().empty()) {
+    for (Operation &op : getRegion().front()) {
+      cast<transform::TypeConversionOpInterface>(&op)
+          .populateTypeMaterializations(converter);
+    }
+  }
+
+  OpBuilder builder(ctx);
+  if (insertAfter)
+    builder.setInsertionPointAfter(insertionPoint);
+  else
+    builder.setInsertionPoint(insertionPoint);
+
+  for (auto [input, type] :
+       llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
+    if (input.getType() != type) {
+      Value newInput = converter.materializeSourceConversion(
+          builder, input.getLoc(), type, input);
+      if (!newInput) {
+        return emitSilenceableFailure(input.getLoc())
+               << "Failed to materialize conversion of " << input << " to type "
+               << type;
+      }
+      input = newInput;
+    }
+  }
+
+  auto callOp = builder.create<func::CallOp>(insertionPoint->getLoc(),
+                                             targetFunction, inputs);
+
+  // Cast the call results back to the expected types. If any conversions fail
+  // this is a definite failure as the call has been constructed at this point.
+  for (auto [output, newOutput] :
+       llvm::zip_equal(outputs, callOp.getResults())) {
+    Value convertedOutput = newOutput;
+    if (output.getType() != newOutput.getType()) {
+      convertedOutput = converter.materializeTargetConversion(
+          builder, output.getLoc(), output.getType(), newOutput);
+      if (!convertedOutput) {
+        return emitSilenceableFailure(output.getLoc())
+               << "Failed to materialize conversion of " << newOutput
+               << " to type " << output.getType();
+      }
+    }
+    output.replaceAllUsesExcept(convertedOutput, callOp);
+  }
+  results.set(cast<OpResult>(getResult()), {callOp});
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::CastAndCallOp::verify() {
+  if (!getRegion().empty()) {
+    for (Operation &op : getRegion().front()) {
+      if (!isa<transform::TypeConversionOpInterface>(&op)) {
+        InFlightDiagnostic diag = emitOpError()
+                                  << "expected children ops to implement "
+                                     "TypeConversionOpInterface";
+        diag.attachNote(op.getLoc()) << "op without interface";
+        return diag;
+      }
+    }
+  }
+  if (!getFunction() && !getFunctionName()) {
+    return emitOpError() << "expected a function handle or name to call";
+  }
+  if (getFunction() && getFunctionName()) {
+    return emitOpError() << "function handle and name are mutually exclusive";
+  }
+  return success();
+}
+
+void transform::CastAndCallOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getInsertionPoint(), effects);
+  if (getInputs())
+    transform::onlyReadsHandle(getInputs(), effects);
+  if (getOutputs())
+    transform::onlyReadsHandle(getOutputs(), effects);
+  if (getFunction())
+    transform::onlyReadsHandle(getFunction(), effects);
+  transform::producesHandle(getResult(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index ed274238704713..0c89ba2a1f1895 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -15,6 +15,8 @@
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 using namespace tensor;
@@ -128,6 +130,44 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
   tensor::populateRewriteAsConstantPatterns(patterns);
 }
 
+//===----------------------------------------------------------------------===//
+// TypeConversionCastOp
+//===----------------------------------------------------------------------===//
+
+void transform::TypeConversionCastOp::populateTypeMaterializations(
+    TypeConverter &converter) {
+  bool ignoreDynamicInfo = getIgnoreDynamicInfo();
+  converter.addSourceMaterialization([ignoreDynamicInfo](
+                                         OpBuilder &builder, Type resultType,
+                                         ValueRange inputs,
+                                         Location loc) -> std::optional<Value> {
+    if (inputs.size() != 1) {
+      return std::nullopt;
+    }
+    Value input = inputs[0];
+    if (!ignoreDynamicInfo &&
+        !tensor::preservesStaticInformation(resultType, input.getType())) {
+      return std::nullopt;
+    }
+    if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
+      return std::nullopt;
+    }
+    return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
+  });
+  converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
+                                        ValueRange inputs,
+                                        Location loc) -> std::optional<Value> {
+    if (inputs.size() != 1) {
+      return std::nullopt;
+    }
+    Value input = inputs[0];
+    if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
+      return std::nullopt;
+    }
+    return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // MakeLoopIndependentOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 485d4448e7c368..f2a57383cc5bf9 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -16,10 +16,12 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -30,11 +32,13 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 #include <optional>
 
 #define DEBUG_TYPE "transform-dialect"
diff --git a/mlir/test/Dialect/Func/func-transform.mlir b/mlir/test/Dialect/Func/func-transform.mlir
new file mode 100644
index 00000000000000..6aab07b0cb38a0
--- /dev/null
+++ b/mlir/test/Dialect/Func/func-transform.mlir
@@ -0,0 +1,120 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @basic_cast_and_call
+func.func @basic_cast_and_call() {
+  // CHECK-NEXT: call @second()
+  "test.foo"() : () -> ()
+  // CHECK-NEXT: test.foo
+  // CHECK-NEXT: call @third()
+  func.return
+}
+
+func.func @second() {
+  "test.bar"() : () -> ()
+  func.return
+}
+
+func.func private @third()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:3 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    transform.func.cast_and_call @second before %foo : (!transform.any_op) -> !transform.any_op
+    transform.func.cast_and_call %f#2 after %foo : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @non_empty_arg_and_out
+func.func @non_empty_arg_and_out(%arg0 : index) -> i32 {
+  // CHECK-NEXT: %[[FOO:.+]] = "test.foo"
+  %0 = "test.foo"(%arg0) : (index) -> (index)
+  // CHECK-NEXT: %[[CALL:.+]] = call @second(%[[FOO]]) : (index) -> i32
+  %1 = "test.bar"(%0) : (index) -> (i32)
+  // CHECK: return %[[CALL]] : i32
+  func.return %1 : i32
+}
+
+func.func private @second(%arg1 : index) -> i32
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %in = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %bar[0] : (!transform.any_op) -> !transform.any_value
+    transform.func.cast_and_call %f#1(%in) -> %out before %bar
+        : (!transform.any_op, !transform.any_value,
+           !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @multi_arg_and_result
+func.func @multi_arg_and_result(%arg0 : index) -> (index, index) {
+  // CHECK-NEXT: %[[FOO:.+]] = "test.foo"
+  %0 = "test.foo"(%arg0) : (index) -> (index)
+  %1 = "test.bar"(%0) : (index) -> (index)
+  %2 = "test.bar"(%0) : (index) -> (index)
+  // CHECK: %[[CALL:.+]]:2 = call @second(%[[FOO]], %[[FOO]]) : (index, index) -> (index, index)
+  // CHECK: return %[[CALL]]#0, %[[CALL]]#1 : index, index
+  func.return %1, %2 : index, index
+}
+
+func.func private @second(%arg1: index, %arg2: index) -> (index, index)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %foo = transform.structured.match ops{["test.foo"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %bars = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %in0 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
+    %in1 = transform.get_result %foo[0] : (!transform.any_op) -> !transform.any_value
+    %ins = transform.merge_handles %in0, %in1 : !transform.any_value
+
+    %outs = transform.get_result %bars[0] : (!transform.any_op) -> !transform.any_value
+
+    transform.func.cast_and_call %f#1(%ins) -> %outs after %foo
+        : (!transform.any_op, !transform.any_value,
+           !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @nested_call
+func.func @nested_call() {
+  // CHECK-NEXT: %[[ARG:.+]] = "test.arg"
+  // CHECK-NEXT: test.foo
+  %0 = "test.arg"() : () -> (index)
+  "test.foo"() ({
+    // CHECK-NEXT: call @second(%[[ARG]]) : (index) -> ()
+    "test.bar"(%0) : (index) -> ()
+  }) : () -> ()
+}
+
+func.func private @second(%arg1: index) -> ()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %arg = transform.structured.match ops{["test.arg"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %bar = transform.structured.match ops{["test.bar"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %in = transform.get_result %arg[0] : (!transform.any_op) -> !transform.any_value
+
+    transform.func.cast_and_call %f#1(%in) before %bar
+        : (!transform.any_op, !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Tensor/transform-op-casting.mlir b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
new file mode 100644
index 00000000000000..fd2fc8a1883a3c
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file | FileCheck %s
+
+func.func @cast_to_dynamic(%arg0: tensor<10x13xf32>, %arg1: tensor<3x13xf32>) -> tensor<13x13xf32> {
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<10x13xf32>, tensor<3x13xf32>) -> tensor<13x13xf32>
+  func.return %0 : tensor<13x13xf32>
+}
+
+func.func private @concat_replacement(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %concat = transform.structured.match ops{["tensor.concat"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value
+    transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
+      transform.type_conversion.tensor.cast
+    } : (!transform.any_op, !transform.any_value,
+         !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.apply_dce to %f#0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @cast_to_dynamic
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<10x13xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x13xf32>
+//   CHECK-DAG:   %[[CAST0:.+]] = tensor.cast %[[ARG0]] : tensor<10x13xf32> to tensor<?x?xf32>
+//   CHECK-DAG:   %[[CAST1:.+]] = tensor.cast %[[ARG1]] : tensor<3x13xf32> to tensor<?x?xf32>
+//       CHECK:   %[[CALL:.+]] = call @concat_replacement(%[[CAST0]], %[[CAST1]])
+//       CHECK:   %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<?x?xf32> to tensor<13x13xf32>
+//       CHECK:   return %[[CAST_RES]] : tensor<13x13xf32>
+
+// -----
+
+func.func @cast_to_static(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+
+func.func private @collapse_replacement(%arg0: tensor<4x5xf32>) -> tensor<20xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %concat = transform.structured.match ops{["tensor.collapse_shape"]} in %f#0 : (!transform.any_op) -> !transform.any_op
+    %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value
+    transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
+      transform.type_conversion.tensor.cast ignore_dynamic_info
+    } : (!transform.any_op, !transform.any_value,
+         !transform.any_value, !transform.any_op) -> !transform.any_op
+    transform.apply_dce to %f#0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @cast_to_static
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[CAST_IN:.+]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x5xf32>
+//       CHECK:   %[[CALL:.+]] = call @collapse_replacement(%[[CAST_IN]])
+//       CHECK:   %[[CAST_RES:.+]] = tensor.cast %[[CALL]] : tensor<20xf32> to tensor<?xf32>
+//       CHECK:   return %[[CAST_RES]] : tensor<?xf32>

>From b70c13d9b6ce4b0458782b3baa49beedaf341f89 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 18 Jan 2024 00:10:21 -0500
Subject: [PATCH 3/3] Collapse TypeConversion interface into converter builder
 interface and address comments

---
 .../Func/TransformOps/FuncTransformOps.td     | 16 ++---
 .../MemRef/TransformOps/MemRefTransformOps.td |  3 +-
 .../Tensor/TransformOps/TensorTransformOps.td | 15 +++--
 .../Transform/IR/TransformInterfaces.td       | 43 ++++++--------
 .../Func/TransformOps/FuncTransformOps.cpp    | 58 +++++++++----------
 .../TransformOps/TensorTransformOps.cpp       |  6 +-
 .../Dialect/Tensor/transform-op-casting.mlir  | 12 ++--
 .../TestTransformDialectExtension.td          |  3 +-
 8 files changed, 75 insertions(+), 81 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index e5086c26c55a4f..afb08ebd5eb435 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -38,22 +38,22 @@ def CastAndCallOp : Op<Transform_Dialect,
   let summary = "Casts values to the signature of a function and replaces them "
                 "with a call";
   let description = [{
-    This transform takes a set of |input| and |output| value handles and
+    This transform takes value handles to a set of `inputs` and `outputs` and
     attempts to cast them to the function signature of the attached function
     op, then builds a call to the function and replaces the users of the
     outputs. It is the responsibility of the user to ensure that the slice of
     the program replaced by this operation makes sense, i.e. there is no
     verification that the inputs to this operation have any relation to the
-    outputs outside of basic dominance requirements needed for the replacement.
+    outputs outside of basic dominance requirements needed for the call.
 
     The casting materialization functions are specified in the graph region of
-    this op. They must implement the `TypeConversionOpInterface`. The order of
-    ops within the region is irrelevant.
+    this op. They must implement the `TypeConverterBuilderOpInterface`. The
+    order of ops within the region is irrelevant.
 
     The target function can be specified by a symbol name or by a handle to the
     operation.
 
-    This transform only reads the target handles and only replaces the users of
+    This transform only reads the operand handles and only replaces the users of
     the outputs with the results of the call. No handles are consumed and no
     operations are removed. Users are expected to run cleanup separately if
     desired.
@@ -64,11 +64,11 @@ def CastAndCallOp : Op<Transform_Dialect,
      - The insertion point op does not dominate any of the output users
      - The insertion point op is not dominated by any of the inputs
      - The function signature does not match the number of inputs/outputs
-     - Any of the input conversions fail to be materialized
 
     This transform will emit a definite failure if it fails to resolve the
-    target function, or if it fails to materialize the conversion from the call
-    results to the output types.
+    target function, or if it fails to materialize the conversion casts of
+    either the inputs to the function argument types, or the call results to
+    the output types.
   }];
 
   let arguments = (ins
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 76309b9b8a9640..29383a3825be88 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -18,7 +18,8 @@ include "mlir/IR/OpBase.td"
 def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
     "apply_conversion_patterns.memref.memref_to_llvm_type_converter",
     [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
-                               ["getTypeConverterType"]>]> {
+                               ["getTypeConverter",
+                                "getTypeConverterType"]>]> {
   let description = [{
     This operation provides an "LLVMTypeConverter" that lowers memref types to
     LLVM types.
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 28e9249c82e309..39e1d7fa3494a3 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -169,12 +169,17 @@ def MakeLoopIndependentOp
   }];
 }
 
-def TypeConversionCastOp : Op<Transform_Dialect,
-    "type_conversion.tensor.cast",
-    [DeclareOpInterfaceMethods<TypeConversionOpInterface>]> {
+def TypeConversionCastShapeDynamicDimsOp : Op<Transform_Dialect,
+    "type_conversion.tensor.cast_shape_dynamic_dims",
+    [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+                               ["populateTypeMaterializations"]>]> {
   let description = [{
-    Indicates that tensor ops (such as tensor.generate) should be replaced with
-    constants (arith.constant) when possible.
+    Populates a type converter with conversion materialization functions that
+    cast a tensor value between two cast-compatible tensor types. See
+    `tensor.cast` for more information on cast compatibility between tensors.
+
+    If `ignore_dynamic_info` is not set, this will set an additional constraint
+    that source materializations do not cast dynamic dimensions to static ones.
   }];
   let arguments = (ins UnitAttr:$ignore_dynamic_info);
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index 3b601f42a6452d..1ef094436881aa 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -280,34 +280,12 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
   ];
 }
 
-def TypeConversionOpInterface : OpInterface<"TypeConversionOpInterface"> {
-  let description = [{
-    This interface should be implemented by ops that populate type casting
-    of a `transform.cast_and_inline` op. It provides a method to populate a
-    type converter with source/target materialization patterns.
-  }];
-
-  let cppNamespace = "::mlir::transform";
-
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/[{
-        Populate the given type converter with source/target materialization
-        functions.
-      }],
-      /*returnType=*/"void",
-      /*name=*/"populateTypeMaterializations",
-      /*arguments=*/(ins "::mlir::TypeConverter &":$converter)
-    >
-  ];
-}
-
 def TypeConverterBuilderOpInterface
     : OpInterface<"TypeConverterBuilderOpInterface"> {
   let description = [{
     This interface should be implemented by ops that specify a type converter
-    for a dialect conversion. Such ops can be used with
-    "apply_conversion_patterns".
+    for a dialect conversion, or that populate a type converter with
+    materializations. Such ops can be used with "apply_conversion_patterns".
   }];
 
   let cppNamespace = "::mlir::transform";
@@ -319,7 +297,11 @@ def TypeConverterBuilderOpInterface
       }],
       /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
       /*name=*/"getTypeConverter",
-      /*arguments=*/(ins)
+      /*arguments=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return std::make_unique<::mlir::TypeConverter>();
+      }]
     >,
     StaticInterfaceMethod<
       /*desc=*/[{
@@ -332,6 +314,17 @@ def TypeConverterBuilderOpInterface
       /*methodBody=*/"",
       /*defaultImplementation=*/[{ return "TypeConverter"; }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate the given type converter with source/target materialization
+        functions.
+      }],
+      /*returnType=*/"void",
+      /*name=*/"populateTypeMaterializations",
+      /*arguments=*/(ins "::mlir::TypeConverter &":$converter),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return; }]
+    >,
   ];
 }
 
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 14b6e633520d6c..9e79b086c0be84 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -47,21 +47,19 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformState &state) {
   SmallVector<Value> inputs;
   if (getInputs())
-    for (Value input : state.getPayloadValues(getInputs()))
-      inputs.push_back(input);
-  SmallVector<Value> outputs;
-  if (getOutputs())
-    for (Value output : state.getPayloadValues(getOutputs()))
-      outputs.push_back(output);
+    llvm::append_range(inputs, state.getPayloadValues(getInputs()));
 
-  // Verify that the set of output values to be replaced is unique.
-  llvm::SmallDenseSet<Value> outputSet;
-  for (Value output : outputs) {
-    outputSet.insert(output);
-  }
-  if (outputSet.size() != outputs.size()) {
-    return emitSilenceableFailure(getLoc())
-           << "cast and call output values must be unique";
+  SetVector<Value> outputs;
+  if (getOutputs()) {
+    for (auto output : state.getPayloadValues(getOutputs()))
+      outputs.insert(output);
+
+    // Verify that the set of output values to be replaced is unique.
+    if (outputs.size() !=
+        llvm::range_size(state.getPayloadValues(getOutputs()))) {
+      return emitSilenceableFailure(getLoc())
+             << "cast and call output values must be unique";
+    }
   }
 
   // Get the insertion point for the call.
@@ -106,7 +104,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     }
   }
 
-  // Get the function to inline. This can either be specified by symbol or as a
+  // Get the function to call. This can either be specified by symbol or as a
   // transform handle.
   func::FuncOp targetFunction = nullptr;
   if (getFunctionName()) {
@@ -129,7 +127,6 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     llvm_unreachable("Invalid CastAndCall op without a function to call");
     return emitDefiniteFailure();
   }
-  assert(targetFunction && "no target function found");
 
   // Verify that the function argument and result lengths match the inputs and
   // outputs given to this op.
@@ -147,37 +144,34 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
   }
 
   // Gather all specified converters.
-  MLIRContext *ctx = insertionPoint->getContext();
   mlir::TypeConverter converter;
   if (!getRegion().empty()) {
     for (Operation &op : getRegion().front()) {
-      cast<transform::TypeConversionOpInterface>(&op)
+      cast<transform::TypeConverterBuilderOpInterface>(&op)
           .populateTypeMaterializations(converter);
     }
   }
 
-  OpBuilder builder(ctx);
   if (insertAfter)
-    builder.setInsertionPointAfter(insertionPoint);
+    rewriter.setInsertionPointAfter(insertionPoint);
   else
-    builder.setInsertionPoint(insertionPoint);
+    rewriter.setInsertionPoint(insertionPoint);
 
   for (auto [input, type] :
        llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
     if (input.getType() != type) {
       Value newInput = converter.materializeSourceConversion(
-          builder, input.getLoc(), type, input);
+          rewriter, input.getLoc(), type, input);
       if (!newInput) {
-        return emitSilenceableFailure(input.getLoc())
-               << "Failed to materialize conversion of " << input << " to type "
-               << type;
+        return emitDefiniteFailure() << "Failed to materialize conversion of "
+                                     << input << " to type " << type;
       }
       input = newInput;
     }
   }
 
-  auto callOp = builder.create<func::CallOp>(insertionPoint->getLoc(),
-                                             targetFunction, inputs);
+  auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
+                                              targetFunction, inputs);
 
   // Cast the call results back to the expected types. If any conversions fail
   // this is a definite failure as the call has been constructed at this point.
@@ -186,14 +180,14 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     Value convertedOutput = newOutput;
     if (output.getType() != newOutput.getType()) {
       convertedOutput = converter.materializeTargetConversion(
-          builder, output.getLoc(), output.getType(), newOutput);
+          rewriter, output.getLoc(), output.getType(), newOutput);
       if (!convertedOutput) {
-        return emitSilenceableFailure(output.getLoc())
+        return emitDefiniteFailure()
                << "Failed to materialize conversion of " << newOutput
                << " to type " << output.getType();
       }
     }
-    output.replaceAllUsesExcept(convertedOutput, callOp);
+    rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
   }
   results.set(cast<OpResult>(getResult()), {callOp});
   return DiagnosedSilenceableFailure::success();
@@ -202,10 +196,10 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
 LogicalResult transform::CastAndCallOp::verify() {
   if (!getRegion().empty()) {
     for (Operation &op : getRegion().front()) {
-      if (!isa<transform::TypeConversionOpInterface>(&op)) {
+      if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
         InFlightDiagnostic diag = emitOpError()
                                   << "expected children ops to implement "
-                                     "TypeConversionOpInterface";
+                                     "TypeConverterBuilderOpInterface";
         diag.attachNote(op.getLoc()) << "op without interface";
         return diag;
       }
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 0c89ba2a1f1895..38f1824a3634a3 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -131,11 +131,11 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
 }
 
 //===----------------------------------------------------------------------===//
-// TypeConversionCastOp
+// TypeConversionCastShapeDynamicDimsOp
 //===----------------------------------------------------------------------===//
 
-void transform::TypeConversionCastOp::populateTypeMaterializations(
-    TypeConverter &converter) {
+void transform::TypeConversionCastShapeDynamicDimsOp::
+    populateTypeMaterializations(TypeConverter &converter) {
   bool ignoreDynamicInfo = getIgnoreDynamicInfo();
   converter.addSourceMaterialization([ignoreDynamicInfo](
                                          OpBuilder &builder, Type resultType,
diff --git a/mlir/test/Dialect/Tensor/transform-op-casting.mlir b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
index fd2fc8a1883a3c..16a1fa2b0ba9c7 100644
--- a/mlir/test/Dialect/Tensor/transform-op-casting.mlir
+++ b/mlir/test/Dialect/Tensor/transform-op-casting.mlir
@@ -12,10 +12,10 @@ module attributes {transform.with_named_sequence} {
     %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %concat = transform.structured.match ops{["tensor.concat"]} in %f#0 : (!transform.any_op) -> !transform.any_op
-    %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value
-    %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value
+    %ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value
     transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
-      transform.type_conversion.tensor.cast
+      transform.type_conversion.tensor.cast_shape_dynamic_dims
     } : (!transform.any_op, !transform.any_value,
          !transform.any_value, !transform.any_op) -> !transform.any_op
     transform.apply_dce to %f#0 : !transform.any_op
@@ -46,10 +46,10 @@ module attributes {transform.with_named_sequence} {
     %funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %concat = transform.structured.match ops{["tensor.collapse_shape"]} in %f#0 : (!transform.any_op) -> !transform.any_op
-    %ins = transform.get_operand %concat : (!transform.any_op) -> !transform.any_value
-    %out = transform.get_result %concat : (!transform.any_op) -> !transform.any_value
+    %ins = transform.get_operand %concat[all] : (!transform.any_op) -> !transform.any_value
+    %out = transform.get_result %concat[all] : (!transform.any_op) -> !transform.any_value
     transform.func.cast_and_call %f#1(%ins) -> %out before %concat {
-      transform.type_conversion.tensor.cast ignore_dynamic_info
+      transform.type_conversion.tensor.cast_shape_dynamic_dims ignore_dynamic_info
     } : (!transform.any_op, !transform.any_value,
          !transform.any_value, !transform.any_op) -> !transform.any_op
     transform.apply_dce to %f#0 : !transform.any_op
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 54036f7929d1b8..c00cc560e83e9b 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -502,7 +502,8 @@ def ApplyTestConversionPatternsOp
 
 def TestTypeConverterOp
   : Op<Transform_Dialect, "apply_conversion_patterns.transform.test_type_converter",
-      [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface>]> {
+      [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+                                 ["getTypeConverter"]>]> {
   let arguments = (ins);
   let results = (outs);
   let assemblyFormat = "attr-dict";



More information about the Mlir-commits mailing list