[Mlir-commits] [mlir] 0e777e4 - [mlir][linalg] remove interchange option on linalg to loop lowering.
Tobias Gysi
llvmlistbot at llvm.org
Thu Apr 22 01:56:19 PDT 2021
Author: Tobias Gysi
Date: 2021-04-22T08:55:17Z
New Revision: 0e777e4ad7d554436a1c181674bdbaeab9053c31
URL: https://github.com/llvm/llvm-project/commit/0e777e4ad7d554436a1c181674bdbaeab9053c31
DIFF: https://github.com/llvm/llvm-project/commit/0e777e4ad7d554436a1c181674bdbaeab9053c31.diff
LOG: [mlir][linalg] remove interchange option on linalg to loop lowering.
The interchange option attached to the linalg-to-loop lowering affects only the loops and does not update the memory accesses generated in the body of the operation. Instead of performing the interchange during the loop lowering, use the interchange pattern.
Differential Revision: https://reviews.llvm.org/D100758
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
Removed:
mlir/test/Dialect/Linalg/loop-order.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 344ffe977caf..8d411d5964c5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -62,12 +62,6 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
let summary = "Lower the operations from the linalg dialect into affine "
"loops";
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
- let options = [
- ListOption<"interchangeVector", "interchange-vector", "unsigned",
- "Permute the loops in the nest following the given "
- "interchange vector",
- "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
- ];
let dependentDialects = [
"AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"];
}
@@ -75,12 +69,6 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
let summary = "Lower the operations from the linalg dialect into loops";
let constructor = "mlir::createConvertLinalgToLoopsPass()";
- let options = [
- ListOption<"interchangeVector", "interchange-vector", "unsigned",
- "Permute the loops in the nest following the given "
- "interchange vector",
- "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
- ];
let dependentDialects = [
"linalg::LinalgDialect",
"scf::SCFDialect",
@@ -103,12 +91,6 @@ def LinalgLowerToParallelLoops
let summary = "Lower the operations from the linalg dialect into parallel "
"loops";
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
- let options = [
- ListOption<"interchangeVector", "interchange-vector", "unsigned",
- "Permute the loops in the nest following the given "
- "interchange vector",
- "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
- ];
let dependentDialects = [
"AffineDialect",
"linalg::LinalgDialect",
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 251a2f8e6d03..2338198b5f2e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -338,28 +338,16 @@ LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
template <typename LoopTy>
-Optional<LinalgLoops>
-linalgLowerOpToLoops(OpBuilder &builder, Operation *op,
- ArrayRef<unsigned> interchangeVector = {});
-
-/// Emits a loop nest of `scf.for` with the proper body for `op`. The generated
-/// loop nest will follow the `interchangeVector`-permutated iterator order. If
-/// `interchangeVector` is empty, then no permutation happens.
-LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op,
- ArrayRef<unsigned> interchangeVector = {});
-
-/// Emits a loop nest of `scf.parallel` with the proper body for `op`. The
-/// generated loop nest will follow the `interchangeVector`-permutated
-// iterator order. If `interchangeVector` is empty, then no permutation happens.
-LogicalResult
-linalgOpToParallelLoops(OpBuilder &builder, Operation *op,
- ArrayRef<unsigned> interchangeVector = {});
+Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op);
+
+/// Emits a loop nest of `scf.for` with the proper body for `op`.
+LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);
-/// Emits a loop nest of `affine.for` with the proper body for `op`. The
-/// generated loop nest will follow the `interchangeVector`-permutated
-// iterator order. If `interchangeVector` is empty, then no permutation happens.
-LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op,
- ArrayRef<unsigned> interchangeVector = {});
+/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
+LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);
+
+/// Emits a loop nest of `affine.for` with the proper body for `op`.
+LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);
//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
@@ -808,10 +796,9 @@ struct LinalgLoweringPattern : public RewritePattern {
LinalgLoweringPattern(
MLIRContext *context, LinalgLoweringType loweringType,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
- ArrayRef<unsigned> interchangeVector = {}, PatternBenefit benefit = 1)
+ PatternBenefit benefit = 1)
: RewritePattern(OpTy::getOperationName(), benefit, context),
- filter(filter), loweringType(loweringType),
- interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
+ filter(filter), loweringType(loweringType) {}
// TODO: Move implementation to .cpp once named ops are auto-generated.
LogicalResult matchAndRewrite(Operation *op,
@@ -827,15 +814,15 @@ struct LinalgLoweringPattern : public RewritePattern {
// TODO: Move lowering to library calls here.
return failure();
case LinalgLoweringType::Loops:
- if (failed(linalgOpToLoops(rewriter, op, interchangeVector)))
+ if (failed(linalgOpToLoops(rewriter, op)))
return failure();
break;
case LinalgLoweringType::AffineLoops:
- if (failed(linalgOpToAffineLoops(rewriter, op, interchangeVector)))
+ if (failed(linalgOpToAffineLoops(rewriter, op)))
return failure();
break;
case LinalgLoweringType::ParallelLoops:
- if (failed(linalgOpToParallelLoops(rewriter, op, interchangeVector)))
+ if (failed(linalgOpToParallelLoops(rewriter, op)))
return failure();
break;
}
@@ -850,8 +837,6 @@ struct LinalgLoweringPattern : public RewritePattern {
/// Controls whether the pattern lowers to library calls, scf.for, affine.for
/// or scf.parallel.
LinalgLoweringType loweringType;
- /// Permutated loop order in the generated loop nest.
- SmallVector<unsigned, 4> interchangeVector;
};
/// Linalg generalization patterns
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index c85f4a9abd38..f19493c3cca9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -457,9 +457,8 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
}
template <typename LoopTy>
-static Optional<LinalgLoops>
-linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
- ArrayRef<unsigned> interchangeVector) {
+static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
+ OpBuilder &builder) {
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
ScopedContext scope(builder, op->getLoc());
@@ -472,13 +471,6 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc());
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
- if (!interchangeVector.empty()) {
- assert(interchangeVector.size() == loopRanges.size());
- assert(interchangeVector.size() == iteratorTypes.size());
- applyPermutationToVector(loopRanges, interchangeVector);
- applyPermutationToVector(iteratorTypes, interchangeVector);
- }
-
SmallVector<Value, 4> allIvs;
GenerateLoopNest<LoopTy>::doit(
loopRanges, /*iterInitArgs=*/{}, iteratorTypes,
@@ -511,11 +503,10 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
}
/// Replace the index operations in the body of the loop nest by the matching
-/// induction variables. If available use the interchange vector to map the
-/// interchanged induction variables to the dimension of the index operation.
-static void replaceIndexOpsByInductionVariables(
- LinalgOp linalgOp, PatternRewriter &rewriter, ArrayRef<Operation *> loopOps,
- ArrayRef<unsigned> interchangeVector) {
+/// induction variables.
+static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
+ PatternRewriter &rewriter,
+ ArrayRef<Operation *> loopOps) {
// Extract the induction variables of the loop nest from outer to inner.
SmallVector<Value> allIvs;
for (Operation *loopOp : loopOps) {
@@ -538,16 +529,8 @@ static void replaceIndexOpsByInductionVariables(
if (!loopOps.empty()) {
LoopLikeOpInterface loopOp = loopOps.back();
for (IndexOp indexOp :
- llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>())) {
- // Search the indexing dimension in the interchange vector if available.
- assert(interchangeVector.empty() ||
- interchangeVector.size() == linalgOp.getNumLoops());
- const auto *it = llvm::find(interchangeVector, indexOp.dim());
- uint64_t dim = it != interchangeVector.end()
- ? std::distance(interchangeVector.begin(), it)
- : indexOp.dim();
- rewriter.replaceOp(indexOp, allIvs[dim]);
- }
+ llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
+ rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
}
}
@@ -555,39 +538,31 @@ namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
- LinalgRewritePattern(MLIRContext *context,
- ArrayRef<unsigned> interchangeVector)
- : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
- interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
+ LinalgRewritePattern(MLIRContext *context)
+ : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!isa<LinalgOp>(op))
return failure();
- Optional<LinalgLoops> loopOps =
- linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector);
+ Optional<LinalgLoops> loopOps = linalgOpToLoopsImpl<LoopType>(op, rewriter);
if (!loopOps.hasValue())
return failure();
- replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue(),
- interchangeVector);
+ replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
rewriter.eraseOp(op);
return success();
}
-
-private:
- SmallVector<unsigned, 4> interchangeVector;
};
struct FoldAffineOp;
} // namespace
template <typename LoopType>
-static void lowerLinalgToLoopsImpl(FuncOp funcOp,
- ArrayRef<unsigned> interchangeVector) {
+static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
- patterns.add<LinalgRewritePattern<LoopType>>(context, interchangeVector);
+ patterns.add<LinalgRewritePattern<LoopType>>(context);
memref::DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldAffineOp>(context);
@@ -639,7 +614,7 @@ struct LowerToAffineLoops
registry.insert<memref::MemRefDialect>();
}
void runOnFunction() override {
- lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), interchangeVector);
+ lowerLinalgToLoopsImpl<AffineForOp>(getFunction());
}
};
@@ -648,14 +623,14 @@ struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
registry.insert<memref::MemRefDialect, scf::SCFDialect>();
}
void runOnFunction() override {
- lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), interchangeVector);
+ lowerLinalgToLoopsImpl<scf::ForOp>(getFunction());
}
};
struct LowerToParallelLoops
: public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
void runOnFunction() override {
- lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), interchangeVector);
+ lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction());
}
};
} // namespace
@@ -676,43 +651,38 @@ mlir::createConvertLinalgToAffineLoopsPass() {
/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy>
-Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op,
- ArrayRef<unsigned> interchangeVector) {
- return linalgOpToLoopsImpl<LoopTy>(op, builder, interchangeVector);
+Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
+ Operation *op) {
+ return linalgOpToLoopsImpl<LoopTy>(op, builder);
}
-template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<AffineForOp>(
- OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
-template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(
- OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
template Optional<LinalgLoops>
-mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(
- OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
+mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder,
+ Operation *op);
+template Optional<LinalgLoops>
+mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder,
+ Operation *op);
+template Optional<LinalgLoops>
+mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder,
+ Operation *op);
/// Emits a loop nest of `affine.for` with the proper body for `op`.
-LogicalResult
-mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op,
- ArrayRef<unsigned> interchangeVector) {
- Optional<LinalgLoops> loops =
- linalgLowerOpToLoops<AffineForOp>(builder, op, interchangeVector);
+LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
+ Operation *op) {
+ Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op);
return loops ? success() : failure();
}
/// Emits a loop nest of `scf.for` with the proper body for `op`.
-LogicalResult
-mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op,
- ArrayRef<unsigned> interchangeVector) {
- Optional<LinalgLoops> loops =
- linalgLowerOpToLoops<scf::ForOp>(builder, op, interchangeVector);
+LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
+ Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op);
return loops ? success() : failure();
}
/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
-LogicalResult
-mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op,
- ArrayRef<unsigned> interchangeVector) {
+LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
+ Operation *op) {
Optional<LinalgLoops> loops =
- linalgLowerOpToLoops<scf::ParallelOp>(builder, op, interchangeVector);
+ linalgLowerOpToLoops<scf::ParallelOp>(builder, op);
return loops ? success() : failure();
}
diff --git a/mlir/test/Dialect/Linalg/loop-order.mlir b/mlir/test/Dialect/Linalg/loop-order.mlir
deleted file mode 100644
index c572967e6d10..000000000000
--- a/mlir/test/Dialect/Linalg/loop-order.mlir
+++ /dev/null
@@ -1,72 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=LOOP %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=PARALLEL %s
-// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=AFFINE %s
-
-func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
- linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32>
- return
-}
-
-// LOOP: scf.for %{{.*}} = %c0 to %c5 step %c1
-// LOOP: scf.for %{{.*}} = %c0 to %c1 step %c1
-// LOOP: scf.for %{{.*}} = %c0 to %c4 step %c1
-// LOOP: scf.for %{{.*}} = %c0 to %c2 step %c1
-// LOOP: scf.for %{{.*}} = %c0 to %c3 step %c1
-
-// PARALLEL: scf.parallel
-// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3)
-
-// AFFINE: affine.for %{{.*}} = 0 to 5
-// AFFINE: affine.for %{{.*}} = 0 to 1
-// AFFINE: affine.for %{{.*}} = 0 to 4
-// AFFINE: affine.for %{{.*}} = 0 to 2
-// AFFINE: affine.for %{{.*}} = 0 to 3
-
-// -----
-
-#map = affine_map<(i, j, k, l, m) -> (i, j, k, l, m)>
-func @generic(%output: memref<1x2x3x4x5xindex>) {
- linalg.generic {indexing_maps = [#map],
- iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
- outs(%output : memref<1x2x3x4x5xindex>) {
- ^bb0(%arg0 : index):
- %i = linalg.index 0 : index
- %j = linalg.index 1 : index
- %k = linalg.index 2 : index
- %l = linalg.index 3 : index
- %m = linalg.index 4 : index
- %0 = addi %i, %j : index
- %1 = addi %0, %k : index
- %2 = addi %1, %l : index
- %3 = addi %2, %m : index
- linalg.yield %3: index
- }
- return
-}
-
-// LOOP: scf.for %[[m:.*]] = %c0 to %c5 step %c1
-// LOOP: scf.for %[[i:.*]] = %c0 to %c1 step %c1
-// LOOP: scf.for %[[l:.*]] = %c0 to %c4 step %c1
-// LOOP: scf.for %[[j:.*]] = %c0 to %c2 step %c1
-// LOOP: scf.for %[[k:.*]] = %c0 to %c3 step %c1
-// LOOP: %{{.*}} = addi %[[i]], %[[j]] : index
-// LOOP: %{{.*}} = addi %{{.*}}, %[[k]] : index
-// LOOP: %{{.*}} = addi %{{.*}}, %[[l]] : index
-// LOOP: %{{.*}} = addi %{{.*}}, %[[m]] : index
-
-// PARALLEL: scf.parallel (%[[m:.*]], %[[i:.*]], %[[l:.*]], %[[j:.*]], %[[k:.*]]) =
-// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3)
-// PARALLEL: %{{.*}} = addi %[[i]], %[[j]] : index
-// PARALLEL: %{{.*}} = addi %{{.*}}, %[[k]] : index
-// PARALLEL: %{{.*}} = addi %{{.*}}, %[[l]] : index
-// PARALLEL: %{{.*}} = addi %{{.*}}, %[[m]] : index
-
-// AFFINE: affine.for %[[m:.*]] = 0 to 5
-// AFFINE: affine.for %[[i:.*]] = 0 to 1
-// AFFINE: affine.for %[[l:.*]] = 0 to 4
-// AFFINE: affine.for %[[j:.*]] = 0 to 2
-// AFFINE: affine.for %[[k:.*]] = 0 to 3
-// AFFINE: %{{.*}} = addi %[[i]], %[[j]] : index
-// AFFINE: %{{.*}} = addi %{{.*}}, %[[k]] : index
-// AFFINE: %{{.*}} = addi %{{.*}}, %[[l]] : index
-// AFFINE: %{{.*}} = addi %{{.*}}, %[[m]] : index
More information about the Mlir-commits
mailing list