[Mlir-commits] [mlir] 5ddff0d - [mlir][linalg] Fix a bug in lower_pack when there are no padding values.

Hanhan Wang llvmlistbot at llvm.org
Fri Apr 14 11:27:13 PDT 2023


Author: Hanhan Wang
Date: 2023-04-14T11:27:00-07:00
New Revision: 5ddff0d8cd79fbef2b32ecbe4d1caf1e08a037e5

URL: https://github.com/llvm/llvm-project/commit/5ddff0d8cd79fbef2b32ecbe4d1caf1e08a037e5
DIFF: https://github.com/llvm/llvm-project/commit/5ddff0d8cd79fbef2b32ecbe4d1caf1e08a037e5.diff

LOG: [mlir][linalg] Fix a bug in lower_pack when there are no padding values.

Reviewed By: chelini

Differential Revision: https://reviews.llvm.org/D148061

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1abbe1699382..23c402020052 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -790,7 +790,7 @@ static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
       packingMetadata.reassociations);
   Value paddingValue = packOp.getPaddingValue();
   if (!paddingValue) {
-    rewriter.create<arith::ConstantOp>(
+    paddingValue = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
   }
   auto padOp =

diff  --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index f42e3e3dbbcb..83141ec75aba 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -27,6 +27,36 @@ transform.sequence failures(propagate) {
     -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
 }
 
+// -----
+
+  // CHECK-LABEL: func.func @pack(
+func.func @pack(%arg0: tensor<128x8xf32>, %arg1: tensor<8x8x16x1xf32>) -> tensor<8x8x16x1xf32> {
+
+  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
+  //      CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]]]
+  //      CHECK:   : tensor<128x8xf32> to tensor<128x8xf32>
+  //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3]]
+  // CHECK-SAME:   : tensor<128x8xf32> into tensor<8x16x8x1xf32>
+  //      CHECK: linalg.transpose
+  // CHECK-SAME:   ins(%{{.*}} : tensor<8x16x8x1xf32>)
+  // CHECK-SAME:   outs(%{{.*}} : tensor<8x8x16x1xf32>)
+  // CHECK-SAME:   permutation = [0, 2, 1, 3]
+
+  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %arg1
+    : tensor<128x8xf32> -> tensor<8x8x16x1xf32>
+
+  return %pack : tensor<8x8x16x1xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
 // -----
 
 // CHECK-LABEL: func.func @pack_as_pad(


        


More information about the Mlir-commits mailing list