[llvm-branch-commits] [mlir] 5522547 - [mlir][linalg] Support permutation when lowering to loop nests
Lei Zhang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Jan 11 06:17:57 PST 2021
Author: Lei Zhang
Date: 2021-01-11T09:13:06-05:00
New Revision: 55225471d9838e452cfb31e0edae6162b7226221
URL: https://github.com/llvm/llvm-project/commit/55225471d9838e452cfb31e0edae6162b7226221
DIFF: https://github.com/llvm/llvm-project/commit/55225471d9838e452cfb31e0edae6162b7226221.diff
LOG: [mlir][linalg] Support permutation when lowering to loop nests
Linalg ops are perfect loop nests. When materializing the concrete
loop nest, the default order specified by the Linalg op's iterators
may not be the best for further CodeGen: targets frequently need
to plan the loop order to achieve better data access patterns, and
different targets can have different preferences. So there should
be a way to control the order.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D91795
Added:
mlir/test/Dialect/Linalg/loop-order.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 14f845589a6f..a20289af3054 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -28,8 +28,8 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
let options = [
Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool",
/*default=*/"false",
- "Only folds the one-trip loops from Linalg ops on tensors "
- "(for testing purposes only)">
+ "Only folds the one-trip loops from Linalg ops on tensors "
+ "(for testing purposes only)">
];
let dependentDialects = ["linalg::LinalgDialect"];
}
@@ -52,12 +52,24 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
let summary = "Lower the operations from the linalg dialect into affine "
"loops";
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
+ let options = [
+ ListOption<"interchangeVector", "interchange-vector", "unsigned",
+ "Permute the loops in the nest following the given "
+ "interchange vector",
+ "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
+ ];
let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
}
def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
let summary = "Lower the operations from the linalg dialect into loops";
let constructor = "mlir::createConvertLinalgToLoopsPass()";
+ let options = [
+ ListOption<"interchangeVector", "interchange-vector", "unsigned",
+ "Permute the loops in the nest following the given "
+ "interchange vector",
+ "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> // e.g. "interchange-vector=4,0,3,1,2"
+ ];
let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect", "AffineDialect"];
}
@@ -72,6 +84,12 @@ def LinalgLowerToParallelLoops
let summary = "Lower the operations from the linalg dialect into parallel "
"loops";
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
+ let options = [
+ ListOption<"interchangeVector", "interchange-vector", "unsigned",
+ "Permute the loops in the nest following the given "
+ "interchange vector",
+ "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
+ ];
let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index dc82569aac38..d816414ef8b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -267,16 +267,28 @@ void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
template <typename LoopTy>
-Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op);
-
-/// Emits a loop nest of `scf.for` with the proper body for `op`.
-LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);
-
-/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
-LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);
+Optional<LinalgLoops>
+linalgLowerOpToLoops(OpBuilder &builder, Operation *op,
+ ArrayRef<unsigned> interchangeVector = {});
+
+/// Emits a loop nest of `scf.for` with the proper body for `op`. The generated
+/// loop nest will follow the `interchangeVector`-permuted iterator order. If
+/// `interchangeVector` is empty, then no permutation happens.
+LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op,
+ ArrayRef<unsigned> interchangeVector = {});
+
+/// Emits a loop nest of `scf.parallel` with the proper body for `op`. The
+/// generated loop nest will follow the `interchangeVector`-permuted iterator
+/// order. If `interchangeVector` is empty, then no permutation happens.
+LogicalResult
+linalgOpToParallelLoops(OpBuilder &builder, Operation *op,
+ ArrayRef<unsigned> interchangeVector = {});
-/// Emits a loop nest of `affine.for` with the proper body for `op`.
-LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);
+/// Emits a loop nest of `affine.for` with the proper body for `op`. The
+/// generated loop nest will follow the `interchangeVector`-permuted iterator
+/// order. If `interchangeVector` is empty, then no permutation happens.
+LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op,
+ ArrayRef<unsigned> interchangeVector = {});
//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
@@ -587,13 +599,17 @@ enum class LinalgLoweringType {
AffineLoops = 2,
ParallelLoops = 3
};
+
template <typename OpTy>
struct LinalgLoweringPattern : public RewritePattern {
LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType,
LinalgMarker marker = LinalgMarker(),
+ ArrayRef<unsigned> interchangeVector = {},
PatternBenefit benefit = 1)
: RewritePattern(OpTy::getOperationName(), {}, benefit, context),
- marker(marker), loweringType(loweringType) {}
+ marker(marker), loweringType(loweringType),
+ interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
+
// TODO: Move implementation to .cpp once named ops are auto-generated.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
@@ -603,18 +619,24 @@ struct LinalgLoweringPattern : public RewritePattern {
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
- if (loweringType == LinalgLoweringType::LibraryCall) {
+ switch (loweringType) {
+ case LinalgLoweringType::LibraryCall:
// TODO: Move lowering to library calls here.
return failure();
- } else if (loweringType == LinalgLoweringType::Loops) {
- if (failed(linalgOpToLoops(rewriter, op)))
+ case LinalgLoweringType::Loops:
+ if (failed(linalgOpToLoops(rewriter, op, interchangeVector)))
return failure();
- } else if (loweringType == LinalgLoweringType::AffineLoops) {
- if (failed(linalgOpToAffineLoops(rewriter, op)))
+ break;
+ case LinalgLoweringType::AffineLoops:
+ if (failed(linalgOpToAffineLoops(rewriter, op, interchangeVector)))
return failure();
- } else if (failed(linalgOpToParallelLoops(rewriter, op))) {
- return failure();
+ break;
+ case LinalgLoweringType::ParallelLoops:
+ if (failed(linalgOpToParallelLoops(rewriter, op, interchangeVector)))
+ return failure();
+ break;
}
+
rewriter.eraseOp(op);
return success();
}
@@ -625,6 +647,8 @@ struct LinalgLoweringPattern : public RewritePattern {
/// Controls whether the pattern lowers to library calls, scf.for, affine.for
/// or scf.parallel.
LinalgLoweringType loweringType;
+ /// Permuted loop order in the generated loop nest (empty: no permutation).
+ SmallVector<unsigned, 4> interchangeVector;
};
/// Linalg generalization patterns
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 3a5b79176959..09b5c5ee562b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -23,7 +23,6 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -505,10 +504,10 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
}
template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
- OpBuilder &builder) {
+static Optional<LinalgLoops>
+linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
+ ArrayRef<unsigned> interchangeVector) {
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
-
ScopedContext scope(builder, op->getLoc());
// The flattened loopToOperandRangesMaps is expected to be an invertible
@@ -516,10 +515,20 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
auto linalgOp = cast<LinalgOp>(op);
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
+
auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc());
+ auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
+
+ if (!interchangeVector.empty()) {
+ assert(interchangeVector.size() == loopRanges.size());
+ assert(interchangeVector.size() == iteratorTypes.size());
+ applyPermutationToVector(loopRanges, interchangeVector);
+ applyPermutationToVector(iteratorTypes, interchangeVector);
+ }
+
SmallVector<Value, 4> allIvs;
GenerateLoopNest<LoopTy>::doit(
- loopRanges, /*iterInitArgs*/ {}, linalgOp.iterator_types().getValue(),
+ loopRanges, /*iterInitArgs=*/{}, iteratorTypes,
[&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
assert(iterArgs.empty() && "unexpected iterArgs");
allIvs.append(ivs.begin(), ivs.end());
@@ -552,26 +561,33 @@ namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
- LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+ explicit LinalgRewritePattern(ArrayRef<unsigned> interchangeVector)
+ : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()),
+ interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (!isa<LinalgOp>(op))
return failure();
- if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
+ if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector))
return failure();
rewriter.eraseOp(op);
return success();
}
+
+private:
+ SmallVector<unsigned, 4> interchangeVector; // Loop permutation; empty means none.
};
struct FoldAffineOp;
} // namespace
template <typename LoopType>
-static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
+static void lowerLinalgToLoopsImpl(FuncOp funcOp,
+ ArrayRef<unsigned> interchangeVector) {
+ MLIRContext *context = funcOp.getContext();
OwningRewritePatternList patterns;
- patterns.insert<LinalgRewritePattern<LoopType>>();
+ patterns.insert<LinalgRewritePattern<LoopType>>(interchangeVector);
DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldAffineOp>(context);
@@ -620,20 +636,20 @@ struct FoldAffineOp : public RewritePattern {
struct LowerToAffineLoops
: public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
void runOnFunction() override {
- lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
+ lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), interchangeVector); // "interchange-vector" pass option.
}
};
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
void runOnFunction() override {
- lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
+ lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), interchangeVector); // "interchange-vector" pass option.
}
};
struct LowerToParallelLoops
: public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
void runOnFunction() override {
- lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), &getContext());
+ lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), interchangeVector); // "interchange-vector" pass option.
}
};
} // namespace
@@ -654,38 +670,43 @@ mlir::createConvertLinalgToAffineLoopsPass() {
/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy>
-Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
- Operation *op) {
- return linalgOpToLoopsImpl<LoopTy>(op, builder);
+Optional<LinalgLoops>
+mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op,
+ ArrayRef<unsigned> interchangeVector) {
+ return linalgOpToLoopsImpl<LoopTy>(op, builder, interchangeVector); // Empty vector: no permutation.
}
+template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<AffineForOp>(
+ OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
+template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(
+ OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder,
- Operation *op);
-template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder,
- Operation *op);
-template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder,
- Operation *op);
+mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(
+ OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
/// Emits a loop nest of `affine.for` with the proper body for `op`.
-LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
- Operation *op) {
- Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op);
+LogicalResult
+mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op,
+ ArrayRef<unsigned> interchangeVector) {
+ Optional<LinalgLoops> loops =
+ linalgLowerOpToLoops<AffineForOp>(builder, op, interchangeVector);
return loops ? success() : failure();
}
/// Emits a loop nest of `scf.for` with the proper body for `op`.
-LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
- Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op);
+LogicalResult
+mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op,
+ ArrayRef<unsigned> interchangeVector) {
+ Optional<LinalgLoops> loops =
+ linalgLowerOpToLoops<scf::ForOp>(builder, op, interchangeVector);
return loops ? success() : failure();
}
/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
-LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
- Operation *op) {
+LogicalResult
+mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op,
+ ArrayRef<unsigned> interchangeVector) {
Optional<LinalgLoops> loops =
- linalgLowerOpToLoops<scf::ParallelOp>(builder, op);
+ linalgLowerOpToLoops<scf::ParallelOp>(builder, op, interchangeVector);
return loops ? success() : failure();
}
diff --git a/mlir/test/Dialect/Linalg/loop-order.mlir b/mlir/test/Dialect/Linalg/loop-order.mlir
new file mode 100644
index 000000000000..d1ff47977c35
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/loop-order.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=LOOP %s
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=PARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=AFFINE %s
+
+func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
+ linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32>
+ return
+}
+// Distinct dim sizes let the checks below observe the permuted order (5,1,4,2,3).
+// LOOP: scf.for %{{.*}} = %c0 to %c5 step %c1
+// LOOP: scf.for %{{.*}} = %c0 to %c1 step %c1
+// LOOP: scf.for %{{.*}} = %c0 to %c4 step %c1
+// LOOP: scf.for %{{.*}} = %c0 to %c2 step %c1
+// LOOP: scf.for %{{.*}} = %c0 to %c3 step %c1
+
+// PARALLEL: scf.parallel
+// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3)
+
+// AFFINE: affine.for %{{.*}} = 0 to 5
+// AFFINE: affine.for %{{.*}} = 0 to 1
+// AFFINE: affine.for %{{.*}} = 0 to 4
+// AFFINE: affine.for %{{.*}} = 0 to 2
+// AFFINE: affine.for %{{.*}} = 0 to 3
+
More information about the llvm-branch-commits
mailing list