[Mlir-commits] [mlir] bffec21 - [mlir][transform] Add ForeachOp to transform dialect

Matthias Springer llvmlistbot at llvm.org
Tue Jul 26 09:11:05 PDT 2022


Author: Matthias Springer
Date: 2022-07-26T18:07:44+02:00
New Revision: bffec215abbd643ceac83e58caa244ded8cd837c

URL: https://github.com/llvm/llvm-project/commit/bffec215abbd643ceac83e58caa244ded8cd837c
DIFF: https://github.com/llvm/llvm-project/commit/bffec215abbd643ceac83e58caa244ded8cd837c.diff

LOG: [mlir][transform] Add ForeachOp to transform dialect

This op "unbatches" an op handle and executes the loop body for each payload op.

Differential Revision: https://reviews.llvm.org/D130257

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/ops-invalid.mlir
    mlir/test/Dialect/Transform/ops.mlir
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 27a68b05429c4..f11e3da6f93d3 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -64,6 +64,14 @@ def Transform_Dialect : Dialect {
     correspond to groups of outer and inner loops, respectively, produced by
     the tiling transformation.
 
+    A Transform IR value such as `%0` may be associated with multiple payload
+    operations. This is conceptually a set of operations and no assumptions
+    should be made about the order of ops. Most Transform IR ops support
+    operand values that are mapped to multiple operations. They usually apply
+    the respective transformation for every mapped op ("batched execution").
+    Deviations from this convention are described in the documentation of
+    Transform IR ops.
+
     Overall, Transform IR ops are expected to be contained in a single top-level
     op. Such top-level ops specify how to apply the transformations described
     by the operations they contain, e.g., `transform.sequence` executes

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index d578c15e48370..bc5fd01d215ae 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -95,6 +95,46 @@ def AlternativesOp : TransformDialectOp<"alternatives",
   let hasVerifier = 1;
 }
 
