[llvm-branch-commits] [mlir] b8d211f - [MLIR][Linalg] Canonicalization patterns for linalg.generic.

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Mar 15 11:44:20 PDT 2022


Author: Nirvedh
Date: 2022-03-15T18:42:43Z
New Revision: b8d211fc317ffefaed1d65b226cda6c464f7d216

URL: https://github.com/llvm/llvm-project/commit/b8d211fc317ffefaed1d65b226cda6c464f7d216
DIFF: https://github.com/llvm/llvm-project/commit/b8d211fc317ffefaed1d65b226cda6c464f7d216.diff

LOG: [MLIR][Linalg] Canonicalization patterns for linalg.generic.
Fold linalg.fill into linalg.generic.
Remove dead (unused) input arguments of linalg.generic.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D121535

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/fusion-indexed.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 331a8b91bd330..2c9b1e0c53553 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -105,6 +105,33 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }
 
+/// Returns true if the indexing map at position `testMapLocation` in `maps`
+/// has at least one dimension (an AffineDimExpr result) that appears in no
+/// other map of `maps`. A map without any AffineDimExpr result has no unique
+/// dimension, so false is returned for it.
+/// DeadArgsGenericOpInputs uses this to keep a dead input whose map is the
+/// only one mentioning some loop dimension.
+static bool hasaUniqueDim(ArrayRef<AffineMap> maps, unsigned testMapLocation) {
+  // Collect the dimensions produced by the map under test.
+  AffineMap testMap = maps[testMapLocation];
+  llvm::SmallDenseSet<unsigned> dimsToCheck;
+  for (AffineExpr result : testMap.getResults())
+    if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
+      dimsToCheck.insert(dimExpr.getPosition());
+  // Strike out every dimension that another map also produces. Once nothing
+  // is left the answer is settled, so exit early.
+  for (auto it : llvm::enumerate(maps)) {
+    if (dimsToCheck.empty())
+      return false;
+    if (it.index() == testMapLocation)
+      continue;
+    for (AffineExpr result : it.value().getResults())
+      if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
+        dimsToCheck.erase(dimExpr.getPosition());
+  }
+  return !dimsToCheck.empty();
+}
+
 //===----------------------------------------------------------------------===//
 // Region builder helper.
 // TODO: Move this to a utility library.
