[Mlir-commits] [mlir] [mlir][transform] Add transform.get_operand op (PR #78397)

Quinn Dawkins llvmlistbot at llvm.org
Wed Jan 17 19:57:40 PST 2024


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/78397

>From 6ce93698660378d02dccb060327f8e53d2c67e98 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Tue, 16 Jan 2024 23:53:33 -0500
Subject: [PATCH 1/4] [mlir][transform] Add transform.get_operand op

Similar to `transform.get_result`, except it returns a handle to the
operand indicated by `operand_number`, or all operands if no index is
given.

Additionally updates `get_result` to make the `result_number`
optional. This makes the use case of wanting to get all of the
results of an operation easier by no longer requiring the user to
reconstruct the list of results one-by-one.
---
 .../mlir/Dialect/Transform/IR/TransformOps.td | 38 +++++++---
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 37 +++++++++-
 .../Dialect/Transform/test-interpreter.mlir   | 73 +++++++++++++++++++
 3 files changed, 138 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 2627553cff69533..e1da7299fad1a93 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -728,23 +728,43 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
                        "functional-type(operands, results)";
 }
 
+def GetOperandOp : TransformDialectOp<"get_operand",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
+  let summary = "Get a handle to the operand(s) of the targeted op";
+  let description = [{
+    The handle defined by this Transform op corresponds to the Operands of the
+    given `target` operation. Optionally `operand_number` can be specified to
+    select a specific operand.
+    
+    This transform fails silently if the targeted operation does not have enough
+    operands. It reads the target handle and produces the result handle.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       OptionalAttr<I64Attr>:$operand_number);
+  let results = (outs TransformValueHandleTypeInterface:$result);
+  let assemblyFormat = "$target (`[` $operand_number^ `]`)? attr-dict `:` "
+                       "functional-type(operands, results)";
+}
+
 def GetResultOp : TransformDialectOp<"get_result",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
-  let summary = "Get handle to the a result of the targeted op";
+  let summary = "Get a handle to the result(s) of the targeted op";
   let description = [{
-    The handle defined by this Transform op corresponds to the OpResult with
-    `result_number` that is defined by the given `target` operation.
-
-    This transform produces a silenceable failure if the targeted operation
-    does not have enough results. It reads the target handle and produces the
-    result handle.
+    The handle defined by this Transform op corresponds to the OpResults of the
+    given `target` operation. Optionally `result_number` can be specified to
+    select a specific result.
+    
+    This transform fails silently if the targeted operation does not have enough
+    results. It reads the target handle and produces the result handle.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       I64Attr:$result_number);
+                       OptionalAttr<I64Attr>:$result_number);
   let results = (outs TransformValueHandleTypeInterface:$result);
-  let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
+  let assemblyFormat = "$target (`[` $result_number^ `]`)? attr-dict `:` "
                        "functional-type(operands, results)";
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index b80fc09751d2aa4..56baae9b5fadf28 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1464,6 +1464,35 @@ transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// GetOperandOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
+                               transform::TransformResults &results,
+                               transform::TransformState &state) {
+  std::optional<int64_t> maybeOperandNumber = getOperandNumber();
+  SmallVector<Value> operands;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    if (!maybeOperandNumber) {
+      for (Value operand : target->getOperands())
+        operands.push_back(operand);
+      continue;
+    }
+    int64_t operandNumber = *maybeOperandNumber;
+    if (operandNumber >= target->getNumOperands()) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "targeted op does not have enough operands";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+    operands.push_back(target->getOperand(operandNumber));
+  }
+  results.setValues(llvm::cast<OpResult>(getResult()), operands);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // GetResultOp
 //===----------------------------------------------------------------------===//
