[Mlir-commits] [mlir] 984c2c8 - [mlir] verify against nullptr payload in transform dialect
Alex Zinenko
llvmlistbot at llvm.org
Mon Jan 9 05:03:42 PST 2023
Author: Alex Zinenko
Date: 2023-01-09T14:03:35+01:00
New Revision: 984c2c8cb343e9a9d43b085f27f2f2ac3253cae7
URL: https://github.com/llvm/llvm-project/commit/984c2c8cb343e9a9d43b085f27f2f2ac3253cae7
DIFF: https://github.com/llvm/llvm-project/commit/984c2c8cb343e9a9d43b085f27f2f2ac3253cae7.diff
LOG: [mlir] verify against nullptr payload in transform dialect
When establishing the correspondence between transform values and
payload operations or parameters, check that the latter are non-null and
report an error otherwise. This was previously allowed for exotic cases of
partially successful transformations with the "apply each" trait, but it was
dangerous.
The "apply each" implementation was reworked to remove the need for this
functionality, so this check can now be hardened to avoid null pointer
dereferences.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D141142
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 0ac2c457188e..b2c3827fdb82 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -832,36 +832,30 @@ applyTransformToEach(TransformOpTy transformOp, ArrayRef<Operation *> targets,
SmallVector<Diagnostic> silenceableStack;
unsigned expectedNumResults = transformOp->getNumResults();
for (Operation *target : targets) {
- // Emplace back a placeholder for the returned new ops and params.
- // This is filled with `expectedNumResults` if the op fails to apply.
- ApplyToEachResultList placeholder;
- placeholder.reserve(expectedNumResults);
- results.push_back(std::move(placeholder));
-
auto specificOp = dyn_cast<OpTy>(target);
if (!specificOp) {
Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error);
diag << "transform applied to the wrong op kind";
diag.attachNote(target->getLoc()) << "when applied to this op";
- // Producing `expectedNumResults` nullptr is a silenceableFailure mode.
- // TODO: encode this implicit `expectedNumResults` nullptr ==
- // silenceableFailure with a proper trait.
- results.back().assign(expectedNumResults, nullptr);
silenceableStack.push_back(std::move(diag));
continue;
}
+ ApplyToEachResultList partialResults;
+ partialResults.reserve(expectedNumResults);
Location specificOpLoc = specificOp->getLoc();
DiagnosedSilenceableFailure res =
- transformOp.applyToOne(specificOp, results.back(), state);
+ transformOp.applyToOne(specificOp, partialResults, state);
if (res.isDefiniteFailure() ||
failed(detail::checkApplyToOne(transformOp, specificOpLoc,
- results.back()))) {
+ partialResults))) {
return DiagnosedSilenceableFailure::definiteFailure();
}
if (res.isSilenceableFailure())
res.takeDiagnostics(silenceableStack);
+ else
+ results.push_back(std::move(partialResults));
}
if (!silenceableStack.empty()) {
return DiagnosedSilenceableFailure::silenceableFailure(
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 40e4d7908c55..e8ea213142d5 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -80,6 +80,13 @@ transform::TransformState::setPayloadOps(Value value,
assert(!value.getType().isa<TransformParamTypeInterface>() &&
"cannot associate payload ops with a value of parameter type");
+ for (Operation *target : targets) {
+ if (target)
+ continue;
+ return emitError(value.getLoc())
+ << "attempting to assign a null payload op to this transform value";
+ }
+
auto iface = value.getType().cast<TransformHandleTypeInterface>();
DiagnosedSilenceableFailure result =
iface.checkPayload(value.getLoc(), targets);
@@ -105,6 +112,13 @@ LogicalResult transform::TransformState::setParams(Value value,
ArrayRef<Param> params) {
assert(value != nullptr && "attempting to set params for a null value");
+ for (Attribute attr : params) {
+ if (attr)
+ continue;
+ return emitError(value.getLoc())
+ << "attempting to assign a null parameter to this transform value";
+ }
+
auto valueType = value.getType().dyn_cast<TransformParamTypeInterface>();
assert(value &&
"cannot associate parameter with a value of non-parameter type");
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index ca327c67ed58..da48fe234633 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1024,3 +1024,19 @@ transform.sequence failures(propagate) {
{ second_result_is_handle }
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{attempting to assign a null payload op to this transform value}}
+ %0 = transform.test_produce_null_payload : !transform.any_op
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{attempting to assign a null parameter to this transform value}}
+ %0 = transform.test_produce_null_param : !transform.param<i64>
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 71bf51de70b5..338d72e3042d 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -458,6 +458,28 @@ mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+void mlir::test::TestProduceNullPayloadOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::producesHandle(getOut(), effects);
+}
+
+DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
+ transform::TransformResults &results, transform::TransformState &state) {
+ SmallVector<Operation *, 1> null({nullptr});
+ results.set(getOut().cast<OpResult>(), null);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestProduceNullParamOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
+
+DiagnosedSilenceableFailure
+mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ results.setParams(getOut().cast<OpResult>(), Attribute());
+ return DiagnosedSilenceableFailure::success();
+}
+
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index dbe058c86bb3..9ff5e30944e7 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -334,4 +334,22 @@ def TestProduceTransformParamOrForwardOperandOp
}];
}
+def TestProduceNullPayloadOp
+ : Op<Transform_Dialect, "test_produce_null_payload",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let results = (outs TransformHandleTypeInterface:$out);
+ let assemblyFormat = "attr-dict `:` type($out)";
+ let cppNamespace = "::mlir::test";
+}
+
+def TestProduceNullParamOp
+ : Op<Transform_Dialect, "test_produce_null_param",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let results = (outs TransformParamTypeInterface:$out);
+ let assemblyFormat = "attr-dict `:` type($out)";
+ let cppNamespace = "::mlir::test";
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
More information about the Mlir-commits
mailing list