[Mlir-commits] [mlir] [mlir] Fix bug in pack op canonicalization for folding dynamic dims (PR #82539)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 21 13:44:34 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
Author: None (Max191)
<details>
<summary>Changes</summary>
This PR fixes a bug in the inference of pack static shapes that should be using an inverse permutation.
---
Full diff: https://github.com/llvm/llvm-project/pull/82539.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+5-3)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+3-3)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e6efec14e31a60..b687bc8768056b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4012,15 +4012,17 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
llvm::SmallSetVector<int64_t, 4> innerDims;
innerDims.insert(packOp.getInnerDimsPos().begin(),
packOp.getInnerDimsPos().end());
- auto outerDimsPerm = packOp.getOuterDimsPerm();
+ SmallVector<int64_t> inverseOuterDimsPerm;
+ if (!packOp.getOuterDimsPerm().empty())
+ inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
int srcRank = packOp.getSourceRank();
for (auto i : llvm::seq<int64_t>(0, srcRank)) {
if (innerDims.contains(i))
continue;
int64_t srcPos = i;
int64_t destPos = i;
- if (!outerDimsPerm.empty())
- destPos = outerDimsPerm[srcPos];
+ if (!inverseOuterDimsPerm.empty())
+ destPos = inverseOuterDimsPerm[srcPos];
if (ShapedType::isDynamic(srcShape[srcPos]) ==
ShapedType::isDynamic(destShape[destPos])) {
continue;
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e123c77aabd57c..5a754dd0d61cf5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -822,7 +822,7 @@ func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x3
// CHECK-LABEL: func.func @infer_src_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
-// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
// CHECK: return %[[PACK]]
@@ -841,9 +841,9 @@ func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?
// CHECK-LABEL: func.func @infer_dest_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
-// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
-// CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<10x20x30x?x16xf32> to tensor<?x?x?x?x16xf32>
+// CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
// CHECK: return %[[CAST_PACK]]
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/82539
More information about the Mlir-commits
mailing list