[Mlir-commits] [mlir] f32427e - [mlir][linalg] Fix lowering of tensor.pack operations

Lorenzo Chelini llvmlistbot at llvm.org
Tue Sep 5 12:14:03 PDT 2023


Author: Spenser Bauman
Date: 2023-09-05T21:08:13+02:00
New Revision: f32427e0447c787755fb29415f4deeb4c00adacd

URL: https://github.com/llvm/llvm-project/commit/f32427e0447c787755fb29415f4deeb4c00adacd
DIFF: https://github.com/llvm/llvm-project/commit/f32427e0447c787755fb29415f4deeb4c00adacd.diff

LOG: [mlir][linalg] Fix lowering of tensor.pack operations

Tensor pack operations are optimistically lowered to pad + insert_slice
when the pack operation only pads the input tensor. In some of these
cases the existing lowering emits an invalid insert_slice: the padded
source type is not a rank-reduced version of the packed destination
type, as insert_slice requires.

This change updates the logic in linalg::lowerPack to first check the
rank-reducibility requirement. When the requirement is met, the pad +
insert_slice lowering is kept; otherwise the lowering falls through to
the full sequence of pad + expand_shape + transpose.

Reviewed By: chelini

Differential Revision: https://reviews.llvm.org/D159382

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 12f0bed76031af..a2d219f669905e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -321,28 +321,36 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
 
   if (packOp.isLikePad()) {
-    // This pack is just a plain pad.
-    // Just insert the pad in the higher ranked tensor.
-    auto emptyOp =
-        rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
-    // Offsets.
-    SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
-    // Strides.
-    SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes =
-        tensor::getMixedSizes(rewriter, loc, packOp.getDest());
-
-    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, /*source=*/padOp, /*dest=*/emptyOp,
-        /*offsets=*/zeros, sizes,
-        /*strides=*/ones);
-
-    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
-
-    rewriter.replaceOp(packOp, insertSliceOp->getResults());
-
-    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
-                           /*transposeOp=*/nullptr};
+    // Pack ops which operate as simple pads may not produce legal
+    // tensor.insert_slice operations when the packed type does not rank reduce
+    // to the padded type.
+    SliceVerificationResult rankReduces =
+        isRankReducedType(packedTensorType, padOp.getResultType());
+
+    if (rankReduces == SliceVerificationResult::Success) {
+      // This pack is just a plain pad.
+      // Just insert the pad in the higher ranked tensor.
+      auto emptyOp =
+          rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
+      // Offsets.
+      SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+      // Strides.
+      SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+      SmallVector<OpFoldResult> sizes =
+          tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+
+      auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+          loc, /*source=*/padOp, /*dest=*/emptyOp,
+          /*offsets=*/zeros, sizes,
+          /*strides=*/ones);
+
+      LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+
+      rewriter.replaceOp(packOp, insertSliceOp->getResults());
+
+      return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
+                             /*transposeOp=*/nullptr};
+    }
   }
   // 5. Expand from the padded result to the stripMinedShape.
   auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(

diff  --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index c11d301140039a..c71feddcc1c848 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -356,6 +356,40 @@ func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %ar
   return %pack :  tensor<1x1x1x1x136x64x16x16xf32>
 }
 
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!transform.any_op) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_as_pad_with_unit_dims(
+// CHECK: %[[SRC:.+]]: tensor<3x1x1x1xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x8x1xf32>)
+func.func @pack_as_pad_with_unit_dims(%arg0: tensor<3x1x1x1xf32>, %arg1: tensor<1x1x1x1x8x1xf32>) -> (tensor<1x1x1x1x8x1xf32>) {
+  %zero = arith.constant 0.0 : f32
+
+  // CHECK:      %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[5, 0, 0, 0] {
+  // CHECK:        : tensor<3x1x1x1xf32> to tensor<8x1x1x1xf32>
+  // CHECK:      %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] [{{.*}}[0, 1], [2, 3], [4], [5]]
+  // CHECK-SAME:   tensor<8x1x1x1xf32> into tensor<1x8x1x1x1x1xf32>
+  // CHECK:      %[[TRANSPOSED:.+]] = linalg.transpose
+  // CHECK-SAME:   ins(%[[EXPAND]] : tensor<1x8x1x1x1x1xf32>)
+  // CHECK-SAME:   outs(%[[OUT]] : tensor<1x1x1x1x8x1xf32>)
+  // CHECK-SAME:   permutation = [0, 2, 4, 5, 1, 3]
+  // CHECK:      return %[[TRANSPOSED]] : tensor<1x1x1x1x8x1xf32>
+  %pack = tensor.pack %arg0
+      padding_value(%zero : f32)
+      inner_dims_pos = [0, 1]
+      inner_tiles = [8, 1] into %arg1 : tensor<3x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
+
+  return %pack : tensor<1x1x1x1x8x1xf32>
+}
+
+
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
   %pack = transform.structured.match ops{["tensor.pack"]} in %module_op


        


More information about the Mlir-commits mailing list