[Mlir-commits] [mlir] 4110934 - [mlir] add readonly/consume annotations to transform named sequences
Alex Zinenko
llvmlistbot at llvm.org
Tue Apr 4 02:38:09 PDT 2023
Author: Alex Zinenko
Date: 2023-04-04T09:38:00Z
New Revision: 4110934120ed4e4309099be3389faef128f72c03
URL: https://github.com/llvm/llvm-project/commit/4110934120ed4e4309099be3389faef128f72c03
DIFF: https://github.com/llvm/llvm-project/commit/4110934120ed4e4309099be3389faef128f72c03.diff
LOG: [mlir] add readonly/consume annotations to transform named sequences
Use the argument attribute mechanism for function-like operations to
annotate the arguments of named transform sequences as consuming or only
reading the handles passed as arguments. This makes it possible to
correctly specify handle invalidation for external named sequences by
requiring their declarations to always provide such annotations.
Additionally, these annotations remove the need to analyze the body of
a named sequence to understand its effects on the arguments. Make them
required for named sequences that are called from the same file, in
addition to external sequences.
Provide a convenience pass that infers annotations by analyzing bodies
of named sequences provided they are not called from the same file.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D147223
Added:
mlir/lib/Dialect/Transform/Transforms/InferEffects.cpp
mlir/test/Dialect/Transform/infer-effects.mlir
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/Transforms/Passes.h
mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
mlir/test/Dialect/Transform/test-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 639a7c70db3ef..f034f3a277f52 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -36,6 +36,13 @@ def Transform_Dialect : Dialect {
constexpr const static llvm::StringLiteral
kTargetTagAttrName = "transform.target_tag";
+ /// Names of the argument attributes indicating whether an argument of a
+ /// transform dialect callable symbol is consumed or only read.
+ constexpr const static llvm::StringLiteral
+ kArgConsumedAttrName = "transform.consumed";
+ constexpr const static llvm::StringLiteral
+ kArgReadOnlyAttrName = "transform.readonly";
+
/// Returns the named PDL constraint functions available in the dialect
/// as a map from their name to the function.
const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
@@ -114,7 +121,7 @@ def Transform_Dialect : Dialect {
}];
}
-// Base class for ops that belong to the tranfsorm dialect. Ops defined in
+// Base class for ops that belong to the transform dialect. Ops defined in
// extensions of this dialect may also use this.
class TransformDialectOp<string mnemonic, list<Trait> traits = []>
: Op<Transform_Dialect, mnemonic, traits>;
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 41b084008acf4..9f4e3d8e089ff 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -847,6 +847,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+/// Populates `consumedArguments` with positions of `block` arguments freed in
+/// the transform mapping by the effects of operations directly in `block`.
+void getConsumedBlockArguments(
+ Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
+
/// Trait implementing the MemoryEffectOpInterface for operations that "consume"
/// their operands and produce new results.
template <typename OpTy>
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.h b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.h
index f19570fbdf94c..7a7dfe4709b22 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.h
@@ -1,4 +1,4 @@
-//===- CheckUses.h - Expensive transform value validity checks --*- C++ -*-===//
+//===- Passes.h - Transform dialect pass entry points -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
index 4fdd2e3d875ff..2400066c8ad8c 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
@@ -32,4 +32,14 @@ def CheckUsesPass : Pass<"transform-dialect-check-uses"> {
}];
}
+def InferEffectsPass : Pass<"transform-infer-effects"> {
+ let summary = "infer transform side effects for symbols";
+ let description = [{
+ This pass analyzes the definitions of transform dialect callable symbol
+ operations, such as `transform.named_sequence`, and annotates the symbol
+ arguments with attributes indicating the side effects that the nested
+ operations have on them.
+ }];
+}
+
#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 99ff80e08ebdd..d4578e0648179 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -175,6 +175,14 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
}
return success();
}
+ if (attribute.getName().getValue() == kArgConsumedAttrName ||
+ attribute.getName().getValue() == kArgReadOnlyAttrName) {
+ if (!attribute.getValue().isa<UnitAttr>()) {
+ return op->emitError()
+ << attribute.getName() << " must be a unit attribute";
+ }
+ return success();
+ }
return emitError(op->getLoc())
<< "unknown attribute: " << attribute.getName();
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index c9d28f9eef899..c4e868ef6742c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1318,6 +1318,29 @@ void transform::onlyReadsPayload(
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
}
+void transform::getConsumedBlockArguments(
+ Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ for (Operation &nested : block) {
+ auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
+ if (!iface)
+ continue;
+
+ effects.clear();
+ iface.getEffects(effects);
+ for (const MemoryEffects::EffectInstance &effect : effects) {
+ BlockArgument argument =
+ dyn_cast_or_null<BlockArgument>(effect.getValue());
+ if (!argument || argument.getOwner() != &block ||
+ !isa<MemoryEffects::Free>(effect.getEffect()) ||
+ effect.getResource() != transform::TransformMappingResource::get()) {
+ continue;
+ }
+ consumedArguments.insert(argument.getArgNumber());
+ }
+ }
+}
+
//===----------------------------------------------------------------------===//
// Utilities for TransformOpInterface.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index a37822d9d0998..c4f5769531c9d 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -720,32 +720,49 @@ verifyNamedSequenceOp(transform::NamedSequenceOp op);
void transform::IncludeOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ // Always mark as modifying the payload.
+ // TODO: a mechanism to annotate effects on payload. Even when all handles are
+ // only read, the payload may still be modified, so we currently stay on the
+ // conservative side and always indicate modification. This may prevent some
+ // code reordering.
+ modifiesPayload(effects);
+
+ // Results are always produced.
+ producesHandle(getResults(), effects);
+
+ // Adds default effects to operands and results. This will be added if
+ // preconditions fail so the trait verifier doesn't complain about missing
+ // effects and the real precondition failure is reported later on.
+ auto defaultEffects = [&] { onlyReadsHandle(getOperands(), effects); };
+
// Bail if the callee is unknown. This may run as part of the verification
// process before we verified the validity of the callee or of this op.
auto target =
getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
if (!target)
- return;
+ return defaultEffects();
auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
getOperation(), getTarget());
if (!callee)
- return;
+ return defaultEffects();
DiagnosedSilenceableFailure earlyVerifierResult =
verifyNamedSequenceOp(callee);
if (!earlyVerifierResult.succeeded()) {
(void)earlyVerifierResult.silence();
- return;
+ return defaultEffects();
}
- // Carry over effects from the callee.
- // TODO: external callees must provides attributes annotating the
- // readonly/consume effects on operands.
- if (!callee.isExternal())
- remapArgumentEffects(callee.getBody().front(), getOperands(), effects);
-
- // Proper effects.
- onlyReadsHandle(getOperands(), effects);
- producesHandle(getResults(), effects);
+ // This may also run before the number of operands was verified against the
+ // callee signature, so avoid querying argument attributes out of bounds.
+ if (getNumOperands() != callee.getFunctionType().getNumInputs())
+ return defaultEffects();
+
+ for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
+ if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
+ consumesHandle(getOperand(i), effects);
+ else
+ onlyReadsHandle(getOperand(i), effects);
+ }
}
template <typename... Tys>
@@ -753,6 +770,64 @@ static bool implementSameInterface(Type t1, Type t2) {
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
}
+/// Checks that the attributes of the named sequence operation have correct
+/// consumption effect annotations. If `alsoVerifyInternal`, checks for
+/// annotations being present even if they can be inferred from the body.
+///
+/// Errors are reported for the first offending argument. The warning about an
+/// argument marked as consumed but not consumed in the body is only reported
+/// after all arguments have been checked, so it cannot mask an error on a
+/// later argument.
+static DiagnosedSilenceableFailure
+verifyNamedSequenceConsumeAnnotations(transform::NamedSequenceOp op,
+ bool alsoVerifyInternal = false) {
+ llvm::SmallDenseSet<unsigned> consumedArguments;
+ if (!op.isExternal()) {
+ transform::getConsumedBlockArguments(op.getBody().front(),
+ consumedArguments);
+ }
+ unsigned numArguments = op.getFunctionType().getNumInputs();
+ // Position of the first argument marked as consumed but not consumed in the
+ // body, or `numArguments` if there is no such argument.
+ unsigned firstOverAnnotated = numArguments;
+ for (unsigned i = 0; i < numArguments; ++i) {
+ bool isConsumed =
+ op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
+ nullptr;
+ bool isReadOnly =
+ op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
+ nullptr;
+ if (isConsumed && isReadOnly) {
+ return op.emitSilenceableError()
+ << "argument #" << i << " cannot be both readonly and consumed";
+ }
+ if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
+ return op.emitSilenceableError()
+ << "must provide consumed/readonly status for arguments of "
+ "external or called ops";
+ }
+ if (op.isExternal())
+ continue;
+
+ if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
+ return op.emitSilenceableError()
+ << "argument #" << i
+ << " is consumed in the body but is not marked as such";
+ }
+ if (!consumedArguments.contains(i) && isConsumed &&
+ firstOverAnnotated == numArguments)
+ firstOverAnnotated = i;
+ }
+
+ if (firstOverAnnotated != numArguments) {
+ Diagnostic warning(op->getLoc(), DiagnosticSeverity::Warning);
+ warning << "argument #" << firstOverAnnotated
+ << " is not consumed in the body but is marked as consumed";
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(warning));
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
LogicalResult
transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Access through indirection and do additional checking because this may be
@@ -794,7 +869,9 @@ transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
}
- return success();
+ return verifyNamedSequenceConsumeAnnotations(target,
+ /*alsoVerifyInternal=*/true)
+ .checkAndReport();
}
//===----------------------------------------------------------------------===//
@@ -899,7 +976,7 @@ verifyNamedSequenceOp(transform::NamedSequenceOp op) {
}
if (op.isExternal() || op.getBody().empty())
- return DiagnosedSilenceableFailure::success();
+ return verifyNamedSequenceConsumeAnnotations(op);
if (op.getBody().front().empty())
return emitSilenceableFailure(op) << "expected a non-empty body block";
@@ -931,7 +1008,7 @@ verifyNamedSequenceOp(transform::NamedSequenceOp op) {
<< operandType << " vs " << resultType << ")";
}
- return DiagnosedSilenceableFailure::success();
+ return verifyNamedSequenceConsumeAnnotations(op);
}
LogicalResult transform::NamedSequenceOp::verify() {
diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
index bf9a255bacad3..68b363d4a0961 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRTransformDialectTransforms
CheckUses.cpp
+ InferEffects.cpp
TransformInterpreterPassBase.cpp
DEPENDS
diff --git a/mlir/lib/Dialect/Transform/Transforms/InferEffects.cpp b/mlir/lib/Dialect/Transform/Transforms/InferEffects.cpp
new file mode 100644
index 0000000000000..461ae9b37b897
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/Transforms/InferEffects.cpp
@@ -0,0 +1,81 @@
+//===- InferEffects.cpp - Infer memory effects for named symbols ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/DenseSet.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace transform {
+#define GEN_PASS_DEF_INFEREFFECTSPASS
+#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
+} // namespace transform
+} // namespace mlir
+
+/// Annotates every argument of `op` with `transform.consumed` if operations
+/// in its body consume it and with `transform.readonly` otherwise, replacing
+/// an existing annotation of the opposite kind. Ops that are not function-like
+/// transform ops, or are external, are left unchanged. Emits an error and
+/// fails if the body does not consist of exactly one block.
+static LogicalResult inferSideEffectAnnotations(Operation *op) {
+ if (!isa<transform::TransformOpInterface>(op))
+ return success();
+
+ auto func = dyn_cast<FunctionOpInterface>(op);
+ if (!func || func.isExternal())
+ return success();
+
+ if (!func.getFunctionBody().hasOneBlock()) {
+ return op->emitError()
+ << "only single-block operations are currently supported";
+ }
+
+ // Note that there can't be an inclusion of an unannotated symbol because it
+ // wouldn't have passed the verifier, so recursion isn't necessary here.
+ llvm::SmallDenseSet<unsigned> consumedArguments;
+ transform::getConsumedBlockArguments(func.getFunctionBody().front(),
+ consumedArguments);
+
+ for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
+ bool isConsumed = consumedArguments.contains(i);
+ StringRef inferred =
+ isConsumed ? transform::TransformDialect::kArgConsumedAttrName
+ : transform::TransformDialect::kArgReadOnlyAttrName;
+ StringRef opposite =
+ isConsumed ? transform::TransformDialect::kArgReadOnlyAttrName
+ : transform::TransformDialect::kArgConsumedAttrName;
+ // Drop a pre-existing annotation of the opposite kind so the argument never
+ // ends up marked as both consumed and readonly, which the named sequence
+ // verifier rejects.
+ func.removeArgAttr(i, opposite);
+ func.setArgAttr(i, inferred, UnitAttr::get(op->getContext()));
+ }
+ return success();
+}
+
+namespace {
+class InferEffectsPass
+ : public transform::impl::InferEffectsPassBase<InferEffectsPass> {
+public:
+ void runOnOperation() override {
+ WalkResult result = getOperation()->walk([](Operation *op) {
+ return failed(inferSideEffectAnnotations(op)) ? WalkResult::interrupt()
+ : WalkResult::advance();
+ });
+ if (result.wasInterrupted())
+ return signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index b3fe45efea94b..1f651dcce115a 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/FunctionInterfaces.h"
@@ -298,6 +299,12 @@ static void performOptionalDebugActions(
/// Replaces external symbols in `block` with their (non-external) definitions
/// from the given module.
static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
+ MLIRContext &ctx = *definitions->getContext();
+ auto consumedName =
+ StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
+ auto readOnlyName =
+ StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
+
for (Operation &op : llvm::make_early_inc_range(block)) {
LLVM_DEBUG(DBGS() << op << "\n");
auto symbol = dyn_cast<SymbolOpInterface>(op);
@@ -330,6 +337,32 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
<< externalSymbolFunc.getFunctionType() << ")";
}
+ // Reconcile consumption annotations: an unannotated definition inherits the
+ // annotations of the declaration, an annotated one must match them.
+ for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
+ bool isExternalConsumed =
+ externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
+ bool isExternalReadonly =
+ externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+ bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
+ bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+ if (!isExternalConsumed && !isExternalReadonly) {
+ if (isConsumed)
+ externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
+ else if (isReadonly)
+ externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+ continue;
+ }
+
+ if ((isExternalConsumed && !isConsumed) ||
+ (isExternalReadonly && !isReadonly)) {
+ return symbolFunc.emitError()
+ << "external definition has mismatching consumption annotations "
+ "for argument #"
+ << i;
+ }
+ }
+
OpBuilder builder(&op);
builder.setInsertionPoint(&op);
builder.clone(*externalSymbol);
diff --git a/mlir/test/Dialect/Transform/infer-effects.mlir b/mlir/test/Dialect/Transform/infer-effects.mlir
new file mode 100644
index 0000000000000..05c6a5a540944
--- /dev/null
+++ b/mlir/test/Dialect/Transform/infer-effects.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --transform-infer-effects | FileCheck %s
+
+module attributes { transform.with_named_sequence } {
+ // CHECK-LABEL: @infer
+ // CHECK-SAME: %{{.*}}: !transform.any_op {transform.consumed}
+ // CHECK-SAME: %{{.*}}: !transform.any_op {transform.readonly}
+ // CHECK-SAME: %{{.*}}: !transform.param<i32> {transform.readonly}
+ transform.named_sequence @infer(%op: !transform.any_op, %other: !transform.any_op, %param: !transform.param<i32>) {
+ transform.test_consume_operand %op : !transform.any_op
+ transform.test_print_remark_at_operand %other, "" : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index a6bfa64950707..df2792a598dcb 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -467,3 +467,98 @@ module attributes { transform.with_named_sequence} {
transform.yield %arg0 : !transform.any_op
}
}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{must provide consumed/readonly status for arguments of external or called ops}}
+ transform.named_sequence @foo(%op: !transform.any_op )
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{argument #0 cannot be both readonly and consumed}}
+ transform.named_sequence @foo(%op: !transform.any_op { transform.readonly, transform.consumed } )
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{must provide consumed/readonly status for arguments of external or called ops}}
+ transform.named_sequence @foo(%op: !transform.any_op) {
+ transform.test_print_remark_at_operand %op, "message" : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{argument #0 cannot be both readonly and consumed}}
+ transform.named_sequence @foo(%op: !transform.any_op {transform.readonly, transform.consumed}) {
+ transform.test_print_remark_at_operand %op, "message" : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-warning @below {{argument #0 is not consumed in the body but is marked as consume}}
+ transform.named_sequence @foo(%op: !transform.any_op {transform.consumed}) {
+ transform.test_print_remark_at_operand %op, "message" : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-error @below {{argument #0 is consumed in the body but is not marked as such}}
+ transform.named_sequence @foo(%op: !transform.any_op {transform.readonly}) {
+ transform.test_consume_operand %op : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+
+// -----
+
+// Check that consumption annotations are used correctly in invocation checks.
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @foo(%op: !transform.any_op { transform.consumed } )
+
+ // expected-error @below {{'transform.sequence' op block argument #0 has more than one potential consumer}}
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-note @below {{used here as operand #0}}
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ // expected-note @below {{used here as operand #0}}
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
index da50c6bd3c1da..3d4cb07769829 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
@@ -10,4 +10,5 @@
// produced twice at the same location only needs to be matched once.
// expected-remark @below {{message}}
+// expected-remark @below {{unannotated}}
module {}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index bb8acf88c9765..b21abbbdfd6d0 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -6,7 +6,7 @@
module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has a mismatching signature}}
- transform.named_sequence private @foo(!transform.op<"builtin.module">)
+ transform.named_sequence private @foo(!transform.op<"builtin.module"> {transform.readonly})
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.op<"builtin.module">):
@@ -25,3 +25,15 @@ module attributes {transform.with_named_sequence} {
include @undefined_sequence failures(suppress) () : () -> ()
}
}
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ // expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
+ transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
+
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 6e236411044cf..04b6c5a02e0ad 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -1,11 +1,11 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RUN: --verify-diagnostics | FileCheck %s
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
-// RUN: --verify-diagnostics | FileCheck %s
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RUN: --verify-diagnostics | FileCheck %s
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
// The definition of the @foo named sequence is provided in another file. It
// will be included because of the pass option. Repeated application of the
@@ -14,13 +14,19 @@
// needs to be matched once.
// expected-remark @below {{message}}
+// expected-remark @below {{unannotated}}
module attributes {transform.with_named_sequence} {
// CHECK: transform.named_sequence @foo
// CHECK: test_print_remark_at_operand %{{.*}}, "message"
- transform.named_sequence private @foo(!transform.any_op)
+ transform.named_sequence private @foo(!transform.any_op {transform.readonly})
+
+ // CHECK: transform.named_sequence @unannotated
+ // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated"
+ transform.named_sequence private @unannotated(!transform.any_op {transform.readonly})
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ include @unannotated failures(propagate) (%arg0) : (!transform.any_op) -> ()
}
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
index 509612b284a89..1149bda98ab85 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
@@ -1,8 +1,18 @@
// RUN: mlir-opt %s
module attributes {transform.with_named_sequence} {
- transform.named_sequence @foo(%arg0: !transform.any_op) {
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
transform.yield
}
+
+ transform.named_sequence @consuming(%arg0: !transform.any_op {transform.consumed}) {
+ transform.test_consume_operand %arg0 : !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @unannotated(%arg0: !transform.any_op) {
+ transform.test_print_remark_at_operand %arg0, "unannotated" : !transform.any_op
+ transform.yield
+ }
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 6b2b0dd3157c5..3c2b9b0204f8b 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1260,7 +1260,7 @@ transform.sequence failures(propagate) {
module @named_inclusion attributes { transform.with_named_sequence } {
- transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
// expected-remark @below {{applying transformation "a"}}
transform.test_transform_op "a"
transform.yield
@@ -1276,13 +1276,13 @@ module @named_inclusion attributes { transform.with_named_sequence } {
module @named_inclusion_in_named attributes { transform.with_named_sequence } {
- transform.named_sequence @foo(%arg0: !transform.any_op) -> () {
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
// expected-remark @below {{applying transformation "a"}}
transform.test_transform_op "a"
transform.yield
}
- transform.named_sequence @bar(%arg0: !transform.any_op) -> () {
+ transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> () {
// expected-remark @below {{applying transformation "b"}}
transform.test_transform_op "b"
transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
@@ -1300,7 +1300,8 @@ module @named_inclusion_in_named attributes { transform.with_named_sequence } {
// expected-remark @below {{operation}}
module @named_operands attributes { transform.with_named_sequence } {
- transform.named_sequence @foo(%arg0: !transform.any_op, %arg1: !transform.any_value) -> () {
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly},
+ %arg1: !transform.any_value {transform.readonly}) -> () {
transform.test_print_remark_at_operand %arg0, "operation" : !transform.any_op
transform.test_print_remark_at_operand_value %arg1, "value" : !transform.any_value
transform.yield
@@ -1322,7 +1323,7 @@ module @named_return attributes { transform.with_named_sequence } {
// expected-remark @below {{value}}
// expected-note @below {{value handle points to a block argument #0 in block #0 in region #0}}
- transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op, !transform.any_value) {
+ transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_value) {
%0 = transform.test_produce_value_handle_to_self_operand %arg0 : (!transform.any_op) -> !transform.any_value
transform.yield %arg0, %0 : !transform.any_op, !transform.any_value
}
More information about the Mlir-commits
mailing list