[Mlir-commits] [mlir] 5caab8b - [mlir][transform] Add transform.get_operand op (#78397)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 18 06:33:18 PST 2024


Author: Quinn Dawkins
Date: 2024-01-18T09:33:14-05:00
New Revision: 5caab8bbc0f89f46aca07be2090c8d23c78605ba

URL: https://github.com/llvm/llvm-project/commit/5caab8bbc0f89f46aca07be2090c8d23c78605ba
DIFF: https://github.com/llvm/llvm-project/commit/5caab8bbc0f89f46aca07be2090c8d23c78605ba.diff

LOG: [mlir][transform] Add transform.get_operand op (#78397)

Similar to `transform.get_result`, except that it returns a handle to the
operand(s) selected by a positional specification, the same specification
already used by the linalg match ops.

Additionally updates `get_result` to take the same positional specification.
This makes the use case of wanting to get all of the results of an
operation easier by no longer requiring the user to reconstruct the list
of results one-by-one.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
    mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
    mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/python/mlir/dialects/transform/extras/__init__.py
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index 9e108529ec129b2..162dd05f93030f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -288,7 +288,7 @@ def MatchStructuredDimOp : Op<Transform_Dialect, "match.structured.dim", [
   let results = (outs Optional<TransformParamTypeInterface>:$result);
   let assemblyFormat =
       "$operand_handle `[`"
-      "custom<StructuredTransformDims>($raw_dim_list, $is_inverted, $is_all)"
+      "custom<TransformMatchDims>($raw_dim_list, $is_inverted, $is_all)"
       "`]` attr-dict `:` "
       "custom<SemiFunctionType>(type($operand_handle), type($result))";
 
@@ -347,7 +347,7 @@ class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
       (outs Optional<AnyTypeOf<[TransformAnyHandle,Transform_AffineMapParamType]>>:$result);
   let assemblyFormat =
       "$operand_handle `[`"
-      "custom<StructuredTransformDims>($raw_position_list, $is_inverted, $is_all)"
+      "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
       "`]` attr-dict "
       "`:` custom<SemiFunctionType>(type($operand_handle), type($result))";
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index b155b110677d6c7..36aeb4583029c96 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -9,11 +9,12 @@
 #ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
 #define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
 
+#include <optional>
+#include <type_traits>
+
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/STLExtras.h"
-#include <optional>
-#include <type_traits>
 
 namespace mlir {
 namespace transform {
@@ -168,6 +169,52 @@ class SingleValueMatcherOpTrait
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Printing/parsing for positional specification matchers
+//===----------------------------------------------------------------------===//
+
+/// Parses a positional index specification for transform match operations.
+/// The following forms are accepted:
+///
+///  - `all`: sets `isAll` and returns;
+///  - comma-separated-integer-list: populates `rawDimList` with the values;
+///  - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
+///  with the values and sets `isInverted`.
+ParseResult parseTransformMatchDims(OpAsmParser &parser,
+                                    DenseI64ArrayAttr &rawDimList,
+                                    UnitAttr &isInverted, UnitAttr &isAll);
+
+/// Prints a positional index specification for transform match operations.
+void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
+                             DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
+                             UnitAttr isAll);
+
+//===----------------------------------------------------------------------===//
+// Utilities for positional specification matchers
+//===----------------------------------------------------------------------===//
+
+/// Checks if the positional specification defined is valid and reports errors
+/// otherwise.
+LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
+                                         bool inverted, bool all);
+
+/// Populates `result` with the positional identifiers relative to `maxNumber`.
+/// If `isAll` is set, the result will contain all numbers from `0` to
+/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
+/// values from `rawList` are interpreted as counting backwards from
+/// `maxNumber`, i.e., `-1` is interpreted as `maxNumber - 1`, while positive
+/// numbers remain as is. If `isInverted` is set, populates `result` with those
+/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
+/// `rawList`. If `rawList` contains values that are greater than or equal to
+/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
+/// given location. `maxNumber` must be positive. If `rawList` contains
+/// duplicate numbers or numbers that become duplicate after negative value
+/// remapping, emits a silenceable error.
+DiagnosedSilenceableFailure
+expandTargetSpecification(Location loc, bool isAll, bool isInverted,
+                          ArrayRef<int64_t> rawList, int64_t maxNumber,
+                          SmallVectorImpl<int64_t> &result);
+
 } // namespace transform
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 2627553cff69533..9f513822ed0a4e1 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -728,24 +728,75 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
                        "functional-type(operands, results)";
 }
 
+def GetOperandOp : TransformDialectOp<"get_operand",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
+  let summary = "Get a handle to the operand(s) of the targeted op";
+  let description = [{
+    The handle defined by this Transform op corresponds to the operands of the
+    given `target` operation specified by the given set of positions. There are
+    three possible modes:
+
+     - Position list directly, i.e. `%target[0, 1, 2]`. This will return the
+       operands at the specified positions.
+     - Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
+       all operands except those at the given positions.
+     - All, i.e. `%target[all]`. This will return all operands of the operation.
+    
+    This transform produces a silenceable failure if any of the operand indices
+    exceeds the number of operands in the target. It reads the target handle and
+    produces the result handle.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       DenseI64ArrayAttr:$raw_position_list,
+                       UnitAttr:$is_inverted,
+                       UnitAttr:$is_all);
+  let results = (outs TransformValueHandleTypeInterface:$result);
+  let assemblyFormat =
+      "$target `[`"
+      "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
+      "`]` attr-dict `:` functional-type(operands, results)";
+  let hasVerifier = 1;
+}
+
 def GetResultOp : TransformDialectOp<"get_result",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
-  let summary = "Get handle to the a result of the targeted op";
+  let summary = "Get a handle to the result(s) of the targeted op";
   let description = [{
-    The handle defined by this Transform op corresponds to the OpResult with
-    `result_number` that is defined by the given `target` operation.
-
-    This transform produces a silenceable failure if the targeted operation
-    does not have enough results. It reads the target handle and produces the
-    result handle.
+    The handle defined by this Transform op corresponds to the OpResults of the
+    given `target` operation. The positional specification described below
+    selects which of those results are included.
+    
+    This transform produces a silenceable failure if the targeted operation
+    does not have enough results. It reads `target` and produces `result`.
+
+    The handle defined by this Transform op corresponds to the results of the
+    given `target` operation specified by the given set of positions. There are
+    three possible modes:
+
+     - Position list directly, i.e. `%target[0, 1, 2]`. This will return the
+       results at the specified positions.
+     - Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
+       all results except those at the given positions.
+     - All, i.e. `%target[all]`. This will return all results of the operation.
+    
+    This transform produces a silenceable failure if any of the result indices
+    exceeds the number of results returned by the target. It reads the target
+    handle and produces the result handle.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       I64Attr:$result_number);
+                       DenseI64ArrayAttr:$raw_position_list,
+                       UnitAttr:$is_inverted,
+                       UnitAttr:$is_all);
   let results = (outs TransformValueHandleTypeInterface:$result);
-  let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
-                       "functional-type(operands, results)";
+  let assemblyFormat =
+      "$target `[`"
+      "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
+      "`]` attr-dict `:` functional-type(operands, results)";
+  let hasVerifier = 1;
 }
 
 def GetTypeOp : TransformDialectOp<"get_type",

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index fb021ed76242e90..115da4b90e063ac 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -330,91 +330,6 @@ static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Populates `result` with the positional identifiers relative to `maxNumber`.
-/// If `isAll` is set, the result will contain all numbers from `0` to
-/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
-/// values from `rawList` are  are interpreted as counting backwards from
-/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
-/// numbers remain as is. If `isInverted` is set, populates `result` with those
-/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
-/// `rawList`. If `rawList` contains values that are greater than or equal to
-/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
-/// given location. `maxNumber` must be positive. If `rawList` contains
-/// duplicate numbers or numbers that become duplicate after negative value
-/// remapping, emits a silenceable error.
-static DiagnosedSilenceableFailure
-expandTargetSpecification(Location loc, bool isAll, bool isInverted,
-                          ArrayRef<int64_t> rawList, int64_t maxNumber,
-                          SmallVectorImpl<int64_t> &result) {
-  assert(maxNumber > 0 && "expected size to be positive");
-  assert(!(isAll && isInverted) && "cannot invert all");
-  if (isAll) {
-    result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
-    return DiagnosedSilenceableFailure::success();
-  }
-
-  SmallVector<int64_t> expanded;
-  llvm::SmallDenseSet<int64_t> visited;
-  expanded.reserve(rawList.size());
-  SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
-  for (int64_t raw : rawList) {
-    int64_t updated = raw < 0 ? maxNumber + raw : raw;
-    if (updated >= maxNumber) {
-      return emitSilenceableFailure(loc)
-             << "position overflow " << updated << " (updated from " << raw
-             << ") for maximum " << maxNumber;
-    }
-    if (updated < 0) {
-      return emitSilenceableFailure(loc) << "position underflow " << updated
-                                         << " (updated from " << raw << ")";
-    }
-    if (!visited.insert(updated).second) {
-      return emitSilenceableFailure(loc) << "repeated position " << updated
-                                         << " (updated from " << raw << ")";
-    }
-    target.push_back(updated);
-  }
-
-  if (!isInverted)
-    return DiagnosedSilenceableFailure::success();
-
-  result.reserve(result.size() + (maxNumber - expanded.size()));
-  for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
-    if (llvm::is_contained(expanded, candidate))
-      continue;
-    result.push_back(candidate);
-  }
-
-  return DiagnosedSilenceableFailure::success();
-}
-
-/// Checks if the positional specification defined is valid and reports errors
-/// otherwise.
-LogicalResult verifyStructuredTransformDimsOp(Operation *op,
-                                              ArrayRef<int64_t> raw,
-                                              bool inverted, bool all) {
-  if (all) {
-    if (inverted) {
-      return op->emitOpError()
-             << "cannot request both 'all' and 'inverted' values in the list";
-    }
-    if (!raw.empty()) {
-      return op->emitOpError()
-             << "cannot both request 'all' and specific values in the list";
-    }
-  }
-  if (!all && raw.empty()) {
-    return op->emitOpError() << "must request specific values in the list if "
-                                "'all' is not specified";
-  }
-  SmallVector<int64_t> rawVector = llvm::to_vector(raw);
-  auto *it = std::unique(rawVector.begin(), rawVector.end());
-  if (it != rawVector.end())
-    return op->emitOpError() << "expected the listed values to be unique";
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // MatchStructuredDimOp
 //===----------------------------------------------------------------------===//
@@ -475,8 +390,8 @@ LogicalResult transform::MatchStructuredDimOp::verify() {
     return emitOpError() << "cannot request the same dimension to be both "
                             "parallel and reduction";
   }
-  return verifyStructuredTransformDimsOp(getOperation(), getRawDimList(),
-                                         getIsInverted(), getIsAll());
+  return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
+                                    getIsInverted(), getIsAll());
 }
 
 //===----------------------------------------------------------------------===//
