[Mlir-commits] [mlir] af664e4 - [mlir][Transform] Add a transform.split_handles operation and fix general silenceable bugs.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Oct 7 09:01:49 PDT 2022


Author: Nicolas Vasilache
Date: 2022-10-07T09:01:34-07:00
New Revision: af664e445940f24ce7af1b18af489e2b8ae17614

URL: https://github.com/llvm/llvm-project/commit/af664e445940f24ce7af1b18af489e2b8ae17614
DIFF: https://github.com/llvm/llvm-project/commit/af664e445940f24ce7af1b18af489e2b8ae17614.diff

LOG: [mlir][Transform] Add a transform.split_handles operation and fix general silenceable bugs.

The transform.split_handles op is useful for ensuring that a statically known number of operations
is tracked by the source `handle` and for extracting them into individual handles
that can be further manipulated in isolation.

In the process of making the op robust with respect to silenceable errors and the suppress mode,
several issues were uncovered and fixed.

The main issue was that silenceable errors were short-circuited too early, so the result payloads
were not set. As a result, suppressed silenceable errors did not propagate correctly.
Fixing the issue triggered a few test failures: ops that return a silenceable error must now properly set the results state.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D135426

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 4b979928314b8..817c1c652daca 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -210,6 +210,34 @@ def MergeHandlesOp : TransformDialectOp<"merge_handles",
   let hasFolder = 1;
 }
 
+def SplitHandlesOp : TransformDialectOp<"split_handles",
+    [FunctionalStyleTransformOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Splits handles from a union of payload ops to a list";
+  let description = [{
+    Creates `num_result_handles` transform IR handles extracted from the 
+    `handle` operand. The resulting Payload IR operation handles are listed
+    in the same order as the operations appear in the source `handle`.
+    This is useful for ensuring a statically known number of operations are
+    tracked by the source `handle` and to extract them into individual handles
+    that can be further manipulated in isolation.
+
+    This operation succeeds and returns `num_result_handles` if the statically
+    specified `num_result_handles` corresponds to the dynamic number of 
+    operations contained in the source `handle`. Otherwise it silently fails.
+  }];
+
+  let arguments = (ins PDL_Operation:$handle,
+                       I64Attr:$num_result_handles);
+  let results = (outs Variadic<PDL_Operation>:$results);
+  let assemblyFormat = [{
+    $handle `in` `[` $num_result_handles `]` 
+    custom<StaticNumPDLResults>(type($results), ref($num_result_handles))
+    attr-dict
+  }];
+}
+
 def PDLMatchOp : TransformDialectOp<"pdl_match",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4245d4cfc9dd1..99f93ed8a5ae9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -590,9 +590,11 @@ transform::MatchOp::apply(transform::TransformResults &results,
                 getOps()->getAsValueRange<StringAttr>().end());
 
   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
-  if (payloadOps.size() != 1)
+  if (payloadOps.size() != 1) {
+    results.set(getResult().cast<OpResult>(), {});
     return DiagnosedSilenceableFailure(
         this->emitOpError("requires exactly one target handle"));
+  }
 
   SmallVector<Operation *> res;
   auto matchFun = [&](Operation *op) {
@@ -877,8 +879,11 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
           }
           return OpFoldResult(op->getResult(0));
         }));
-    if (!diag.succeeded())
+    if (diag.isSilenceableFailure()) {
+      results.set(getFirst().cast<OpResult>(), {});
+      results.set(getSecond().cast<OpResult>(), {});
       return diag;
+    }
 
     if (splitPoints.size() != payload.size()) {
       emitError() << "expected the dynamic split point handle to point to as "
@@ -900,6 +905,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
     if (!linalgOp) {
       auto diag = emitSilenceableError() << "only applies to structured ops";
       diag.attachNote(target->getLoc()) << "target op";
+      results.set(getFirst().cast<OpResult>(), {});
+      results.set(getSecond().cast<OpResult>(), {});
       return diag;
     }
 
@@ -907,6 +914,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
       auto diag = emitSilenceableError() << "dimension " << getDimension()
                                          << " does not exist in target op";
       diag.attachNote(target->getLoc()) << "target op";
+      results.set(getFirst().cast<OpResult>(), {});
+      results.set(getSecond().cast<OpResult>(), {});
       return diag;
     }
 

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 5f32e78826c49..fde8f20bd72cb 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -47,6 +47,7 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
                                            << scf::ForOp::getOperationName()
                                            << "' parent";
         diag.attachNote(target->getLoc()) << "target op";
+        results.set(getResult().cast<OpResult>(), {});
         return diag;
       }
       current = loop;
@@ -100,6 +101,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
       DiagnosedSilenceableFailure diag = emitSilenceableError()
                                          << "failed to outline";
       diag.attachNote(target->getLoc()) << "target op";
+      results.set(getTransformed().cast<OpResult>(), {});
       return diag;
     }
     func::CallOp call;

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 1d841f94accd5..176e93ae3ff78 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -225,8 +225,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   }
 
   transform::TransformResults results(transform->getNumResults());
