[Mlir-commits] [mlir] [mlir][Linalg] Move `linalg.fill` -> `tensor.pack` pattern into `fill` canonicalization patterns. (PR #66002)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 11 12:57:56 PDT 2023


llvmbot wrote:

@llvm/pr-subscribers-mlir-core

<details>
<summary>Changes</summary>

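This moves the pattern that folds `tensor.pack(linalg.fill)` into a single `linalg.fill` out of the data-layout propagation patterns and into the `linalg.fill` canonicalization patterns, where it fits better alongside the other existing `linalg.fill` folds. Note that the moved pattern no longer consults the user-supplied `ControlPropagationFn`, so the fold now applies unconditionally during canonicalization.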
--
Full diff: https://github.com/llvm/llvm-project/pull/66002.diff

4 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+57-6) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (-60) 
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+44) 
- (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (-44) 


<pre>
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index bcc36d6bd3e95e5..e05a82855c66bd0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -737,7 +737,8 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
 
   LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    // See if tensor input of tensor.extract op is the result of a linalg.fill op.
+    // See if tensor input of tensor.extract op is the result of a linalg.fill
+    // op.
     auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
     if (!fillOp)
       return failure();
@@ -751,15 +752,65 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
   }
 };
 
+/// Folds pack(fill) into a single fill op if
+///   1. The pack op does not have padding value, or
+///   2. The filled value and padding value are the same.
+static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
+                                                tensor::PackOp packOp) {
+  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
+  if (!fillOp)
+    return failure();
+
+  if (auto paddingValue = packOp.getPaddingValue())
+    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
+      return failure();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(fillOp);
+
+  Value packOpDest = packOp.getDest();
+  if (!packOpDest.hasOneUse())
+    return failure();
+  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
+    packOpDest = tensor::PackOp::createDestinationTensor(
+        rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
+        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
+        packOp.getOuterDimsPerm());
+  } else {
+    DominanceInfo dom(fillOp);
+    if (!dom.properlyDominates(packOpDest, fillOp))
+      return failure();
+  }
+
+  Value fillDest = packOpDest;
+  return clone(rewriter, fillOp, packOpDest.getType(),
+               {fillOp.value(), fillDest});
+}
+
+/// Wrapper pattern that applies foldFillPackIntoFillOp method.
+struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
+public:
+  FoldFillWithPack(MLIRContext *context)
+      : OpRewritePattern<tensor::PackOp>(context) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
+    if (failed(fillOp))
+      return failure();
+    rewriter.replaceOp(packOp, fillOp.value().result());
+    return success();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results
-      .add<FoldFillWithTensorExtract,
-           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
-           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-           FoldInsertPadIntoFill>(context);
+  results.add<FoldFillWithTensorExtract, FoldFillWithPack, FoldFillWithPad,
+              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
+              FoldInsertPadIntoFill>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 1ddd8b144c60e85..95a20f2369f9e07 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -448,46 +448,6 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                        *packInfo);
 }
 
-/// Folds pack(fill) into a single fill op if
-///   1. The pack op does not have padding value, or
-///   2. The filled value and padding value are the same.
-static FailureOr<FillOp>
-foldFillPackIntoFillOp(RewriterBase &rewriter, tensor::PackOp packOp,
-                       ControlPropagationFn controlFn) {
-  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
-  if (!fillOp)
-    return failure();
-
-  // User controlled propagation function.
-  if (!controlFn(fillOp))
-    return failure();
-
-  if (auto paddingValue = packOp.getPaddingValue())
-    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
-      return failure();
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(fillOp);
-
-  Value packOpDest = packOp.getDest();
-  if (!packOpDest.hasOneUse())
-    return failure();
-  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
-    packOpDest = tensor::PackOp::createDestinationTensor(
-        rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
-        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
-        packOp.getOuterDimsPerm());
-  } else {
-    DominanceInfo dom(fillOp);
-    if (!dom.properlyDominates(packOpDest, fillOp))
-      return failure();
-  }
-
-  Value fillDest = packOpDest;
-  return clone(rewriter, fillOp, packOpDest.getType(),
-               {fillOp.value(), fillDest});
-}
-
 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
 struct BubbleUpPackOpThroughGenericOpPattern
     : public OpRewritePattern<tensor::PackOp> {
@@ -510,25 +470,6 @@ struct BubbleUpPackOpThroughGenericOpPattern
   ControlPropagationFn controlFn;
 };
 
-/// Wrapper pattern that applies foldFillPackIntoFillOp method.
-struct FoldFillPackIntoFillOpPattern : public OpRewritePattern<tensor::PackOp> {
-public:
-  FoldFillPackIntoFillOpPattern(MLIRContext *context, ControlPropagationFn fun)
-      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
-
-  LogicalResult matchAndRewrite(tensor::PackOp packOp,
-                                PatternRewriter &rewriter) const override {
-    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp, controlFn);
-    if (failed(fillOp))
-      return failure();
-    rewriter.replaceOp(packOp, fillOp.value().result());
-    return success();
-  }
-
-private:
-  ControlPropagationFn controlFn;
-};
-
 // TODO: Relax this restriction. We should unpack a generic op also
 // in the presence of multiple unpack ops as producers.
 /// Return the unpacked operand, if present, for the current generic op.
@@ -750,7 +691,6 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation) {
   patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
-                  FoldFillPackIntoFillOpPattern,
                   PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
       patterns.getContext(), controlPackUnPackPropagation);
 }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 297b5c4e332c811..7793e435582746c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -353,6 +353,50 @@ func.func @fold_fill_extract(%arg0 : i1) -> i1 {
 
 // -----
 
+func.func @fill_pack() -> tensor<24x32x16x16xf32> {
+  %dest = tensor.empty() : tensor<384x512xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<24x32x16x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%dest : tensor<384x512xf32>) -> tensor<384x512xf32>
+  %pack = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<384x512xf32> -> tensor<24x32x16x16xf32>
+  return %pack : tensor<24x32x16x16xf32>
+}
+// CHECK-LABEL: func.func @fill_pack
+// CHECK:         %[[PACKED_EMPTY:.+]] = tensor.empty() : tensor<24x32x16x16xf32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
+// CHECK:         return %[[FILL]]
+
+// -----
+
+#map = affine_map<()[s0] -> (s0 ceildiv 16)>
+func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %dim = tensor.dim %0, %c0 : tensor<?x?xf32>
+  %dim_0 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  %1 = affine.apply #map()[%dim]
+  %2 = affine.apply #map()[%dim_0]
+  %3 = tensor.empty(%1, %2) : tensor<?x?x16x16xf32>
+  %pack = tensor.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %3 : tensor<?x?xf32> -> tensor<?x?x16x16xf32>
+  return %pack : tensor<?x?x16x16xf32>
+}
+// CHECK-DAG:   #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK:       func.func @dynamic_fill_pack
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+// CHECK:         %[[D0:.+]] = tensor.dim %[[DEST]], %[[C0]]
+// CHECK:         %[[D1:.+]] = tensor.dim %[[DEST]], %[[C1]]
+// CHECK:         %[[PACKED_D0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK:         %[[PACKED_D1:.+]] = affine.apply #[[MAP]]()[%[[D1]]]
+// CHECK:         %[[PACKED_EMPTY:.+]] = tensor.empty(%[[PACKED_D0]], %[[PACKED_D1]]) : tensor<?x?x16x16xf32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
+// CHECK:         return %[[FILL]]
+
+// -----
+
 // CHECK: func @fold_self_copy
 func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 // CHECK-NEXT: return
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 3f98bf06433d2ec..4c59c97aecc2519 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -839,47 +839,3 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME:     inner_dims_pos = [0] inner_tiles = [16]
 // CHECK-SAME:     into %[[UNPACK_NEW_DEST]]
 // CHECK:        return %[[UNPACK]] : tensor<16x540x960xi32>
-
-// -----
-
-func.func @fill_pack() -> tensor<24x32x16x16xf32> {
-  %dest = tensor.empty() : tensor<384x512xf32>
-  %cst = arith.constant 0.000000e+00 : f32
-  %0 = tensor.empty() : tensor<24x32x16x16xf32>
-  %1 = linalg.fill ins(%cst : f32) outs(%dest : tensor<384x512xf32>) -> tensor<384x512xf32>
-  %pack = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<384x512xf32> -> tensor<24x32x16x16xf32>
-  return %pack : tensor<24x32x16x16xf32>
-}
-// CHECK-LABEL: func.func @fill_pack
-// CHECK:         %[[PACKED_EMPTY:.+]] = tensor.empty() : tensor<24x32x16x16xf32>
-// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
-// CHECK:         return %[[FILL]]
-
-// -----
-
-#map = affine_map<()[s0] -> (s0 ceildiv 16)>
-func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %dim = tensor.dim %0, %c0 : tensor<?x?xf32>
-  %dim_0 = tensor.dim %0, %c1 : tensor<?x?xf32>
-  %1 = affine.apply #map()[%dim]
-  %2 = affine.apply #map()[%dim_0]
-  %3 = tensor.empty(%1, %2) : tensor<?x?x16x16xf32>
-  %pack = tensor.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %3 : tensor<?x?xf32> -> tensor<?x?x16x16xf32>
-  return %pack : tensor<?x?x16x16xf32>
-}
-// CHECK-DAG:   #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-// CHECK:       func.func @dynamic_fill_pack
-// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-// CHECK:         %[[D0:.+]] = tensor.dim %[[DEST]], %[[C0]]
-// CHECK:         %[[D1:.+]] = tensor.dim %[[DEST]], %[[C1]]
-// CHECK:         %[[PACKED_D0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK:         %[[PACKED_D1:.+]] = affine.apply #[[MAP]]()[%[[D1]]]
-// CHECK:         %[[PACKED_EMPTY:.+]] = tensor.empty(%[[PACKED_D0]], %[[PACKED_D1]]) : tensor<?x?x16x16xf32>
-// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
-// CHECK:         return %[[FILL]]
</pre>

</details>

https://github.com/llvm/llvm-project/pull/66002


More information about the Mlir-commits mailing list