[Mlir-commits] [mlir] e3b93a1 - [mlir] Fix bug in pack and unpack op canonicalization for folding dynamic dims (#82539)

llvmlistbot at llvm.org
Wed Feb 28 14:39:26 PST 2024


Author: Max191
Date: 2024-02-28T17:39:22-05:00
New Revision: e3b93a16201d96745a2c64523801a2b1eb43b4c0

URL: https://github.com/llvm/llvm-project/commit/e3b93a16201d96745a2c64523801a2b1eb43b4c0
DIFF: https://github.com/llvm/llvm-project/commit/e3b93a16201d96745a2c64523801a2b1eb43b4c0.diff

LOG: [mlir] Fix bug in pack and unpack op canonicalization for folding dynamic dims (#82539)

This PR fixes a bug in the static shape inference for pack and unpack ops:
when an outer_dims_perm is present, mapping between source and destination
dimension positions must use the inverse permutation rather than the
permutation itself.
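
To see why the inverse is needed, consider tensor.pack's outer_dims_perm
semantics: destination outer dimension i is read from source dimension
outer_dims_perm[i], i.e. destShape[i] == srcShape[perm[i]]. When the
inference walks source dimensions, the destination position corresponding
to srcPos is therefore inversePerm[srcPos]; the old code indexed with
perm[srcPos] directly. The following standalone C++ sketch illustrates
this (the invert() helper is a local stand-in for
mlir::invertPermutationVector, and the permutation and shapes are assumed
example values chosen to line up with the updated tests):

#include <cassert>
#include <cstdint>
#include <vector>

// Local stand-in for mlir::invertPermutationVector: builds the inverse
// of a permutation vector.
static std::vector<int64_t> invert(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (int64_t i = 0, e = perm.size(); i < e; ++i)
    inv[perm[i]] = i; // dest dim i reads src dim perm[i]
  return inv;
}

int main() {
  // Assumed example: outer_dims_perm = [2, 1, 3, 0], destination outer
  // shape 10x20x30x40.
  std::vector<int64_t> perm = {2, 1, 3, 0};
  std::vector<int64_t> inv = invert(perm); // {3, 1, 0, 2}
  std::vector<int64_t> destShape = {10, 20, 30, 40};

  // For source dim 0, the buggy code looked at dest dim perm[0] == 2
  // (extent 30); the fix looks at dest dim inv[0] == 3 (extent 40),
  // which is the dimension that actually reads from source dim 0.
  assert(perm[0] == 2 && inv[0] == 3);
  assert(destShape[inv[0]] == 40);
  return 0;
}

This is consistent with the updated @infer_src_shape_pack check below,
where the inferred source type changes from tensor<30x20x?x10xf32> to
tensor<40x20x?x30xf32>.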

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e6efec14e31a60..fe2f250e6b9290 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4012,15 +4012,17 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
   llvm::SmallSetVector<int64_t, 4> innerDims;
   innerDims.insert(packOp.getInnerDimsPos().begin(),
                    packOp.getInnerDimsPos().end());
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  SmallVector<int64_t> inverseOuterDimsPerm;
+  if (!packOp.getOuterDimsPerm().empty())
+    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
   int srcRank = packOp.getSourceRank();
   for (auto i : llvm::seq<int64_t>(0, srcRank)) {
     if (innerDims.contains(i))
       continue;
     int64_t srcPos = i;
     int64_t destPos = i;
-    if (!outerDimsPerm.empty())
-      destPos = outerDimsPerm[srcPos];
+    if (!inverseOuterDimsPerm.empty())
+      destPos = inverseOuterDimsPerm[srcPos];
     if (ShapedType::isDynamic(srcShape[srcPos]) ==
         ShapedType::isDynamic(destShape[destPos])) {
       continue;
@@ -4240,15 +4242,17 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                    op.getDestType().getShape().end());
   llvm::SmallSetVector<int64_t, 4> innerDims;
   innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
-  auto outerDimsPerm = op.getOuterDimsPerm();
+  SmallVector<int64_t> inverseOuterDimsPerm;
+  if (!op.getOuterDimsPerm().empty())
+    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
   int destRank = op.getDestRank();
   for (auto i : llvm::seq<int64_t>(0, destRank)) {
     if (innerDims.contains(i))
       continue;
     int64_t srcPos = i;
     int64_t destPos = i;
-    if (!outerDimsPerm.empty())
-      srcPos = outerDimsPerm[destPos];
+    if (!inverseOuterDimsPerm.empty())
+      srcPos = inverseOuterDimsPerm[destPos];
     if (ShapedType::isDynamic(srcShape[srcPos]) ==
         ShapedType::isDynamic(destShape[destPos])) {
       continue;

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e123c77aabd57c..d17c23adfb14d8 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -822,7 +822,7 @@ func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x3
 // CHECK-LABEL: func.func @infer_src_shape_pack
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
-// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
 // CHECK:         %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
 // CHECK:         return %[[PACK]]
 
@@ -841,13 +841,24 @@ func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?
 // CHECK-LABEL: func.func @infer_dest_shape_pack
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
-// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
 // CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
-// CHECK:         %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<10x20x30x?x16xf32> to tensor<?x?x?x?x16xf32>
+// CHECK:         %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
 // CHECK:         return %[[CAST_PACK]]
 
 // -----
 
+func.func @no_infer_pack_shape(%arg0: tensor<?x32x100xf32>, %arg1: index) -> tensor<32x7x?x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty(%arg1) : tensor<32x7x?x16x1xf32>
+  %pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<?x32x100xf32> -> tensor<32x7x?x16x1xf32>
+  return %pack : tensor<32x7x?x16x1xf32>
+}
+// CHECK-LABEL: func.func @no_infer_pack_shape
+// CHECK-NOT:     tensor.cast
+
+// -----
+
 func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
@@ -920,9 +931,9 @@ func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tens
 // CHECK-LABEL: func.func @infer_dest_shape_unpack
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
-// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
-// CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>
+// CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32> to tensor<?x?x?x?xf32>
 // CHECK:         return %[[CAST_UNPACK]]
 
 // -----
@@ -938,12 +949,24 @@ func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30
 // CHECK-LABEL: func.func @infer_src_shape_unpack
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
-// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
 // CHECK:         return %[[UNPACK]]
 
 // -----
 
+func.func @no_infer_unpack_shape(%arg1: tensor<32x7x?x16x1xf32>, %arg2: index) -> tensor<?x32x100xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty(%arg2) : tensor<?x32x100xf32>
+  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<32x7x?x16x1xf32> -> tensor<?x32x100xf32>
+  return %unpack : tensor<?x32x100xf32>
+}
+// CHECK-LABEL: func.func @no_infer_unpack_shape
+// CHECK-NOT:     tensor.cast
+
+// -----
+
+
 // CHECK-LABEL: func @fold_overlapping_insert
 //  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