+def ForeachOp : TransformDialectOp<"foreach",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+         "getSuccessorRegions", "getSuccessorEntryOperands"]>,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
+    ]> {
+  let summary = "Executes the body for each payload op";
+  let description = [{
+    This op has exactly one region with exactly one block ("body"). The body is
+    executed for each payload op that is associated with the target operand in
+    an unbatched fashion. I.e., the first block argument ("iteration variable")
+    is always mapped to exactly one payload op.
+
+    This op always reads the target handle. Furthermore, it consumes the handle
+    if there is a transform op in the body that consumes the iteration variable.
+    This op does not return anything.
+
+    The transformations inside the body are applied in order of their
+    appearance. During application, if any transformation in the body fails,
+    this op fails immediately and the body is not applied to the remaining
+    payload ops. This may leave the payload IR in a potentially invalid state,
+    i.e., this operation offers no transformation rollback capabilities.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = "$target $body attr-dict";
+
+  let extraClassDeclaration = [{
+    /// Allow the dialect prefix to be omitted.
+    static StringRef getDefaultDialect() { return "transform"; }
+
+    /// Returns the first argument of the body block, which is mapped to a
+    /// single payload op in each iteration. NOTE: this op declares no
+    /// verifier, so the body block is assumed (not checked) to have at least
+    /// one argument.
+    BlockArgument getIterationVariable() {
+      return getBody().front().getArgument(0);
+    }
+  }];
+}
+
 def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index bc2d9710b8a41..5d702d9f3dbdd 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -273,6 +273,64 @@ LogicalResult transform::AlternativesOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ForeachOp
+//===----------------------------------------------------------------------===//
+
+/// Applies the body once per payload op associated with the target handle.
+/// The iteration variable is mapped to exactly one payload op per iteration,
+/// and the body's transform ops are applied in order. The first failure
+/// (definite or silenceable) is returned unchanged and ends the whole loop.
+DiagnosedSilenceableFailure
+transform::ForeachOp::apply(transform::TransformResults &results,
+                            transform::TransformState &state) {
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+  for (Operation *payloadOp : targets) {
+    // Open a fresh scope for the body region in every iteration (presumably
+    // dropping this iteration's body mappings on exit -- confirm against
+    // TransformState::RegionScope).
+    auto iterationScope = state.make_region_scope(getBody());
+    BlockArgument iterationVar = getIterationVariable();
+    if (failed(state.mapBlockArguments(iterationVar, {payloadOp})))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    // Run the body ops in order; bail out on the first non-success.
+    for (Operation &bodyOp : getBody().front().without_terminator()) {
+      auto transformOp = cast<transform::TransformOpInterface>(bodyOp);
+      DiagnosedSilenceableFailure status = state.applyTransform(transformOp);
+      if (!status.succeeded())
+        return status;
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Reports the memory effects on the target handle: the handle is consumed as
+/// soon as any transform op directly in the body consumes the iteration
+/// variable; otherwise the handle is only read.
+void transform::ForeachOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  BlockArgument iterationVar = getIterationVariable();
+  auto consumesIterationVar = [&](Operation &bodyOp) {
+    auto transformOp = cast<TransformOpInterface>(&bodyOp);
+    return isHandleConsumed(iterationVar, transformOp);
+  };
+  bool consumed = llvm::any_of(getBody().front().without_terminator(),
+                               consumesIterationVar);
+  if (!consumed) {
+    onlyReadsHandle(getTarget(), effects);
+    return;
+  }
+  consumesHandle(getTarget(), effects);
+}
+
+/// Models the op as a loop: the body is entered from the parent op, and from
+/// the end of the body control either re-enters the body (next payload op) or
+/// returns to the parent op.
+void transform::ForeachOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  Region *body = &getBody();
+
+  // Entering from the parent op: the only successor is the body.
+  if (!index) {
+    regions.emplace_back(body, body->getArguments());
+    return;
+  }
+
+  // Leaving the body: loop back into it, or exit to the parent op.
+  assert(*index == 0 && "unexpected region index");
+  regions.emplace_back(body, body->getArguments());
+  regions.emplace_back();
+}
+
+/// Returns the operands forwarded when entering the body region (index 0).
+OperandRange
+transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
+  assert(index && *index == 0 && "unexpected region index");
+  // The op's only operand (the target handle) flows into the iteration
+  // variable. At run time the variable is mapped to just one of the target's
+  // payload ops per iteration, i.e. to a subset of the operand's payload ops.
+  return (*this)->getOperands();
+}
+
 //===----------------------------------------------------------------------===//
 // GetClosestIsolatedParentOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index b76bd07d3a475..2650651d25b33 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -184,3 +184,18 @@ transform.alternatives {
 ^bb0:
   transform.yield
 }
+
+// -----
+
+// The foreach consumes %0 because its body consumes the iteration variable;
+// together with the later consumer this gives %0 two potential consumers.
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{result #0 has more than one potential consumer}}
+  %0 = test_produce_param_or_forward_operand 42
+  // expected-note @below {{used here as operand #0}}
+  transform.foreach %0 {
+  ^bb1(%arg1: !pdl.operation):
+    transform.test_consume_operand %arg1
+  }
+  // expected-note @below {{used here as operand #0}}
+  transform.test_consume_operand %0
+}

diff  --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index e9e99de310b72..23dd6b835dd8f 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -49,3 +49,12 @@ transform.sequence {
   ^bb3(%arg3: !pdl.operation):
   }
 }
+
+// Round-trip of a foreach whose body is empty (the yield terminator is
+// implicit and elided in the custom assembly format).
+// CHECK: transform.sequence
+// CHECK: foreach
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  transform.foreach %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+  }
+}

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 1f0dd40fc8b8b..ed3a2507e2a1c 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -597,3 +597,33 @@ module {
     }
   }
 }
+
+// -----
+
+// Payload IR: two arith.constant ops; the foreach below visits each of them
+// in a separate iteration.
+func.func @bar() {
+  // expected-remark @below {{transform applied}}
+  %0 = arith.constant 0 : i32
+  // expected-remark @below {{transform applied}}
+  %1 = arith.constant 1 : i32
+  return
+}
+
+// Matches all arith.constant ops into one handle and checks that, inside the
+// foreach body, the iteration variable is mapped to exactly one payload op.
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @const : benefit(1) {
+    %r = pdl.types
+    %0 = pdl.operation "arith.constant" -> (%r : !pdl.range<type>)
+    pdl.rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %f = pdl_match @const in %arg1
+    // %f is associated with both constants; each iteration sees one of them.
+    transform.foreach %f {
+    ^bb2(%arg2: !pdl.operation):
+      // expected-remark @below {{1}}
+      transform.test_print_number_of_associated_payload_ir_ops %arg2
+      transform.test_print_remark_at_operand %arg2, "transform applied"
+    }
+  }
+}


        


More information about the Mlir-commits mailing list