@@ -1472,9 +1501,15 @@ DiagnosedSilenceableFailure
 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
                               transform::TransformResults &results,
                               transform::TransformState &state) {
-  int64_t resultNumber = getResultNumber();
+  std::optional<int64_t> maybeResultNumber = getResultNumber();
   SmallVector<Value> opResults;
   for (Operation *target : state.getPayloadOps(getTarget())) {
+    if (!maybeResultNumber) {
+      for (Value result : target->getResults())
+        opResults.push_back(result);
+      continue;
+    }
+    int64_t resultNumber = *maybeResultNumber;
     if (resultNumber >= target->getNumResults()) {
       DiagnosedSilenceableFailure diag =
           emitSilenceableError() << "targeted op does not have enough results";
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 96f2122e976df5a..b89b52e2f403d5c 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1483,6 +1483,60 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// expected-remark @below {{addi operand}}
+// expected-note @below {{value handle points to a block argument #0}}
+func.func @get_operand_of_op(%arg0: index, %arg1: index) -> index {
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %operand = transform.get_operand %addi[0] : (!transform.any_op) -> !transform.any_value
+    transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @get_out_of_bounds_operand_of_op(%arg0: index, %arg1: index) -> index {
+  // expected-note @below {{target op}}
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{targeted op does not have enough operands}}
+    %operand = transform.get_operand %addi[2] : (!transform.any_op) -> !transform.any_value
+    transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @get_multiple_operands_of_op(%arg0: index, %arg1: index) -> index {
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addui = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %operands = transform.get_operand %addui : (!transform.any_op) -> !transform.any_value
+    %p = transform.num_associations %operands : (!transform.any_value) -> !transform.param<i64>
+    // expected-remark @below {{2}}
+    transform.debug.emit_param_as_remark %p : !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @get_result_of_op(%arg0: index, %arg1: index) -> index {
   // expected-remark @below {{addi result}}
   // expected-note @below {{value handle points to an op result #0}}
@@ -1537,6 +1591,25 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @get_multiple_result_of_op(%arg0: index, %arg1: index) -> (index, i1) {
+  // expected-remark @below {{matched bool}}
+  %r, %b = arith.addui_extended %arg0, %arg1 : index, i1
+  return %r, %b : index, i1
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addui = transform.structured.match ops{["arith.addui_extended"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %results = transform.get_result %addui : (!transform.any_op) -> !transform.any_value
+    %adds = transform.get_defining_op %results : (!transform.any_value) -> !transform.any_op
+    %_, %add_again = transform.split_handle %adds : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.debug.emit_remark_at %add_again, "matched bool" : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // expected-note @below {{target value}}
 func.func @get_result_of_op_bbarg(%arg0: index, %arg1: index) -> index {
   %r = arith.addi %arg0, %arg1 : index

>From 59fae0926a3afb2ede3c4e40653e9d53f1fb4000 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 17 Jan 2024 11:41:27 -0500
Subject: [PATCH 2/4] Address comments and switch to mirroring the linalg.match
 positional spec

---
 .../Linalg/TransformOps/LinalgMatchOps.td     |   4 +-
 .../Dialect/Transform/IR/MatchInterfaces.h    |  51 ++++++-
 .../mlir/Dialect/Transform/IR/TransformOps.td |  53 +++++--
 .../Linalg/TransformOps/LinalgMatchOps.cpp    | 112 +--------------
 .../Dialect/Transform/IR/MatchInterfaces.cpp  | 135 ++++++++++++++++++
 .../lib/Dialect/Transform/IR/TransformOps.cpp |  60 ++++----
 .../Dialect/Transform/test-interpreter.mlir   |  37 +++--
 7 files changed, 295 insertions(+), 157 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index 9e108529ec129b2..162dd05f93030f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -288,7 +288,7 @@ def MatchStructuredDimOp : Op<Transform_Dialect, "match.structured.dim", [
   let results = (outs Optional<TransformParamTypeInterface>:$result);
   let assemblyFormat =
       "$operand_handle `[`"
-      "custom<StructuredTransformDims>($raw_dim_list, $is_inverted, $is_all)"
+      "custom<TransformMatchDims>($raw_dim_list, $is_inverted, $is_all)"
       "`]` attr-dict `:` "
       "custom<SemiFunctionType>(type($operand_handle), type($result))";
 
@@ -347,7 +347,7 @@ class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
       (outs Optional<AnyTypeOf<[TransformAnyHandle,Transform_AffineMapParamType]>>:$result);
   let assemblyFormat =
       "$operand_handle `[`"
-      "custom<StructuredTransformDims>($raw_position_list, $is_inverted, $is_all)"
+      "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
       "`]` attr-dict "
       "`:` custom<SemiFunctionType>(type($operand_handle), type($result))";
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index b155b110677d6c7..36aeb4583029c96 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -9,11 +9,12 @@
 #ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
 #define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
 
+#include <optional>
+#include <type_traits>
+
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/STLExtras.h"
-#include <optional>
-#include <type_traits>
 
 namespace mlir {
 namespace transform {
@@ -168,6 +169,52 @@ class SingleValueMatcherOpTrait
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Printing/parsing for positional specification matchers
+//===----------------------------------------------------------------------===//
+
+/// Parses a positional index specification for transform match operations.
+/// The following forms are accepted:
+///
+///  - `all`: sets `isAll` and returns;
+///  - comma-separated-integer-list: populates `rawDimList` with the values;
+///  - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
+///  with the values and sets `isInverted`.
+ParseResult parseTransformMatchDims(OpAsmParser &parser,
+                                    DenseI64ArrayAttr &rawDimList,
+                                    UnitAttr &isInverted, UnitAttr &isAll);
+
+/// Prints a positional index specification for transform match operations.
+void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
+                             DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
+                             UnitAttr isAll);
+
+//===----------------------------------------------------------------------===//
+// Utilities for positional specification matchers
+//===----------------------------------------------------------------------===//
+
+/// Checks if the positional specification defined is valid and reports errors
+/// otherwise.
+LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
+                                         bool inverted, bool all);
+
+/// Populates `result` with the positional identifiers relative to `maxNumber`.
+/// If `isAll` is set, the result will contain all numbers from `0` to
+/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
+/// values from `rawList` are interpreted as counting backwards from
+/// `maxNumber`, i.e., `-1` is interpreted as `maxNumber - 1`, while positive
+/// numbers remain as is. If `isInverted` is set, populates `result` with those
+/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
+/// `rawList`. If `rawList` contains values that are greater than or equal to
+/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
+/// given location. `maxNumber` must be positive. If `rawList` contains
+/// duplicate numbers or numbers that become duplicate after negative value
+/// remapping, emits a silenceable error.
+DiagnosedSilenceableFailure
+expandTargetSpecification(Location loc, bool isAll, bool isInverted,
+                          ArrayRef<int64_t> rawList, int64_t maxNumber,
+                          SmallVectorImpl<int64_t> &result);
+
 } // namespace transform
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index e1da7299fad1a93..9f513822ed0a4e1 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -733,19 +733,31 @@ def GetOperandOp : TransformDialectOp<"get_operand",
      NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
   let summary = "Get a handle to the operand(s) of the targeted op";
   let description = [{
-    The handle defined by this Transform op corresponds to the Operands of the
-    given `target` operation. Optionally `operand_number` can be specified to
-    select a specific operand.
+    The handle defined by this Transform op corresponds to the operands of the
+    given `target` operation specified by the given set of positions. There are
+    three possible modes:
+
+     - Position list directly, i.e. `%target[0, 1, 2]`. This will return the
+       operands at the specified positions.
+     - Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
+       all operands except those at the given positions.
+     - All, i.e. `%target[all]`. This will return all operands of the operation.
     
-    This transform fails silently if the targeted operation does not have enough
-    operands. It reads the target handle and produces the result handle.
+    This transform produces a silenceable failure if any of the operand indices
+    exceeds the number of operands in the target. It reads the target handle and
+    produces the result handle.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       OptionalAttr<I64Attr>:$operand_number);
+                       DenseI64ArrayAttr:$raw_position_list,
+                       UnitAttr:$is_inverted,
+                       UnitAttr:$is_all);
   let results = (outs TransformValueHandleTypeInterface:$result);
-  let assemblyFormat = "$target (`[` $operand_number^ `]`)? attr-dict `:` "
-                       "functional-type(operands, results)";
+  let assemblyFormat =
+      "$target `[`"
+      "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
+      "`]` attr-dict `:` functional-type(operands, results)";
+  let hasVerifier = 1;
 }
 
 def GetResultOp : TransformDialectOp<"get_result",
@@ -759,13 +771,32 @@ def GetResultOp : TransformDialectOp<"get_result",
     
     This transform fails silently if the targeted operation does not have enough
     results. It reads the target handle and produces the result handle.
+
+    The handle defined by this Transform op corresponds to the results of the
+    given `target` operation specified by the given set of positions. There are
+    three possible modes:
+
+     - Position list directly, i.e. `%target[0, 1, 2]`. This will return the
+       results at the specified positions.
+     - Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
+       all results except those at the given positions.
+     - All, i.e. `%target[all]`. This will return all results of the operation.
+    
+    This transform produces a silenceable failure if any of the result indices
+    exceeds the number of results returned by the target. It reads the target
+    handle and produces the result handle.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       OptionalAttr<I64Attr>:$result_number);
+                       DenseI64ArrayAttr:$raw_position_list,
+                       UnitAttr:$is_inverted,
+                       UnitAttr:$is_all);
   let results = (outs TransformValueHandleTypeInterface:$result);
-  let assemblyFormat = "$target (`[` $result_number^ `]`)? attr-dict `:` "
-                       "functional-type(operands, results)";
+  let assemblyFormat =
+      "$target `[`"
+      "custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
+      "`]` attr-dict `:` functional-type(operands, results)";
+  let hasVerifier = 1;
 }
 
 def GetTypeOp : TransformDialectOp<"get_type",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index fb021ed76242e90..c739813d01195b4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -388,33 +388,6 @@ expandTargetSpecification(Location loc, bool isAll, bool isInverted,
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Checks if the positional specification defined is valid and reports errors
-/// otherwise.
-LogicalResult verifyStructuredTransformDimsOp(Operation *op,
-                                              ArrayRef<int64_t> raw,
-                                              bool inverted, bool all) {
-  if (all) {
-    if (inverted) {
-      return op->emitOpError()
-             << "cannot request both 'all' and 'inverted' values in the list";
-    }
-    if (!raw.empty()) {
-      return op->emitOpError()
-             << "cannot both request 'all' and specific values in the list";
-    }
-  }
-  if (!all && raw.empty()) {
-    return op->emitOpError() << "must request specific values in the list if "
-                                "'all' is not specified";
-  }
-  SmallVector<int64_t> rawVector = llvm::to_vector(raw);
-  auto *it = std::unique(rawVector.begin(), rawVector.end());
-  if (it != rawVector.end())
-    return op->emitOpError() << "expected the listed values to be unique";
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // MatchStructuredDimOp
 //===----------------------------------------------------------------------===//
@@ -475,8 +448,8 @@ LogicalResult transform::MatchStructuredDimOp::verify() {
     return emitOpError() << "cannot request the same dimension to be both "
                             "parallel and reduction";
   }
-  return verifyStructuredTransformDimsOp(getOperation(), getRawDimList(),
-                                         getIsInverted(), getIsAll());
+  return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
+                                    getIsInverted(), getIsAll());
 }
 
 //===----------------------------------------------------------------------===//
@@ -592,8 +565,8 @@ LogicalResult verifyStructuredOperandOp(OpTy op) {
 LogicalResult transform::MatchStructuredInputOp::verify() {
   if (failed(verifyStructuredOperandOp(*this)))
     return failure();
-  return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
-                                         getIsInverted(), getIsAll());
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
 }
 
 //===----------------------------------------------------------------------===//
@@ -665,8 +638,8 @@ DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
 LogicalResult transform::MatchStructuredInitOp::verify() {
   if (failed(verifyStructuredOperandOp(*this)))
     return failure();
-  return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
-                                         getIsInverted(), getIsAll());
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
 }
 
 //===----------------------------------------------------------------------===//
@@ -793,78 +766,5 @@ void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
   build(builder, state, ValueRange());
 }
 
-//===----------------------------------------------------------------------===//
-// Printing and parsing for structured match ops.
-//===----------------------------------------------------------------------===//
-
-/// Keyword syntax for positional specification inversion.
-constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
-
-/// Keyword syntax for full inclusion in positional specification.
-constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
-
-/// Parses a positional specification for structured transform operations. The
-/// following forms are accepted:
-///
-///  - `all`: sets `isAll` and returns;
-///  - comma-separated-integer-list: populates `rawDimList` with the values;
-///  - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
-///  with the values and sets `isInverted`.
-static ParseResult parseStructuredTransformDims(OpAsmParser &parser,
-                                                DenseI64ArrayAttr &rawDimList,
-                                                UnitAttr &isInverted,
-                                                UnitAttr &isAll) {
-  Builder &builder = parser.getBuilder();
-  if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
-    rawDimList = builder.getDenseI64ArrayAttr({});
-    isInverted = nullptr;
-    isAll = builder.getUnitAttr();
-    return success();
-  }
-
-  isAll = nullptr;
-  isInverted = nullptr;
-  if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
-    isInverted = builder.getUnitAttr();
-  }
-
-  if (isInverted) {
-    if (parser.parseLParen().failed())
-      return failure();
-  }
-
-  SmallVector<int64_t> values;
-  ParseResult listResult = parser.parseCommaSeparatedList(
-      [&]() { return parser.parseInteger(values.emplace_back()); });
-  if (listResult.failed())
-    return failure();
-
-  rawDimList = builder.getDenseI64ArrayAttr(values);
-
-  if (isInverted) {
-    if (parser.parseRParen().failed())
-      return failure();
-  }
-  return success();
-}
-
-/// Prints a positional specification for structured transform operations.
-static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
-                                         DenseI64ArrayAttr rawDimList,
-                                         UnitAttr isInverted, UnitAttr isAll) {
-  if (isAll) {
-    printer << kDimAllKeyword;
-    return;
-  }
-  if (isInverted) {
-    printer << kDimExceptKeyword << "(";
-  }
-  llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
-                        [&](int64_t value) { printer << value; });
-  if (isInverted) {
-    printer << ")";
-  }
-}
-
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
index 6049ee767e1bb74..b9b6dabc26216e1 100644
--- a/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/MatchInterfaces.cpp
@@ -10,6 +10,141 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Printing and parsing for match ops.
+//===----------------------------------------------------------------------===//
+
+/// Keyword syntax for positional specification inversion.
+constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
+
+/// Keyword syntax for full inclusion in positional specification.
+constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
+
+ParseResult transform::parseTransformMatchDims(OpAsmParser &parser,
+                                               DenseI64ArrayAttr &rawDimList,
+                                               UnitAttr &isInverted,
+                                               UnitAttr &isAll) {
+  Builder &builder = parser.getBuilder();
+  if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
+    rawDimList = builder.getDenseI64ArrayAttr({});
+    isInverted = nullptr;
+    isAll = builder.getUnitAttr();
+    return success();
+  }
+
+  isAll = nullptr;
+  isInverted = nullptr;
+  if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
+    isInverted = builder.getUnitAttr();
+  }
+
+  if (isInverted) {
+    if (parser.parseLParen().failed())
+      return failure();
+  }
+
+  SmallVector<int64_t> values;
+  ParseResult listResult = parser.parseCommaSeparatedList(
+      [&]() { return parser.parseInteger(values.emplace_back()); });
+  if (listResult.failed())
+    return failure();
+
+  rawDimList = builder.getDenseI64ArrayAttr(values);
+
+  if (isInverted) {
+    if (parser.parseRParen().failed())
+      return failure();
+  }
+  return success();
+}
+
+void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
+                                        DenseI64ArrayAttr rawDimList,
+                                        UnitAttr isInverted, UnitAttr isAll) {
+  if (isAll) {
+    printer << kDimAllKeyword;
+    return;
+  }
+  if (isInverted) {
+    printer << kDimExceptKeyword << "(";
+  }
+  llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
+                        [&](int64_t value) { printer << value; });
+  if (isInverted) {
+    printer << ")";
+  }
+}
+
+LogicalResult transform::verifyTransformMatchDimsOp(Operation *op,
+                                                    ArrayRef<int64_t> raw,
+                                                    bool inverted, bool all) {
+  if (all) {
+    if (inverted) {
+      return op->emitOpError()
+             << "cannot request both 'all' and 'inverted' values in the list";
+    }
+    if (!raw.empty()) {
+      return op->emitOpError()
+             << "cannot both request 'all' and specific values in the list";
+    }
+  }
+  if (!all && raw.empty()) {
+    return op->emitOpError() << "must request specific values in the list if "
+                                "'all' is not specified";
+  }
+  SmallVector<int64_t> rawVector = llvm::to_vector(raw);
+  llvm::sort(rawVector); // std::unique only collapses *adjacent* duplicates.
+  if (std::unique(rawVector.begin(), rawVector.end()) != rawVector.end())
+    return op->emitOpError() << "expected the listed values to be unique";
+
+  return success();
+}
+
+DiagnosedSilenceableFailure transform::expandTargetSpecification(
+    Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList,
+    int64_t maxNumber, SmallVectorImpl<int64_t> &result) {
+  assert(maxNumber > 0 && "expected size to be positive");
+  assert(!(isAll && isInverted) && "cannot invert all");
+  if (isAll) {
+    result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  SmallVector<int64_t> expanded;
+  llvm::SmallDenseSet<int64_t> visited;
+  expanded.reserve(rawList.size());
+  SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
+  for (int64_t raw : rawList) {
+    int64_t updated = raw < 0 ? maxNumber + raw : raw;
+    if (updated >= maxNumber) {
+      return emitSilenceableFailure(loc)
+             << "position overflow " << updated << " (updated from " << raw
+             << ") for maximum " << maxNumber;
+    }
+    if (updated < 0) {
+      return emitSilenceableFailure(loc) << "position underflow " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    if (!visited.insert(updated).second) {
+      return emitSilenceableFailure(loc) << "repeated position " << updated
+                                         << " (updated from " << raw << ")";
+    }
+    target.push_back(updated);
+  }
+
+  if (!isInverted)
+    return DiagnosedSilenceableFailure::success();
+
+  result.reserve(result.size() + (maxNumber - expanded.size()));
+  for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
+    if (llvm::is_contained(expanded, candidate))
+      continue;
+    result.push_back(candidate);
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Generated interface implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 56baae9b5fadf28..485d4448e7c3683 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1472,27 +1472,31 @@ DiagnosedSilenceableFailure
 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
                                transform::TransformResults &results,
                                transform::TransformState &state) {
-  std::optional<int64_t> maybeOperandNumber = getOperandNumber();
   SmallVector<Value> operands;
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    if (!maybeOperandNumber) {
-      for (Value operand : target->getOperands())
-        operands.push_back(operand);
-      continue;
-    }
-    int64_t operandNumber = *maybeOperandNumber;
-    if (operandNumber >= target->getNumOperands()) {
-      DiagnosedSilenceableFailure diag =
-          emitSilenceableError() << "targeted op does not have enough operands";
-      diag.attachNote(target->getLoc()) << "target op";
+    SmallVector<int64_t> operandPositions;
+    DiagnosedSilenceableFailure diag = expandTargetSpecification(
+        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
+        target->getNumOperands(), operandPositions);
+    if (diag.isSilenceableFailure()) {
+      diag.attachNote(target->getLoc())
+          << "while considering positions of this payload operation";
       return diag;
     }
-    operands.push_back(target->getOperand(operandNumber));
+    llvm::append_range(operands,
+                       llvm::map_range(operandPositions, [&](int64_t pos) {
+                         return target->getOperand(pos);
+                       }));
   }
-  results.setValues(llvm::cast<OpResult>(getResult()), operands);
+  results.setValues(cast<OpResult>(getResult()), operands);
   return DiagnosedSilenceableFailure::success();
 }
 
+LogicalResult transform::GetOperandOp::verify() {
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
+}
+
 //===----------------------------------------------------------------------===//
 // GetResultOp
 //===----------------------------------------------------------------------===//
@@ -1501,27 +1505,31 @@ DiagnosedSilenceableFailure
 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
                               transform::TransformResults &results,
                               transform::TransformState &state) {
-  std::optional<int64_t> maybeResultNumber = getResultNumber();
   SmallVector<Value> opResults;
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    if (!maybeResultNumber) {
-      for (Value result : target->getResults())
-        opResults.push_back(result);
-      continue;
-    }
-    int64_t resultNumber = *maybeResultNumber;
-    if (resultNumber >= target->getNumResults()) {
-      DiagnosedSilenceableFailure diag =
-          emitSilenceableError() << "targeted op does not have enough results";
-      diag.attachNote(target->getLoc()) << "target op";
+    SmallVector<int64_t> resultPositions;
+    DiagnosedSilenceableFailure diag = expandTargetSpecification(
+        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
+        target->getNumResults(), resultPositions);
+    if (diag.isSilenceableFailure()) {
+      diag.attachNote(target->getLoc())
+          << "while considering positions of this payload operation";
       return diag;
     }
-    opResults.push_back(target->getOpResult(resultNumber));
+    llvm::append_range(opResults,
+                       llvm::map_range(resultPositions, [&](int64_t pos) {
+                         return target->getResult(pos);
+                       }));
   }
-  results.setValues(llvm::cast<OpResult>(getResult()), opResults);
+  results.setValues(cast<OpResult>(getResult()), opResults);
   return DiagnosedSilenceableFailure::success();
 }
 
+LogicalResult transform::GetResultOp::verify() {
+  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
+                                    getIsInverted(), getIsAll());
+}
+
 //===----------------------------------------------------------------------===//
 // GetTypeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b89b52e2f403d5c..de5807b2874b27a 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1502,7 +1502,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @get_out_of_bounds_operand_of_op(%arg0: index, %arg1: index) -> index {
