[Mlir-commits] [mlir] 5e7ac25 - [mlir][transform] Add op for adding attributes to payload IR

Quinn Dawkins llvmlistbot at llvm.org
Tue May 30 08:54:46 PDT 2023


Author: Quinn Dawkins
Date: 2023-05-30T11:46:18-04:00
New Revision: 5e7ac2503a1bbfa13b84f00d8e12865cd16b0164

URL: https://github.com/llvm/llvm-project/commit/5e7ac2503a1bbfa13b84f00d8e12865cd16b0164
DIFF: https://github.com/llvm/llvm-project/commit/5e7ac2503a1bbfa13b84f00d8e12865cd16b0164.diff

LOG: [mlir][transform] Add op for adding attributes to payload IR

Adding attributes to payload IR is useful independently of any dialect.
This patch provides that capability through `transform.annotate`, which
adds an attribute carried by a `TransformParamTypeInterface` handle (which
internally refers to an Attribute) to each target operation under a given
name.

`AnnotateOp` does not produce a new handle because adding an attribute
should not affect any existing handle. An existing attribute with the same
name on a payload operation is overwritten.

Differential Revision: https://reviews.llvm.org/D151689

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 9305b6b0859e..6036687017a5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -101,6 +101,31 @@ def AlternativesOp : TransformDialectOp<"alternatives",
   let hasVerifier = 1;
 }
 
+def AnnotateOp : TransformDialectOp<"annotate",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Annotates the target operation with an attribute by name";
+  let description = [{
+    Adds an attribute with the given `name` to each `target` operation. An
+    optional `param` handle can be provided to give the attribute a specific
+    value, else a UnitAttr is added. A single parameter is broadcast to all
+    target operations; otherwise parameters are mapped 1:1 to targets in
+    handle order. An existing attribute with the same name is overwritten.
+
+    Produces a silenceable failure if the parameter payload length is neither
+    1 nor the target payload length. Does not consume the provided handles.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       StrAttr:$name,
+                       Optional<TransformParamTypeInterface>:$param);
+  let results = (outs);
+
+  let assemblyFormat =
+    "$target $name attr-dict (`=` $param^)?"
+    "`:` type($target) (`,` type($param)^)?";
+}
+
 def CastOp : TransformDialectOp<"cast",
     [TransformOpInterface, TransformEachOpTrait,
      DeclareOpInterfaceMethods<CastOpInterface>,

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index a3b55a45dd96..5f18d9042fdf 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -301,6 +301,52 @@ LogicalResult transform::AlternativesOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AnnotateOp
+//===----------------------------------------------------------------------===//
+
+/// Sets the attribute `name` on every payload op associated with `target`.
+/// Without a `param` handle the attribute is a UnitAttr; a single param value
+/// is broadcast to all targets; otherwise param values are assigned to targets
+/// 1:1 in handle order. Emits a silenceable error, before modifying any
+/// payload op, when the param count is neither 1 nor the target count.
+DiagnosedSilenceableFailure
+transform::AnnotateOp::apply(transform::TransformResults &results,
+                             transform::TransformState &state) {
+  SmallVector<Operation *> targets =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  Attribute attr = UnitAttr::get(getContext());
+  if (auto paramH = getParam()) {
+    ArrayRef<Attribute> params = state.getParams(paramH);
+    if (params.size() != 1) {
+      // Not a broadcast: require exactly one param value per target op.
+      if (targets.size() != params.size()) {
+        return emitSilenceableError()
+               << "parameter and target have different payload lengths ("
+               << params.size() << " vs " << targets.size() << ")";
+      }
+      for (auto &&[target, paramAttr] : llvm::zip_equal(targets, params))
+        target->setAttr(getName(), paramAttr);
+      return DiagnosedSilenceableFailure::success();
+    }
+    attr = params[0];
+  }
+  for (Operation *target : targets)
+    target->setAttr(getName(), attr);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Both handles are only read; payload ops are modified in place.
+void transform::AnnotateOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTarget(), effects);
+  // `param` is optional: don't record an effect on a null Value.
+  if (Value param = getParam())
+    onlyReadsHandle(param, effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index a885c89af031..932b2cb01135 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1620,3 +1620,41 @@ transform.sequence failures(propagate) {
   // expected-remark @below {{2}}
   test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
 }
+
+// -----
+
+// transform.annotate attaches attributes from a param handle (`new_attr`,
+// overwriting the pre-existing value on the third op), from a single
+// broadcast param (`broadcast_attr`) and, with no param, as a unit attribute
+// (`unit_attr`); unrelated attributes (`existing_attr`) are left intact.
+// CHECK-LABEL: func @test_annotation()
+//  CHECK-NEXT:   "test.annotate_me"()
+//  CHECK-SAME:                        broadcast_attr = 2 : i64
+//  CHECK-SAME:                        new_attr = 1 : i32
+//  CHECK-SAME:                        unit_attr
+//  CHECK-NEXT:   "test.annotate_me"()
+//  CHECK-SAME:                        broadcast_attr = 2 : i64
+//  CHECK-SAME:                        existing_attr = "test"
+//  CHECK-SAME:                        new_attr = 1 : i32
+//  CHECK-SAME:                        unit_attr
+//  CHECK-NEXT:   "test.annotate_me"()
+//  CHECK-SAME:                        broadcast_attr = 2 : i64
+//  CHECK-SAME:                        new_attr = 1 : i32
+//  CHECK-SAME:                        unit_attr
+func.func @test_annotation() {
+  %0 = "test.annotate_me"() : () -> (i1)
+  %1 = "test.annotate_me"() {existing_attr = "test"} : () -> (i1)
+  %2 = "test.annotate_me"() {new_attr = 0} : () -> (i1)
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["test.annotate_me"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.test_produce_param_with_number_of_test_ops %0 : !transform.any_op
+  transform.annotate %0 "new_attr" = %1 : !transform.any_op, !transform.test_dialect_param
+
+  %2 = transform.param.constant 2 -> !transform.param<i64>
+  transform.annotate %0 "broadcast_attr" = %2 : !transform.any_op, !transform.param<i64>
+  transform.annotate %0 "unit_attr" : !transform.any_op
+}


        


More information about the Mlir-commits mailing list