[Mlir-commits] [mlir] f32427e - [mlir][linalg] Fix lowering of tensor.pack operations
Lorenzo Chelini
llvmlistbot at llvm.org
Tue Sep 5 12:14:03 PDT 2023
Author: Spenser Bauman
Date: 2023-09-05T21:08:13+02:00
New Revision: f32427e0447c787755fb29415f4deeb4c00adacd
URL: https://github.com/llvm/llvm-project/commit/f32427e0447c787755fb29415f4deeb4c00adacd
DIFF: https://github.com/llvm/llvm-project/commit/f32427e0447c787755fb29415f4deeb4c00adacd.diff
LOG: [mlir][linalg] Fix lowering of tensor.pack operations
Tensor pack operations are optimistically lowered to pad + insert_slice
when the pack operation only pads the input tensor. The existing
lowering can emit insert_slice operations that do not meet the
rank-reducibility requirements of insert_slice, i.e. when the packed
type does not rank-reduce to the padded type.

This change updates the logic in linalg::lowerPack to check the
rank-reducibility requirement first. When the requirement is not met,
the lowering instead emits the full sequence of pad + expand_shape +
transpose.
Reviewed By: chelini
Differential Revision: https://reviews.llvm.org/D159382
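For context, a minimal sketch of the pad-like case where the insert_slice
path stays legal (the function name and shapes below are illustrative,
modeled on the existing pack_as_pad tests, and are not part of this patch):

  // Packed type tensor<1x1x136x64xf32> rank-reduces to the padded type
  // tensor<136x64xf32> (only unit dims are dropped), so lowering to
  // tensor.pad + tensor.insert_slice remains valid here.
  func.func @pack_as_pad_sketch(%src: tensor<129x47xf32>,
                                %dest: tensor<1x1x136x64xf32>) -> tensor<1x1x136x64xf32> {
    %zero = arith.constant 0.0 : f32
    %pack = tensor.pack %src
        padding_value(%zero : f32)
        inner_dims_pos = [0, 1]
        inner_tiles = [136, 64] into %dest
        : tensor<129x47xf32> -> tensor<1x1x136x64xf32>
    return %pack : tensor<1x1x136x64xf32>
  }

By contrast, in the new @pack_as_pad_with_unit_dims test added below, the
packed type tensor<1x1x1x1x8x1xf32> does not rank-reduce to the padded type
tensor<8x1x1x1xf32>, so the lowering now falls back to pad + expand_shape +
transpose.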
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 12f0bed76031af..a2d219f669905e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -321,28 +321,36 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
if (packOp.isLikePad()) {
- // This pack is just a plain pad.
- // Just insert the pad in the higher ranked tensor.
- auto emptyOp =
- rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
- // Offsets.
- SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
- // Strides.
- SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
- SmallVector<OpFoldResult> sizes =
- tensor::getMixedSizes(rewriter, loc, packOp.getDest());
-
- auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, /*source=*/padOp, /*dest=*/emptyOp,
- /*offsets=*/zeros, sizes,
- /*strides=*/ones);
-
- LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
-
- rewriter.replaceOp(packOp, insertSliceOp->getResults());
-
- return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
- /*transposeOp=*/nullptr};
+ // Pack ops which operate as simple pads may not produce legal
+ // tensor.insert_slice operations when the packed type does not rank reduce
+ // to the padded type.
+ SliceVerificationResult rankReduces =
+ isRankReducedType(packedTensorType, padOp.getResultType());
+
+ if (rankReduces == SliceVerificationResult::Success) {
+ // This pack is just a plain pad.
+ // Just insert the pad in the higher ranked tensor.
+ auto emptyOp =
+ rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
+ // Offsets.
+ SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+ // Strides.
+ SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> sizes =
+ tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+
+ auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ loc, /*source=*/padOp, /*dest=*/emptyOp,
+ /*offsets=*/zeros, sizes,
+ /*strides=*/ones);
+
+ LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+
+ rewriter.replaceOp(packOp, insertSliceOp->getResults());
+
+ return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
+ /*transposeOp=*/nullptr};
+ }
}
// 5. Expand from the padded result to the stripMinedShape.
auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index c11d301140039a..c71feddcc1c848 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -356,6 +356,40 @@ func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %ar
return %pack : tensor<1x1x1x1x136x64x16x16xf32>
}
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @pack_as_pad_with_unit_dims(
+// CHECK: %[[SRC:.+]]: tensor<3x1x1x1xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x8x1xf32>)
+func.func @pack_as_pad_with_unit_dims(%arg0: tensor<3x1x1x1xf32>, %arg1: tensor<1x1x1x1x8x1xf32>) -> (tensor<1x1x1x1x8x1xf32>) {
+ %zero = arith.constant 0.0 : f32
+
+ // CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[5, 0, 0, 0] {
+ // CHECK: : tensor<3x1x1x1xf32> to tensor<8x1x1x1xf32>
+ // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] [{{.*}}[0, 1], [2, 3], [4], [5]]
+ // CHECK-SAME: tensor<8x1x1x1xf32> into tensor<1x8x1x1x1x1xf32>
+ // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose
+ // CHECK-SAME: ins(%[[EXPAND]] : tensor<1x8x1x1x1x1xf32>)
+ // CHECK-SAME: outs(%[[OUT]] : tensor<1x1x1x1x8x1xf32>)
+ // CHECK-SAME: permutation = [0, 2, 4, 5, 1, 3]
+ // CHECK: return %[[TRANSPOSED]] : tensor<1x1x1x1x8x1xf32>
+ %pack = tensor.pack %arg0
+ padding_value(%zero : f32)
+ inner_dims_pos = [0, 1]
+ inner_tiles = [8, 1] into %arg1 : tensor<3x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
+
+ return %pack : tensor<1x1x1x1x8x1xf32>
+}
+
+
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
%pack = transform.structured.match ops{["tensor.pack"]} in %module_op