[Mlir-commits] [mlir] 8eed9f3 - [mlir][linalg] Add support for folding pack(fill) into fill.
Hanhan Wang
llvmlistbot at llvm.org
Fri May 5 11:42:17 PDT 2023
Author: Hanhan Wang
Date: 2023-05-05T11:42:06-07:00
New Revision: 8eed9f38ca36bcb972a2aeff496e073de89c1b38
URL: https://github.com/llvm/llvm-project/commit/8eed9f38ca36bcb972a2aeff496e073de89c1b38
DIFF: https://github.com/llvm/llvm-project/commit/8eed9f38ca36bcb972a2aeff496e073de89c1b38.diff
LOG: [mlir][linalg] Add support for folding pack(fill) into fill.
Reviewed By: qedawkins
Differential Revision: https://reviews.llvm.org/D149801
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/test/Dialect/Linalg/data-layout-propagation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index b1a3bfff239f2..8cf1718a92b47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -446,6 +446,46 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
*packInfo);
}
+/// Folds pack(fill) into a single fill op on the pack destination if
+/// 1. the pack op has no padding value or pads with the filled value, and
+/// 2. the pack destination has no users other than the pack op.
+/// A `tensor.empty` destination is re-created from the fill's init operand;
+/// any other destination is reused as-is. New ops are inserted before
+/// `packOp` because every value they use dominates it, whereas dynamic inner
+/// tiles may be defined after `fillOp`.
+static FailureOr<FillOp>
+foldFillPackIntoFillOp(RewriterBase &rewriter, tensor::PackOp packOp,
+                       ControlPropagationFn controlFn) {
+  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
+  if (!fillOp)
+    return failure();
+
+  // User controlled propagation function.
+  if (!controlFn(fillOp))
+    return failure();
+
+  if (auto paddingValue = packOp.getPaddingValue())
+    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
+      return failure();
+
+  Value packOpDest = packOp.getDest();
+  if (!packOpDest.hasOneUse())
+    return failure();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(packOp);
+  // The pack's own empty may be sized from the old fill result (e.g. via
+  // tensor.dim), so re-create it from the fill's init operand instead.
+  if (packOpDest.getDefiningOp<tensor::EmptyOp>()) {
+    packOpDest = tensor::PackOp::createDestinationTensor(
+        rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
+        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
+        packOp.getOuterDimsPerm());
+  }
+  return clone(rewriter, fillOp, packOpDest.getType(),
+               {fillOp.value(), packOpDest});
+}
+
/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
struct BubbleUpPackOpThroughGenericOpPattern
: public OpRewritePattern<tensor::PackOp> {
@@ -468,6 +508,25 @@ struct BubbleUpPackOpThroughGenericOpPattern
ControlPropagationFn controlFn;
};
+/// Pattern form of foldFillPackIntoFillOp: replaces the pack with its fill.
+struct FoldFillPackIntoFillOpPattern : public OpRewritePattern<tensor::PackOp> {
+  FoldFillPackIntoFillOpPattern(MLIRContext *context, ControlPropagationFn fun)
+      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto folded = foldFillPackIntoFillOp(rewriter, packOp, controlFn);
+    if (succeeded(folded)) {
+      rewriter.replaceOp(packOp, folded->result());
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
// TODO: Relax this restriction. We should unpack an elementwise also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
@@ -689,6 +748,7 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation) {
patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
+ FoldFillPackIntoFillOpPattern,
PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
patterns.getContext(), controlPackUnPackPropagation);
}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index f8844b7271589..8d00770c672d3 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -830,3 +830,47 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
// CHECK-SAME: into %[[UNPACK_NEW_DEST]]
// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
+
+// -----
+// No padding value: pack(fill) folds into a fill of the packed tensor.empty.
+func.func @fill_pack() -> tensor<24x32x16x16xf32> {
+ %dest = tensor.empty() : tensor<384x512xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<24x32x16x16xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%dest : tensor<384x512xf32>) -> tensor<384x512xf32>
+ %pack = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<384x512xf32> -> tensor<24x32x16x16xf32>
+ return %pack : tensor<24x32x16x16xf32>
+}
+// CHECK-LABEL: func.func @fill_pack
+// CHECK: %[[PACKED_EMPTY:.+]] = tensor.empty() : tensor<24x32x16x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
+// CHECK: return %[[FILL]]
+
+// -----
+// Padding value == fill value: fold and rebuild the packed dest from %arg0.
+#map = affine_map<()[s0] -> (s0 ceildiv 16)>
+func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %dim = tensor.dim %0, %c0 : tensor<?x?xf32>
+ %dim_0 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ %1 = affine.apply #map()[%dim]
+ %2 = affine.apply #map()[%dim_0]
+ %3 = tensor.empty(%1, %2) : tensor<?x?x16x16xf32>
+ %pack = tensor.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %3 : tensor<?x?xf32> -> tensor<?x?x16x16xf32>
+ return %pack : tensor<?x?x16x16xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK: func.func @dynamic_fill_pack
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[D0:.+]] = tensor.dim %[[DEST]], %[[C0]]
+// CHECK: %[[D1:.+]] = tensor.dim %[[DEST]], %[[C1]]
+// CHECK: %[[PACKED_D0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: %[[PACKED_D1:.+]] = affine.apply #[[MAP]]()[%[[D1]]]
+// CHECK: %[[PACKED_EMPTY:.+]] = tensor.empty(%[[PACKED_D0]], %[[PACKED_D1]]) : tensor<?x?x16x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
+// CHECK: return %[[FILL]]
More information about the Mlir-commits
mailing list