[Mlir-commits] [mlir] 5ddff0d - [mlir][linalg] Fix a bug in lower_pack when there are no padding values.
Hanhan Wang
llvmlistbot at llvm.org
Fri Apr 14 11:27:13 PDT 2023
Author: Hanhan Wang
Date: 2023-04-14T11:27:00-07:00
New Revision: 5ddff0d8cd79fbef2b32ecbe4d1caf1e08a037e5
URL: https://github.com/llvm/llvm-project/commit/5ddff0d8cd79fbef2b32ecbe4d1caf1e08a037e5
DIFF: https://github.com/llvm/llvm-project/commit/5ddff0d8cd79fbef2b32ecbe4d1caf1e08a037e5.diff
LOG: [mlir][linalg] Fix a bug in lower_pack when there are no padding values.
Reviewed By: chelini
Differential Revision: https://reviews.llvm.org/D148061
Added:
Modified:
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1abbe1699382..23c402020052 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -790,7 +790,7 @@ static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
packingMetadata.reassociations);
Value paddingValue = packOp.getPaddingValue();
if (!paddingValue) {
- rewriter.create<arith::ConstantOp>(
+ paddingValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
}
auto padOp =
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index f42e3e3dbbcb..83141ec75aba 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -27,6 +27,36 @@ transform.sequence failures(propagate) {
-> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
}
+// -----
+
+ // CHECK-LABEL: func.func @pack(
+func.func @pack(%arg0: tensor<128x8xf32>, %arg1: tensor<8x8x16x1xf32>) -> tensor<8x8x16x1xf32> {
+
+ // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]]]
+ // CHECK: : tensor<128x8xf32> to tensor<128x8xf32>
+ // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3]]
+ // CHECK-SAME: : tensor<128x8xf32> into tensor<8x16x8x1xf32>
+ // CHECK: linalg.transpose
+ // CHECK-SAME: ins(%{{.*}} : tensor<8x16x8x1xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<8x8x16x1xf32>)
+ // CHECK-SAME: permutation = [0, 2, 1, 3]
+
+ %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %arg1
+ : tensor<128x8xf32> -> tensor<8x8x16x1xf32>
+
+ return %pack : tensor<8x8x16x1xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
// -----
// CHECK-LABEL: func.func @pack_as_pad(
More information about the Mlir-commits mailing list