[Mlir-commits] [mlir] aa93659 - [mlir][SCF] Add utility to clone an scf.ForOp while appending new yield values.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri May 29 04:31:29 PDT 2020
Author: Nicolas Vasilache
Date: 2020-05-29T07:28:17-04:00
New Revision: aa93659c9ffcf36dc1b0c70ce5b0f526a2b9c007
URL: https://github.com/llvm/llvm-project/commit/aa93659c9ffcf36dc1b0c70ce5b0f526a2b9c007
DIFF: https://github.com/llvm/llvm-project/commit/aa93659c9ffcf36dc1b0c70ce5b0f526a2b9c007.diff
LOG: [mlir][SCF] Add utility to clone an scf.ForOp while appending new yield values.
This utility factors out the machinery required to add iterArgs and yield values to an scf.ForOp.
Differential Revision: https://reviews.llvm.org/D80656
Added:
mlir/include/mlir/Dialect/SCF/Utils.h
mlir/lib/Dialect/SCF/Transforms/Utils.cpp
mlir/test/Transforms/loop-utils.mlir
mlir/test/lib/Transforms/TestSCFUtils.cpp
Modified:
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils.h
new file mode 100644
index 000000000000..7f8ebd3a4260
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SCF/Utils.h
@@ -0,0 +1,50 @@
+//===- Utils.h - SCF dialect utilities --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes for various SCF utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_UTILS_H_
+#define MLIR_DIALECT_SCF_UTILS_H_
+
+namespace mlir {
+class OpBuilder;
+class ValueRange;
+
+namespace scf {
+class ForOp;
+class ParallelOp;
+} // end namespace scf
+
+/// Create a clone of `loop` with `newIterOperands` added as new initialization
+/// values and `newYieldedValues` added as new yielded values. The returned
+/// ForOp has `newYieldedValues.size()` new result values, placed after the
+/// results of `loop`. Inside the cloned body, the block arguments of `loop`
+/// (induction variable and region iter_args) are remapped to the matching
+/// leading block arguments of the new loop, and uses of `newIterOperands` are
+/// remapped to the new trailing region iter_args.
+///
+/// Additionally, if `replaceLoopResults` is true, all uses of
+/// `loop.getResults()` are replaced with the first `loop.getNumResults()`
+/// results of the returned loop. This additional replacement is provided as a
+/// convenience to update the consumers of `loop`, e.g. when `loop` is about to
+/// be deleted.
+///
+/// Return the cloned loop, which is inserted immediately before `loop`.
+///
+/// This convenience function is useful to factorize common mechanisms related
+/// to hoisting roundtrips to memory into yields. It does not perform any
+/// legality checks.
+///
+/// Prerequisite: `newIterOperands.size() == newYieldedValues.size()`.
+scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
+                              ValueRange newIterOperands,
+                              ValueRange newYieldedValues,
+                              bool replaceLoopResults = true);
+
+} // end namespace mlir
+#endif // MLIR_DIALECT_SCF_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index f2b0265ed74d..58890d4a3782 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ParallelLoopFusion.cpp
ParallelLoopSpecialization.cpp
ParallelLoopTiling.cpp
+ Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
new file mode 100644
index 000000000000..6ae360a34abc
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -0,0 +1,73 @@
+//===- Utils.cpp ---- Misc utilities for loop transformation --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous loop transformation routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Utils.h"
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+
+using namespace mlir;
+
+/// Clone `loop` with extra iter_args / yielded values; the contract is
+/// documented on the declaration in mlir/Dialect/SCF/Utils.h.
+scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
+                                    ValueRange newIterOperands,
+                                    ValueRange newYieldedValues,
+                                    bool replaceLoopResults) {
+  assert(newIterOperands.size() == newYieldedValues.size() &&
+         "newIterOperands must be of the same size as newYieldedValues");
+
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(loop);
+  auto operands = llvm::to_vector<4>(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  scf::ForOp newLoop =
+      b.create<scf::ForOp>(loop.getLoc(), loop.lowerBound(), loop.upperBound(),
+                           loop.step(), operands);
+
+  auto &loopBody = *loop.getBody();
+  auto &newLoopBody = *newLoop.getBody();
+
+  // Remap the block arguments of `loop` (the induction variable followed by
+  // the region iter_args) to the leading block arguments of `newLoop`, and
+  // `newIterOperands` to the trailing region iter_args of `newLoop`.
+  BlockAndValueMapping bvm;
+  bvm.map(loopBody.getArguments(),
+          newLoopBody.getArguments().take_front(loopBody.getNumArguments()));
+  bvm.map(newIterOperands,
+          newLoop.getRegionIterArgs().take_back(newIterOperands.size()));
+
+  // Clone every non-terminator op of `loop`. Inserting at the start of the new
+  // body keeps the clones in their original order and places them before any
+  // terminator that building `newLoop` may already have created.
+  b.setInsertionPointToStart(&newLoopBody);
+  for (auto &o : loopBody.without_terminator())
+    b.clone(o, bvm);
+
+  // Build the terminator of `newLoop` directly from the remapped operands of
+  // the original yield followed by the remapped `newYieldedValues`, so that no
+  // temporary op ever has to be created in (and erased from) `loop`.
+  // If `newLoop` already ends with a yield, it was materialized while building
+  // `newLoop` (presumably only when there are no iter_args at all, i.e. nothing
+  // to yield); adding a second one would leave the block with two terminators.
+  if (newLoopBody.empty() || !isa<scf::YieldOp>(newLoopBody.back())) {
+    auto yield = cast<scf::YieldOp>(loopBody.getTerminator());
+    SmallVector<Value, 4> yieldOperands;
+    for (Value v : yield.getOperands())
+      yieldOperands.push_back(bvm.lookupOrDefault(v));
+    for (Value v : newYieldedValues)
+      yieldOperands.push_back(bvm.lookupOrDefault(v));
+    b.create<scf::YieldOp>(yield.getLoc(), yieldOperands);
+  }
+
+  // Replace `loop`'s results if requested.
+  if (replaceLoopResults) {
+    for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
+                                                   loop.getNumResults())))
+      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  }
+
+  return newLoop;
+}
diff --git a/mlir/test/Transforms/loop-utils.mlir b/mlir/test/Transforms/loop-utils.mlir
new file mode 100644
index 000000000000..3d3dadfba179
--- /dev/null
+++ b/mlir/test/Transforms/loop-utils.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-utils -mlir-disable-threading %s | FileCheck %s
+
+// CHECK-LABEL: @hoist
+// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[ub:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[step:[a-zA-Z0-9]*]]: index
+// Loop without iter_args: the test pass hoists "fake_read" before the loop,
+// threads its value through a new iter_arg that "fake_compute" updates and the
+// loop yields, and sinks "fake_write" after the loop to consume the result.
+func @hoist(%lb: index, %ub: index, %step: index) {
+ // CHECK: %[[A:.*]] = "fake_read"() : () -> index
+ // CHECK: %[[RES:.*]] = scf.for %[[I:.*]] = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[VAL:.*]] = %[[A]]) -> (index)
+ // CHECK: %[[YIELD:.*]] = "fake_compute"(%[[VAL]]) : (index) -> index
+ // CHECK: scf.yield %[[YIELD]] : index
+ // CHECK: "fake_write"(%[[RES]]) : (index) -> ()
+ scf.for %i = %lb to %ub step %step {
+ %0 = "fake_read"() : () -> (index)
+ %1 = "fake_compute"(%0) : (index) -> (index)
+ "fake_write"(%1) : (index) -> ()
+ }
+ return
+}
+
+// CHECK-LABEL: @hoist2
+// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[ub:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[step:[a-zA-Z0-9]*]]: index
+// CHECK-SAME: %[[extra_arg:[a-zA-Z0-9]*]]: f32
+// Loop that already carries one f32 iter_arg: the hoisted value is appended as
+// a second iter_arg, so the new loop has two results. Result #0 keeps feeding
+// the original `return`, and the new result #1 feeds the sunk "fake_write".
+func @hoist2(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f32 {
+ // CHECK: %[[A:.*]] = "fake_read"() : () -> index
+ // CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[VAL0:.*]] = %[[extra_arg]], %[[VAL1:.*]] = %[[A]]) -> (f32, index)
+ // CHECK: %[[YIELD:.*]] = "fake_compute"(%[[VAL1]]) : (index) -> index
+ // CHECK: scf.yield %[[VAL0]], %[[YIELD]] : f32, index
+ // CHECK: "fake_write"(%[[RES]]#1) : (index) -> ()
+ // CHECK: return %[[RES]]#0 : f32
+ %0 = scf.for %i = %lb to %ub step %step iter_args(%iter = %extra_arg) -> (f32) {
+ %0 = "fake_read"() : () -> (index)
+ %1 = "fake_compute"(%0) : (index) -> (index)
+ "fake_write"(%1) : (index) -> ()
+ scf.yield %iter: f32
+ }
+ return %0: f32
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 6069570316a8..3f2befefe704 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_library(MLIRTestTransforms
TestMemRefBoundCheck.cpp
TestMemRefDependenceCheck.cpp
TestMemRefStrideCalculation.cpp
+ TestSCFUtils.cpp
TestVectorTransforms.cpp
EXCLUDE_FROM_LIBMLIR
@@ -41,6 +42,7 @@ add_mlir_library(MLIRTestTransforms
MLIRLinalgTransforms
MLIRNVVMIR
MLIRSCF
+ MLIRSCFTransforms
MLIRGPU
MLIRPass
MLIRROCDLIR
diff --git a/mlir/test/lib/Transforms/TestSCFUtils.cpp b/mlir/test/lib/Transforms/TestSCFUtils.cpp
new file mode 100644
index 000000000000..ba06bbcc8860
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestSCFUtils.cpp
@@ -0,0 +1,58 @@
+//===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test SCF dialect utils.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/SetVector.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass for `cloneWithNewYields`: for each "fake_read" nested in an
+/// scf.for, hoist the read before the loop, sink its "fake_write" after the
+/// loop, and carry the "fake_compute" value through a new loop iter_arg.
+class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
+public:
+  TestSCFUtilsPass() = default;
+
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+
+    // Collect the candidates up front: the rewrite below moves, creates and
+    // erases operations, which must not interleave with walking the IR.
+    SmallVector<Operation *, 4> fakeReads;
+    func.walk([&](Operation *op) {
+      if (op->getName().getStringRef() == "fake_read")
+        fakeReads.push_back(op);
+    });
+
+    SmallVector<scf::ForOp, 4> toErase;
+    for (Operation *fakeRead : fakeReads) {
+      // Only the read -> compute -> write chain nested in an scf.for is
+      // handled; anything else is skipped instead of dereferencing a missing
+      // use or a null loop.
+      auto loop = fakeRead->getParentOfType<scf::ForOp>();
+      if (!loop || fakeRead->getNumResults() != 1 ||
+          fakeRead->getResult(0).use_empty())
+        continue;
+      Operation *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
+      if (fakeCompute->getNumResults() != 1 ||
+          fakeCompute->getResult(0).use_empty())
+        continue;
+      Operation *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
+
+      OpBuilder b(loop);
+      loop.moveOutOfLoop({fakeRead});
+      fakeWrite->moveAfter(loop);
+      auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
+                                        fakeCompute->getResult(0));
+      // The sunk write must now consume the value yielded out of the new loop.
+      fakeCompute->getResult(0).replaceAllUsesWith(
+          newLoop.getResults().take_back()[0]);
+      toErase.push_back(loop);
+    }
+
+    // The original loops have been superseded by their clones.
+    for (auto loop : llvm::reverse(toErase))
+      loop.erase();
+  }
+};
+} // end namespace
+
+namespace mlir {
+/// Register TestSCFUtilsPass under the `-test-scf-utils` command-line flag.
+void registerTestSCFUtilsPass() {
+  // Constructing the registration object is all that is needed to register.
+  PassRegistration<TestSCFUtilsPass> registration("test-scf-utils",
+                                                  "test scf utils");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2764b23b7b35..165d6ad159ac 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -61,6 +61,7 @@ void registerTestMemRefStrideCalculation();
void registerTestOpaqueLoc();
void registerTestParallelismDetection();
void registerTestGpuParallelLoopMappingPass();
+void registerTestSCFUtilsPass();
void registerTestVectorConversions();
void registerVectorizerTestPass();
} // namespace mlir
@@ -131,6 +132,7 @@ void registerTestPasses() {
registerTestOpaqueLoc();
registerTestParallelismDetection();
registerTestGpuParallelLoopMappingPass();
+ registerTestSCFUtilsPass();
registerTestVectorConversions();
registerVectorizerTestPass();
}
More information about the Mlir-commits
mailing list