@@ -592,8 +507,8 @@ LogicalResult verifyStructuredOperandOp(OpTy op) {
 LogicalResult transform::MatchStructuredInputOp::verify() {
   if (failed(verifyStructuredOperandOp(*this)))
     return failure();
-  return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
-                                         getIsInverted(), getIsAll());
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
 }
 
 //===----------------------------------------------------------------------===//
@@ -665,8 +580,8 @@ DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
 LogicalResult transform::MatchStructuredInitOp::verify() {
   if (failed(verifyStructuredOperandOp(*this)))
     return failure();
-  return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
-                                         getIsInverted(), getIsAll());
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
 }
 
 //===----------------------------------------------------------------------===//
@@ -793,78 +708,5 @@ void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
   build(builder, state, ValueRange());
 }
 
-//===----------------------------------------------------------------------===//
-// Printing and parsing for structured match ops.
-//===----------------------------------------------------------------------===//
-
-/// Keyword syntax for positional specification inversion.
-constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
-
-/// Keyword syntax for full inclusion in positional specification.
-constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
-
-/// Parses a positional specification for structured transform operations. The
-/// following forms are accepted:
-///
-///  - `all`: sets `isAll` and returns;
-///  - comma-separated-integer-list: populates `rawDimList` with the values;
-///  - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
-///  with the values and sets `isInverted`.
-static ParseResult parseStructuredTransformDims(OpAsmParser &parser,
-                                                DenseI64ArrayAttr &rawDimList,
-                                                UnitAttr &isInverted,
-                                                UnitAttr &isAll) {
-  Builder &builder = parser.getBuilder();
-  if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
-    rawDimList = builder.getDenseI64ArrayAttr({});
-    isInverted = nullptr;
-    isAll = builder.getUnitAttr();
-    return success();
-  }
-
-  isAll = nullptr;
-  isInverted = nullptr;
-  if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
-    isInverted = builder.getUnitAttr();
-  }
-
-  if (isInverted) {
-    if (parser.parseLParen().failed())
-      return failure();
-  }
-
-  SmallVector<int64_t> values;
-  ParseResult listResult = parser.parseCommaSeparatedList(
-      [&]() { return parser.parseInteger(values.emplace_back()); });
-  if (listResult.failed())
-    return failure();
-
-  rawDimList = builder.getDenseI64ArrayAttr(values);
-
-  if (isInverted) {
-    if (parser.parseRParen().failed())
-      return failure();
-  }
-  return success();
-}
-
-/// Prints a positional specification for structured transform operations.
-static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
-                                         DenseI64ArrayAttr rawDimList,
-                                         UnitAttr isInverted, UnitAttr isAll) {
-  if (isAll) {
-    printer << kDimAllKeyword;
-    return;
-  }
-  if (isInverted) {
-    printer << kDimExceptKeyword << "(";
-  }
-  llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
-                        [&](int64_t value) { printer << value; });
-  if (isInverted) {
-    printer << ")";
-  }
-}
-
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
index 6049ee767e1bb74..b9b6dabc26216e1 100644
--- a/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
@@ -10,6 +10,141 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Printing and parsing for match ops.
+//===----------------------------------------------------------------------===//
+
+/// Keyword syntax for positional specification inversion.
+constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
+
+/// Keyword syntax for full inclusion in positional specification.
+constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
+
+ParseResult transform::parseTransformMatchDims(OpAsmParser &parser,
+                                               DenseI64ArrayAttr &rawDimList,
+                                               UnitAttr &isInverted,
+                                               UnitAttr &isAll) {
+  Builder &builder = parser.getBuilder();
+  if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
+    rawDimList = builder.getDenseI64ArrayAttr({});
+    isInverted = nullptr;
+    isAll = builder.getUnitAttr();
+    return success();
+  }
+
+  isAll = nullptr;
+  isInverted = nullptr;
+  if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
+    isInverted = builder.getUnitAttr();
+  }
+
+  if (isInverted) {
+    if (parser.parseLParen().failed())
+      return failure();
+  }
+
+  SmallVector<int64_t> values;
+  ParseResult listResult = parser.parseCommaSeparatedList(
+      [&]() { return parser.parseInteger(values.emplace_back()); });
+  if (listResult.failed())
+    return failure();
+
+  rawDimList = builder.getDenseI64ArrayAttr(values);
+
+  if (isInverted) {
+    if (parser.parseRParen().failed())
+      return failure();
+  }
+  return success();
+}
+
+void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
+                                        DenseI64ArrayAttr rawDimList,
+                                        UnitAttr isInverted, UnitAttr isAll) {
+  if (isAll) {
+    printer << kDimAllKeyword;
+    return;
+  }
+  if (isInverted) {
+    printer << kDimExceptKeyword << "(";
+  }
+  llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
+                        [&](int64_t value) { printer << value; });
+  if (isInverted) {
+    printer << ")";
+  }
+}
+
+LogicalResult transform::verifyTransformMatchDimsOp(Operation *op,
+                                                    ArrayRef<int64_t> raw,
+                                                    bool inverted, bool all) {
+  if (all) {
+    if (inverted) {
+      return op->emitOpError()
+             << "cannot request both 'all' and 'inverted' values in the list";
+    }
+    if (!raw.empty()) {
+      return op->emitOpError()
+             << "cannot both request 'all' and specific values in the list";
+    }
+  }
+  if (!all && raw.empty()) {
+    return op->emitOpError() << "must request specific values in the list if "
+                                "'all' is not specified";
+  }
+  SmallVector<int64_t> rawVector = llvm::to_vector(raw);
+  auto *it = std::unique(rawVector.begin(), rawVector.end());
+  if (it != rawVector.end())
+    return op->emitOpError() << "expected the listed values to be unique";
+
+  return success();
+}
+
+DiagnosedSilenceableFailure transform::expandTargetSpecification(
+    Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList,
+    int64_t maxNumber, SmallVectorImpl<int64_t> &result) {
+  assert(maxNumber > 0 && "expected size to be positive");
+  assert(!(isAll && isInverted) && "cannot invert all");
+  if (isAll) {
+    result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  SmallVector<int64_t> expanded;
+  llvm::SmallDenseSet<int64_t> visited;
+  expanded.reserve(rawList.size());
+  SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
+  for (int64_t raw : rawList) {
+    int64_t updated = raw < 0 ? maxNumber + raw : raw;
+    if (updated >= maxNumber) {
+      return emitSilenceableFailure(loc)
+             << "position overflow " << updated << " (updated from " << raw
+             << ") for maximum " << maxNumber;
+    }
+    if (updated < 0) {
+      return emitSilenceableFailure(loc) << "position underflow " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    if (!visited.insert(updated).second) {
+      return emitSilenceableFailure(loc) << "repeated position " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    target.push_back(updated);
+  }
+
+  if (!isInverted)
+    return DiagnosedSilenceableFailure::success();
+
+  result.reserve(result.size() + (maxNumber - expanded.size()));
+  for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
+    if (llvm::is_contained(expanded, candidate))
+      continue;
+    result.push_back(candidate);
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Generated interface implementation.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index b80fc09751d2aa4..485d4448e7c3683 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1464,6 +1464,39 @@ transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// GetOperandOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
+                               transform::TransformResults &results,
+                               transform::TransformState &state) {
+  SmallVector<Value> operands;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    SmallVector<int64_t> operandPositions;
+    DiagnosedSilenceableFailure diag = expandTargetSpecification(
+        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
+        target->getNumOperands(), operandPositions);
+    if (diag.isSilenceableFailure()) {
+      diag.attachNote(target->getLoc())
+          << "while considering positions of this payload operation";
+      return diag;
+    }
+    llvm::append_range(operands,
+                       llvm::map_range(operandPositions, [&](int64_t pos) {
+                         return target->getOperand(pos);
+                       }));
+  }
+  results.setValues(cast<OpResult>(getResult()), operands);
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::GetOperandOp::verify() {
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
+}
+
 //===----------------------------------------------------------------------===//
 // GetResultOp
 //===----------------------------------------------------------------------===//
@@ -1472,21 +1505,31 @@ DiagnosedSilenceableFailure
 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
                               transform::TransformResults &results,
                               transform::TransformState &state) {
-  int64_t resultNumber = getResultNumber();
   SmallVector<Value> opResults;
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    if (resultNumber >= target->getNumResults()) {
-      DiagnosedSilenceableFailure diag =
-          emitSilenceableError() << "targeted op does not have enough results";
-      diag.attachNote(target->getLoc()) << "target op";
+    SmallVector<int64_t> resultPositions;
+    DiagnosedSilenceableFailure diag = expandTargetSpecification(
+        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
+        target->getNumResults(), resultPositions);
+    if (diag.isSilenceableFailure()) {
+      diag.attachNote(target->getLoc())
+          << "while considering positions of this payload operation";
       return diag;
     }
-    opResults.push_back(target->getOpResult(resultNumber));
+    llvm::append_range(opResults,
+                       llvm::map_range(resultPositions, [&](int64_t pos) {
+                         return target->getResult(pos);
+                       }));
   }
-  results.setValues(llvm::cast<OpResult>(getResult()), opResults);
+  results.setValues(cast<OpResult>(getResult()), opResults);
   return DiagnosedSilenceableFailure::success();
 }
 
+LogicalResult transform::GetResultOp::verify() {
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
+}
+
 //===----------------------------------------------------------------------===//
 // GetTypeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index ba51c400fe2cb2c..8d045cad7a4a36f 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -43,7 +43,6 @@ def __init__(
         self.parent = parent
         self.children = children if children is not None else []
 
-
 @ir.register_value_caster(AnyOpType.get_static_typeid())
 @ir.register_value_caster(OperationType.get_static_typeid())
 class OpHandle(Handle):
@@ -61,16 +60,16 @@ def __init__(
     ):
         super().__init__(v, parent=parent, children=children)
 
-    def get_result(self, idx: int = 0) -> "ValueHandle":
+    def get_result(self, indices: Sequence[int] = [0]) -> "ValueHandle":
         """
         Emits a `transform.GetResultOp`.
         Returns a handle to the result of the payload operation at the given
-        index.
+        indices.
         """
         get_result_op = transform.GetResultOp(
             AnyValueType.get(),
             self,
-            idx,
+            indices,
         )
         return get_result_op.result
 

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 96f2122e976df5a..de5807b2874b27a 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1483,6 +1483,78 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// expected-remark @below {{addi operand}}
+// expected-note @below {{value handle points to a block argument #0}}
+func.func @get_operand_of_op(%arg0: index, %arg1: index) -> index {
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %operand = transform.get_operand %addi[0] : (!transform.any_op) -> !transform.any_value
+    transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @get_out_of_bounds_operand_of_op(%arg0: index, %arg1: index) -> index {
+  // expected-note @below {{while considering positions of this payload operation}}
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{position overflow 2 (updated from 2) for maximum 2}}
+    %operand = transform.get_operand %addi[2] : (!transform.any_op) -> !transform.any_value
+    transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-remark @below {{addi operand}}
+// expected-note @below {{value handle points to a block argument #1}}
+func.func @get_inverted_operand_of_op(%arg0: index, %arg1: index) -> index {
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %operand = transform.get_operand %addi[except(0)] : (!transform.any_op) -> !transform.any_value
+    transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @get_multiple_operands_of_op(%arg0: index, %arg1: index) -> index {
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addui = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %operands = transform.get_operand %addui[all] : (!transform.any_op) -> !transform.any_value
+    %p = transform.num_associations %operands : (!transform.any_value) -> !transform.param<i64>
+    // expected-remark @below {{2}}
+    transform.debug.emit_param_as_remark %p : !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @get_result_of_op(%arg0: index, %arg1: index) -> index {
   // expected-remark @below {{addi result}}
   // expected-note @below {{value handle points to an op result #0}}
@@ -1502,7 +1574,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
-  // expected-note @below {{target op}}
+  // expected-note @below {{while considering positions of this payload operation}}
   %r = arith.addi %arg0, %arg1 : index
   return %r : index
 }
@@ -1510,7 +1582,7 @@ func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error @below {{targeted op does not have enough results}}
+    // expected-error @below {{position overflow 1 (updated from 1) for maximum 1}}
     %result = transform.get_result %addi[1] : (!transform.any_op) -> !transform.any_value
     transform.debug.emit_remark_at %result, "addi result" : !transform.any_value
     transform.yield
@@ -1537,6 +1609,24 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @get_multiple_result_of_op(%arg0: index, %arg1: index) -> (index, i1) {
+  %r, %b = arith.addui_extended %arg0, %arg1 : index, i1
+  return %r, %b : index, i1
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addui = transform.structured.match ops{["arith.addui_extended"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %results = transform.get_result %addui[all] : (!transform.any_op) -> !transform.any_value
+    %p = transform.num_associations %results : (!transform.any_value) -> !transform.param<i64>
+    // expected-remark @below {{2}}
+    transform.debug.emit_param_as_remark %p : !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
 // expected-note @below {{target value}}
 func.func @get_result_of_op_bbarg(%arg0: index, %arg1: index) -> index {
   %r = arith.addi %arg0, %arg1 : index


        


More information about the Mlir-commits mailing list