[Mlir-commits] [mlir] 9d3057c - [mlir][Linalg] Add support for lowerPack on dynamic outer shapes.

Hanhan Wang llvmlistbot at llvm.org
Thu May 11 10:47:32 PDT 2023


Author: Hanhan Wang
Date: 2023-05-11T10:47:19-07:00
New Revision: 9d3057c1cf11759720f4d71f34b4e0e14d273f57

URL: https://github.com/llvm/llvm-project/commit/9d3057c1cf11759720f4d71f34b4e0e14d273f57
DIFF: https://github.com/llvm/llvm-project/commit/9d3057c1cf11759720f4d71f34b4e0e14d273f57.diff

LOG: [mlir][Linalg] Add support for lowerPack on dynamic outer shapes.

The revision adds support for tensor.pack op decomposition when all
inner tile sizes are static. The generated tensor.expand_shape op is
still valid because only one of the expanding dimensions is dynamic.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D150233

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 984ff35515230..a9e8ac0bbabbb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -477,7 +477,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   // 1. Filter out NYI cases.
   auto packedTensorType =
       packOp->getResultTypes().front().cast<RankedTensorType>();
-  if (!packedTensorType.hasStaticShape()) {
+  if (llvm::any_of(packOp.getStaticInnerTiles(),
+                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
     return rewriter.notifyMatchFailure(
         packOp,
         "non-static shape NYI, needs a more powerful tensor.expand_shape op");
@@ -520,6 +521,22 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
 
   // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
+                                 rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
+                                  rewriter.getIndexAttr(0));
+  for (auto [pos, innerSize] :
+       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
+    OpFoldResult origSize = rewriter.createOrFold<tensor::DimOp>(
+        loc, packOp.getSource(),
+        rewriter.create<arith::ConstantIndexOp>(loc, pos));
+    AffineExpr s0, d0;
+    bindDims(rewriter.getContext(), d0);
+    bindSymbols(rewriter.getContext(), s0);
+    auto map = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0 - d0);
+    highs[pos] = affine::makeComposedFoldedAffineApply(rewriter, loc, map,
+                                                       {origSize, innerSize});
+  }
   RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
       packingMetadata.reassociations);
@@ -529,8 +546,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
         loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
   }
   auto padOp =
-      tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
-                              /*nofold=*/false, loc, rewriter);
+      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
+                                     highs, paddingValue, /*nofold=*/false);
 
   LLVM_DEBUG(
       DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,

diff  --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 40f9f3e0761e7..9e33b27505f16 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -1,12 +1,11 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -cse --split-input-file | FileCheck %s
 
   // CHECK-LABEL: func.func @pack(
 func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
-  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
-  //      CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
   // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<17x8x2x32x16x16xf32>
@@ -33,8 +32,7 @@ transform.sequence failures(propagate) {
 func.func @pack(%arg0: tensor<128x8xf32>, %arg1: tensor<8x8x16x1xf32>) -> tensor<8x8x16x1xf32> {
 
   // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
-  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
-  //      CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]]]
+  //      CHECK: tensor.pad {{.*}} low[0, 0]
   //      CHECK:   : tensor<128x8xf32> to tensor<128x8xf32>
   //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3]]
   // CHECK-SAME:   : tensor<128x8xf32> into tensor<8x16x8x1xf32>
@@ -64,8 +62,7 @@ func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x13
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
-  //      CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
   //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
   //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
@@ -100,8 +97,7 @@ transform.sequence failures(propagate) {
 func.func @pack_not_a_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x16x16x136x64xf32>) -> tensor<1x1x16x16x136x64xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
-  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
-  //      CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
   // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<1x136x1x64x16x16xf32>
@@ -190,8 +186,7 @@ transform.sequence failures(propagate) {
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)
     -> tensor<200x4x16x100x16x32xi32> {
-  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
-  //      CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   //      CHECK:   : tensor<100x200x128x256xi32> to tensor<100x200x128x256xi32>
   //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
   // CHECK-SAME:   : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
@@ -221,8 +216,7 @@ transform.sequence failures(propagate) {
 func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>,
                                              %dest: tensor<200x4x16x100x16x32xi32>)
     -> tensor<200x4x16x100x16x32xi32> {
-  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
-  //      CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   //      CHECK:   : tensor<100x200x127x255xi32> to tensor<100x200x128x256xi32>
   //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
   // CHECK-SAME:   : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
@@ -250,13 +244,64 @@ transform.sequence failures(propagate) {
 
 // -----
 
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 32) * 32)>
+// CHECK:       func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(%source: tensor<?x?xf32>) -> tensor<?x?x16x32xf32> {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-DAG:   %[[C16:.+]] = arith.constant 16 : index
+  // CHECK-DAG:   %[[C32:.+]] = arith.constant 32 : index
+  // CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[SRC]], %[[C0]]
+  // CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[SRC]], %[[C1]]
+  // CHECK-DAG:   %[[OUT_D0:.+]] = arith.ceildivui %[[D1]], %[[C16]] : index
+  // CHECK-DAG:   %[[OUT_D1:.+]] = arith.ceildivui %[[D0]], %[[C32]] : index
+  // CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]], %[[OUT_D1]]) : tensor<?x?x16x32xf32>
+  // CHECK-DAG:   %[[H1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+  // CHECK-DAG:   %[[H0:.+]] = affine.apply #[[MAP1]]()[%[[D0]]]
+  // CHECK:       %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[%[[H0]], %[[H1]]]
+  // CHECK:         : tensor<?x?xf32> to tensor<?x?xf32>
+  // CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]]
+  // CHECK-SAME:   : tensor<?x?xf32> into tensor<?x32x?x16xf32>
+  // CHECK:       %[[TRANSP:.+]] = linalg.transpose
+  // CHECK-SAME:    ins(%[[EXPAND]] : tensor<?x32x?x16xf32>)
+  // CHECK-SAME:    outs(%[[EMPTY]] : tensor<?x?x16x32xf32>)
+  // CHECK-SAME:    permutation = [2, 0, 3, 1]
+  // CHECK:       return %[[TRANSP]]
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %source, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %source, %c1 : tensor<?x?xf32>
+  %padding_value = arith.constant 0.0 : f32
+
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %tiled_d0 = arith.ceildivui %d0, %c32 : index
+  %tiled_d1 = arith.ceildivui %d1, %c16 : index
+  %init_pack = tensor.empty(%tiled_d1, %tiled_d0) : tensor<?x?x16x32xf32>
+  %pack = tensor.pack %source padding_value(%padding_value : f32)
+      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %init_pack
+      : tensor<?x?xf32> -> tensor<?x?x16x32xf32>
+  return %pack : tensor<?x?x16x32xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
 func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
-  //      CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
   //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
   //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]


        


More information about the Mlir-commits mailing list