@@ -826,11 +853,95 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
+
+/// Drop dead inputs of a linalg.generic op. An input is dead when its payload
+/// block argument has no uses. A dead input is still kept when its indexing
+/// map is the only map mentioning some dimension (see hasaUniqueDim).
+struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // Maps of the operands still present (dropped inputs are erased below).
+    SmallVector<AffineMap> oldIndexingMaps = genericOp.getIndexingMaps();
+    // Maps must be projected permutations.
+    if (llvm::any_of(oldIndexingMaps, [](AffineMap map) {
+          return !map.isProjectedPermutation();
+        }))
+      return failure();
+    Block &payload = genericOp.region().front();
+    SmallVector<Value> newInputOperands;
+    SmallVector<AffineMap> newIndexingMaps;
+    bool deadArgFound = false;
+    int inputSize = genericOp.getInputOperands().size();
+    // Iterate in reverse so that erasing block argument `i` (and map `i`)
+    // never shifts the indices of the inputs that are still to be visited.
+    for (int i = inputSize - 1; i >= 0; i--) {
+      OpOperand *opOperand = genericOp.getInputOperand(i);
+      if (payload.getArgument(i).use_empty() &&
+          !hasaUniqueDim(oldIndexingMaps, i)) {
+        payload.eraseArgument(i);
+        deadArgFound = true;
+        oldIndexingMaps.erase(oldIndexingMaps.begin() + i);
+        continue;
+      }
+      // Live input: prepend so kept inputs retain their original order.
+      newInputOperands.insert(newInputOperands.begin(), opOperand->get());
+      newIndexingMaps.insert(newIndexingMaps.begin(),
+                             genericOp.getTiedIndexingMap(opOperand));
+    }
+    // Bail out if there are no dead args; the IR is untouched in that case.
+    if (!deadArgFound)
+      return failure();
+    for (OpOperand *opOperand : genericOp.getOutputOperands())
+      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+
+    auto newOp = rewriter.create<GenericOp>(
+        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
+        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        genericOp.iterator_types(), genericOp.docAttr(),
+        genericOp.library_callAttr());
+    // Copy over unknown attributes. They might be load bearing for some flow.
+    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
+    for (NamedAttribute kv : genericOp->getAttrs())
+      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
+        newOp->setAttr(kv.getName(), kv.getValue());
+    // The payload, already stripped of the dead arguments, moves to newOp.
+    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
+                                newOp.region().begin());
+    rewriter.replaceOp(genericOp, newOp->getResults());
+    return success();
+  }
+};
+
+/// Fold linalg.fill into linalg.generic. Only inputs whose payload argument is
+/// still used are folded, so the pattern fails once nothing is left to fold.
+struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    Block &payload = genericOp.region().front();
+    bool fillFound = false;
+    for (OpOperand *opOperand : genericOp.getInputOperands()) {
+      BlockArgument arg = payload.getArgument(opOperand->getOperandNumber());
+      auto fillOp = opOperand->get().getDefiningOp<FillOp>();
+      if (!fillOp || arg.use_empty())
+        continue;
+      fillFound = true;
+      rewriter.updateRootInPlace(
+          genericOp, [&] { arg.replaceAllUsesWith(fillOp.value()); });
+    }
+    return success(fillFound);
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context)
+      .add<DeadArgsGenericOpInputs, FoldFillWithGenericOp>(context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 28655b96e5df6..0e0faab56f6c9 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -325,6 +325,106 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @fold_fill_generic_basic
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+//   CHECK-NOT: linalg.fill
+//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+//  CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = linalg.init_tensor [%0] : tensor<?xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
+  %3 = linalg.init_tensor [%0] : tensor<?xf32>
+  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs(%3 : tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %5 = arith.addf %arg1, %arg2 : f32
+    linalg.yield %5 : f32
+  } -> tensor<?xf32>
+  return %4 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_generic_mixedaccess
+//   CHECK-NOT: linalg.fill
+//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+//   CHECK-NOT: ins
+//  CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst1 = arith.constant 7.0 : f32
+  %cst2 = arith.constant 6.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %4 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
+  %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %6 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%6 : tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %8 = arith.divf %arg1, %arg2 : f32
+    linalg.yield %8 : f32
+  } -> tensor<?x?xf32>
+  return %7 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_deadargs_generic_basic
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+//  CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func @remove_deadargs_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = linalg.init_tensor [%0] : tensor<?xf32>
+  %2 = linalg.init_tensor [%0] : tensor<?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} ins(%arg0, %1 : tensor<?xf32>, tensor<?xf32>) outs(%2 : tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %4 = arith.addf %arg1, %cst : f32
+    linalg.yield %4 : f32
+  } -> tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_deadargs_generic_mixedaccess
+//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+//   CHECK-NOT: ins
+//  CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @remove_deadargs_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst1 = arith.constant 7.0 : f32
+  %cst2 = arith.constant 6.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
+  %4 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%2, %3 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%4 : tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %6 = arith.divf %cst1, %cst2 : f32
+    linalg.yield %6 : f32
+  } -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// -----
 // CHECK-LABEL: func @fold_fill_reshape()
 func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32

diff  --git a/mlir/test/Dialect/Linalg/fusion-indexed.mlir b/mlir/test/Dialect/Linalg/fusion-indexed.mlir
index 1b075cc5ac483..03ac767136f00 100644
--- a/mlir/test/Dialect/Linalg/fusion-indexed.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-indexed.mlir
@@ -46,7 +46,8 @@ func @fuse_indexed_consumer(%A: memref<?x?xf32>,
         %10 = arith.index_cast %7 : index to i32
         %11 = arith.sitofp %10 : i32 to f32
         %12 = arith.addf %9, %11 : f32
-        linalg.yield %12 : f32
+        %13 = arith.addf %12, %arg4 : f32
+        linalg.yield %13 : f32
       }
     }
   }


        


More information about the llvm-branch-commits mailing list