+  // Compute the result but do not short-circuit the silenceable failure case as
+  // we still want the handles to propagate properly so the "suppress" mode can
+  // proceed on a best effort basis.
   DiagnosedSilenceableFailure result(transform.apply(results, *this));
-  if (!result.succeeded())
+  if (result.isDefiniteFailure())
     return result;
 
   // Remove the mapping for the operand if it is consumed by the operation. This
@@ -258,7 +261,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     DBGS() << "Top-level payload:\n";
     getTopLevel()->print(llvm::dbgs());
   });
-  return DiagnosedSilenceableFailure::success();
+  return result;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 26a2bafbe8746..126500f1f13cc 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -23,6 +23,7 @@
 
 using namespace mlir;
 
+/// Custom parser for ReplicateOp.
 static ParseResult parsePDLOpTypedResults(
     OpAsmParser &parser, SmallVectorImpl<Type> &types,
     const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
@@ -30,9 +31,23 @@ static ParseResult parsePDLOpTypedResults(
   return success();
 }
 
+/// Custom printer for ReplicateOp.
 static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
                                    ValueRange) {}
 
+/// Custom parser for SplitHandlesOp.
+static ParseResult parseStaticNumPDLResults(OpAsmParser &parser,
+                                            SmallVectorImpl<Type> &types,
+                                            IntegerAttr numHandlesAttr) {
+  types.resize(numHandlesAttr.getInt(),
+               pdl::OperationType::get(parser.getContext()));
+  return success();
+}
+
+/// Custom printer for SplitHandlesOp.
+static void printStaticNumPDLResults(OpAsmPrinter &, Operation *, TypeRange,
+                                     IntegerAttr) {}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
 
@@ -452,6 +467,46 @@ OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
   return getHandles().front();
 }
 
+//===----------------------------------------------------------------------===//
+// SplitHandlesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::SplitHandlesOp::apply(transform::TransformResults &results,
+                                 transform::TransformState &state) {
+  int64_t numResultHandles =
+      getHandle() ? state.getPayloadOps(getHandle()).size() : 0;
+  int64_t expectedNumResultHandles = getNumResultHandles();
+  if (numResultHandles != expectedNumResultHandles) {
+    // Failing case needs to propagate gracefully for both suppress and
+    // propagate modes.
+    for (int64_t idx = 0; idx < expectedNumResultHandles; ++idx)
+      results.set(getResults()[idx].cast<OpResult>(), {});
+    // Empty input handle corner case: always propagates empty handles in both
+    // suppress and propagate modes.
+    if (numResultHandles == 0)
+      return DiagnosedSilenceableFailure::success();
+    // If the input handle was not empty and the number of result handles does
+    // not match, this is a legit silenceable error.
+    return emitSilenceableError()
+           << getHandle() << " expected to contain " << expectedNumResultHandles
+           << " operation handles but it only contains " << numResultHandles
+           << " handles";
+  }
+  // Normal successful case.
+  for (auto en : llvm::enumerate(state.getPayloadOps(getHandle())))
+    results.set(getResults()[en.index()].cast<OpResult>(), en.value());
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SplitHandlesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getHandle(), effects);
+  producesHandle(getResults(), effects);
+  // There are no effects on the Payload IR as this is only a handle
+  // manipulation.
+}
+
 //===----------------------------------------------------------------------===//
 // PDLMatchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 7c0deb1dbfffa..ddfd645379f9b 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -761,3 +761,42 @@ transform.sequence failures(propagate) {
 
 }
 
+// -----
+
+func.func @split_handles(%a: index, %b: index, %c: index) {
+  %0 = arith.muli %a, %b : index  
+  %1 = arith.muli %a, %c : index  
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%fun: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %fun
+  %h:2 = split_handles %muli in [2]
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#0
+  %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun
+  // expected-error @below {{expected to contain 3 operation handles but it only contains 2 handles}}
+  %h_2:3 = split_handles %muli_2 in [3]
+}
+
+// -----
+
+func.func @split_handles(%a: index, %b: index, %c: index) {
+  %0 = arith.muli %a, %b : index  
+  %1 = arith.muli %a, %c : index  
+  return
+}
+
+transform.sequence failures(suppress) {
+^bb1(%fun: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %fun
+  %h:2 = split_handles %muli in [2]
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#0
+  %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun
+  // Silenceable failure and all handles are now empty.
+  %h_2:3 = split_handles %muli_2 in [3]
+  // expected-remark @below {{0}}
+ transform.test_print_number_of_associated_payload_ir_ops %h_2#0
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 752aaff298b0a..3994f0ee5564c 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -292,6 +292,8 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
 DiagnosedSilenceableFailure
 mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
     transform::TransformResults &results, transform::TransformState &state) {
+  if (!getHandle())
+    emitRemark() << 0;
   emitRemark() << state.getPayloadOps(getHandle()).size();
   return DiagnosedSilenceableFailure::success();
 }


        


More information about the Mlir-commits mailing list