[Mlir-commits] [mlir] [mlir][Linalg] Fix foldFillPackIntoFillOp to work for general cases (PR #74148)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 1 13:55:37 PST 2023
https://github.com/yzhang93 created https://github.com/llvm/llvm-project/pull/74148
None
>From 87b8f55cc4397a80eb974c2364b1d143be159dab Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Fri, 1 Dec 2023 13:52:14 -0800
Subject: [PATCH] [mlir][Linalg] Fix foldFillPackIntoFillOp to work for general
cases
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 18 ++----------------
mlir/test/Dialect/Linalg/canonicalize.mlir | 19 +++++++++++++++++++
2 files changed, 21 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 58af9995548e939..9a4d5e8845b2143 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -765,26 +765,12 @@ static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
return failure();
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(fillOp);
-
Value packOpDest = packOp.getDest();
if (!packOpDest.hasOneUse())
return failure();
- if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
- packOpDest = tensor::PackOp::createDestinationTensor(
- rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
- packOp.getMixedTiles(), packOp.getInnerDimsPos(),
- packOp.getOuterDimsPerm());
- } else {
- DominanceInfo dom(fillOp);
- if (!dom.properlyDominates(packOpDest, fillOp))
- return failure();
- }
- Value fillDest = packOpDest;
- return clone(rewriter, fillOp, packOpDest.getType(),
- {fillOp.value(), fillDest});
+ return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
+ packOp.getDest());
}
/// Wrapper pattern that applies foldFillPackIntoFillOp method.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index e875bae4730946b..052dc367ca67791 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -368,6 +368,25 @@ func.func @fill_pack() -> tensor<24x32x16x16xf32> {
// -----
+func.func @fill_pack_general() -> tensor<1x1x8x4x4x8xi32>{
+ %c0_i32 = arith.constant 0 : i32
+ %alloc = memref.alloc() : memref<1x1x8x4x4x8xi32>
+ %9 = tensor.empty() : tensor<1x1x16x64xi32>
+ %extracted_slice_15 = tensor.extract_slice %9[0, 0, 0, 0] [1, 1, 16, 64] [1, 1, 1, 1] : tensor<1x1x16x64xi32> to tensor<1x1x16x64xi32>
+ %16 = linalg.fill ins(%c0_i32 : i32) outs(%extracted_slice_15 : tensor<1x1x16x64xi32>) -> tensor<1x1x16x64xi32>
+ %0 = bufferization.to_tensor %alloc restrict writable : memref<1x1x8x4x4x8xi32>
+ %pack_18 = tensor.pack %16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %0 : tensor<1x1x16x64xi32> -> tensor<1x1x8x4x4x8xi32>
+ return %pack_18 : tensor<1x1x8x4x4x8xi32>
+}
+
+// CHECK-LABEL: func.func @fill_pack_general
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x4x8xi32>
+// CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[TENSOR]]
+// CHECK: return %[[FILL]]
+
+// -----
+
#map = affine_map<()[s0] -> (s0 ceildiv 16)>
func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
%cst = arith.constant 0.000000e+00 : f32
More information about the Mlir-commits
mailing list