[Mlir-commits] [mlir] [Draft][MLIR] Add reshape propagation through tensor.pad (PR #136681)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 28 08:42:45 PDT 2025


================
@@ -893,3 +893,26 @@ func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[EXPANDED]] :
 //      CHECK:   return %[[GENERIC]]
+
+// -----
+
+func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0   = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
+    ^bb0(%i: index, %j: index, %k: index):
+      tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+//      CHECK: func @fold_tensor_pad_with_expand(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+//  CHECK-DAG:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+//      CHECK:   %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
+//      CHECK:   %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+//      CHECK:   ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+//      CHECK:     tensor.yield %[[CST]] : f32
+//      CHECK:   } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+//      CHECK:   return %[[PADDED]] : tensor<32x16x258x258xf32>
----------------
Max191 wrote:

nit: Add a newline at the end of the file.

https://github.com/llvm/llvm-project/pull/136681

