[Mlir-commits] [mlir] e4e64ea - [MLIR][Transform] Consolidate the transform ops of get_parent_for and loop unroll from affine and scf dialects.

Prabhdeep Singh Soni llvmlistbot at llvm.org
Wed Nov 30 08:08:19 PST 2022


Author: Amy Wang
Date: 2022-11-30T11:07:44-05:00
New Revision: e4e64eaade9463974ce92e4ab5f04d8e7a699de5

URL: https://github.com/llvm/llvm-project/commit/e4e64eaade9463974ce92e4ab5f04d8e7a699de5
DIFF: https://github.com/llvm/llvm-project/commit/e4e64eaade9463974ce92e4ab5f04d8e7a699de5.diff

LOG: [MLIR][Transform] Consolidate the transform ops of get_parent_for and loop unroll from affine and scf dialects.

This patch merges the duplicated `get_parent_for` and `unroll` transform ops
of the affine and scf dialects into the shared `transform.loop.*` ops.

This addresses the review comments on
https://reviews.llvm.org/D137997.

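The directory/file structure for the affine dialect's transform ops is
kept for forthcoming affine transform ops, but `affine.get_parent_for`
and `affine.unroll` are removed. Use `transform.loop.get_parent_for`
with `affine = true` and `transform.loop.unroll`, which now accept
`affine.for` loops, instead.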

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D138980

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
    mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
    mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/test/Dialect/SCF/transform-ops.mlir

Removed: 
    mlir/test/Dialect/Affine/transform-ops.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
index dc59435a7896d..e2b7e50ef8cd7 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -18,61 +18,4 @@ include "mlir/IR/OpBase.td"
 
 def Transform_AffineForOp : Transform_ConcreteOpType<"affine.for">;
 
-def AffineGetParentForOp : Op<Transform_Dialect, "affine.get_parent_for", [
-  NavigationTransformOpTrait, MemoryEffectsOpInterface,
-  DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let summary =
-      "Gets a handle to the parent 'affine.for' loop of the given operation";
-  let description = [{
-    Produces a handle to the n-th (default 1) parent `affine.for` loop for each
-    Payload IR operation associated with the operand. Fails if such a loop
-    cannot be found. The list of operations associated with the handle contains
-    parent operations in the same order as the list associated with the operand,
-    except for operations that are parents to more than one input which are only
-    present once.
-  }];
-
-  let arguments =
-    (ins TransformTypeInterface:$target,
-         DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
-                           "1">:$num_loops);
-  let results = (outs TransformTypeInterface:$parent);
-
-  let assemblyFormat =
-      "$target attr-dict `:` functional-type(operands, results)";
-}
-
-
-def AffineLoopUnrollOp : Op<Transform_Dialect, "affine.unroll", [
-  FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-  TransformOpInterface, TransformEachOpTrait]> {
-  let summary = "Unrolls the given loop with the given unroll factor";
-  let description = [{
-    Unrolls each loop associated with the given handle to have up to the given
-    number of loop body copies per iteration. If the unroll factor is larger
-    than the loop trip count, the latter is used as the unroll factor instead.
-
-    #### Return modes
-
-    This operation ignores non-affine::For ops and drops them in the return.
-    If all the operations referred to by the `target` PDLOperation unroll
-    properly, the transform succeeds. Otherwise the transform silently fails.
-
-    Does not return handles as the operation may result in the loop being
-    removed after a full unrolling.
-  }];
-
-  let arguments = (ins Transform_AffineForOp:$target,
-                       ConfinedAttr<I64Attr, [IntPositive]>:$factor);
-
-  let assemblyFormat = "$target attr-dict `:` type($target)";
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::AffineForOp target,
-        ::llvm::SmallVector<::mlir::Operation *> & results,
-        ::mlir::transform::TransformState & state);
-  }];
-}
-
 #endif // Affine_TRANSFORM_OPS

diff  --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index d1c5c595f4e2c..59d25da1b2d7c 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -23,19 +23,20 @@ def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let summary = "Gets a handle to the parent 'for' loop of the given operation";
   let description = [{
-    Produces a handle to the n-th (default 1) parent `scf.for` loop for each
-    Payload IR operation associated with the operand. Fails if such a loop
-    cannot be found. The list of operations associated with the handle contains
-    parent operations in the same order as the list associated with the operand,
-    except for operations that are parents to more than one input which are only
-    present once.
+    Produces a handle to the n-th (default 1) parent `scf.for` or `affine.for`
+    (when the affine flag is true) loop for each Payload IR operation
+    associated with the operand. Fails if such a loop cannot be found. The list
+    of operations associated with the handle contains parent operations in the
+    same order as the list associated with the operand, except for operations
+    that are parents to more than one input which are only present once.
   }];
 
   let arguments =
     (ins TransformTypeInterface:$target,
          DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
-                           "1">:$num_loops);
+                           "1">:$num_loops,
+         DefaultValuedAttr<BoolAttr, "false">:$affine);
   let results = (outs TransformTypeInterface:$parent);
 
   let assemblyFormat =
     "$target attr-dict `:` functional-type(operands, results)";
@@ -166,22 +167,23 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
 
     #### Return modes
 
-    This operation ignores non-scf::For ops and drops them in the return.
-    If all the operations referred to by the `target` PDLOperation unroll
-    properly, the transform succeeds. Otherwise the transform silently fails.
+    Unrolls both `scf.for` and `affine.for` loops. If all the operations
+    referred to by the `target` handle unroll properly, the transform
+    succeeds. Otherwise, including when a payload op is neither an `scf.for`
+    nor an `affine.for`, the transform produces a silenceable failure.
 
     Does not return handles as the operation may result in the loop being
     removed after a full unrolling.
   }];
 
-  let arguments = (ins Transform_ScfForOp:$target,
+  let arguments = (ins TransformTypeInterface:$target,
                        ConfinedAttr<I64Attr, [IntPositive]>:$factor);
 
   let assemblyFormat = "$target attr-dict `:` type($target)";
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::scf::ForOp target,
+        ::mlir::Operation *target,
         ::llvm::SmallVector<::mlir::Operation *> &results,
         ::mlir::transform::TransformState &state);
   }];

diff  --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 7c32166fd6ad0..605c07f33ad44 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -22,52 +22,6 @@ class SimpleRewriter : public PatternRewriter {
 };
 } // namespace
 
-//===----------------------------------------------------------------------===//
-// AffineGetParentForOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::AffineGetParentForOp::apply(transform::TransformResults &results,
-                                       transform::TransformState &state) {
-  SetVector<Operation *> parents;
-  for (Operation *target : state.getPayloadOps(getTarget())) {
-    AffineForOp loop;
-    Operation *current = target;
-    for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
-      loop = current->getParentOfType<AffineForOp>();
-      if (!loop) {
-        DiagnosedSilenceableFailure diag = emitSilenceableError()
-                                           << "could not find an '"
-                                           << AffineForOp::getOperationName()
-                                           << "' parent";
-        diag.attachNote(target->getLoc()) << "target op";
-        results.set(getResult().cast<OpResult>(), {});
-        return diag;
-      }
-      current = loop;
-    }
-    parents.insert(loop);
-  }
-  results.set(getResult().cast<OpResult>(), parents.getArrayRef());
-  return DiagnosedSilenceableFailure::success();
-}
-
-//===----------------------------------------------------------------------===//
-// LoopUnrollOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::AffineLoopUnrollOp::applyToOne(AffineForOp target,
-                                          SmallVector<Operation *> &results,
-                                          transform::TransformState &state) {
-  if (failed(loopUnrollByFactor(target, getFactor()))) {
-    Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
-    diag << "op failed to unroll";
-    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
-  }
-  return DiagnosedSilenceableFailure(success());
-}
-
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index ec8516c9e422e..ca057156c21bc 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -30,21 +31,23 @@ class SimpleRewriter : public PatternRewriter {
 //===----------------------------------------------------------------------===//
 // GetParentForOp
 //===----------------------------------------------------------------------===//
-
 DiagnosedSilenceableFailure
 transform::GetParentForOp::apply(transform::TransformResults &results,
                                  transform::TransformState &state) {
   SetVector<Operation *> parents;
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    scf::ForOp loop;
-    Operation *current = target;
+    Operation *loop = nullptr, *current = target;
     for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
-      loop = current->getParentOfType<scf::ForOp>();
+      loop = getAffine() ? current->getParentOfType<AffineForOp>()
+                         : current->getParentOfType<scf::ForOp>();
+
       if (!loop) {
-        DiagnosedSilenceableFailure diag = emitSilenceableError()
-                                           << "could not find an '"
-                                           << scf::ForOp::getOperationName()
-                                           << "' parent";
+        DiagnosedSilenceableFailure diag =
+            emitSilenceableError()
+            << "could not find an '"
+            << (getAffine() ? AffineForOp::getOperationName()
+                            : scf::ForOp::getOperationName())
+            << "' parent";
         diag.attachNote(target->getLoc()) << "target op";
         results.set(getResult().cast<OpResult>(), {});
         return diag;
@@ -215,12 +218,19 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target,
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::LoopUnrollOp::applyToOne(scf::ForOp target,
+transform::LoopUnrollOp::applyToOne(Operation *op,
                                     SmallVector<Operation *> &results,
                                     transform::TransformState &state) {
-  if (failed(loopUnrollByFactor(target, getFactor()))) {
-    Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
+  // Ops that are neither scf.for nor affine.for keep `result` as failure.
+  LogicalResult result(failure());
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
+    result = loopUnrollByFactor(scfFor, getFactor());
+  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+    result = loopUnrollByFactor(affineFor, getFactor());
+
+  if (failed(result)) {
+    Diagnostic diag(op->getLoc(), DiagnosticSeverity::Note);
     diag << "op failed to unroll";
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
   return DiagnosedSilenceableFailure(success());

diff  --git a/mlir/test/Dialect/Affine/transform-ops.mlir b/mlir/test/Dialect/Affine/transform-ops.mlir
deleted file mode 100644
index 0a122092ac5e4..0000000000000
--- a/mlir/test/Dialect/Affine/transform-ops.mlir
+++ /dev/null
@@ -1,67 +0,0 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
-
-// CHECK-LABEL: @get_parent_for_op
-func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) {
-  // expected-remark @below {{first loop}}
-  affine.for %i = %arg0 to %arg1 {
-    // expected-remark @below {{second loop}}
-    affine.for %j = %arg0 to %arg1 {
-      // expected-remark @below {{third loop}}
-      affine.for %k = %arg0 to %arg1 {
-        arith.addi %i, %j : index
-      }
-    }
-  }
-  return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
-  // CHECK: = transform.affine.get_parent_for
-  %1 = transform.affine.get_parent_for %0 : (!pdl.operation) -> !transform.op<"affine.for">
-  %2 = transform.affine.get_parent_for %0 { num_loops = 2 } : (!pdl.operation) -> !transform.op<"affine.for">
-  %3 = transform.affine.get_parent_for %0 { num_loops = 3 } : (!pdl.operation) -> !transform.op<"affine.for">
-  transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"affine.for">
-  transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"affine.for">
-  transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"affine.for">
-}
-
-// -----
-
-func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
-  // expected-note @below {{target op}}
-  arith.addi %arg0, %arg1 : index
-  return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
-  // expected-error @below {{could not find an 'affine.for' parent}}
-  %1 = transform.affine.get_parent_for %0 : (!pdl.operation) -> !transform.op<"affine.for">
-}
-
-// -----
-
-func.func @loop_unroll_op() {
-  %c0 = arith.constant 0 : index
-  %c42 = arith.constant 42 : index
-  %c5 = arith.constant 5 : index
-  // CHECK: affine.for %[[I:.+]] =
-  // expected-remark @below {{affine for loop}}
-  affine.for %i = %c0 to %c42 {
-    // CHECK-COUNT-4: arith.addi
-    arith.addi %i, %i : index
-  }
-  return
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
-  %1 = transform.affine.get_parent_for %0 : (!pdl.operation) -> !transform.op<"affine.for">
-  transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for">
-  transform.affine.unroll %1 { factor = 4 } : !transform.op<"affine.for">
-}
-

diff  --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index baca3c811ec0b..d6ff2f2b821dc 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -192,3 +192,94 @@ transform.sequence failures(propagate) {
   transform.loop.unroll %1 { factor = 4 } : !transform.op<"scf.for">
 }
 
+// -----
+
+// CHECK-LABEL: @get_parent_for_op
+func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) {
+  // expected-remark @below {{first loop}}
+  affine.for %i = %arg0 to %arg1 {
+    // expected-remark @below {{second loop}}
+    affine.for %j = %arg0 to %arg1 {
+      // expected-remark @below {{third loop}}
+      affine.for %k = %arg0 to %arg1 {
+        arith.addi %i, %j : index
+      }
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
+  // CHECK: = transform.loop.get_parent_for
+  %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+  %2 = transform.loop.get_parent_for %0 { num_loops = 2, affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+  %3 = transform.loop.get_parent_for %0 { num_loops = 3, affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+  transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"affine.for">
+  transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"affine.for">
+  transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"affine.for">
+}
+
+// -----
+
+func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
+  // expected-note @below {{target op}}
+  arith.addi %arg0, %arg1 : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
+  // expected-error @below {{could not find an 'affine.for' parent}}
+  %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+}
+
+// -----
+
+func.func @loop_unroll_op() {
+  %c0 = arith.constant 0 : index
+  %c42 = arith.constant 42 : index
+  %c5 = arith.constant 5 : index
+  // CHECK: affine.for %[[I:.+]] =
+  // expected-remark @below {{affine for loop}}
+  affine.for %i = %c0 to %c42 {
+    // CHECK-COUNT-4: arith.addi
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
+  %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+  transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for">
+  transform.loop.unroll %1 { factor = 4 } : !transform.op<"affine.for">
+}
+
+// -----
+
+func.func @test_mixed_loops() {
+  %c0 = arith.constant 0 : index
+  %c42 = arith.constant 42 : index
+  %c5 = arith.constant 5 : index
+  scf.for %j = %c0 to %c42 step %c5 {
+    // CHECK: affine.for %[[I:.+]] =
+    // expected-remark @below {{affine for loop}}
+    affine.for %i = %c0 to %c42 {
+      // CHECK-COUNT-4: arith.addi
+      arith.addi %i, %i : index
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["arith.addi"]} in %arg1
+  %1 = transform.loop.get_parent_for %0 { num_loops = 1, affine = true } : (!pdl.operation) -> !transform.op<"affine.for">
+  transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for">
+  transform.loop.unroll %1 { factor = 4 } : !transform.op<"affine.for">
+}


        


More information about the Mlir-commits mailing list