[Mlir-commits] [mlir] [mlir][memref] Simplify expand_shape size/stride computation using output_shape (PR #187844)
Longsheng Mou
llvmlistbot at llvm.org
Mon Mar 23 18:04:21 PDT 2026
================
@@ -516,98 +496,73 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
// See extract_strided_metadata_of_expand_shape_all_static for an explanation
// of the expansion.
//
-// One of the important characteristic of this test is that the dynamic
-// dimensions produced by the expand_shape appear both in the first dimension
-// (for group 1) and the non-first dimension (second dimension for group 2.)
-// The idea is to make sure that:
-// 1. We properly account for dynamic shapes even when the strides are not
-// affected by them. (When the dynamic dimension is the first one.)
-// 2. We properly compute the strides affected by dynamic shapes. (When the
-// dynamic dimension is not the first one.)
//
// Here we have:
// For the group applying to dim0:
-// size 0 = baseSizes#0 / (all static sizes in that group)
-// = baseSizes#0 / (7 * 8 * 9)
-// = baseSizes#0 / 504
-// size 1 = 7
+// size 0 = %sz0
+// size 1 = %sz1
// size 2 = 8
// size 3 = 9
-// stride 0 = baseStrides#0 * 7 * 8 * 9
-// = baseStrides#0 * 504
+// stride 0 = baseStrides#0 * %sz1 * 8 * 9
+// = baseStrides#0 * %sz1 *72
// stride 1 = baseStrides#0 * 8 * 9
// = baseStrides#0 * 72
// stride 2 = baseStrides#0 * 9
// stride 3 = baseStrides#0
//
// For the group applying to dim1:
// size 4 = 10
-// size 5 = 2
-// size 6 = baseSizes#1 / (all static sizes in that group)
-// = baseSizes#1 / (10 * 2 * 3)
-// = baseSizes#1 / 60
+// size 5 = %sz2
+// size 6 = %sz3
// size 7 = 3
// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
-// = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
-// = baseStrides#1 * (baseSizes#1 / 60) * 6
-// and since we know that baseSizes#1 is a multiple of 60:
-// = baseStrides#1 * (baseSizes#1 / 10)
+// = baseStrides#1 * %sz2 * %sz3 *3
// stride 5 = baseStrides#1 * size 6 * size 7
-// = baseStrides#1 * (baseSizes#1 / 60) * 3
-// = baseStrides#1 * (baseSizes#1 / 20)
+// = baseStrides#1 * %sz3 *3
// stride 6 = baseStrides#1 * size 7
// = baseStrides#1 * 3
// stride 7 = baseStrides#1
//
// Base and offset are unchanged.
//
-// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
-// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
-//
-// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
+// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 72)>
// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
-// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
+// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1, s2] -> (((s1 * s2) * s0) * 3)>
+// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 3)>
// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+// CHECK-SAME: %[[SIZE0:.*]]: index, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index, %[[SIZE3:.*]]: index)
//
// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
//
-// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
-// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[SIZE1]], %[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZE2]], %[[SIZE3]], %[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZE3]], %[[STRIDES]]#1]
// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
-// CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE0]], %[[SIZE1]], %[[C8]], %[[C9]], %[[C10]], %[[SIZE2]], %[[SIZE3]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
%base: memref<?x?xf32, strided<[?,?], offset:?>>,
- %offset0: index, %offset1: index, %offset2: index,
- %size0: index, %size1: index, %size2: index,
- %stride0: index, %stride1: index, %stride2: index,
- %sz0: index, %sz1: index)
+ %sz0: index, %sz1: index, %sz2: index, %sz3: index)
-> (memref<f32>, index,
index, index, index, index, index, index, index, index,
index, index, index, index, index, index, index, index) {
- %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
----------------
CoTinker wrote:
Actually, I think the name `%subview` here is a typo; it should be `%expand_shape`.
https://github.com/llvm/llvm-project/pull/187844
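
For anyone reading the comment block in the hunk above, here is a minimal Python sketch of the stride rule it documents: within one reassociation group, the stride of expanded dimension i is the base stride times the product of the sizes that follow it in that group. The helper `group_strides` and its string-based representation of dynamic sizes are hypothetical and only for illustration; this is not the code in the patch.

```python
from typing import List, Sequence, Union

# An int is a static size; a str names a dynamic size (e.g. an output_shape operand).
Dim = Union[int, str]


def group_strides(base_stride: str, sizes: Sequence[Dim]) -> List[str]:
    """Return one symbolic stride expression per expanded dim of a group.

    stride[i] = base_stride * product(sizes[i+1:]), with static factors folded
    into a single constant and dynamic factors kept as symbols.
    """
    strides = []
    for i in range(len(sizes)):
        static = 1
        symbols = []
        for s in sizes[i + 1:]:
            if isinstance(s, int):
                static *= s  # fold static sizes into one constant
            else:
                symbols.append(s)  # keep dynamic sizes symbolic
        factors = [base_stride] + symbols
        if static != 1:
            factors.append(str(static))
        strides.append(" * ".join(factors))
    return strides


# The two groups from the test's output_shape [%sz0, %sz1, 8, 9, 10, %sz2, %sz3, 3].
dim0 = group_strides("baseStrides#0", ["%sz0", "%sz1", 8, 9])
dim1 = group_strides("baseStrides#1", [10, "%sz2", "%sz3", 3])
```

Under these assumptions, `dim0[0]` comes out as `baseStrides#0 * %sz1 * 72` and `dim1[0]` as `baseStrides#1 * %sz2 * %sz3 * 3`, which lines up with the stride 0 and stride 4 entries in the updated comment. The first size of each group never appears in that group's strides.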