[Mlir-commits] [mlir] 984c2c8 - [mlir] verify against nullptr payload in transform dialect

Alex Zinenko llvmlistbot at llvm.org
Mon Jan 9 05:03:42 PST 2023


Author: Alex Zinenko
Date: 2023-01-09T14:03:35+01:00
New Revision: 984c2c8cb343e9a9d43b085f27f2f2ac3253cae7

URL: https://github.com/llvm/llvm-project/commit/984c2c8cb343e9a9d43b085f27f2f2ac3253cae7
DIFF: https://github.com/llvm/llvm-project/commit/984c2c8cb343e9a9d43b085f27f2f2ac3253cae7.diff

LOG: [mlir] verify against nullptr payload in transform dialect

When establishing the correspondence between transform values and
payload operations or parameters, check that the latter are non-null and
report errors. This was previously allowed for exotic cases of partially
successful transformations with the "apply each" trait, but was dangerous.
The "apply each" implementation was reworked to remove the need for this
functionality, so this can now be hardened to avoid null pointer
dereferences.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D141142

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 0ac2c457188e..b2c3827fdb82 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -832,36 +832,30 @@ applyTransformToEach(TransformOpTy transformOp, ArrayRef<Operation *> targets,
   SmallVector<Diagnostic> silenceableStack;
   unsigned expectedNumResults = transformOp->getNumResults();
   for (Operation *target : targets) {
-    // Emplace back a placeholder for the returned new ops and params.
-    // This is filled with `expectedNumResults` if the op fails to apply.
-    ApplyToEachResultList placeholder;
-    placeholder.reserve(expectedNumResults);
-    results.push_back(std::move(placeholder));
-
     auto specificOp = dyn_cast<OpTy>(target);
     if (!specificOp) {
       Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error);
       diag << "transform applied to the wrong op kind";
       diag.attachNote(target->getLoc()) << "when applied to this op";
-      // Producing `expectedNumResults` nullptr is a silenceableFailure mode.
-      // TODO: encode this implicit `expectedNumResults` nullptr ==
-      // silenceableFailure with a proper trait.
-      results.back().assign(expectedNumResults, nullptr);
       silenceableStack.push_back(std::move(diag));
       continue;
     }
 
+    ApplyToEachResultList partialResults;
+    partialResults.reserve(expectedNumResults);
     Location specificOpLoc = specificOp->getLoc();
     DiagnosedSilenceableFailure res =
-        transformOp.applyToOne(specificOp, results.back(), state);
+        transformOp.applyToOne(specificOp, partialResults, state);
     if (res.isDefiniteFailure() ||
         failed(detail::checkApplyToOne(transformOp, specificOpLoc,
-                                       results.back()))) {
+                                       partialResults))) {
       return DiagnosedSilenceableFailure::definiteFailure();
     }
 
     if (res.isSilenceableFailure())
       res.takeDiagnostics(silenceableStack);
+    else
+      results.push_back(std::move(partialResults));
   }
   if (!silenceableStack.empty()) {
     return DiagnosedSilenceableFailure::silenceableFailure(

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 40e4d7908c55..e8ea213142d5 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -80,6 +80,13 @@ transform::TransformState::setPayloadOps(Value value,
   assert(!value.getType().isa<TransformParamTypeInterface>() &&
          "cannot associate payload ops with a value of parameter type");
 
+  for (Operation *target : targets) {
+    if (target)
+      continue;
+    return emitError(value.getLoc())
+           << "attempting to assign a null payload op to this transform value";
+  }
+
   auto iface = value.getType().cast<TransformHandleTypeInterface>();
   DiagnosedSilenceableFailure result =
       iface.checkPayload(value.getLoc(), targets);
@@ -105,6 +112,13 @@ LogicalResult transform::TransformState::setParams(Value value,
                                                    ArrayRef<Param> params) {
   assert(value != nullptr && "attempting to set params for a null value");
 
+  for (Attribute attr : params) {
+    if (attr)
+      continue;
+    return emitError(value.getLoc())
+           << "attempting to assign a null parameter to this transform value";
+  }
+
   auto valueType = value.getType().dyn_cast<TransformParamTypeInterface>();
   assert(value &&
          "cannot associate parameter with a value of non-parameter type");

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index ca327c67ed58..da48fe234633 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1024,3 +1024,19 @@ transform.sequence failures(propagate) {
     { second_result_is_handle }
     : (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{attempting to assign a null payload op to this transform value}}
+  %0 = transform.test_produce_null_payload : !transform.any_op
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{attempting to assign a null parameter to this transform value}}
+  %0 = transform.test_produce_null_param : !transform.param<i64>
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 71bf51de70b5..338d72e3042d 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -458,6 +458,28 @@ mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestProduceNullPayloadOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::producesHandle(getOut(), effects);
+}
+
+DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  SmallVector<Operation *, 1> null({nullptr});
+  results.set(getOut().cast<OpResult>(), null);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestProduceNullParamOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
+
+DiagnosedSilenceableFailure
+mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
+                                          transform::TransformState &state) {
+  results.setParams(getOut().cast<OpResult>(), Attribute());
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index dbe058c86bb3..9ff5e30944e7 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -334,4 +334,22 @@ def TestProduceTransformParamOrForwardOperandOp
   }];
 }
 
+def TestProduceNullPayloadOp
+  : Op<Transform_Dialect, "test_produce_null_payload",
+      [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+       DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let results = (outs TransformHandleTypeInterface:$out);
+  let assemblyFormat = "attr-dict `:` type($out)";
+  let cppNamespace = "::mlir::test";
+}
+
+def TestProduceNullParamOp
+  : Op<Transform_Dialect, "test_produce_null_param",
+      [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+       DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let results = (outs TransformParamTypeInterface:$out);
+  let assemblyFormat = "attr-dict `:` type($out)";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD


        


More information about the Mlir-commits mailing list