[Mlir-commits] [mlir] 281ee42 - [mlir] Add a pass to distribute linalg::TiledLoopOp.
Alexander Belyaev
llvmlistbot at llvm.org
Wed May 26 23:45:37 PDT 2021
Author: Alexander Belyaev
Date: 2021-05-27T08:45:20+02:00
New Revision: 281ee4291110af5d1337d1da819a284eecf368ec
URL: https://github.com/llvm/llvm-project/commit/281ee4291110af5d1337d1da819a284eecf368ec
DIFF: https://github.com/llvm/llvm-project/commit/281ee4291110af5d1337d1da819a284eecf368ec.diff
LOG: [mlir] Add a pass to distribute linalg::TiledLoopOp.
Differential Revision: https://reviews.llvm.org/D103194
Added:
mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp
mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e3469039ffa6f..d6cb5cb0e39b5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -860,6 +860,13 @@ void populateLinalgConvGeneralizationPatterns(
RewritePatternSet &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter());
+/// Linalg distribution patterns
+///
+/// Populates `patterns` with a pattern that distributes `linalg.tiled_loop`.
+/// `opts.procInfoMap` supplies, per distribution type, the callback producing
+/// the processor id and processor count used to rewrite the loop bounds and
+/// steps. `marker` selects the loops to rewrite and tags them afterwards.
+void populateLinalgDistributeTiledLoopPattern(
+ RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
+ const LinalgTransformationFilter &marker);
+
//===----------------------------------------------------------------------===//
// Op-specific patterns.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 03728e3ea03c9..55da21d4c5b3e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -184,6 +184,8 @@ struct ProcInfo {
};
using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges)>;
+/// Callback returning the `ProcInfo` for a single distributed loop dimension.
+/// It receives a builder `b` and a location `loc` that it may use to build the
+/// values it returns.
+using OneDimProcInfoCallBackFn =
+ std::function<ProcInfo(OpBuilder &b, Location loc)>;
/// Options that allow distribution of loops generated in Linalg transforms to
/// processors while generating the loops.
@@ -201,6 +203,11 @@ struct LinalgLoopDistributionOptions {
/// applied. If the vector is less than the number of `scf.parallel` loops
/// generated, then no distribution is applied.
SmallVector<DistributionMethod, 0> distributionMethod = {};
+
+ /// The map keyed by the distribution type that contains callback functions
+ /// that return the Values for processor ID (`procId`), and number of
+ /// processors (`nprocs`) used to execute the parallel loops.
+ /// The keys are `StringRef`s, which do not own their characters: the strings
+ /// used as keys must stay alive for as long as the map is used.
+ DenseMap<StringRef, OneDimProcInfoCallBackFn> procInfoMap;
};
/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 1458c94fc905d..d954db975d108 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
CodegenStrategy.cpp
ComprehensiveBufferize.cpp
Detensorize.cpp
+ Distribution.cpp
DropUnitDims.cpp
ElementwiseToLinalg.cpp
Fusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp b/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp
new file mode 100644
index 0000000000000..994f7c76ddfda
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp
@@ -0,0 +1,85 @@
+//===- Distribution.cpp - Distribution of linalg.tiled_loop ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Linalg distribution pattern. It updates the
+// `tiled_loop` control variables depending on the distribution type.
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#define DEBUG_TYPE "linalg-distribution"
+
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Rewrites the lower bounds, upper bounds and steps of a `linalg.tiled_loop`
+/// that carries `distribution_types`. For every loop dimension whose
+/// distribution type has a callback registered in `options.procInfoMap`, the
+/// callback is invoked to obtain the processor id / processor count and the
+/// dimension is updated through `updateBoundsForCyclicDistribution`.
+/// Dimensions without a registered callback are left untouched.
+struct DistributeTiledLoopPattern
+    : public OpRewritePattern<linalg::TiledLoopOp> {
+  DistributeTiledLoopPattern(MLIRContext *context,
+                             LinalgLoopDistributionOptions options,
+                             LinalgTransformationFilter marker)
+      : OpRewritePattern<linalg::TiledLoopOp>(context), options(options),
+        marker(marker) {}
+
+  /// Returns failure, without touching the IR, when the op is rejected by
+  /// `marker`, has no distribution types, has fewer distribution types than
+  /// loops, or requests distribution of a non-parallel loop (an op error is
+  /// emitted in that last case). Otherwise updates the op in place, applies
+  /// the replacement marker and returns success.
+  LogicalResult matchAndRewrite(linalg::TiledLoopOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(marker.checkAndNotify(rewriter, op)))
+      return failure();
+    if (!op.distribution_types().hasValue())
+      return failure();
+
+    auto distributionTypes = op.distribution_types().getValue();
+    const int numLoops = op.getNumLoops();
+    // Indexing below assumes one distribution type per loop; bail out instead
+    // of reading past the end of the attribute array.
+    if (distributionTypes.size() < static_cast<size_t>(numLoops))
+      return rewriter.notifyMatchFailure(
+          op, "expected one distribution type per loop");
+
+    // Phase 1: validation only. Resolve the callback of every loop and check
+    // all preconditions *before* any callback is invoked. Callbacks create
+    // operations through the rewriter, and a pattern that returns failure
+    // must leave the IR unmodified.
+    SmallVector<const OneDimProcInfoCallBackFn *, 2> callbacks(numLoops,
+                                                               nullptr);
+    for (int i = 0; i < numLoops; ++i) {
+      StringRef type = distributionTypes[i].cast<StringAttr>().getValue();
+      auto procInfoCallback = options.procInfoMap.find(type);
+      if (procInfoCallback == options.procInfoMap.end())
+        continue;
+
+      if (!isParallelIteratorType(op.iterator_types()[i])) {
+        op.emitOpError("only support for parallel loops is implemented");
+        return failure();
+      }
+      callbacks[i] = &procInfoCallback->second;
+    }
+
+    // Phase 2: nothing can fail anymore, so it is now safe to create IR.
+    Location loc = op.getLoc();
+    SmallVector<Value, 2> newLowerBounds = op.lowerBound();
+    SmallVector<Value, 2> newUpperBounds = op.upperBound();
+    SmallVector<Value, 2> newSteps = op.step();
+
+    // Update bounds and steps of the dimensions that have a callback.
+    for (int i = 0; i < numLoops; ++i) {
+      if (!callbacks[i])
+        continue;
+      ProcInfo info = (*callbacks[i])(rewriter, loc);
+      updateBoundsForCyclicDistribution(rewriter, loc, info.procId, info.nprocs,
+                                        newLowerBounds[i], newUpperBounds[i],
+                                        newSteps[i]);
+    }
+    rewriter.updateRootInPlace(op, [&] {
+      op.setLowerBounds(newLowerBounds);
+      op.setUpperBounds(newUpperBounds);
+      op.setSteps(newSteps);
+    });
+    marker.replaceLinalgTransformationFilter(rewriter, op);
+    return success();
+  }
+
+private:
+  /// Distribution options; only `procInfoMap` is read by this pattern.
+  LinalgLoopDistributionOptions options;
+  /// Filter deciding which ops match and which marker they get afterwards.
+  LinalgTransformationFilter marker;
+};
+
+} // namespace
+
+/// Adds a `DistributeTiledLoopPattern`, configured with `opts` and `marker`,
+/// to `patterns`.
+void mlir::linalg::populateLinalgDistributeTiledLoopPattern(
+    RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
+    const LinalgTransformationFilter &marker) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<DistributeTiledLoopPattern>(context, opts, marker);
+}
diff --git a/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir b/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
new file mode 100644
index 0000000000000..564db5ab4fbe7
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -test-linalg-distribution %s | FileCheck %s
+
+func private @foo(%A: tensor<64x64xf32>,
+ %B: tensor<64x64xf32>) -> tensor<64x64xf32>
+
+func @distribute_for_gpu(%A: tensor<64x64xf32>,
+ %B: tensor<64x64xf32>) -> tensor<64x64xf32> {
+ %c0 = constant 0 : index
+ %c16 = constant 16 : index
+ %c64 = constant 64 : index
+ %c24 = constant 24 : index
+ %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c64, %c64) step (%c24, %c16)
+ ins (%A_ = %A: tensor<64x64xf32>) outs (%B_ = %B:tensor<64x64xf32>)
+ distribution ["block_x", "block_y"] {
+ %0 = call @foo(%A_, %B_)
+ : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+ linalg.yield %0 : tensor<64x64xf32>
+ }
+ return %0 : tensor<64x64xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 * 24)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 * 16)>
+
+// CHECK-LABEL: func @distribute_for_gpu
+// CHECK: %[[C64:.*]] = constant 64 : index
+
+// CHECK-DAG: %[[GPU_BLOCK_X:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[GPU_GRID_DIM_X:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK-DAG: %[[LB_I:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[GPU_BLOCK_X]]]
+// CHECK-DAG: %[[STEP_I:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[GPU_GRID_DIM_X]]]
+
+// CHECK-DAG: %[[GPU_BLOCK_Y:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[GPU_GRID_DIM_Y:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK-DAG: %[[LB_J:.*]] = affine.apply #[[$MAP1]](){{\[}}%[[GPU_BLOCK_Y]]]
+// CHECK-DAG: %[[STEP_J:.*]] = affine.apply #[[$MAP1]](){{\[}}%[[GPU_GRID_DIM_Y]]]
+
+// CHECK: linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = (%[[LB_I]], %[[LB_J]])
+// CHECK-SAME: to (%[[C64]], %[[C64]]) step (%[[STEP_I]], %[[STEP_J]])
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp
new file mode 100644
index 0000000000000..224d8ca164723
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp
@@ -0,0 +1,79 @@
+//===- TestLinalgDistribution.cpp - Test Linalg distribution --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing Linalg distribution patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Builds the processor info for GPU dimension `dim`: a `gpu::BlockIdOp` is
+/// used as the processor id and a `gpu::GridDimOp` as the processor count.
+/// Both ops are created with `b` at `loc` and have index type.
+template <char dim>
+static linalg::ProcInfo getGpuBlockInfo(OpBuilder &b, Location loc) {
+  StringAttr dimAttr = b.getStringAttr(std::string(1, dim));
+  Type indexTy = b.getIndexType();
+
+  // Keep the creation order: block id first, then grid dimension.
+  Value blockId = b.create<gpu::BlockIdOp>(loc, indexTy, dimAttr);
+  Value gridDim = b.create<gpu::GridDimOp>(loc, indexTy, dimAttr);
+  return ProcInfo{blockId, gridDim};
+}
+
+/// Returns distribution options whose `procInfoMap` maps the distribution
+/// types "block_x" and "block_y" to the GPU block info callbacks for the `x`
+/// and `y` dimensions respectively.
+static LinalgLoopDistributionOptions getDistributionOptions() {
+  LinalgLoopDistributionOptions distributionOptions;
+  auto &callbacksByType = distributionOptions.procInfoMap;
+  callbacksByType.try_emplace("block_x", getGpuBlockInfo<'x'>);
+  callbacksByType.try_emplace("block_y", getGpuBlockInfo<'y'>);
+  return distributionOptions;
+}
+
+namespace {
+/// Function pass used to test the `linalg.tiled_loop` distribution pattern.
+struct TestLinalgDistribution
+    : public PassWrapper<TestLinalgDistribution, FunctionPass> {
+  TestLinalgDistribution() = default;
+  TestLinalgDistribution(const TestLinalgDistribution &pass) {}
+
+  /// Declares the dialects of the ops this pass creates: the processor info
+  /// callbacks build `gpu` ops, and the rewritten bounds are computed with
+  /// `affine.apply`.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, gpu::GPUDialect>();
+  }
+
+  void runOnFunction() override;
+};
+} // namespace
+
+/// Applies the tiled-loop distribution pattern greedily to the current
+/// function and then strips the transformation marker again.
+void TestLinalgDistribution::runOnFunction() {
+  auto funcOp = getFunction();
+  OwningRewritePatternList distributeTiledLoopsPatterns(&getContext());
+  // Only ops without a marker match; rewritten ops are tagged "distributed".
+  // The extra filter rejects ops nested inside another linalg.tiled_loop.
+  populateLinalgDistributeTiledLoopPattern(
+      distributeTiledLoopsPatterns, getDistributionOptions(),
+      LinalgTransformationFilter(
+          ArrayRef<Identifier>{},
+          {Identifier::get("distributed", funcOp.getContext())})
+          .addFilter([](Operation *op) {
+            return success(!op->getParentOfType<linalg::TiledLoopOp>());
+          }));
+  (void)applyPatternsAndFoldGreedily(funcOp,
+                                     std::move(distributeTiledLoopsPatterns));
+  // Ensure we drop the marker in the end. The pattern sets the marker on the
+  // linalg.tiled_loop op itself, so walk every operation instead of relying
+  // on the marked op implementing the LinalgOp interface.
+  funcOp.walk([](Operation *op) {
+    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
+  });
+}
+
+namespace mlir {
+namespace test {
+/// Registers the `TestLinalgDistribution` pass under the command-line flag
+/// `test-linalg-distribution`.
+void registerTestLinalgDistribution() {
+  // Registration happens in the constructor; the object itself is not needed.
+  PassRegistration<TestLinalgDistribution>("test-linalg-distribution",
+                                           "Test Linalg distribution.");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 23bfe775cae2c..c2966e623f5cc 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -77,6 +77,7 @@ void registerTestGpuParallelLoopMappingPass();
void registerTestIRVisitorsPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
+void registerTestLinalgDistribution();
void registerTestLinalgElementwiseFusion();
void registerTestPushExpandingReshape();
void registerTestLinalgFusionTransforms();
@@ -156,6 +157,7 @@ void registerTestPasses() {
test::registerTestIRVisitorsPass();
test::registerTestInterfaces();
test::registerTestLinalgCodegenStrategy();
+ test::registerTestLinalgDistribution();
test::registerTestLinalgElementwiseFusion();
test::registerTestPushExpandingReshape();
test::registerTestLinalgFusionTransforms();
More information about the Mlir-commits
mailing list