[Mlir-commits] [mlir] b26ee97 - [MLIR][Linalg] Support dynamic sizes in `lower_unpack` (#75494)
llvmlistbot at llvm.org
Mon Dec 18 10:02:08 PST 2023
Author: srcarroll
Date: 2023-12-18T19:02:04+01:00
New Revision: b26ee9753777bdb6430e830397d0d6532597a0da
URL: https://github.com/llvm/llvm-project/commit/b26ee9753777bdb6430e830397d0d6532597a0da
DIFF: https://github.com/llvm/llvm-project/commit/b26ee9753777bdb6430e830397d0d6532597a0da.diff
LOG: [MLIR][Linalg] Support dynamic sizes in `lower_unpack` (#75494)
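
The static-shape guard in `linalg::lowerUnPack` is removed. The sizes of the
`tensor.empty` destination for the intermediate `linalg.transpose` are now
queried from the unpack source with `tensor::getMixedSizes` and permuted by
`lastDimsToInsertPositionsPerm`, so `tensor.unpack` ops with dynamic input,
destination, and inner tile sizes can be lowered. The lowering sequence itself
(transpose, collapse_shape, extract_slice, copy) is unchanged.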
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 10dfbe6cec781d..9d230e2c2e5749 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -380,17 +380,11 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
if (!unPackOp.getOuterDimsPerm().empty())
return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
- RankedTensorType packedTensorType = unPackOp.getSourceType();
- if (!packedTensorType.hasStaticShape()) {
- return rewriter.notifyMatchFailure(
- unPackOp,
- "non-static shape NYI, needs a more powerful tensor.expand_shape op");
- }
-
Location loc = unPackOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
+ RankedTensorType packedTensorType = unPackOp.getSourceType();
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -434,8 +428,14 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMinedTensorType, packingMetadata.reassociations);
- auto emptyOp =
- rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
+
+ // Get the dynamic sizes of the input tensor, permuted according to
+ // lastDimsToInsertPositionsPerm.
+ SmallVector<OpFoldResult, 4> dims =
+ tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
+ applyPermutationToVector(dims, lastDimsToInsertPositionsPerm);
+ auto emptyOp = rewriter.create<tensor::EmptyOp>(
+ loc, dims, stripMinedTensorType.getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
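
For reference, a minimal sketch (not part of the patch; `source`, `perm`, and
`elementType` are illustrative names) of how the mixed sizes feed the new
`tensor.empty`:

    // One OpFoldResult per source dimension: static sizes come back as
    // IndexAttrs, dynamic sizes as tensor.dim Values.
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, source);
    // Reorder the sizes to the transpose destination's dimension order.
    applyPermutationToVector(sizes, perm);
    // This EmptyOp builder folds the static entries into the result type
    // and passes only the dynamic ones as SSA operands.
    Value empty =
        rewriter.create<tensor::EmptyOp>(loc, sizes, elementType);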
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index b9706eed54b608..316df431a9c0c8 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -464,6 +464,129 @@ module attributes {transform.with_named_sequence} {
// -----
+// Check that we can lower unpack with dynamic dimensions in the input and destination.
+// CHECK-LABEL: func.func @unpack_with_dynamic_input_dest(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x8x16xf32>, %[[ARG1:.*]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM01]]) : tensor<?x8x?x16xf32>
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x8x16xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x?x16xf32>)
+// CHECK-SAME: permutation = [0, 2, 1, 3]
+// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
+// CHECK-SAME: : tensor<?x8x?x16xf32> into tensor<?x?xf32>
+// CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
+// CHECK-SAME: : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ARG1]] : tensor<?x?xf32>)
+func.func @unpack_with_dynamic_input_dest(%arg0: tensor<?x?x8x16xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %arg1 : tensor<?x?x8x16xf32> -> tensor<?x?xf32>
+ return %unpack : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+ transform.yield
+ }
+}
+
+// -----
+
+// Check that we can lower unpack with dynamic dimensions in the input, destination, and inner_tiles.
+// CHECK-LABEL: func.func @unpack_fully_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[DIM02:.*]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[DIM03:.*]] = tensor.dim %[[ARG0]], %[[C3]]
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM02]], %[[DIM01]], %[[DIM03]]) : tensor<?x?x?x?xf32>
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?x?xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x?x?x?xf32>)
+// CHECK-SAME: permutation = [0, 2, 1, 3]
+// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
+// CHECK-SAME: : tensor<?x?x?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
+// CHECK-SAME: : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ARG1]] : tensor<?x?xf32>)
+func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?xf32>, %tile_n : index, %tile_m : index) -> tensor<?x?xf32> {
+ %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+ transform.yield
+ }
+}
+
+// -----
+
+// Check that we can lower unpack "as unpad" (all outer dims are 1) with dynamic dims;
+// the lowering then reduces to a single tensor.extract_slice.
+// CHECK-LABEL: func.func @unpack_as_pad_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x1x?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
+// CHECK: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+// offsets.
+// CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
+// sizes.
+// CHECK-SAME: [1, 1, 1, 1, %[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
+// stride multipliers.
+// CHECK-SAME: [1, 1, 1, 1, 1, 1, 1, 1]
+// CHECK-SAME: : tensor<1x1x1x1x?x?x?x?xf32> to tensor<?x?x?x?xf32>
+func.func @unpack_as_pad_dynamic(%arg0: tensor<1x1x1x1x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+ : tensor<1x1x1x1x?x?x?x?xf32> -> tensor<?x?x?x?xf32>
+ return %pack : tensor<?x?x?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+ transform.yield
+ }
+}
+
+// -----
+
// At the moment, we cannot lower tensor.unpack with outer_dims_perm.
func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
// expected-note @below {{target payload op}}
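
For reference, the tests in transform-lower-pack.mlir run under the transform
interpreter; with a local build of mlir-opt, an invocation along the lines of

    mlir-opt --transform-interpreter --split-input-file \
        mlir/test/Dialect/Linalg/transform-lower-pack.mlir

exercises the new cases (the exact RUN line lives at the top of the test file).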