[Mlir-commits] [mlir] [mlir] Fix `lower_unpack` when dynamic dimensions are involved (PR #68423)

llvmlistbot at llvm.org
Fri Oct 6 08:25:54 PDT 2023


https://github.com/qcolombet created https://github.com/llvm/llvm-project/pull/68423

When lowering `tensor.unpack`, we need to use the sizes of the destination tensor in the final `tensor.extract_slice` operation. Prior to this patch, when the destination tensor had dynamic dimensions, we would compute them from the result of the `tensor.unpack` operation instead of its destination argument.

This produced invalid IR: the `tensor.dim` operations must appear before the `tensor.extract_slice` operation, but they took as input the final result of the lowering of `tensor.unpack`, which is only created after the `tensor.extract_slice` operation. In other words, the definition did not dominate its uses.

I.e., we were generating:
```
%dynDim = tensor.dim %defLater, ... <-- %defLater defined below
%res = tensor.extract_slice ..., %dynDim, ...
%defLater = linalg.copy (ins %res)
```
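
For comparison, here is a rough sketch of the shape the lowering is expected to produce once the sizes are taken from the destination operand, in the same abbreviated notation. `%dest` is a placeholder name for the destination operand of `tensor.unpack`, not an identifier from the patch:
```
%dynDim = tensor.dim %dest, ... <-- %dest is already defined (placeholder name)
%res = tensor.extract_slice ..., %dynDim, ...
%defLater = linalg.copy (ins %res) (outs %dest)
```
Because `%dest` is an operand of the original `tensor.unpack`, it is defined before any operation created by the lowering, so the `tensor.dim` operations can legally precede the `tensor.extract_slice`.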

Note: I checked the implementation of `lower_pack` and the code is correct as far as I can tell.

From 11e7138d60cf9fbbdbd734ba44d92e7ff407257c Mon Sep 17 00:00:00 2001
From: Quentin Colombet <quentin.colombet at gmail.com>
Date: Fri, 6 Oct 2023 17:12:33 +0200
Subject: [PATCH] [mlir] Fix `lower_unpack` when dynamic dimensions are
 involved

When lowering `tensor.unpack`, we need to use the sizes of the destination
tensor in the final `tensor.extract_slice` operation.
Prior to this patch, when the destination tensor had dynamic dimensions, we
would compute them from the result of the `tensor.unpack` operation instead
of its destination argument.

This would produce invalid IR because the `tensor.dim` operations would
need to appear before the `tensor.extract_slice` operation, but the input
of the `tensor.dim` operations would consume the final result of the
lowering of `tensor.unpack`, which happens after the
`tensor.extract_slice` operation. In other words, the definition wouldn't
dominate its uses.

I.e., we were generating:
```
%dynDim = tensor.dim %defLater, ... <-- %defLater defined below
%res = tensor.extract_slice ..., %dynDim, ...
%defLater = linalg.copy (ins %res)
```

Note: I checked the implementation of `lower_pack` and the code is correct
as far as I can tell.
---
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  2 +-
 .../Dialect/Linalg/transform-lower-pack.mlir  | 39 ++++++++++++++++++-
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8183b40ad7346f4..bca343cf8777149 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -467,7 +467,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
   auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
       loc, destTensorType, reshapeOp->getResult(0),
       SmallVector<OpFoldResult>(destRank, zero),
-      tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
       SmallVector<OpFoldResult>(destRank, one));
 
   // 7. Inject a copy to preserve DPS.
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index c71feddcc1c8486..ad6c6a6f6199cc6 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -133,7 +133,7 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
   // CHECK-SAME:   : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
   //      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
   // CHECK-SAME:   : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
-  //      CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>) 
+  //      CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
   // CHECK-SAME:        outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
   %pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
     : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
@@ -397,3 +397,40 @@ transform.sequence failures(propagate) {
   transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
     -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
 }
+
+// -----
+
+// Check that we can lower unpack with dynamic dimensions in the destination.
+// CHECK-LABEL: func.func @unpack_with_dynamic_dest(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<32x2x49x16x16xf32>, %[[ARG1:.*]]: tensor<32x?x?xf32>)
+//      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32x2x16x49x16xf32>
+//      CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME:    ins(%[[ARG0]] : tensor<32x2x49x16x16xf32>)
+// CHECK-SAME:   outs(%[[EMPTY]] : tensor<32x2x16x49x16xf32>)
+// CHECK-SAME:   permutation = [0, 1, 3, 2, 4]
+//      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0], [1, 2], [3, 4]]
+// CHECK-SAME:   : tensor<32x2x16x49x16xf32> into tensor<32x32x784xf32>
+//      CHECK:  %[[C1:.*]] = arith.constant 1 : index
+//      CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<32x?x?xf32>
+//      CHECK: %[[C2:.*]] = arith.constant 2 : index
+//      CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<32x?x?xf32>
+//      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0] [32, %[[DIM1]], %[[DIM2]]] [1, 1, 1]
+// CHECK-SAME:   : tensor<32x32x784xf32> to tensor<32x?x?xf32>
+//      CHECK: linalg.copy ins(%[[SLICE]] : tensor<32x?x?xf32>)
+// CHECK-SAME:        outs(%[[ARG1]] : tensor<32x?x?xf32>)
+func.func @unpack_with_dynamic_dest(%arg0: tensor<32x2x49x16x16xf32>, %arg1: tensor<32x?x?xf32>) -> tensor<32x?x?xf32> {
+  %pack = tensor.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %arg1
+    : tensor<32x2x49x16x16xf32> -> tensor<32x?x?xf32>
+  return %pack : tensor<32x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+    : (!transform.any_op) -> !transform.op<"tensor.unpack">
+  transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+    -> (!transform.op<"tensor.empty">,
+        !transform.op<"linalg.transpose">,
+        !transform.op<"tensor.collapse_shape">,
+        !transform.op<"tensor.extract_slice">)
+}


