[Mlir-commits] [mlir] c303d9b - [mlir][Linalg] NFC - Cleanup explicitly instantiated patterns 2/n - Loops.cpp

Nicolas Vasilache llvmlistbot at llvm.org
Fri Oct 9 13:00:46 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-09T19:59:49Z
New Revision: c303d9b394427e93aa772d543426715b24f98fd1

URL: https://github.com/llvm/llvm-project/commit/c303d9b394427e93aa772d543426715b24f98fd1
DIFF: https://github.com/llvm/llvm-project/commit/c303d9b394427e93aa772d543426715b24f98fd1.diff

LOG: [mlir][Linalg] NFC - Cleanup explicitly instantiated patterns 2/n - Loops.cpp

This revision belongs to a series of patches that reduce the reliance of Linalg transformations on templated rewrite and conversion patterns.
Instead of instantiating one pattern per op, this uses a single pattern constructed with `MatchAnyOpTypeTag` for the vast majority of cases and dispatches internally on the concrete op type.

Differential revision: https://reviews.llvm.org/D89133

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b95469d8a955..2abf0aed37b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -23,6 +23,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
@@ -65,7 +67,7 @@ static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
   assert(op.getOperation()->getNumRegions() == 1 &&
          "Expected single region op");
   auto &b = ScopedContext::getBuilderRef();
-  auto &block = op.region().front();
+  auto &block = op.getOperation()->getRegion(0).front();
   BlockAndValueMapping map;
   map.map(block.getArguments(), indexedValues);
   for (auto &op : block.without_terminator()) {
@@ -102,8 +104,6 @@ static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
       makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
 }
 
-namespace {
-
 /// Emits the MLIR for the scalar part of the generic op by:
 ///   1. Emitting load ops for each input and output view in order. This is
 ///      achieved by applying the appropriate input or output map to the
@@ -134,10 +134,9 @@ namespace {
 ///      }
 ///    }
 /// ```
-// TODO: need a LinalgStructuredOpInterface.
-template <typename IndexedValueType, typename LinalgStructuredOpType>
-void emitScalarImplementation(ArrayRef<Value> allIvs,
-                              LinalgStructuredOpType linalgOp) {
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs,
+                                     LinalgOp linalgOp) {
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto &b = ScopedContext::getBuilderRef();
@@ -150,7 +149,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
   auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
   auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());
   if (attr) {
-    auto operand = linalgOp.getOperand(attr.getInt());
+    auto operand = linalgOp.getOperation()->getOperand(attr.getInt());
     auto shapedType = operand.getType().template cast<ShapedType>();
     allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank());
     for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx)
@@ -190,7 +189,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
   assert(copyOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto nPar = copyOp.getNumParallelLoops();
@@ -211,7 +210,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
   assert(fillOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto nPar = fillOp.getNumParallelLoops();
@@ -224,8 +223,8 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
 }
 
 template <typename IndexedValueType>
-Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
-                     MutableArrayRef<Value> imIdx) {
+static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
+                            MutableArrayRef<Value> imIdx) {
   // TODO: add a level of indirection to linalg.generic.
   if (!convOp.padding())
     return im(imIdx);
@@ -311,8 +310,9 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
   }
 }
 
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+template <typename IndexedValueType, typename OpType>
+static void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs,
+                                                  OpType op) {
   InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
   // Emit scalar form.
   IndexedValueType output(op.output());
@@ -320,30 +320,34 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
   Value lhs = output(indices.outputs);
   Value rhs = input(indices.inputs);
   using edsc::op::sgt;
-  Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
-  output(indices.outputs) = maxValue;
+  using edsc::op::slt;
+  Value value = std::is_same<OpType, PoolingMinOp>()
+                    ? std_select(slt(lhs, rhs), lhs, rhs)
+                    : std_select(sgt(lhs, rhs), lhs, rhs);
+  output(indices.outputs) = value;
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
-  InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
-  // Emit scalar form.
-  IndexedValueType output(op.output());
-  IndexedValueType input(op.input());
-  Value lhs = output(indices.outputs);
-  Value rhs = input(indices.inputs);
-  using edsc::op::slt;
-  Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
-  output(indices.outputs) = minValue;
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
+                                                                        op);
 }
+
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
+  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
+                                                                        op);
+}
+
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
   auto indices = getInputAndOutputIndices(allIvs, op);
   IndexedValueType input(op.input()), output(op.output());
 
   // Emit scalar form.
   output(indices.outputs) += input(indices.inputs);
 }
+
 /// Emits the MLIR for the scalar part of the indexed generic op by:
 ///   1. Emitting load ops for each input and output view in order. This is
 ///      achieved by applying the appropriate input or output map to the
@@ -422,15 +426,16 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                              indexing, outputBuffers);
 }
 
