[Mlir-commits] [mlir] [mlir][linalg] Bail out tensor.cast pack/unpack fold on unprovable tile sizes (PR #188000)

Hocky Yudhiono llvmlistbot at llvm.org
Wed Mar 25 18:55:08 PDT 2026


================
@@ -0,0 +1,149 @@
+// RUN: mlir-opt %s --inline -canonicalize="test-convergence" -split-input-file | FileCheck %s --check-prefixes=CHECK
+
+// A tile size passed in as a function argument cannot be proven to equal the
+// static inner dimension (8) of the cast source, so the tensor.cast must not
+// be folded into the linalg.unpack.
+// CHECK-LABEL: func.func @dynamic_tile_arg_no_fold
+// CHECK-SAME:    %[[SRC:.+]]: tensor<1x3x8x1xi32>, %[[TILE:.+]]: index
+// CHECK-DAG:     %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-DAG:     %[[CAST:.+]] = tensor.cast %[[SRC]] : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
+// CHECK-SAME:      inner_dims_pos = [0, 1]
+// CHECK-SAME:      inner_tiles = [%[[TILE]], 1]
+// CHECK-SAME:      into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+// CHECK:         return %[[UNPACK]] : tensor<7x3xi32>
+module {
+  func.func @dynamic_tile_arg_no_fold(%arg0: tensor<1x3x8x1xi32>, %arg1: index) -> tensor<7x3xi32> {
+    %0 = tensor.empty() : tensor<7x3xi32>
+    %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+    %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%arg1, 1] into %0 : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+    return %unpack : tensor<7x3xi32>
+  }
+}
+
+
+// -----
+
+// After inlining, the tile size becomes the constant 256, which contradicts the
+// static inner dimension (8) of the cast source, so the tensor.cast must stay.
+// CHECK-LABEL: func.func @dynamic_tile_from_inlined_mismatch_no_fold
+// CHECK-DAG:   %[[C256:.+]] = arith.constant 256 : index
+// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-DAG:   %[[CAST:.+]] = tensor.cast %{{.+}} : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+// CHECK:       %[[UNPACK:.+]] = linalg.unpack %[[CAST]]
+// CHECK-SAME:    inner_dims_pos = [0, 1]
+// CHECK-SAME:    inner_tiles = [%[[C256]], 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
+module {
+  // Callee that supplies the tile size; the inliner substitutes its constant.
+  func.func @get_tile() -> index {
+    %tile_size = arith.constant 256 : index
+    return %tile_size : index
+  }
+  func.func @dynamic_tile_from_inlined_mismatch_no_fold(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> {
+    %tile = call @get_tile() : () -> index
+    %init = tensor.empty() : tensor<7x3xi32>
+    %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
+    %unpack = linalg.unpack %cast
+        inner_dims_pos = [0, 1] inner_tiles = [%tile, 1]
+        into %init : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
+    return %unpack : tensor<7x3xi32>
+  }
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @constant_tile_from_inlined_match_folds
+// CHECK:       %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32>
+// CHECK-NOT:   tensor.cast
+// CHECK:       %[[UNPACK:.+]] = linalg.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 1]
+// CHECK-SAME:    into %[[EMPTY]] : tensor<1x3x8x1xi32> -> tensor<7x3xi32>
+// CHECK:       return %[[UNPACK]] : tensor<7x3xi32>
+module {
+  func.func @get_tile() -> index {
+    %c8 = arith.constant 8 : index
+    return %c8 : index
+  }
----------------
hockyy wrote:

Refactored the test cases and moved them to `canonicalize.mlir`.

https://github.com/llvm/llvm-project/pull/188000


More information about the Mlir-commits mailing list