[Mlir-commits] [mlir] 5e7ac25 - [mlir][transform] Add op for adding attributes to payload IR
Quinn Dawkins
llvmlistbot at llvm.org
Tue May 30 08:54:46 PDT 2023
Author: Quinn Dawkins
Date: 2023-05-30T11:46:18-04:00
New Revision: 5e7ac2503a1bbfa13b84f00d8e12865cd16b0164
URL: https://github.com/llvm/llvm-project/commit/5e7ac2503a1bbfa13b84f00d8e12865cd16b0164
DIFF: https://github.com/llvm/llvm-project/commit/5e7ac2503a1bbfa13b84f00d8e12865cd16b0164.diff
LOG: [mlir][transform] Add op for adding attributes to payload IR
The ability to add attributes to payload IR is useful functionality
independent of any dialect. This is added here through `transform.annotate`
by enabling attributes tied to a `TransformParamTypeInterface` (which
internally refers to an Attribute) to be added to a target operation by
name.
The AnnotateOp does not produce a new handle as no existing handles
should be affected by adding an attribute. Existing attributes on
the payload with the same name will be overwritten.
Differential Revision: https://reviews.llvm.org/D151689
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 9305b6b0859e..6036687017a5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -101,6 +101,31 @@ def AlternativesOp : TransformDialectOp<"alternatives",
let hasVerifier = 1;
}
+// Sets a named attribute on payload operations. The op produces no result
+// handle: adding an attribute does not invalidate any existing handle.
+def AnnotateOp : TransformDialectOp<"annotate",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Annotates the target operation with an attribute by name";
+  let description = [{
+    Adds an attribute with the given `name` to the `target` operation. An
+    optional `param` handle can be provided to give the attribute a specific
+    value, else a UnitAttr is added. A single attribute will be broadcast to
+    all target operations, otherwise the attributes will be mapped 1:1 based on
+    the order within the handles. An attribute with the same name that already
+    exists on a target operation is overwritten.
+
+    Produces a silenceable failure if the parameter payload holds more than one
+    attribute (or none) and its length does not match the length of the target
+    payload. Does not consume the provided handles.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       StrAttr:$name,
+                       Optional<TransformParamTypeInterface>:$param);
+  let results = (outs);
+
+  let assemblyFormat =
+    "$target $name attr-dict (`=` $param^)?"
+    "`:` type($target) (`,` type($param)^)?";
+}
+
def CastOp : TransformDialectOp<"cast",
[TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<CastOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index a3b55a45dd96..5f18d9042fdf 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -301,6 +301,43 @@ LogicalResult transform::AlternativesOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// AnnotateOp
+//===----------------------------------------------------------------------===//
+
+/// Sets the attribute `name` on every payload op associated with `target`.
+/// Without a `param` handle the attribute is a UnitAttr; a single param is
+/// broadcast to all targets; otherwise params are mapped 1:1 onto targets in
+/// order, and a length mismatch yields a silenceable failure. An existing
+/// attribute with the same name is replaced (Operation::setAttr overwrites).
+DiagnosedSilenceableFailure
+transform::AnnotateOp::apply(transform::TransformResults &results,
+                             transform::TransformState &state) {
+  SmallVector<Operation *> targets =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  // Default annotation when no `param` handle is supplied.
+  Attribute attr = UnitAttr::get(getContext());
+  if (auto paramH = getParam()) {
+    ArrayRef<Attribute> params = state.getParams(paramH);
+    if (params.size() != 1) {
+      // Anything other than exactly one param must pair up 1:1 with targets.
+      if (targets.size() != params.size()) {
+        return emitSilenceableError()
+               << "parameter and target have different payload lengths ("
+               << params.size() << " vs " << targets.size() << ")";
+      }
+      // Named `paramAttr` so it does not shadow the outer `attr`.
+      for (auto &&[target, paramAttr] : llvm::zip_equal(targets, params))
+        target->setAttr(getName(), paramAttr);
+      return DiagnosedSilenceableFailure::success();
+    }
+    // Exactly one param: broadcast it to every target below.
+    attr = params[0];
+  }
+  for (Operation *target : targets)
+    target->setAttr(getName(), attr);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Declares that `target` and, when present, `param` are only read (never
+/// consumed), and that the payload IR is modified by setting attributes.
+void transform::AnnotateOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTarget(), effects);
+  // `param` is an optional operand. Passing a null Value would still form a
+  // one-element ValueRange and record a Read effect with no associated value,
+  // so only declare the effect when the operand is actually present.
+  if (Value param = getParam())
+    onlyReadsHandle(param, effects);
+  modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index a885c89af031..932b2cb01135 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1620,3 +1620,37 @@ transform.sequence failures(propagate) {
// expected-remark @below {{2}}
test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
}
+
+
+// -----
+
+// Exercises transform.annotate in its three modes: values taken from a param
+// handle (`new_attr`), a single constant param applied to every matched op
+// (`broadcast_attr`), and no param at all, which adds a UnitAttr
+// (`unit_attr`). The second payload op verifies that an unrelated existing
+// attribute is preserved; the third verifies that an attribute with the same
+// name (new_attr = 0) is overwritten.
+// CHECK-LABEL: func @test_annotation()
+// CHECK-NEXT: "test.annotate_me"()
+// CHECK-SAME: broadcast_attr = 2 : i64
+// CHECK-SAME: new_attr = 1 : i32
+// CHECK-SAME: unit_attr
+// CHECK-NEXT: "test.annotate_me"()
+// CHECK-SAME: broadcast_attr = 2 : i64
+// CHECK-SAME: existing_attr = "test"
+// CHECK-SAME: new_attr = 1 : i32
+// CHECK-SAME: unit_attr
+// CHECK-NEXT: "test.annotate_me"()
+// CHECK-SAME: broadcast_attr = 2 : i64
+// CHECK-SAME: new_attr = 1 : i32
+// CHECK-SAME: unit_attr
+func.func @test_annotation() {
+ %0 = "test.annotate_me"() : () -> (i1)
+ %1 = "test.annotate_me"() {existing_attr = "test"} : () -> (i1)
+ %2 = "test.annotate_me"() {new_attr = 0} : () -> (i1)
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+ %0 = transform.structured.match ops{["test.annotate_me"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ // NOTE(review): presumably %1 holds one param per matched op, exercising
+ // the 1:1 mapping path -- confirm against the test op's definition.
+ %1 = transform.test_produce_param_with_number_of_test_ops %0 : !transform.any_op
+ transform.annotate %0 "new_attr" = %1 : !transform.any_op, !transform.test_dialect_param
+
+ // A single constant param is applied to every matched op.
+ %2 = transform.param.constant 2 -> !transform.param<i64>
+ transform.annotate %0 "broadcast_attr" = %2 : !transform.any_op, !transform.param<i64>
+ // No param: a UnitAttr is added under the given name.
+ transform.annotate %0 "unit_attr" : !transform.any_op
+}
More information about the Mlir-commits
mailing list