[Mlir-commits] [mlir] [mlir][memref] Simplify expand_shape size/stride computation using output_shape (PR #187844)

Longsheng Mou llvmlistbot at llvm.org
Mon Mar 23 18:04:21 PDT 2026


================
@@ -516,98 +496,73 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
 // See extract_strided_metadata_of_expand_shape_all_static for an explanation
 // of the expansion.
 //
-// One of the important characteristic of this test is that the dynamic
-// dimensions produced by the expand_shape appear both in the first dimension
-// (for group 1) and the non-first dimension (second dimension for group 2.)
-// The idea is to make sure that:
-// 1. We properly account for dynamic shapes even when the strides are not
-//    affected by them. (When the dynamic dimension is the first one.)
-// 2. We properly compute the strides affected by dynamic shapes. (When the
-//    dynamic dimension is not the first one.)
 //
 // Here we have:
 // For the group applying to dim0:
-// size 0 = baseSizes#0 / (all static sizes in that group)
-//        = baseSizes#0 / (7 * 8 * 9)
-//        = baseSizes#0 / 504
-// size 1 = 7
+// size 0 = %sz0
+// size 1 = %sz1
 // size 2 = 8
 // size 3 = 9
-// stride 0 = baseStrides#0 * 7 * 8 * 9
-//          = baseStrides#0 * 504
+// stride 0 = baseStrides#0 * %sz1 * 8 * 9
+//          = baseStrides#0 * %sz1 * 72
 // stride 1 = baseStrides#0 * 8 * 9
 //          = baseStrides#0 * 72
 // stride 2 = baseStrides#0 * 9
 // stride 3 = baseStrides#0
 //
 // For the group applying to dim1:
 // size 4 = 10
-// size 5 = 2
-// size 6 = baseSizes#1 / (all static sizes in that group)
-//        = baseSizes#1 / (10 * 2 * 3)
-//        = baseSizes#1 / 60
+// size 5 = %sz2
+// size 6 = %sz3
 // size 7 = 3
 // stride 4 = baseStrides#1 * size 5 * size 6 * size 7
-//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 60) * 6
-//          and since we know that baseSizes#1 is a multiple of 60:
-//          = baseStrides#1 * (baseSizes#1 / 10)
+//          = baseStrides#1 * %sz2 * %sz3 * 3
 // stride 5 = baseStrides#1 * size 6 * size 7
-//          = baseStrides#1 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 20)
+//          = baseStrides#1 * %sz3 * 3
 // stride 6 = baseStrides#1 * size 7
 //          = baseStrides#1 * 3
 // stride 7 = baseStrides#1
 //
 // Base and offset are unchanged.
 //
-//   CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
-//   CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
-//
-//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
+//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 72)>
 //   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
 //   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
-//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
+//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1, s2] -> (((s1 * s2) * s0) * 3)>
+//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 3)>
 //   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
 // CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+//  CHECK-SAME: %[[SIZE0:.*]]: index,  %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index,  %[[SIZE3:.*]]: index)
 //
 //   CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
 //   CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
 //   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-//   CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
 //   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //
 //   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
 //
-//   CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
-//   CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[SIZE1]], %[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZE2]], %[[SIZE3]], %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZE3]], %[[STRIDES]]#1]
 //   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
 
-//   CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
+//   CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE0]], %[[SIZE1]], %[[C8]], %[[C9]], %[[C10]], %[[SIZE2]], %[[SIZE3]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
 func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
     %base: memref<?x?xf32, strided<[?,?], offset:?>>,
-    %offset0: index, %offset1: index, %offset2: index,
-    %size0: index, %size1: index, %size2: index,
-    %stride0: index, %stride1: index, %stride2: index,
-    %sz0: index, %sz1: index)
+    %sz0: index, %sz1: index, %sz2: index, %sz3: index)
     -> (memref<f32>, index,
        index, index, index, index, index, index, index, index,
        index, index, index, index, index, index, index, index) {
 
-  %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
----------------
CoTinker wrote:

Actually, I think the name `subview` is a typo; it should be `expand_shape`.

https://github.com/llvm/llvm-project/pull/187844


More information about the Mlir-commits mailing list