[Mlir-commits] [mlir] 58e4231 - [mlir][linalg] Fix tensor.pad sizes computation in lowerPack.

Hanhan Wang <llvmlistbot at llvm.org>
Thu May 18 15:37:02 PDT 2023


Author: Hanhan Wang
Date: 2023-05-18T15:36:50-07:00
New Revision: 58e4231b346f52aa32d3aaddc407391abed0cbd1

URL: https://github.com/llvm/llvm-project/commit/58e4231b346f52aa32d3aaddc407391abed0cbd1
DIFF: https://github.com/llvm/llvm-project/commit/58e4231b346f52aa32d3aaddc407391abed0cbd1.diff

LOG: [mlir][linalg] Fix tensor.pad sizes computation in lowerPack.

The padded sizes should be derived from the destination tensor, not the
source tensor: there can be more than one incomplete tile in the padding
domain, and a source-based ceildiv only ever accounts for one.
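
For example, in the pack_with_pad test added below the source is
tensor<4225x12xf32> and the destination is tensor<265x16x16x1xf32> with
inner_tiles = [16, 1]. Along dim 1 the destination holds 16 outer tiles of
size 1, so the high padding must be 16 * 1 - 12 = 4, all four of which are
whole padding tiles; the old source-based formula, ceilDiv(12, 1) * 1 - 12,
yields 0 and would drop them.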

Reviewed By: qedawkins

Differential Revision: https://reviews.llvm.org/D150726

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 230089582f25..6f932bd9fead 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -527,15 +527,20 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                   rewriter.getIndexAttr(0));
   for (auto [pos, innerSize] :
        llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
+    int outerPos =
+        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
     OpFoldResult origSize = rewriter.createOrFold<tensor::DimOp>(
         loc, packOp.getSource(),
         rewriter.create<arith::ConstantIndexOp>(loc, pos));
-    AffineExpr s0, d0;
-    bindDims(rewriter.getContext(), d0);
+    OpFoldResult outerSize = rewriter.createOrFold<tensor::DimOp>(
+        loc, packOp.getDest(),
+        rewriter.create<arith::ConstantIndexOp>(loc, outerPos));
+    AffineExpr s0, d0, d1;
+    bindDims(rewriter.getContext(), d0, d1);
     bindSymbols(rewriter.getContext(), s0);
-    auto map = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0 - d0);
-    highs[pos] = affine::makeComposedFoldedAffineApply(rewriter, loc, map,
-                                                       {origSize, innerSize});
+    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
+    highs[pos] = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, map, {outerSize, origSize, innerSize});
   }
   RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),

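For readers skimming the diff, here is a minimal scalar sketch of the old
and new formulas; the helper names and the standalone main are illustrative
only (they do not appear in the patch), and the numbers come from the
pack_with_pad test added below:

  #include <cassert>
  #include <cstdint>

  // New map, d0 * s0 - d1 over {outerSize, origSize, innerSize}: the padded
  // extent is dictated by the destination's outer dimension.
  int64_t newHighPad(int64_t destOuterSize, int64_t srcSize, int64_t tile) {
    return destOuterSize * tile - srcSize;
  }

  // Old map, (d0 ceildiv s0) * s0 - d0 over {origSize, innerSize}: assumes
  // the destination never carries whole tiles of pure padding.
  int64_t oldHighPad(int64_t srcSize, int64_t tile) {
    return (srcSize + tile - 1) / tile * tile - srcSize;
  }

  int main() {
    // Source 4225x12, dest 265x16x16x1, inner_tiles = [16, 1].
    assert(newHighPad(265, 4225, 16) == 15); // dim 0: both formulas agree
    assert(oldHighPad(4225, 16) == 15);
    assert(newHighPad(16, 12, 1) == 4);      // dim 1: dest wants 16 tiles
    assert(oldHighPad(12, 1) == 0);          // old formula misses 4 rows
    return 0;
  }
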
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index f3346c146833..374ea994ed49 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -212,6 +212,36 @@ transform.sequence failures(propagate) {
 
 // -----
 
+// CHECK-LABEL: func.func @pack_with_pad(
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
+    -> tensor<265x16x16x1xf32> {
+  //      CHECK: tensor.pad {{.*}} low[0, 0]
+  //      CHECK:   : tensor<4225x12xf32> to tensor<4240x16xf32>
+  //      CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
+  // CHECK-SAME:   : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
+  //      CHECK: linalg.transpose
+  // CHECK-SAME:   ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+  // CHECK-SAME:   outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+  // CHECK-SAME:   permutation = [0, 2, 1, 3]
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.pack %src
+    padding_value(%cst : f32)
+    inner_dims_pos = [0, 1]
+    inner_tiles = [16, 1] into %dest
+    : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
+  return %0 : tensor<265x16x16x1xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!transform.any_op) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_with_pad_and_outer_dims_perm(
 func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>,
                                              %dest: tensor<200x4x16x100x16x32xi32>)
@@ -244,8 +274,8 @@ transform.sequence failures(propagate) {
 
 // -----
 
-// CHECK-DAG:   #[[MAP0:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>
-// CHECK-DAG:   #[[MAP1:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 32) * 32)>
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * 16 - s1)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 32 - s1)>
 // CHECK:       func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(%source: tensor<?x?xf32>) -> tensor<?x?x16x32xf32> {
@@ -258,8 +288,10 @@ func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(%source: tensor<?x?xf
   // CHECK-DAG:   %[[OUT_D0:.+]] = arith.ceildivui %[[D1]], %[[C16]] : index
   // CHECK-DAG:   %[[OUT_D1:.+]] = arith.ceildivui %[[D0]], %[[C32]] : index
   // CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]], %[[OUT_D1]]) : tensor<?x?x16x32xf32>
-  // CHECK-DAG:   %[[H1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
-  // CHECK-DAG:   %[[H0:.+]] = affine.apply #[[MAP1]]()[%[[D0]]]
+  // CHECK-DAG:   %[[DEST_D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+  // CHECK-DAG:   %[[DEST_D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+  // CHECK-DAG:   %[[H1:.+]] = affine.apply #[[MAP0]]()[%[[DEST_D0]], %[[D1]]]
+  // CHECK-DAG:   %[[H0:.+]] = affine.apply #[[MAP1]]()[%[[DEST_D1]], %[[D0]]]
   // CHECK:       %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[%[[H0]], %[[H1]]]
   // CHECK:         : tensor<?x?xf32> to tensor<?x?xf32>
   // CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]]


        

