[Mlir-commits] [mlir] [MLIR][linalg] Fix unpack rewriter for dynamic shapes (PR #67096)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 22 00:52:46 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
Prior to this patch, `GeneralizeOuterUnitDimsUnPackOpPattern` would hit an assertion failure when it tried to create a `tensor.empty` operation with a dynamic shape.
The problem was that the pattern used a `tensor.empty` builder that takes only the static shape. Each dynamic dimension of a `tensor.empty` result must be supplied as an SSA `index` operand.
This patch fixes the issue by passing the dynamic dimension values (obtained via `tensor.dim`) to the `tensor.empty` builder.
---
Full diff: https://github.com/llvm/llvm-project/pull/67096.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+7-3)
- (modified) mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir (+23)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 49fe937741c77c9..8183b40ad7346f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1256,6 +1256,7 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> readSizes;
SmallVector<int64_t> readShape;
+ SmallVector<Value> dynamicDims;
for (auto i : llvm::seq<unsigned>(0, destRank)) {
if (dimAndTileMapping.count(i)) {
readSizes.push_back(oneIdxAttr);
@@ -1263,8 +1264,10 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
}
if (ShapedType::isDynamic(srcShape[i])) {
- readSizes.push_back(
- rewriter.create<tensor::DimOp>(loc, source, i).getResult());
+ Value dynamicDim =
+ rewriter.create<tensor::DimOp>(loc, source, i).getResult();
+ readSizes.push_back(dynamicDim);
+ dynamicDims.push_back(dynamicDim);
} else {
readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
}
@@ -1292,7 +1295,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
SmallVector<int64_t> transpShape(readShape);
applyPermutationToVector<int64_t>(transpShape, perm);
- Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+ Value empty =
+ rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
auto transposedOp =
rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
index a596690c2e4fd60..023768088650062 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
@@ -94,3 +94,26 @@ func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [1, 32, 16, 8] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tensor<?x1x32x8xf32>) -> tensor<?x1x32x8xf32> {
+ %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<?x1x1x1x8x32xf32> -> tensor<?x1x32x8xf32>
+ return %0 : tensor<?x1x32x8xf32>
+}
+// CHECK-LABEL: func.func @unpack_with_dynamic_dims
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM0_SRC:.+]] = tensor.dim %[[SRC]], %[[C0]] : tensor<?x1x1x1x8x32xf32>
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0, 0] [%[[DIM0_SRC]], 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0_SRC]]) : tensor<?x32x8xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[TILE]] : tensor<?x8x32xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x32x8xf32>)
+// CHECK-SAME: permutation = [0, 2, 1]
+// CHECK: %[[DIM0_DEST:.+]] = tensor.dim %[[DEST]], %[[C0]] : tensor<?x1x32x8xf32>
+// CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0, 0] [%[[DIM0_DEST]], 32, 8] [1, 1, 1] : tensor<?x32x8xf32> to tensor<?x32x8xf32>
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1]
+// CHECK: return %[[INSERT]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/67096
More information about the Mlir-commits mailing list