[Mlir-commits] [mlir] [MLIR][linalg] Fix unpack rewriter for dynamic shapes (PR #67096)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 22 00:51:35 PDT 2023


https://github.com/qcolombet created https://github.com/llvm/llvm-project/pull/67096

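Prior to this patch, `GeneralizeOuterUnitDimsUnPackOpPattern` would hit an assertion when the unpacked source had dynamic dimensions, because it tried to create a `tensor.empty` operation with a dynamic shape without supplying the dynamic sizes.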

The problem was that we were not using the right builder for the `tensor.empty` operation: each dynamic dimension of the result shape must be given a size value when the operation is built.

Simply provide the dynamic dimensions to the `tensor.empty` builder to fix that.

From b7bfe571515a783c4f0ed5b3a142f2cc2bf528ab Mon Sep 17 00:00:00 2001
From: Quentin Colombet <quentin.colombet at gmail.com>
Date: Fri, 22 Sep 2023 09:37:34 +0200
Subject: [PATCH] [MLIR][linalg] Fix unpack rewriter for dynamic shapes

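Prior to this patch, `GeneralizeOuterUnitDimsUnPackOpPattern` would hit an
assertion when the unpacked source had dynamic dimensions, because it tried
to create a `tensor.empty` operation with a dynamic shape without supplying
the dynamic sizes.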

The problem was that we were not using the right builder for the
`tensor.empty` operation: each dynamic dimension of the result shape must
be given a size value when the operation is built.

Simply provide the dynamic dimensions to the `tensor.empty` builder to fix
that.
---
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 10 +++++---
 .../Linalg/generalize-tensor-unpack.mlir      | 23 +++++++++++++++++++
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 49fe937741c77c9..8183b40ad7346f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1256,6 +1256,7 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
   SmallVector<int64_t> readShape;
+  SmallVector<Value> dynamicDims;
   for (auto i : llvm::seq<unsigned>(0, destRank)) {
     if (dimAndTileMapping.count(i)) {
       readSizes.push_back(oneIdxAttr);
@@ -1263,8 +1264,10 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
     }
 
     if (ShapedType::isDynamic(srcShape[i])) {
-      readSizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, source, i).getResult());
+      Value dynamicDim =
+          rewriter.create<tensor::DimOp>(loc, source, i).getResult();
+      readSizes.push_back(dynamicDim);
+      dynamicDims.push_back(dynamicDim);
     } else {
       readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
     }
@@ -1292,7 +1295,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape(readShape);
   applyPermutationToVector<int64_t>(transpShape, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  Value empty =
+      rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
 
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
index a596690c2e4fd60..023768088650062 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
@@ -94,3 +94,26 @@ func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0] [1, 32, 16, 8] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tensor<?x1x32x8xf32>) -> tensor<?x1x32x8xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<?x1x1x1x8x32xf32> -> tensor<?x1x32x8xf32>
+  return %0 : tensor<?x1x32x8xf32>
+}
+// CHECK-LABEL: func.func @unpack_with_dynamic_dims
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[DIM0_SRC:.+]] = tensor.dim %[[SRC]], %[[C0]] : tensor<?x1x1x1x8x32xf32>
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0, 0] [%[[DIM0_SRC]], 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM0_SRC]]) : tensor<?x32x8xf32>
+// CHECK:         %[[TRANSP:.+]] =  linalg.transpose
+// CHECK-SAME:      ins(%[[TILE]] : tensor<?x8x32xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<?x32x8xf32>)
+// CHECK-SAME:      permutation = [0, 2, 1]
+// CHECK:         %[[DIM0_DEST:.+]] = tensor.dim %[[DEST]], %[[C0]] : tensor<?x1x32x8xf32>
+// CHECK:         %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0, 0] [%[[DIM0_DEST]], 32, 8] [1, 1, 1] : tensor<?x32x8xf32> to tensor<?x32x8xf32>
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1]
+// CHECK:         return %[[INSERT]]



More information about the Mlir-commits mailing list