[Mlir-commits] [mlir] c1e6caa - [mlir][transform] Support results on ForeachOp

Matthias Springer llvmlistbot at llvm.org
Thu Jul 28 01:43:27 PDT 2022


Author: Matthias Springer
Date: 2022-07-28T10:39:54+02:00
New Revision: c1e6caac7059ee4dd4a3860f7a1c954f07f6120e

URL: https://github.com/llvm/llvm-project/commit/c1e6caac7059ee4dd4a3860f7a1c954f07f6120e
DIFF: https://github.com/llvm/llvm-project/commit/c1e6caac7059ee4dd4a3860f7a1c954f07f6120e.diff

LOG: [mlir][transform] Support results on ForeachOp

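Handles can now be yielded from the body of the ForeachOp. For each yield operand, the payload ops from all iterations are concatenated and mapped to the corresponding ForeachOp result.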

Differential Revision: https://reviews.llvm.org/D130640

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index bc5fd01d215ae..1f28b33aa7b93 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -118,12 +118,19 @@ def ForeachOp : TransformDialectOp<"foreach",
     the entire sequence fails immediately leaving the payload IR in potentially
     invalid state, i.e., this operation offers no transformation rollback
     capabilities.
+
+    This op generates as many handles as the terminating YieldOp has operands.
+    For each result, the payload ops associated with the corresponding YieldOp
+    operand are collected from every iteration, concatenated in iteration
+    order, and mapped to that result handle. If the target handle is
+    associated with no payload ops, all results are mapped to empty lists.
   }];
 
   let arguments = (ins PDL_Operation:$target);
-  let results = (outs);
+  let results = (outs Variadic<PDL_Operation>:$results);
   let regions = (region SizedRegion<1>:$body);
-  let assemblyFormat = "$target $body attr-dict";
+  let assemblyFormat = "$target (`->` type($results)^)? $body attr-dict";
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Allow the dialect prefix to be omitted.
@@ -132,6 +139,9 @@ def ForeachOp : TransformDialectOp<"foreach",
     BlockArgument getIterationVariable() {
       return getBody().front().getArgument(0);
     }
+
+    /// Return the YieldOp that terminates the loop body.
+    transform::YieldOp getYieldOp();
   }];
 }
 

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 5d702d9f3dbdd..511b8544ddfdc 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -281,18 +281,37 @@ DiagnosedSilenceableFailure
 transform::ForeachOp::apply(transform::TransformResults &results,
                             transform::TransformState &state) {
   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
+  // The terminator of the loop body is loop-invariant: look it up once
+  // rather than once per payload op and per result.
+  transform::YieldOp yieldOp = getYieldOp();
+
   for (Operation *op : payloadOps) {
     auto scope = state.make_region_scope(getBody());
     if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
       return DiagnosedSilenceableFailure::definiteFailure();
 
+    // Execute loop body.
     for (Operation &transform : getBody().front().without_terminator()) {
       DiagnosedSilenceableFailure result = state.applyTransform(
           cast<transform::TransformOpInterface>(transform));
       if (!result.succeeded())
         return result;
     }
+
+    // Append yielded payload ops to result list (if any).
+    for (unsigned i = 0; i < getNumResults(); ++i) {
+      ArrayRef<Operation *> yieldedOps =
+          state.getPayloadOps(yieldOp.getOperand(i));
+      resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
+    }
   }
+
+  // Each result is mapped to the payload ops yielded by the corresponding
+  // terminator operand, concatenated across all iterations.
+  for (unsigned i = 0; i < getNumResults(); ++i)
+    results.set(getResult(i).cast<OpResult>(), resultOps[i]);
+
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -306,6 +325,9 @@ void transform::ForeachOp::getEffects(
   } else {
     onlyReadsHandle(getTarget(), effects);
   }
+
+  for (Value result : getResults())
+    producesHandle(result, effects);
 }
 
 void transform::ForeachOp::getSuccessorRegions(
@@ -331,6 +353,24 @@ transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
   return getOperation()->getOperands();
 }
 
+/// Returns the terminator of the loop body's single block.
+transform::YieldOp transform::ForeachOp::getYieldOp() {
+  return cast<transform::YieldOp>(getBody().front().getTerminator());
+}
+
+/// Verifies that the op has exactly one result per terminator operand and
+/// that all yielded values are PDL operation handles.
+LogicalResult transform::ForeachOp::verify() {
+  auto yieldOp = getYieldOp();
+  if (getNumResults() != yieldOp.getNumOperands())
+    return emitOpError() << "expects the same number of results as the "
+                            "terminator has operands";
+  for (Value v : yieldOp.getOperands())
+    if (!v.getType().isa<pdl::OperationType>())
+      return yieldOp->emitOpError("expects only PDL_Operation operands");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GetClosestIsolatedParentOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index ed3a2507e2a1c..425ba608adeb4 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -627,3 +627,57 @@ transform.with_pdl_patterns {
     }
   }
 }
+
+// -----
+
+// Check that handles yielded from `transform.foreach` are concatenated across
+// iterations: the two `scf.execute_region` ops contain one and two
+// `arith.constant` ops, respectively, so the result handle is expected to be
+// associated with three payload ops.
+
+func.func @bar() {
+  scf.execute_region {
+    // expected-remark @below {{transform applied}}
+    %0 = arith.constant 0 : i32
+    scf.yield
+  }
+
+  scf.execute_region {
+    // expected-remark @below {{transform applied}}
+    %1 = arith.constant 1 : i32
+    // expected-remark @below {{transform applied}}
+    %2 = arith.constant 2 : i32
+    scf.yield
+  }
+
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @const : benefit(1) {
+    %r = pdl.types
+    %0 = pdl.operation "arith.constant" -> (%r : !pdl.range<type>)
+    pdl.rewrite %0 with "transform.dialect"
+  }
+
+  pdl.pattern @execute_region : benefit(1) {
+    %r = pdl.types
+    %0 = pdl.operation "scf.execute_region" -> (%r : !pdl.range<type>)
+    pdl.rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %f = transform.pdl_match @execute_region in %arg1
+    %results = transform.foreach %f -> !pdl.operation {
+    ^bb2(%arg2: !pdl.operation):
+      %g = transform.pdl_match @const in %arg2
+      transform.yield %g : !pdl.operation
+    }
+
+    // expected-remark @below {{3}}
+    transform.test_print_number_of_associated_payload_ir_ops %results
+    transform.test_print_remark_at_operand %results, "transform applied"
+  }
+}


        


More information about the Mlir-commits mailing list