[Mlir-commits] [mlir] 7fdc2ed - [mlir] reallow null results in TransformEachOpTrait

Alex Zinenko llvmlistbot at llvm.org
Tue Feb 14 02:11:40 PST 2023


Author: Alex Zinenko
Date: 2023-02-14T10:11:32Z
New Revision: 7fdc2ed09f441e7e0ca5c88d947e99f259291963

URL: https://github.com/llvm/llvm-project/commit/7fdc2ed09f441e7e0ca5c88d947e99f259291963
DIFF: https://github.com/llvm/llvm-project/commit/7fdc2ed09f441e7e0ca5c88d947e99f259291963.diff

LOG: [mlir] reallow null results in TransformEachOpTrait

Previous changes in 98acd7468307b6099e7deae206a749af324ff95f were overly
eager to disallow null payloads everywhere. The semantics of
TransformEachOpTrait allow individual applications to return null
payloads as a means of filtering out the operations to which they are not
applicable, without emitting even a silenceable failure. This is a
questionable choice, but one that is apparently relied upon. Null payloads
are not supposed to leak outside of the trait.

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D143904

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index dc9612c9679b2..b2332c83cf35e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1015,6 +1015,8 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
     for (OpResult r : this->getOperation()->getResults()) {
       if (r.getType().isa<TransformParamTypeInterface>())
         transformResults.setParams(r, emptyParams);
+      else if (r.getType().isa<TransformValueHandleTypeInterface>())
+        transformResults.setValues(r, ValueRange());
       else
         transformResults.set(r, emptyPayload);
     }

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 6b0f59da9101f..5995485f5a79a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -921,48 +921,60 @@ transform::detail::checkApplyToOne(Operation *transformOp,
   // Check that the right kind of value was produced.
   for (const auto &[ptr, res] :
        llvm::zip(partialResult, transformOp->getResults())) {
-    if (ptr.isNull()) {
-      return emitDiag() << "null result #" << res.getResultNumber()
-                        << " produced";
+    if (ptr.isNull())
+      continue;
+    if (res.getType().template isa<TransformHandleTypeInterface>() &&
+        !ptr.is<Operation *>()) {
+      return emitDiag() << "application of " << transformOpName
+                        << " expected to produce an Operation * for result #"
+                        << res.getResultNumber();
     }
-    if (ptr.is<Operation *>() &&
-        !res.getType().template isa<TransformHandleTypeInterface>()) {
+    if (res.getType().template isa<TransformParamTypeInterface>() &&
+        !ptr.is<Attribute>()) {
       return emitDiag() << "application of " << transformOpName
                         << " expected to produce an Attribute for result #"
                         << res.getResultNumber();
     }
-    if (ptr.is<Attribute>() &&
-        !res.getType().template isa<TransformParamTypeInterface>()) {
+    if (res.getType().template isa<TransformValueHandleTypeInterface>() &&
+        !ptr.is<Value>()) {
       return emitDiag() << "application of " << transformOpName
-                        << " expected to produce an Operation * for result #"
+                        << " expected to produce a Value for result #"
                         << res.getResultNumber();
     }
   }
   return success();
 }
 
+template <typename T>
+static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
+  return llvm::to_vector(llvm::map_range(
+      range, [](transform::MappedValue value) { return value.get<T>(); }));
+}
+
 void transform::detail::setApplyToOneResults(
     Operation *transformOp, TransformResults &transformResults,
     ArrayRef<ApplyToEachResultList> results) {
+  SmallVector<SmallVector<MappedValue>> transposed;
+  transposed.resize(transformOp->getNumResults());
+  for (const ApplyToEachResultList &partialResults : results) {
+    if (llvm::any_of(partialResults,
+                     [](MappedValue value) { return value.isNull(); }))
+      continue;
+    assert(transformOp->getNumResults() == partialResults.size() &&
+           "expected as many partial results as op as results");
+    for (auto &[i, value] : llvm::enumerate(partialResults))
+      transposed[i].push_back(value);
+  }
+
   for (OpResult r : transformOp->getResults()) {
+    unsigned position = r.getResultNumber();
     if (r.getType().isa<TransformParamTypeInterface>()) {
-      auto params = llvm::to_vector(
-          llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
-            return oneResult[r.getResultNumber()].get<Attribute>();
-          }));
-      transformResults.setParams(r, params);
+      transformResults.setParams(r,
+                                 castVector<Attribute>(transposed[position]));
     } else if (r.getType().isa<TransformValueHandleTypeInterface>()) {
-      auto values = llvm::to_vector(
-          llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
-            return oneResult[r.getResultNumber()].get<Value>();
-          }));
-      transformResults.setValues(r, values);
+      transformResults.setValues(r, castVector<Value>(transposed[position]));
     } else {
-      auto payloads = llvm::to_vector(
-          llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
-            return oneResult[r.getResultNumber()].get<Operation *>();
-          }));
-      transformResults.set(r, payloads);
+      transformResults.set(r, castVector<Operation *>(transposed[position]));
     }
   }
 }

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index e8bc530f6a54e..7e2804d1621b2 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -495,8 +495,9 @@ transform.with_pdl_patterns {
 
 // -----
 
+// This should not fail.
+
 func.func @foo() {
-  // expected-note @below {{when applied to this op}}
   "op" () : () -> ()
   return
 }
@@ -513,7 +514,6 @@ transform.with_pdl_patterns {
   transform.sequence %arg0 : !pdl.operation failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
-    // expected-error @below {{null result #0 produced}}
     transform.test_mixed_null_and_non_null_results %0
   }
 }
@@ -1053,11 +1053,11 @@ module {
 
 // -----
 
-// expected-note @below {{when applied to this op}}
+// Should not fail.
+
 module {
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.any_op):
-    // expected-error @below {{null result #0 produced}}
     transform.test_produce_transform_param_or_forward_operand %arg0
       { first_result_is_null }
       : (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
@@ -1079,6 +1079,19 @@ module {
 
 // -----
 
+// expected-note @below {{when applied to this op}}
+module {
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-error @below {{expected to produce a Value for result #0}}
+    transform.test_produce_transform_param_or_forward_operand %arg0
+      { second_result_is_handle }
+      : (!transform.any_op) -> (!transform.any_value, !transform.param<i64>)
+  }
+}
+
+// -----
+
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
   // expected-error @below {{attempting to assign a null payload op to this transform value}}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 4c9b3d58ffcb5..7c4e02ce7e150 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -371,7 +371,7 @@ def TestProduceTransformParamOrForwardOperandOp
                        UnitAttr:$first_result_is_param,
                        UnitAttr:$first_result_is_null,
                        UnitAttr:$second_result_is_handle);
-  let results = (outs TransformHandleTypeInterface:$out,
+  let results = (outs AnyType:$out,
                       TransformParamTypeInterface:$param);
   let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
   let cppNamespace = "::mlir::test";


        


More information about the Mlir-commits mailing list