[Mlir-commits] [mlir] 54fafd1 - [mlir][Linalg] Introduce canonicalization to remove dead LinalgOps

Nicolas Vasilache llvmlistbot at llvm.org
Thu Aug 6 03:14:21 PDT 2020


Author: Nicolas Vasilache
Date: 2020-08-06T06:08:46-04:00
New Revision: 54fafd17a728f3dd33b3cf999b6dfd3cd1d49f12

URL: https://github.com/llvm/llvm-project/commit/54fafd17a728f3dd33b3cf999b6dfd3cd1d49f12
DIFF: https://github.com/llvm/llvm-project/commit/54fafd17a728f3dd33b3cf999b6dfd3cd1d49f12.diff

LOG: [mlir][Linalg] Introduce canonicalization to remove dead LinalgOps

When any memref operand of a structured Linalg op has a dimension of size zero, the op performs zero iterations and is dead.
This is consistent with the fact that Linalg ops derive their loop bounds from their operands.

Note, however, that this is not the case for `tensor<0xelt_type>`, which is a special convention
that must be lowered away into either `memref<elt_type>` or just `elt_type` before this
canonicalization can kick in.
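
As a quick illustration (mirroring the test added at the end of this patch), the
memref form is erased by -canonicalize while the tensor form survives:

    // Erased: the memref has a zero-sized dimension, so the copy runs zero iterations.
    linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32>

    // Kept: tensor<0xf32> is the special convention described above.
    %1 = linalg.generic #trait %arg1 { ... } : tensor<0xf32> -> tensor<0xf32>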

Differential Revision: https://reviews.llvm.org/D85413

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/IR/StandardTypes.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index dad6f4597e62..26406ccdc9ef 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -153,6 +153,7 @@ def CopyOp : LinalgStructured_Op<"copy", [
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
@@ -178,6 +179,7 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 /// A base class for pooling operation such as conv. The arguments must contain
@@ -358,6 +360,7 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 class SingleInputPoolingBase_Op<string mnemonic>
@@ -417,6 +420,7 @@ class SingleInputPoolingBase_Op<string mnemonic>
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> {
@@ -658,6 +662,7 @@ def GenericOp : GenericOpBase<"generic"> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 /// GenericOp with Indexing (i.e. multi-for style in which the region is passed
@@ -795,6 +800,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -817,6 +823,7 @@ class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
   let printer = [{ return ::printNamedStructuredOp(p, *this); }];
   let verifier = [{ return ::verifyNamedStructuredOp(*this); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 // This file is auto-generated from a tc specification.
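
Note: setting hasCanonicalizer = 1 makes ODS declare a canonicalization hook on
each of these ops; the definitions are supplied in LinalgOps.cpp below. Roughly,
the generated declaration is:

    static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context);

The -canonicalize pass then collects these patterns from every registered op.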

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 03bd71f17716..a8d98af2ce40 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1153,38 +1153,6 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
 // TODO: Consider making all this boilerplate easy to autogenerate
 // with Tablegen. This seems a desirable property in the context of OpInterfaces
 // where a Linalg "named" op **isa** LinalgOp.
-LogicalResult ConvOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingMaxOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingMinOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingSumOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult CopyOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult FillOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult GenericOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
-                                     SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
@@ -1299,58 +1267,64 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
   return verifyGenericOp<NamedStructuredOpType>(op);
 }
 
+struct EraseDeadLinalgOp : public RewritePattern {
+  EraseDeadLinalgOp(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+    for (Value v : linalgOp.getInputsAndOutputBuffers()) {
+      // Linalg "inputs" may be either tensor or memref type.
+      // tensor<0xelt_type> is a convention that may not always mean
+      // "0 iterations". Only erase in cases we see memref<...x0x...>.
+      auto mt = v.getType().dyn_cast<MemRefType>();
+      if (!mt)
+        continue;
+      if (llvm::is_contained(mt.getShape(), 0)) {
+        rewriter.eraseOp(linalgOp);
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
+#define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
+  void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,     \
+                                        MLIRContext *context) {                \
+    results.insert<EraseDeadLinalgOp>();                                       \
+  }                                                                            \
+                                                                               \
+  LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
+                          SmallVectorImpl<OpFoldResult> &) {                   \
+    return foldMemRefCast(*this);                                              \
+  }
+
+CANONICALIZERS_AND_FOLDERS(ConvOp);
+CANONICALIZERS_AND_FOLDERS(PoolingMaxOp);
+CANONICALIZERS_AND_FOLDERS(PoolingMinOp);
+CANONICALIZERS_AND_FOLDERS(PoolingSumOp);
+CANONICALIZERS_AND_FOLDERS(CopyOp);
+CANONICALIZERS_AND_FOLDERS(FillOp);
+CANONICALIZERS_AND_FOLDERS(GenericOp);
+CANONICALIZERS_AND_FOLDERS(IndexedGenericOp);
+
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
 
 // TODO: Determine whether we can generate the folders and verifiers.
-LogicalResult BatchMatmulOp::fold(ArrayRef<Attribute>,
-                                  SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult DotOp::fold(ArrayRef<Attribute>,
-                          SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvWOp::fold(ArrayRef<Attribute>,
-                            SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNWCOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCWOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvHWOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNHWCOp::fold(ArrayRef<Attribute>,
-                               SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCHWOp::fold(ArrayRef<Attribute>,
-                               SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvDHWOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNDHWCOp::fold(ArrayRef<Attribute>,
-                                SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCDHWOp::fold(ArrayRef<Attribute>,
-                                SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
+CANONICALIZERS_AND_FOLDERS(BatchMatmulOp);
+CANONICALIZERS_AND_FOLDERS(DotOp);
+CANONICALIZERS_AND_FOLDERS(MatmulOp);
+CANONICALIZERS_AND_FOLDERS(MatvecOp);
+CANONICALIZERS_AND_FOLDERS(ConvWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCWOp);
+CANONICALIZERS_AND_FOLDERS(ConvHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNHWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvDHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNDHWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCDHWOp);
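
For reference, a single instantiation such as CANONICALIZERS_AND_FOLDERS(CopyOp)
expands to the following (a mechanical expansion of the macro above; every listed
op gets the same dead-op pattern and memref-cast folder):

    void CopyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
      // Register the dead-op erasure pattern defined above.
      results.insert<EraseDeadLinalgOp>();
    }

    LogicalResult CopyOp::fold(ArrayRef<Attribute>,
                               SmallVectorImpl<OpFoldResult> &) {
      // Identical to the hand-written per-op folders this macro replaces.
      return foldMemRefCast(*this);
    }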

diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 70b00cf8963a..f878672cd912 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -732,19 +732,16 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                 ArrayRef<AffineExpr> exprs,
                                                 MLIRContext *context) {
+  // Size 0 corner case is useful for canonicalizations.
+  if (llvm::is_contained(sizes, 0))
+    return getAffineConstantExpr(0, context);
+
+  auto maps = AffineMap::inferFromExprList(exprs);
+  assert(!maps.empty() && "Expected one non-empty map");
+  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
+
   AffineExpr expr;
   bool dynamicPoisonBit = false;
-  unsigned numDims = 0;
-  unsigned nSymbols = 0;
-  // Compute the number of symbols and dimensions of the passed exprs.
-  for (AffineExpr expr : exprs) {
-    expr.walk([&numDims, &nSymbols](AffineExpr d) {
-      if (AffineDimExpr dim = d.dyn_cast<AffineDimExpr>())
-        numDims = std::max(numDims, dim.getPosition() + 1);
-      else if (AffineSymbolExpr symbol = d.dyn_cast<AffineSymbolExpr>())
-        nSymbols = std::max(nSymbols, symbol.getPosition() + 1);
-    });
-  }
   int64_t runningSize = 1;
   for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
     int64_t size = std::get<1>(en);
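
A minimal usage sketch of the new size-0 corner case (a standalone illustration,
not part of the patch; assumes the mlir namespace and MLIR's IR headers):

    MLIRContext ctx;
    Builder b(&ctx);
    SmallVector<AffineExpr, 3> exprs{b.getAffineDimExpr(0), b.getAffineDimExpr(1),
                                     b.getAffineDimExpr(2)};
    // Any zero size means the memref holds no elements, so the canonical
    // strided layout expression collapses to the affine constant 0.
    AffineExpr layout = makeCanonicalStridedLayoutExpr({2, 0, 4}, exprs, &ctx);
    assert(layout == getAffineConstantExpr(0, &ctx));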

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 9cb7df05d63e..005bd1c87445 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -172,3 +172,34 @@ func @no_fold_memref_reshape(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
 // CHECK-LABEL: @no_fold_memref_reshape
 //       CHECK:   linalg.reshape
 //       CHECK:   linalg.reshape
+
+// -----
+
+#accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"]
+}
+
+func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
+  // memref<0xf32> is expected to be dce'ed
+  linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32>
+
+  // tensor<0xf32> cannot be dce'ed
+  %1 = linalg.generic #trait %arg1 {
+  ^bb(%0: f32) :
+    linalg.yield %0 : f32
+  } : tensor<0xf32> -> tensor<0xf32>
+
+  return %1: tensor<0xf32>
+}
+// CHECK-LABEL: @dce_zero_memref
+//   CHECK-NOT:   linalg.copy
+//  CHECK-NEXT:   linalg.generic
+