-  // expected-note @below {{target op}}
+  // expected-note @below {{while considering positions of this payload operation}}
   %r = arith.addi %arg0, %arg1 : index
   return %r : index
 }
@@ -1510,7 +1510,7 @@ func.func @get_out_of_bounds_operand_of_op(%arg0: index, %arg1: index) -> index
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error @below {{targeted op does not have enough operands}}
+    // expected-error @below {{position overflow 2 (updated from 2) for maximum 2}}
     %operand = transform.get_operand %addi[2] : (!transform.any_op) -> !transform.any_value
     transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
     transform.yield
@@ -1519,6 +1519,24 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// expected-remark @below {{addi operand}}
+// expected-note @below {{value handle points to a block argument #1}}
+func.func @get_inverted_operand_of_op(%arg0: index, %arg1: index) -> index {
+  %r = arith.addi %arg0, %arg1 : index
+  return %r : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %operand = transform.get_operand %addi[except(0)] : (!transform.any_op) -> !transform.any_value
+    transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @get_multiple_operands_of_op(%arg0: index, %arg1: index) -> index {
   %r = arith.addi %arg0, %arg1 : index
   return %r : index
@@ -1527,7 +1545,7 @@ func.func @get_multiple_operands_of_op(%arg0: index, %arg1: index) -> index {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %addui = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %operands = transform.get_operand %addui : (!transform.any_op) -> !transform.any_value
+    %operands = transform.get_operand %addui[all] : (!transform.any_op) -> !transform.any_value
     %p = transform.num_associations %operands : (!transform.any_value) -> !transform.param<i64>
     // expected-remark @below {{2}}
     transform.debug.emit_param_as_remark %p : !transform.param<i64>
@@ -1556,7 +1574,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
-  // expected-note @below {{target op}}
+  // expected-note @below {{while considering positions of this payload operation}}
   %r = arith.addi %arg0, %arg1 : index
   return %r : index
 }
@@ -1564,7 +1582,7 @@ func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error @below {{targeted op does not have enough results}}
+    // expected-error @below {{position overflow 1 (updated from 1) for maximum 1}}
     %result = transform.get_result %addi[1] : (!transform.any_op) -> !transform.any_value
     transform.debug.emit_remark_at %result, "addi result" : !transform.any_value
     transform.yield
@@ -1592,7 +1610,6 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @get_multiple_result_of_op(%arg0: index, %arg1: index) -> (index, i1) {
-  // expected-remark @below {{matched bool}}
   %r, %b = arith.addui_extended %arg0, %arg1 : index, i1
   return %r, %b : index, i1
 }
@@ -1600,10 +1617,10 @@ func.func @get_multiple_result_of_op(%arg0: index, %arg1: index) -> (index, i1)
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %addui = transform.structured.match ops{["arith.addui_extended"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %results = transform.get_result %addui : (!transform.any_op) -> !transform.any_value
-    %adds = transform.get_defining_op %results : (!transform.any_value) -> !transform.any_op
-    %_, %add_again = transform.split_handle %adds : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.debug.emit_remark_at %add_again, "matched bool" : !transform.any_op
+    %results = transform.get_result %addui[all] : (!transform.any_op) -> !transform.any_value
+    %p = transform.num_associations %results : (!transform.any_value) -> !transform.param<i64>
+    // expected-remark @below {{2}}
+    transform.debug.emit_param_as_remark %p : !transform.param<i64>
     transform.yield
   }
 }

>From 9cd797ffa6c831936d5262d2e99efcc0253894b6 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 17 Jan 2024 11:59:21 -0500
Subject: [PATCH 3/4] Drop unused function and rebase

---
 .../Linalg/TransformOps/LinalgMatchOps.cpp    | 58 -------------------
 1 file changed, 58 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index c739813d01195b4..115da4b90e063ac 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -330,64 +330,6 @@ static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Populates `result` with the positional identifiers relative to `maxNumber`.
-/// If `isAll` is set, the result will contain all numbers from `0` to
-/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
-/// values from `rawList` are  are interpreted as counting backwards from
-/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
-/// numbers remain as is. If `isInverted` is set, populates `result` with those
-/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
-/// `rawList`. If `rawList` contains values that are greater than or equal to
-/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
-/// given location. `maxNumber` must be positive. If `rawList` contains
-/// duplicate numbers or numbers that become duplicate after negative value
-/// remapping, emits a silenceable error.
-static DiagnosedSilenceableFailure
-expandTargetSpecification(Location loc, bool isAll, bool isInverted,
-                          ArrayRef<int64_t> rawList, int64_t maxNumber,
-                          SmallVectorImpl<int64_t> &result) {
-  assert(maxNumber > 0 && "expected size to be positive");
-  assert(!(isAll && isInverted) && "cannot invert all");
-  if (isAll) {
-    result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
-    return DiagnosedSilenceableFailure::success();
-  }
-
-  SmallVector<int64_t> expanded;
-  llvm::SmallDenseSet<int64_t> visited;
-  expanded.reserve(rawList.size());
-  SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
-  for (int64_t raw : rawList) {
-    int64_t updated = raw < 0 ? maxNumber + raw : raw;
-    if (updated >= maxNumber) {
-      return emitSilenceableFailure(loc)
-             << "position overflow " << updated << " (updated from " << raw
-             << ") for maximum " << maxNumber;
-    }
-    if (updated < 0) {
-      return emitSilenceableFailure(loc) << "position underflow " << updated
-                                         << " (updated from " << raw << ")";
-    }
-    if (!visited.insert(updated).second) {
-      return emitSilenceableFailure(loc) << "repeated position " << updated
-                                         << " (updated from " << raw << ")";
-    }
-    target.push_back(updated);
-  }
-
-  if (!isInverted)
-    return DiagnosedSilenceableFailure::success();
-
-  result.reserve(result.size() + (maxNumber - expanded.size()));
-  for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
-    if (llvm::is_contained(expanded, candidate))
-      continue;
-    result.push_back(candidate);
-  }
-
-  return DiagnosedSilenceableFailure::success();
-}
-
 //===----------------------------------------------------------------------===//
 // MatchStructuredDimOp
 //===----------------------------------------------------------------------===//

>From c57bbef8b3a478873278693a8464eab3600b7181 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 17 Jan 2024 22:57:28 -0500
Subject: [PATCH 4/4] Fix python test

---
 mlir/python/mlir/dialects/transform/extras/__init__.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index ba51c400fe2cb2c..8d045cad7a4a36f 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -43,7 +43,6 @@ def __init__(
         self.parent = parent
         self.children = children if children is not None else []
 
-
 @ir.register_value_caster(AnyOpType.get_static_typeid())
 @ir.register_value_caster(OperationType.get_static_typeid())
 class OpHandle(Handle):
@@ -61,16 +60,16 @@ def __init__(
     ):
         super().__init__(v, parent=parent, children=children)
 
-    def get_result(self, idx: int = 0) -> "ValueHandle":
+    def get_result(self, indices: Sequence[int] = [0]) -> "ValueHandle":
         """
         Emits a `transform.GetResultOp`.
         Returns a handle to the result of the payload operation at the given
-        index.
+        indices.
         """
         get_result_op = transform.GetResultOp(
             AnyValueType.get(),
             self,
-            idx,
+            indices,
         )
         return get_result_op.result
 



More information about the Mlir-commits mailing list