[Mlir-commits] [mlir] [mlir][Linalg] Fix foldFillPackIntoFillOp to work for general cases (PR #74148)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 1 13:55:37 PST 2023


https://github.com/yzhang93 created https://github.com/llvm/llvm-project/pull/74148

(No description was provided for this pull request.)

>From 87b8f55cc4397a80eb974c2364b1d143be159dab Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Fri, 1 Dec 2023 13:52:14 -0800
Subject: [PATCH] [mlir][Linalg] Fix foldFillPackIntoFillOp to work for general
 cases

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 18 ++----------------
 mlir/test/Dialect/Linalg/canonicalize.mlir | 19 +++++++++++++++++++
 2 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 58af9995548e939..9a4d5e8845b2143 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -765,26 +765,12 @@ static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
     if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
       return failure();
 
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(fillOp);
-
   Value packOpDest = packOp.getDest();
   if (!packOpDest.hasOneUse())
     return failure();
-  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
-    packOpDest = tensor::PackOp::createDestinationTensor(
-        rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
-        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
-        packOp.getOuterDimsPerm());
-  } else {
-    DominanceInfo dom(fillOp);
-    if (!dom.properlyDominates(packOpDest, fillOp))
-      return failure();
-  }
 
-  Value fillDest = packOpDest;
-  return clone(rewriter, fillOp, packOpDest.getType(),
-               {fillOp.value(), fillDest});
+  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
+                                         packOp.getDest());
 }
 
 /// Wrapper pattern that applies foldFillPackIntoFillOp method.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index e875bae4730946b..052dc367ca67791 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -368,6 +368,25 @@ func.func @fill_pack() -> tensor<24x32x16x16xf32> {
 
 // -----
 
+func.func @fill_pack_general() -> tensor<1x1x8x4x4x8xi32> {  // fill + pack into a non tensor.empty destination
+  %zero = arith.constant 0 : i32
+  %buffer = memref.alloc() : memref<1x1x8x4x4x8xi32>
+  %empty = tensor.empty() : tensor<1x1x16x64xi32>
+  %slice = tensor.extract_slice %empty[0, 0, 0, 0] [1, 1, 16, 64] [1, 1, 1, 1] : tensor<1x1x16x64xi32> to tensor<1x1x16x64xi32>
+  %filled = linalg.fill ins(%zero : i32) outs(%slice : tensor<1x1x16x64xi32>) -> tensor<1x1x16x64xi32>
+  %dest = bufferization.to_tensor %buffer restrict writable : memref<1x1x8x4x4x8xi32>  // pack destination comes from a buffer, not tensor.empty
+  %packed = tensor.pack %filled outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %dest : tensor<1x1x16x64xi32> -> tensor<1x1x8x4x4x8xi32>
+  return %packed : tensor<1x1x8x4x4x8xi32>  // expected to fold into a single fill of %dest
+}
+
+// CHECK-LABEL: func.func @fill_pack_general
+// CHECK:         %[[ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x4x8xi32>
+// CHECK:         %[[TENSOR:.+]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[TENSOR]]
+// CHECK:         return %[[FILL]]
+
+// -----
+
 #map = affine_map<()[s0] -> (s0 ceildiv 16)>
 func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
   %cst = arith.constant 0.000000e+00 : f32



More information about the Mlir-commits mailing list