[Mlir-commits] [mlir] aa93659 - [mlir][SCF] Add utility to clone an scf.ForOp while appending new yield values.

Nicolas Vasilache llvmlistbot at llvm.org
Fri May 29 04:31:29 PDT 2020


Author: Nicolas Vasilache
Date: 2020-05-29T07:28:17-04:00
New Revision: aa93659c9ffcf36dc1b0c70ce5b0f526a2b9c007

URL: https://github.com/llvm/llvm-project/commit/aa93659c9ffcf36dc1b0c70ce5b0f526a2b9c007
DIFF: https://github.com/llvm/llvm-project/commit/aa93659c9ffcf36dc1b0c70ce5b0f526a2b9c007.diff

LOG: [mlir][SCF] Add utility to clone an scf.ForOp while appending new yield values.

This utility factors out the machinery required to add iterArgs and yield values to an scf.ForOp.

Differential Revision: https://reviews.llvm.org/D80656

Added: 
    mlir/include/mlir/Dialect/SCF/Utils.h
    mlir/lib/Dialect/SCF/Transforms/Utils.cpp
    mlir/test/Transforms/loop-utils.mlir
    mlir/test/lib/Transforms/TestSCFUtils.cpp

Modified: 
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils.h
new file mode 100644
index 000000000000..7f8ebd3a4260
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SCF/Utils.h
@@ -0,0 +1,50 @@
+//===- Utils.h - SCF dialect utilities --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes for various SCF utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_UTILS_H_
+#define MLIR_DIALECT_SCF_UTILS_H_
+
+namespace mlir {
+class OpBuilder;
+class ValueRange;
+
+namespace scf {
+class ForOp;
+class ParallelOp;
+} // end namespace scf
+
+/// Create a clone of `loop`, inserted right before it, with `newIterOperands`
+/// appended as new initialization values and `newYieldedValues` appended as
+/// new yielded values. The returned ForOp has `newYieldedValues.size()` more
+/// results than `loop`. Inside the cloned body, the `loop` induction variable
+/// and block arguments map to their new counterparts, and uses of
+/// `newIterOperands` map to the new trailing region iter_args.
+///
+/// Additionally, if `replaceLoopResults` is true, all uses of
+/// `loop.getResults()` are replaced with the first `loop.getNumResults()`
+/// results of the cloned loop. This is a convenience to update the consumers
+/// of `loop`, e.g. when `loop` is soon to be deleted. `loop` itself is left in
+/// place and must be erased by the caller if it is no longer needed.
+///
+/// Return the cloned loop.
+///
+/// This function is useful to factorize common mechanisms related to hoisting
+/// roundtrips to memory into yields. It does not perform any legality checks.
+///
+/// Prerequisite: `newIterOperands.size() == newYieldedValues.size()`.
+scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
+                              ValueRange newIterOperands,
+                              ValueRange newYieldedValues,
+                              bool replaceLoopResults = true);
+
+} // end namespace mlir
+#endif // MLIR_DIALECT_SCF_UTILS_H_

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index f2b0265ed74d..58890d4a3782 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ParallelLoopFusion.cpp
   ParallelLoopSpecialization.cpp
   ParallelLoopTiling.cpp
+  Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF

diff  --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
new file mode 100644
index 000000000000..6ae360a34abc
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -0,0 +1,86 @@
+//===- Utils.cpp ---- Misc utilities for loop transformation --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous loop transformation routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Utils.h"
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+
+using namespace mlir;
+
+/// Clone `loop` right before itself with `newIterOperands` appended to its
+/// iter_args and `newYieldedValues` appended to its yielded values. Optionally
+/// redirect the uses of `loop`'s results to the clone. `loop` is not erased.
+scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
+                                    ValueRange newIterOperands,
+                                    ValueRange newYieldedValues,
+                                    bool replaceLoopResults) {
+  assert(newIterOperands.size() == newYieldedValues.size() &&
+         "newIterOperands must be of the same size as newYieldedValues");
+
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(loop);
+  auto operands = llvm::to_vector<4>(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  scf::ForOp newLoop =
+      b.create<scf::ForOp>(loop.getLoc(), loop.lowerBound(), loop.upperBound(),
+                           loop.step(), operands);
+
+  auto &loopBody = *loop.getBody();
+  auto &newLoopBody = *newLoop.getBody();
+  // scf::ForOp::build may already have populated the new body with a default
+  // scf.yield (presumably only when `operands` is empty). Drop it so that the
+  // augmented yield cloned below is the only terminator of the new body.
+  if (!newLoopBody.empty())
+    newLoopBody.back().erase();
+
+  // Clone / erase the yield inside the original loop to both:
+  //   1. augment its operands with the newYieldedValues.
+  //   2. automatically apply the BlockAndValueMapping on its operands.
+  // The augmented yield is inserted right before the original terminator, so
+  // it belongs to `without_terminator()` and gets cloned with the body.
+  auto yield = cast<scf::YieldOp>(loopBody.getTerminator());
+  b.setInsertionPoint(yield);
+  auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
+  yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
+  auto newYield = b.create<scf::YieldOp>(yield.getLoc(), yieldOperands);
+
+  // Clone the loop body with remaps.
+  BlockAndValueMapping bvm;
+  // a. remap the induction variable.
+  bvm.map(loop.getInductionVar(), newLoop.getInductionVar());
+  // b. remap all original BB args to the leading args of the new body.
+  bvm.map(loopBody.getArguments(),
+          newLoopBody.getArguments().take_front(loopBody.getNumArguments()));
+  // c. remap the new init values: uses of `newIterOperands` inside the body
+  //    become uses of the corresponding trailing region iter_args.
+  bvm.map(newIterOperands,
+          newLoop.getRegionIterArgs().take_back(newIterOperands.size()));
+  b.setInsertionPointToStart(&newLoopBody);
+  // Skip the original yield terminator which does not have enough operands.
+  for (auto &o : loopBody.without_terminator())
+    b.clone(o, bvm);
+
+  // Replace `loop`'s results if requested.
+  if (replaceLoopResults) {
+    for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
+                                                    loop.getNumResults())))
+      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  }
+
+  // The augmented yield was only needed as a cloning template; remove it.
+  // TODO: this is unsafe in the context of a PatternRewrite.
+  newYield.erase();
+
+  return newLoop;
+}

