[Mlir-commits] [mlir] c303d9b - [mlir][Linalg] NFC - Cleanup explicitly instantiated patterns 2/n - Loops.cpp
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Oct 9 13:00:46 PDT 2020
Author: Nicolas Vasilache
Date: 2020-10-09T19:59:49Z
New Revision: c303d9b394427e93aa772d543426715b24f98fd1
URL: https://github.com/llvm/llvm-project/commit/c303d9b394427e93aa772d543426715b24f98fd1
DIFF: https://github.com/llvm/llvm-project/commit/c303d9b394427e93aa772d543426715b24f98fd1.diff
LOG: [mlir][Linalg] NFC - Cleanup explicitly instantiated patterns 2/n - Loops.cpp
This revision belongs to a series of patches that reduce the reliance of Linalg transformations on templated rewrite and conversion patterns.
Instead, this uses a single pattern that matches any op (constructed with MatchAnyOpTypeTag) for the vast majority of cases, and dispatches internally on the concrete op type.
Differential revision: https://reviews.llvm.org/D89133
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b95469d8a955..2abf0aed37b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -23,6 +23,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
+
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
@@ -65,7 +67,7 @@ static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
assert(op.getOperation()->getNumRegions() == 1 &&
"Expected single region op");
auto &b = ScopedContext::getBuilderRef();
- auto &block = op.region().front();
+ auto &block = op.getOperation()->getRegion(0).front();
BlockAndValueMapping map;
map.map(block.getArguments(), indexedValues);
for (auto &op : block.without_terminator()) {
@@ -102,8 +104,6 @@ static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}
-namespace {
-
/// Emits the MLIR for the scalar part of the generic op by:
/// 1. Emitting load ops for each input and output view in order. This is
/// achieved by applying the appropriate input or output map to the
@@ -134,10 +134,9 @@ namespace {
/// }
/// }
/// ```
-// TODO: need a LinalgStructuredOpInterface.
-template <typename IndexedValueType, typename LinalgStructuredOpType>
-void emitScalarImplementation(ArrayRef<Value> allIvs,
- LinalgStructuredOpType linalgOp) {
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs,
+ LinalgOp linalgOp) {
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto &b = ScopedContext::getBuilderRef();
@@ -150,7 +149,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());
if (attr) {
- auto operand = linalgOp.getOperand(attr.getInt());
+ auto operand = linalgOp.getOperation()->getOperand(attr.getInt());
auto shapedType = operand.getType().template cast<ShapedType>();
allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank());
for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx)
@@ -190,7 +189,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
}
template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
assert(copyOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto nPar = copyOp.getNumParallelLoops();
@@ -211,7 +210,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
}
template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
assert(fillOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto nPar = fillOp.getNumParallelLoops();
@@ -224,8 +223,8 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
}
template <typename IndexedValueType>
-Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
- MutableArrayRef<Value> imIdx) {
+static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
+ MutableArrayRef<Value> imIdx) {
// TODO: add a level of indirection to linalg.generic.
if (!convOp.padding())
return im(imIdx);
@@ -311,8 +310,9 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
}
}
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+template <typename IndexedValueType, typename OpType>
+static void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs,
+ OpType op) {
InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
// Emit scalar form.
IndexedValueType output(op.output());
@@ -320,30 +320,34 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
Value lhs = output(indices.outputs);
Value rhs = input(indices.inputs);
using edsc::op::sgt;
- Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
- output(indices.outputs) = maxValue;
+ using edsc::op::slt;
+ Value value = std::is_same<OpType, PoolingMinOp>()
+ ? std_select(slt(lhs, rhs), lhs, rhs)
+ : std_select(sgt(lhs, rhs), lhs, rhs);
+ output(indices.outputs) = value;
}
template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
- InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
- // Emit scalar form.
- IndexedValueType output(op.output());
- IndexedValueType input(op.input());
- Value lhs = output(indices.outputs);
- Value rhs = input(indices.inputs);
- using edsc::op::slt;
- Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
- output(indices.outputs) = minValue;
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+ emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
+ op);
}
+
template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
+ emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
+ op);
+}
+
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
auto indices = getInputAndOutputIndices(allIvs, op);
IndexedValueType input(op.input()), output(op.output());
// Emit scalar form.
output(indices.outputs) += input(indices.inputs);
}
+
/// Emits the MLIR for the scalar part of the indexed generic op by:
/// 1. Emitting load ops for each input and output view in order. This is
/// achieved by applying the appropriate input or output map to the
@@ -422,15 +426,16 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
indexing, outputBuffers);
}
-template <typename LoopTy, typename ConcreteOpTy>
-Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
+template <typename LoopTy>
+static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
+ OpBuilder &builder) {
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
ScopedContext scope(builder, op->getLoc());
// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
- auto linalgOp = cast<ConcreteOpTy>(op);
+ auto linalgOp = cast<LinalgOp>(op);
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto mapsRange =
@@ -447,7 +452,12 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
[&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
assert(iterArgs.empty() && "unexpected iterArgs");
allIvs.append(ivs.begin(), ivs.end());
- emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp);
+ llvm::TypeSwitch<Operation *>(op)
+ .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
+ PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
+ emitScalarImplementation<IndexedValueTy>(allIvs, op);
+ })
+ .Default([&](Operation *op) { assert(false && "unexpected op"); });
return scf::ValueVector{};
});
// Number of loop ops might be different from the number of ivs since some
@@ -467,32 +477,38 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
return loops;
}
-template <typename LoopType, typename ConcreteOp>
+namespace {
+template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
- explicit LinalgRewritePattern(MLIRContext *context)
- : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
+ LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
+ if (!isa<LinalgOp>(op))
+ return failure();
+ if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
return failure();
rewriter.eraseOp(op);
return success();
}
};
-template <typename LoopType, typename ConcreteOp>
-void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
-}
+struct FoldAffineOp;
+} // namespace
-template <typename LoopType, typename... Args>
-void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
- (void)std::initializer_list<int>{
- 0, (insertOnePattern<LoopType, Args>(patterns, ctx), 0)...};
+template <typename LoopType>
+static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
+ OwningRewritePatternList patterns;
+ patterns.insert<LinalgRewritePattern<LoopType>>();
+ DimOp::getCanonicalizationPatterns(patterns, context);
+ AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+ patterns.insert<FoldAffineOp>(context);
+ // Just apply the patterns greedily.
+ applyPatternsAndFoldGreedily(funcOp, patterns);
}
+namespace {
/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
@@ -529,38 +545,20 @@ struct FoldAffineOp : public RewritePattern {
return failure();
}
};
-} // namespace
-
-template <typename LoopType>
-static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
- OwningRewritePatternList patterns;
- // Canonicalization and folding patterns applied greedily allow cleaning up
- // the emitted IR on the fly.
- // TODO: fold view and subview ops?
- insertPatterns<LoopType,
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
- >(patterns, context);
- DimOp::getCanonicalizationPatterns(patterns, context);
- AffineApplyOp::getCanonicalizationPatterns(patterns, context);
- patterns.insert<FoldAffineOp>(context);
- // Just apply the patterns greedily.
- applyPatternsAndFoldGreedily(funcOp, patterns);
-}
-
-namespace {
struct LowerToAffineLoops
: public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
void runOnFunction() override {
lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
}
};
+
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
void runOnFunction() override {
lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
}
};
+
struct LowerToParallelLoops
: public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
void runOnFunction() override {
@@ -583,60 +581,6 @@ mlir::createConvertLinalgToAffineLoopsPass() {
return std::make_unique<LowerToAffineLoops>();
}
-// TODO: gradually remove this layer as more ops become "named".
-template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
- OpBuilder &builder) {
- assert(isa<LinalgOp>(op) && "LinalgOp expected");
- if (isa<CopyOp>(op))
- return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder);
- if (isa<FillOp>(op))
- return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
- if (isa<ConvOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
- if (isa<PoolingMaxOp>(op))
- return linalgOpToLoopsImpl<LoopTy, PoolingMaxOp>(op, builder);
- if (isa<PoolingMinOp>(op))
- return linalgOpToLoopsImpl<LoopTy, PoolingMinOp>(op, builder);
- if (isa<PoolingSumOp>(op))
- return linalgOpToLoopsImpl<LoopTy, PoolingSumOp>(op, builder);
- if (isa<IndexedGenericOp>(op))
- return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder);
-
- // TODO: Cases below are generic and need a LinalgStructuredOpInterface.
- if (isa<GenericOp>(op))
- return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
- if (isa<MatmulOp>(op))
- return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
- if (isa<MatvecOp>(op))
- return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
- if (isa<VecmatOp>(op))
- return linalgOpToLoopsImpl<LoopTy, VecmatOp>(op, builder);
- if (isa<DotOp>(op))
- return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
- if (isa<BatchMatmulOp>(op))
- return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
- if (isa<ConvWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
- if (isa<ConvNWCOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
- if (isa<ConvNCWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
- if (isa<ConvHWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
- if (isa<ConvNHWCOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
- if (isa<ConvNCHWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
- if (isa<ConvDHWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
- if (isa<ConvNDHWCOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
- if (isa<ConvNCDHWOp>(op))
- return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
- llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
-}
-
SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
AffineMap map,
ValueRange viewSizes) {
@@ -705,7 +649,7 @@ SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
template <typename LoopTy>
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
Operation *op) {
- return linalgOpToLoopsImplSwitch<LoopTy>(op, builder);
+ return linalgOpToLoopsImpl<LoopTy>(op, builder);
}
template Optional<LinalgLoops>
More information about the Mlir-commits
mailing list