[Mlir-commits] [mlir] [MLIR][Linalg] Support dynamic sizes in `lower_unpack` (PR #75494)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 14 10:08:39 PST 2023


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/75494

>From 3879e633f4e4ba36ca24e6a870942b695ac935e0 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 13 Dec 2023 15:25:31 -0600
Subject: [PATCH 1/4] [MLIR][Linalg] Support dynamic sizes in `lower_unpack`

---
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 20 +++++++++++++------
 .../Dialect/Linalg/transform-lower-pack.mlir  |  4 ++--
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 10dfbe6cec781d..e66cf4c271d656 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -381,11 +381,6 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
     return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
 
   RankedTensorType packedTensorType = unPackOp.getSourceType();
-  if (!packedTensorType.hasStaticShape()) {
-    return rewriter.notifyMatchFailure(
-        unPackOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
 
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
@@ -434,8 +429,21 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMinedTensorType, packingMetadata.reassociations);
+
+  // Get dynamic dims for stripMined shape
+  SmallVector<Value> dims(llvm::map_range(
+      llvm::seq<int64_t>(0, packedTensorType.getRank()), [&](int64_t i) {
+        return rewriter.create<tensor::DimOp>(loc, unPackOp.getSource(), i);
+      }));
+  applyPermutationToVector(dims, lastDimsToInsertPositionsPerm);
+  SmallVector<Value> dynDims;
+  for (int64_t i = 0; i < stripMinedTensorType.getRank(); i++) {
+    if (stripMinedTensorType.isDynamicDim(i))
+      dynDims.push_back(dims[i]);
+  }
+
   auto emptyOp =
-      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
+      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, dynDims);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
 
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index b9706eed54b608..8849ae6084f3ba 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -428,6 +428,8 @@ module attributes {transform.with_named_sequence} {
 // Check that we can lower unpack with dynamic dimensions in the destination.
 // CHECK-LABEL: func.func @unpack_with_dynamic_dest(
 // CHECK-SAME: %[[ARG0:.*]]: tensor<32x2x49x16x16xf32>, %[[ARG1:.*]]: tensor<32x?x?xf32>)
+//      CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
+//      CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32x2x16x49x16xf32>
 //      CHECK: %[[TRAN:.*]] = linalg.transpose
 // CHECK-SAME:    ins(%[[ARG0]] : tensor<32x2x49x16x16xf32>)
@@ -435,9 +437,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:   permutation = [0, 1, 3, 2, 4]
 //      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0], [1, 2], [3, 4]]
 // CHECK-SAME:   : tensor<32x2x16x49x16xf32> into tensor<32x32x784xf32>
-//      CHECK:  %[[C1:.*]] = arith.constant 1 : index
 //      CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<32x?x?xf32>
-//      CHECK: %[[C2:.*]] = arith.constant 2 : index
 //      CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<32x?x?xf32>
 //      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0] [32, %[[DIM1]], %[[DIM2]]] [1, 1, 1]
 // CHECK-SAME:   : tensor<32x32x784xf32> to tensor<32x?x?xf32>

>From 67d8304adf538cf38bff8c380931203828ae0b7c Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 14 Dec 2023 01:37:37 -0600
Subject: [PATCH 2/4] Add test

---
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  8 ++--
 .../Dialect/Linalg/transform-lower-pack.mlir  | 42 +++++++++++++++++++
 2 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e66cf4c271d656..276bf1045eee3a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -430,13 +430,15 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMinedTensorType, packingMetadata.reassociations);
 
-  // Get dynamic dims for stripMined shape
-  SmallVector<Value> dims(llvm::map_range(
+  // Collect the source tensor's dims, permuted to match the strip-mined
+  // shape, so the dynamic ones can be passed to the `tensor.empty` op.
+  SmallVector<Value, 4> dims(llvm::map_range(
       llvm::seq<int64_t>(0, packedTensorType.getRank()), [&](int64_t i) {
         return rewriter.create<tensor::DimOp>(loc, unPackOp.getSource(), i);
       }));
   applyPermutationToVector(dims, lastDimsToInsertPositionsPerm);
-  SmallVector<Value> dynDims;
+  SmallVector<Value, 4> dynDims;
+  dynDims.reserve(stripMinedTensorType.getNumDynamicDims());
   for (int64_t i = 0; i < stripMinedTensorType.getRank(); i++) {
     if (stripMinedTensorType.isDynamicDim(i))
       dynDims.push_back(dims[i]);
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 8849ae6084f3ba..f19db18f4876fc 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -464,6 +464,48 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Check that we can lower unpack with dynamic dimensions in the input and destination.
+// CHECK-LABEL: func.func @unpack_with_dynamic_input_dest(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x8x16xf32>, %[[ARG1:.*]]: tensor<?x?xf32>)
+//      CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+//      CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//      CHECK-DAG:  %[[C2:.*]] = arith.constant 2 : index
+//      CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//      CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+//      CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+//      CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM01]]) : tensor<?x8x?x16xf32>
+//      CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME:    ins(%[[ARG0]] : tensor<?x?x8x16xf32>)
+// CHECK-SAME:   outs(%[[EMPTY]] : tensor<?x8x?x16xf32>)
+// CHECK-SAME:   permutation = [0, 2, 1, 3]
+//      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
+// CHECK-SAME:   : tensor<?x8x?x16xf32> into tensor<?x?xf32>
+//      CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//      CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
+// CHECK-SAME:   : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
+// CHECK-SAME:        outs(%[[ARG1]] : tensor<?x?xf32>)
+func.func @unpack_with_dynamic_input_dest(%arg0: tensor<?x?x8x16xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %arg1 : tensor<?x?x8x16xf32> -> tensor<?x?xf32>
+    return %unpack : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+          transform.yield
+  }
+}
+
+// -----
+
 // At the moment, we cannot lower tensor.unpack with outer_dims_perm.
 func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
   // expected-note @below {{target payload op}}

>From 208091f47975d0941d31683c3a3a10bf90254741 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 14 Dec 2023 11:50:48 -0600
Subject: [PATCH 3/4] Add fully dynamic test

---
 .../Dialect/Linalg/transform-lower-pack.mlir  | 43 +++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index f19db18f4876fc..8de0f538d6c73d 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -506,6 +506,49 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Check that we can lower unpack with dynamic dimensions in the input, destination, and inner_tiles.
+// CHECK-LABEL: func.func @unpack_fully_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+//      CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+//      CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//      CHECK-DAG:  %[[C2:.*]] = arith.constant 2 : index
+//      CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//      CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+//      CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+//      CHECK-DAG: %[[DIM02:.*]] = tensor.dim %[[ARG0]], %[[C2]]
+//      CHECK-DAG: %[[DIM03:.*]] = tensor.dim %[[ARG0]], %[[C3]]
+//      CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM02]], %[[DIM01]], %[[DIM03]]) : tensor<?x?x?x?xf32>
+//      CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME:    ins(%[[ARG0]] : tensor<?x?x?x?xf32>)
+// CHECK-SAME:   outs(%[[EMPTY]] : tensor<?x?x?x?xf32>)
+// CHECK-SAME:   permutation = [0, 2, 1, 3]
+//      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
+// CHECK-SAME:   : tensor<?x?x?x?xf32> into tensor<?x?xf32>
+//      CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//      CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
+// CHECK-SAME:   : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
+// CHECK-SAME:        outs(%[[ARG1]] : tensor<?x?xf32>)
+func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?xf32>, %tile_n : index, %tile_m : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack"> 
+    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+          -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+      transform.yield
+  }
+}
+
+// -----
+
 // At the moment, we cannot lower tensor.unpack with outer_dims_perm.
 func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
   // expected-note @below {{target payload op}}

>From b2ff96c2b5f86bfb498d4a725215b4a49d717d09 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 14 Dec 2023 12:08:18 -0600
Subject: [PATCH 4/4] Add unpack as unpad test

---
 .../Dialect/Linalg/transform-lower-pack.mlir  | 40 +++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 8de0f538d6c73d..0c3196b3cafc9f 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -549,6 +549,46 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Check that we can lower unpack "as unpad" with dynamic dims.
+// CHECK-LABEL: func.func @unpack_as_pad_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x1x?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?x?xf32>
+//      CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
+//      CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//      CHECK-DAG:  %[[C2:.*]] = arith.constant 2 : index
+//      CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//      CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+//      CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+//      CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+//      CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
+//      CHECK: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+// offsets.
+// CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
+// sizes.
+// CHECK-SAME:   [1, 1, 1, 1, %[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
+// strides multipliers.
+// CHECK-SAME:   [1, 1, 1, 1, 1, 1, 1, 1]
+// CHECK-SAME:   : tensor<1x1x1x1x?x?x?x?xf32> to tensor<?x?x?x?xf32>
+func.func @unpack_as_pad_dynamic(%arg0: tensor<1x1x1x1x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x?x?x?x?xf32> -> tensor<?x?x?x?xf32>
+  return %pack : tensor<?x?x?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+          transform.yield
+  }
+}
+
+// -----
+
 // At the moment, we cannot lower tensor.unpack with outer_dims_perm.
 func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
   // expected-note @below {{target payload op}}



More information about the Mlir-commits mailing list