diff  --git a/mlir/test/Transforms/loop-utils.mlir b/mlir/test/Transforms/loop-utils.mlir
new file mode 100644
index 000000000000..3d3dadfba179
--- /dev/null
+++ b/mlir/test/Transforms/loop-utils.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-utils -mlir-disable-threading %s | FileCheck %s
+
+// CHECK-LABEL: @hoist
+//  CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME: %[[ub:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME: %[[step:[a-zA-Z0-9]*]]: index
+func @hoist(%lb: index, %ub: index, %step: index) {
+  // CHECK: %[[A:.*]] = "fake_read"() : () -> index
+  // CHECK: %[[RES:.*]] = scf.for %[[I:.*]] = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[VAL:.*]] = %[[A]]) -> (index)
+  // CHECK:   %[[YIELD:.*]] = "fake_compute"(%[[VAL]]) : (index) -> index
+  // CHECK:   scf.yield %[[YIELD]] : index
+  // CHECK: "fake_write"(%[[RES]]) : (index) -> ()
+  scf.for %iv = %lb to %ub step %step {
+    %read = "fake_read"() : () -> (index)
+    %computed = "fake_compute"(%read) : (index) -> (index)
+    "fake_write"(%computed) : (index) -> ()
+  }
+  return
+}
+
+// CHECK-LABEL: @hoist2
+//  CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME: %[[ub:[a-zA-Z0-9]*]]: index,
+//  CHECK-SAME: %[[step:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME: %[[extra_arg:[a-zA-Z0-9]*]]: f32
+func @hoist2(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f32 {
+  // CHECK: %[[A:.*]] = "fake_read"() : () -> index
+  // CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[VAL0:.*]] = %[[extra_arg]], %[[VAL1:.*]] = %[[A]]) -> (f32, index)
+  // CHECK:   %[[YIELD:.*]] = "fake_compute"(%[[VAL1]]) : (index) -> index
+  // CHECK:   scf.yield %[[VAL0]], %[[YIELD]] : f32, index
+  // CHECK: "fake_write"(%[[RES]]#1) : (index) -> ()
+  // CHECK: return %[[RES]]#0 : f32
+  %res = scf.for %iv = %lb to %ub step %step iter_args(%iter = %extra_arg) -> (f32) {
+    %read = "fake_read"() : () -> (index)
+    %computed = "fake_compute"(%read) : (index) -> (index)
+    "fake_write"(%computed) : (index) -> ()
+    scf.yield %iter : f32
+  }
+  return %res : f32
+}

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 6069570316a8..3f2befefe704 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_library(MLIRTestTransforms
   TestMemRefBoundCheck.cpp
   TestMemRefDependenceCheck.cpp
   TestMemRefStrideCalculation.cpp
+  TestSCFUtils.cpp
   TestVectorTransforms.cpp
 
   EXCLUDE_FROM_LIBMLIR
@@ -41,6 +42,7 @@ add_mlir_library(MLIRTestTransforms
   MLIRLinalgTransforms
   MLIRNVVMIR
   MLIRSCF
+  MLIRSCFTransforms
   MLIRGPU
   MLIRPass
   MLIRROCDLIR

diff  --git a/mlir/test/lib/Transforms/TestSCFUtils.cpp b/mlir/test/lib/Transforms/TestSCFUtils.cpp
new file mode 100644
index 000000000000..ba06bbcc8860
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestSCFUtils.cpp
@@ -0,0 +1,75 @@
+//===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test SCF dialect utils.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/SetVector.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass for `cloneWithNewYields`. For each "fake_read" op nested in an
+/// scf.for, it assumes the first user of the read is the "fake_compute" op and
+/// the first user of that compute is the "fake_write" op. It moves the read
+/// before the loop and the write after it, clones the loop with the read as a
+/// new iter_arg and the compute result as a new yielded value, and makes the
+/// write consume the new loop's last result. Original loops are erased last.
+class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
+public:
+  TestSCFUtilsPass() = default;
+
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+    SmallVector<scf::ForOp, 4> toErase;
+
+    func.walk([&](Operation *fakeRead) {
+      if (fakeRead->getName().getStringRef() != "fake_read")
+        return;
+      // Skip anything that does not match the expected pattern instead of
+      // dereferencing an empty use list or a null parent loop.
+      auto loop = fakeRead->getParentOfType<scf::ForOp>();
+      if (!loop || fakeRead->getNumResults() == 0 ||
+          fakeRead->getResult(0).use_empty())
+        return;
+      auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
+      if (fakeCompute->getNumResults() == 0 ||
+          fakeCompute->getResult(0).use_empty())
+        return;
+      auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
+
+      OpBuilder b(loop);
+      loop.moveOutOfLoop({fakeRead});
+      fakeWrite->moveAfter(loop);
+      auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
+                                        fakeCompute->getResult(0));
+      // Redirect uses of the original compute result (e.g. the hoisted write)
+      // to the value now carried out of the loop by the appended yield.
+      fakeCompute->getResult(0).replaceAllUsesWith(
+          newLoop.getResults().take_back()[0]);
+      toErase.push_back(loop);
+    });
+    // Erase the original loops only once the walk no longer visits them.
+    for (auto loop : llvm::reverse(toErase))
+      loop.erase();
+  }
+};
+} // end namespace
+
+namespace mlir {
+void registerTestSCFUtilsPass() {
+  PassRegistration<TestSCFUtilsPass>("test-scf-utils", "test scf utils");
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2764b23b7b35..165d6ad159ac 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -61,6 +61,7 @@ void registerTestMemRefStrideCalculation();
 void registerTestOpaqueLoc();
 void registerTestParallelismDetection();
 void registerTestGpuParallelLoopMappingPass();
+void registerTestSCFUtilsPass();
 void registerTestVectorConversions();
 void registerVectorizerTestPass();
 } // namespace mlir
@@ -131,6 +132,7 @@ void registerTestPasses() {
   registerTestOpaqueLoc();
   registerTestParallelismDetection();
   registerTestGpuParallelLoopMappingPass();
+  registerTestSCFUtilsPass();
   registerTestVectorConversions();
   registerVectorizerTestPass();
 }


        


More information about the Mlir-commits mailing list