-template <typename LoopTy, typename ConcreteOpTy>
-Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
+template <typename LoopTy>
+static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
+                                                 OpBuilder &builder) {
   using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
 
   ScopedContext scope(builder, op->getLoc());
 
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (which is asserted in the inverse calculation).
-  auto linalgOp = cast<ConcreteOpTy>(op);
+  auto linalgOp = cast<LinalgOp>(op);
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto mapsRange =
@@ -447,7 +452,12 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
       [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
         assert(iterArgs.empty() && "unexpected iterArgs");
         allIvs.append(ivs.begin(), ivs.end());
-        emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp);
+        llvm::TypeSwitch<Operation *>(op)
+            .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
+                  PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
+              emitScalarImplementation<IndexedValueTy>(allIvs, op);
+            })
+            .Default([&](Operation *op) { assert(false && "unexpected op"); });
         return scf::ValueVector{};
       });
   // Number of loop ops might be different from the number of ivs since some
@@ -467,32 +477,38 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
   return loops;
 }
 
-template <typename LoopType, typename ConcreteOp>
+namespace {
+template <typename LoopType>
 class LinalgRewritePattern : public RewritePattern {
 public:
-  explicit LinalgRewritePattern(MLIRContext *context)
-      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
+  LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
+    if (!isa<LinalgOp>(op))
+      return failure();
+    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
       return failure();
     rewriter.eraseOp(op);
     return success();
   }
 };
 
-template <typename LoopType, typename ConcreteOp>
-void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
-}
+struct FoldAffineOp;
+} // namespace
 
-template <typename LoopType, typename... Args>
-void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  (void)std::initializer_list<int>{
-      0, (insertOnePattern<LoopType, Args>(patterns, ctx), 0)...};
+template <typename LoopType>
+static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
+  OwningRewritePatternList patterns;
+  patterns.insert<LinalgRewritePattern<LoopType>>();
+  DimOp::getCanonicalizationPatterns(patterns, context);
+  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+  patterns.insert<FoldAffineOp>(context);
+  // Just apply the patterns greedily.
+  applyPatternsAndFoldGreedily(funcOp, patterns);
 }
 
+namespace {
 /// Local folding pattern for AffineApplyOp that we can apply greedily.
 /// This replaces AffineApplyOp by the proper value in cases where the
 /// associated map is trivial.
@@ -529,38 +545,20 @@ struct FoldAffineOp : public RewritePattern {
     return failure();
   }
 };
-} // namespace
-
-template <typename LoopType>
-static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
-  OwningRewritePatternList patterns;
-  // Canonicalization and folding patterns applied greedily allow cleaning up
-  // the emitted IR on the fly.
-  // TODO: fold view and subview ops?
-  insertPatterns<LoopType,
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
-                 >(patterns, context);
 
-  DimOp::getCanonicalizationPatterns(patterns, context);
-  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
-  patterns.insert<FoldAffineOp>(context);
-  // Just apply the patterns greedily.
-  applyPatternsAndFoldGreedily(funcOp, patterns);
-}
-
-namespace {
 struct LowerToAffineLoops
     : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
   void runOnFunction() override {
     lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
   }
 };
+
 struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
   void runOnFunction() override {
     lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
   }
 };
+
 struct LowerToParallelLoops
     : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
   void runOnFunction() override {
@@ -583,60 +581,6 @@ mlir::createConvertLinalgToAffineLoopsPass() {
   return std::make_unique<LowerToAffineLoops>();
 }
 
-// TODO: gradually remove this layer as more ops become "named".
-template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
-                                                       OpBuilder &builder) {
-  assert(isa<LinalgOp>(op) && "LinalgOp expected");
-  if (isa<CopyOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder);
-  if (isa<FillOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
-  if (isa<ConvOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
-  if (isa<PoolingMaxOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingMaxOp>(op, builder);
-  if (isa<PoolingMinOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingMinOp>(op, builder);
-  if (isa<PoolingSumOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingSumOp>(op, builder);
-  if (isa<IndexedGenericOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder);
-
-  // TODO: Cases below are generic and need a LinalgStructuredOpInterface.
-  if (isa<GenericOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
-  if (isa<MatmulOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
-  if (isa<MatvecOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
-  if (isa<VecmatOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, VecmatOp>(op, builder);
-  if (isa<DotOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
-  if (isa<BatchMatmulOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
-  if (isa<ConvWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
-  if (isa<ConvNWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
-  if (isa<ConvNCWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
-  if (isa<ConvHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
-  if (isa<ConvNHWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
-  if (isa<ConvNCHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
-  if (isa<ConvDHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
-  if (isa<ConvNDHWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
-  if (isa<ConvNCDHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
-  llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
-}
-
 SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
                                                    AffineMap map,
                                                    ValueRange viewSizes) {
@@ -705,7 +649,7 @@ SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
 template <typename LoopTy>
 Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
                                                          Operation *op) {
-  return linalgOpToLoopsImplSwitch<LoopTy>(op, builder);
+  return linalgOpToLoopsImpl<LoopTy>(op, builder);
 }
 
 template Optional<LinalgLoops>


        


More information about the Mlir-commits mailing list