[Mlir-commits] [mlir] 9d3057c - [mlir][Linalg] Add support for lowerPack on dynamic outer shapes.
Hanhan Wang
llvmlistbot at llvm.org
Thu May 11 10:47:32 PDT 2023
Author: Hanhan Wang
Date: 2023-05-11T10:47:19-07:00
New Revision: 9d3057c1cf11759720f4d71f34b4e0e14d273f57
URL: https://github.com/llvm/llvm-project/commit/9d3057c1cf11759720f4d71f34b4e0e14d273f57
DIFF: https://github.com/llvm/llvm-project/commit/9d3057c1cf11759720f4d71f34b4e0e14d273f57.diff
LOG: [mlir][Linalg] Add support for lowerPack on dynamic outer shapes.
This revision adds support for decomposing tensor.pack ops when all
inner tile sizes are static. The generated tensor.expand_shape op is
still valid because only one of the expanded dimensions in each group is dynamic.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D150233
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 984ff35515230..a9e8ac0bbabbb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -477,7 +477,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
// 1. Filter out NYI cases.
auto packedTensorType =
packOp->getResultTypes().front().cast<RankedTensorType>();
- if (!packedTensorType.hasStaticShape()) {
+ if (llvm::any_of(packOp.getStaticInnerTiles(),
+ [](int64_t size) { return ShapedType::isDynamic(size); })) {
return rewriter.notifyMatchFailure(
packOp,
"non-static shape NYI, needs a more powerful tensor.expand_shape op");
@@ -520,6 +521,22 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+ SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
+ rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
+ rewriter.getIndexAttr(0));
+ for (auto [pos, innerSize] :
+ llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
+ OpFoldResult origSize = rewriter.createOrFold<tensor::DimOp>(
+ loc, packOp.getSource(),
+ rewriter.create<arith::ConstantIndexOp>(loc, pos));
+ AffineExpr s0, d0;
+ bindDims(rewriter.getContext(), d0);
+ bindSymbols(rewriter.getContext(), s0);
+ auto map = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0 - d0);
+ highs[pos] = affine::makeComposedFoldedAffineApply(rewriter, loc, map,
+ {origSize, innerSize});
+ }
RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
packingMetadata.reassociations);
@@ -529,8 +546,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
}
auto padOp =
- tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
- /*nofold=*/false, loc, rewriter);
+ rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
+ highs, paddingValue, /*nofold=*/false);
LLVM_DEBUG(
DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 40f9f3e0761e7..9e33b27505f16 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -1,12 +1,11 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -cse --split-input-file | FileCheck %s
// CHECK-LABEL: func.func @pack(
func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
%cst_0 = arith.constant 0.0 : f32
// tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
// CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
// CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<17x8x2x32x16x16xf32>
@@ -33,8 +32,7 @@ transform.sequence failures(propagate) {
func.func @pack(%arg0: tensor<128x8xf32>, %arg1: tensor<8x8x16x1xf32>) -> tensor<8x8x16x1xf32> {
// tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]]]
+ // CHECK: tensor.pad {{.*}} low[0, 0]
// CHECK: : tensor<128x8xf32> to tensor<128x8xf32>
// CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3]]
// CHECK-SAME: : tensor<128x8xf32> into tensor<8x16x8x1xf32>
@@ -64,8 +62,7 @@ func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x13
%cst_0 = arith.constant 0.0 : f32
// tensor.pack is lowered to tensor.pad + tensor.insert_slice
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
@@ -100,8 +97,7 @@ transform.sequence failures(propagate) {
func.func @pack_not_a_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x16x16x136x64xf32>) -> tensor<1x1x16x16x136x64xf32> {
%cst_0 = arith.constant 0.0 : f32
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
// CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
// CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<1x136x1x64x16x16xf32>
@@ -190,8 +186,7 @@ transform.sequence failures(propagate) {
func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
%dest: tensor<200x4x16x100x16x32xi32>)
-> tensor<200x4x16x100x16x32xi32> {
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
// CHECK: : tensor<100x200x128x256xi32> to tensor<100x200x128x256xi32>
// CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
// CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
@@ -221,8 +216,7 @@ transform.sequence failures(propagate) {
func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>,
%dest: tensor<200x4x16x100x16x32xi32>)
-> tensor<200x4x16x100x16x32xi32> {
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
// CHECK: : tensor<100x200x127x255xi32> to tensor<100x200x128x256xi32>
// CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
// CHECK-SAME: : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
@@ -250,13 +244,64 @@ transform.sequence failures(propagate) {
// -----
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 32) * 32)>
+// CHECK: func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(%source: tensor<?x?xf32>) -> tensor<?x?x16x32xf32> {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+ // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+ // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[SRC]], %[[C0]]
+ // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[SRC]], %[[C1]]
+ // CHECK-DAG: %[[OUT_D0:.+]] = arith.ceildivui %[[D1]], %[[C16]] : index
+ // CHECK-DAG: %[[OUT_D1:.+]] = arith.ceildivui %[[D0]], %[[C32]] : index
+ // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]], %[[OUT_D1]]) : tensor<?x?x16x32xf32>
+ // CHECK-DAG: %[[H1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+ // CHECK-DAG: %[[H0:.+]] = affine.apply #[[MAP1]]()[%[[D0]]]
+ // CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[%[[H0]], %[[H1]]]
+ // CHECK: : tensor<?x?xf32> to tensor<?x?xf32>
+ // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]]
+ // CHECK-SAME: : tensor<?x?xf32> into tensor<?x32x?x16xf32>
+ // CHECK: %[[TRANSP:.+]] = linalg.transpose
+ // CHECK-SAME: ins(%[[EXPAND]] : tensor<?x32x?x16xf32>)
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<?x?x16x32xf32>)
+ // CHECK-SAME: permutation = [2, 0, 3, 1]
+ // CHECK: return %[[TRANSP]]
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %source, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %source, %c1 : tensor<?x?xf32>
+ %padding_value = arith.constant 0.0 : f32
+
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %tiled_d0 = arith.ceildivui %d0, %c32 : index
+ %tiled_d1 = arith.ceildivui %d1, %c16 : index
+ %init_pack = tensor.empty(%tiled_d1, %tiled_d0) : tensor<?x?x16x32xf32>
+ %pack = tensor.pack %source padding_value(%padding_value : f32)
+ outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %init_pack
+ : tensor<?x?xf32> -> tensor<?x?x16x32xf32>
+ return %pack : tensor<?x?x16x32xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
// CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
// tensor.pack is lowered to tensor.pad + tensor.insert_slice
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
More information about the Mlir-commits
mailing list