[Mlir-commits] [mlir] [MLIR][MemRef] Normalize memref.alloc ops with non trivial layout map (PR #129875)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 5 04:08:32 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-affine

Author: Arnab Dutta (arnab-polymage)

<details>
<summary>Changes</summary>



---

Patch is 29.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129875.diff


7 Files Affected:

- (modified) mlir/lib/Analysis/FlatLinearValueConstraints.cpp (+5-5) 
- (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+41-15) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp (+3-1) 
- (modified) mlir/test/Dialect/Affine/memref-bound-check.mlir (+4-3) 
- (modified) mlir/test/Dialect/MemRef/normalize-memrefs-ops-dynamic.mlir (+71-38) 
- (modified) mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir (+1-1) 
- (modified) mlir/test/Dialect/MemRef/normalize-memrefs.mlir (+23-11) 


``````````diff
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 6ad39a3a91293..fefce0e0c087b 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -118,9 +118,11 @@ struct SemiAffineExprFlattener : public AffineExprFlattener {
       // with a positive value." (see AffineExprKind in AffineExpr.h). If this
       // assumption does not hold constraints (added above) are a contradiction.
 
+      return success();
+    } else if (localExpr.getKind() == AffineExprKind::Mul) {
+      (void)localVarCst.appendVar(VarKind::Local);
       return success();
     }
-
     // TODO: Support other semi-affine expressions.
     return failure();
   }
@@ -163,7 +165,6 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
 
     return success();
   };
-
   if (addConservativeSemiAffineBounds) {
     SemiAffineExprFlattener flattener(numDims, numSymbols);
     return flattenExprs(flattener);
@@ -229,7 +230,8 @@ LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) {
   assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
 
   std::vector<SmallVector<int64_t, 8>> flatExprs;
-  if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
+  if (failed(flattenAlignedMapAndMergeLocals(
+          other, &flatExprs, /*addConservativeSemiAffineBounds=*/true)))
     return failure();
   assert(flatExprs.size() == other.getNumResults());
 
@@ -796,8 +798,6 @@ LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
                << "composition unimplemented for semi-affine maps\n");
     return failure();
   }
-
-  // Add localCst information.
   if (localCst.getNumLocalVars() > 0) {
     unsigned numLocalVars = getNumLocalVars();
     // Insert local dims of localCst at the beginning.
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 7ef016f88be37..6671d981f2e4b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1786,7 +1786,6 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
   }
 }
 
-// TODO: Currently works for static memrefs with a single layout map.
 template <typename AllocLikeOp>
 LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
   MemRefType memrefType = allocOp->getType();
@@ -1799,7 +1798,6 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
     // Either memrefType already had an identity map or the map couldn't be
     // transformed to an identity map.
     return failure();
-
   Value oldMemRef = allocOp->getResult();
 
   SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
@@ -1819,8 +1817,40 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
         b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType, newDynamicSizes,
                               allocOp->getAlignmentAttr());
   } else {
+    mlir::ValueRange dynamicSizes = allocOp->getDynamicSizes();
+    mlir::ValueRange symbolOperands = allocOp->getSymbolOperands();
+    ArrayRef<int64_t> newShape = newMemRefType.getShape();
+    ArrayRef<int64_t> oldShape = memrefType.getShape();
+    SmallVector<Value> mapOperands(oldShape.size() + symbolOperands.size());
+    SmallVector<Value> dimensionOperands;
+    unsigned dimId = 0, symId = 0;
+    // Collect all the map operands of `allocOp` (both dynamic sizes and symbol
+    // operands), which will help us to compute the dynamic sizes of the new
+    // alloc op we are going to create.
+    for (unsigned i = 0, e = oldShape.size(); i < e; i++) {
+      if (oldShape[i] == ShapedType::kDynamic)
+        mapOperands[i] = dynamicSizes[dimId++];
+      else
+        mapOperands[i] =
+            b.create<arith::ConstantIndexOp>(allocOp->getLoc(), oldShape[i]);
+    }
+    for (unsigned i = oldShape.size(), e = mapOperands.size(); i < e; i++)
+      mapOperands[i] = symbolOperands[symId++];
+    // Compute the dynamic sizes operands for the new alloc op. If `newShape` is
+    // dynamic along a dimension, compute its shape using the layout map and
+    // dynamic sizes and symbol operands of the old `allocOp`.
+    for (unsigned i = 0, e = newShape.size(); i < e; i++) {
+      if (newShape[i] != ShapedType::kDynamic)
+        continue;
+      AffineExpr resExpr = layoutMap.getResult(i);
+      auto resMap = AffineMap::get(layoutMap.getNumDims(),
+                                   layoutMap.getNumSymbols(), resExpr);
+      dimensionOperands.push_back(
+          b.create<AffineApplyOp>(allocOp->getLoc(), resMap, mapOperands));
+    }
     newAlloc = b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
-                                     allocOp->getAlignmentAttr());
+                                             dimensionOperands,
+                                             allocOp->getAlignmentAttr());
   }
   // Replace all uses of the old memref.
   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
@@ -1868,11 +1898,8 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
 
   // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
   // for now.
-  // TODO: Normalize the other types of dynamic memrefs.
   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
   (void)getTileSizePos(layoutMap, tileSizePos);
-  if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
-    return memrefType;
 
   // We have a single map that is not an identity map. Create a new memref
   // with the right shape and an identity layout map.
@@ -1894,7 +1921,6 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
   unsigned newRank = layoutMap.getNumResults();
   if (failed(fac.composeMatchingMap(layoutMap)))
     return memrefType;
-  // TODO: Handle semi-affine maps.
   // Project out the old data dimensions.
   fac.projectOut(newRank, fac.getNumVars() - newRank - fac.getNumLocalVars());
   SmallVector<int64_t, 4> newShape(newRank);
@@ -1910,14 +1936,14 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
     // For a static memref and an affine map with no symbols, this is
     // always bounded. However, when we have symbols, we may not be able to
     // obtain a constant upper bound. Also, mapping to a negative space is
-    // invalid for normalization.
-    if (!ubConst.has_value() || *ubConst < 0) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "can't normalize map due to unknown/invalid upper bound");
-      return memrefType;
-    }
-    // If dimension of new memrefType is dynamic, the value is -1.
-    newShape[d] = *ubConst + 1;
+    // invalid for normalization. If dimension of new memrefType is dynamic,
+    // the value is `ShapedType::kDynamic`.
+    if (!ubConst.has_value())
+        newShape[d] = ShapedType::kDynamic;
+      else if (*ubConst >= 0)
+        newShape[d] = *ubConst + 1;
+      else
+        return memrefType;
   }
 
   // Create the new memref type after trivializing the old layout map.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 08b853fe65b85..d2bd95b5996c8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -445,8 +445,10 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
           if (oldMemRefType == newMemRefType)
             continue;
           // TODO: Assume single layout map. Multiple maps not supported.
+          // TODO: Semi-affine layout not supported.
           AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
-          if (failed(replaceAllMemRefUsesWith(oldMemRef,
+          if (!layoutMap.getResult(0).isPureAffine() ||
+              failed(replaceAllMemRefUsesWith(oldMemRef,
                                               /*newMemRef=*/newMemRef,
                                               /*extraIndices=*/{},
                                               /*indexRemap=*/layoutMap,
diff --git a/mlir/test/Dialect/Affine/memref-bound-check.mlir b/mlir/test/Dialect/Affine/memref-bound-check.mlir
index 80909abee51d6..321b2ba4a914f 100644
--- a/mlir/test/Dialect/Affine/memref-bound-check.mlir
+++ b/mlir/test/Dialect/Affine/memref-bound-check.mlir
@@ -124,13 +124,14 @@ func.func @mod_floordiv_nested() {
   return
 }
 
-// CHECK-LABEL: func @test_semi_affine_bailout
-func.func @test_semi_affine_bailout(%N : index) {
+// CHECK-LABEL: func @test_semi_affine_access
+func.func @test_semi_affine_access(%N : index) {
   %B = memref.alloc() : memref<10 x i32>
   affine.for %i = 0 to 10 {
     %idx = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%i)[%N]
     %y = affine.load %B[%idx] : memref<10 x i32>
-    // expected-error at -1 {{getMemRefRegion: compose affine map failed}}
+    // expected-error at -1 {{'affine.load' op memref out of upper bound access along dimension #1}}
+    // expected-error at -2 {{'affine.load' op memref out of lower bound access along dimension #1}}
   }
   return
 }
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs-ops-dynamic.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs-ops-dynamic.mlir
index a3f256b30c6a0..0cf8668561395 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs-ops-dynamic.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs-ops-dynamic.mlir
@@ -98,14 +98,15 @@ func.func @test_norm_dynamic1234(%arg0 : memref<?x?x?x?xf32, #map_tiled1>) -> ()
 // -----
 
 // Same test with maps that are not tiled layout maps in the arguments and the operations in the function.
-// This is not normalized since this is not tiled-layout map. No mod and floordiv.
 
 #map_not_tiled0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 - d2)>
 
-// CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 - d2)>
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2 - d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3 - d2)>
 
 // CHECK-LABEL:  func @test_norm_dynamic_not_tiled0
-// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP6]]>) {
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x?xf32>) {
 func.func @test_norm_dynamic_not_tiled0(%arg0 : memref<1x?x?x14xf32, #map_not_tiled0>) -> () {
     %c1 = arith.constant 1 : index
     %c2 = arith.constant 2 : index
@@ -118,11 +119,16 @@ func.func @test_norm_dynamic_not_tiled0(%arg0 : memref<1x?x?x14xf32, #map_not_ti
     // CHECK-DAG:       [[CST_1_:%.+]] = arith.constant 1 : index
     // CHECK-DAG:       [[CST_2_:%.+]] = arith.constant 2 : index
     // CHECK-NOT: separator of consecutive DAGs
-    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK:           [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP6]]>, memref<1x?x?x14xf32, #[[$MAP6]]>) -> ()
-    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
+    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x?xf32>
+    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x?xf32>
+    // CHECK-DAG:       [[C_1_:%.+]] = arith.constant 1 : index
+    // CHECK-DAG:       [[C_14_:%.+]] = arith.constant 14 : index
+    // CHECK:           [[T0_:%.+]] = affine.apply #[[$MAP]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T1_:%.+]] = affine.apply #[[$MAP1]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T2_:%.+]] = affine.apply #[[$MAP2]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[RES_:%.+]] = memref.alloc([[T0_]], [[T1_]], [[T2_]]) : memref<1x?x?x?xf32>
+    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x?xf32>, memref<1x?x?x?xf32>) -> ()
+    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x?xf32>
     // CHECK:           return
 }
 
@@ -133,10 +139,13 @@ func.func @test_norm_dynamic_not_tiled0(%arg0 : memref<1x?x?x14xf32, #map_not_ti
 
 #map_not_tiled1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 - d2, d2 mod 32, d3 mod 64)>
 
-// CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 - d2, d2 mod 32, d3 mod 64)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2 - d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3 - d2)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d2 mod 32)>
 
 // CHECK-LABEL:  func @test_norm_dynamic_not_tiled1
-// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP6]]>) {
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x?x?x64xf32>) {
 func.func @test_norm_dynamic_not_tiled1(%arg0 : memref<1x?x?x14xf32, #map_not_tiled1>) -> () {
     %c1 = arith.constant 1 : index
     %c2 = arith.constant 2 : index
@@ -149,11 +158,17 @@ func.func @test_norm_dynamic_not_tiled1(%arg0 : memref<1x?x?x14xf32, #map_not_ti
     // CHECK-DAG:       [[CST_1_:%.+]] = arith.constant 1 : index
     // CHECK-DAG:       [[CST_2_:%.+]] = arith.constant 2 : index
     // CHECK-NOT: separator of consecutive DAGs
-    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK:           [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP6]]>, memref<1x?x?x14xf32, #[[$MAP6]]>) -> ()
-    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
+    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x?x?x64xf32>
+    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x?x?x64xf32>
+    // CHECK-DAG:       [[C_1_:%.+]] = arith.constant 1 : index
+    // CHECK-DAG:       [[C_14_:%.+]] = arith.constant 14 : index
+    // CHECK:           [[T0_:%.+]] = affine.apply #[[$MAP]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T1_:%.+]] = affine.apply #[[$MAP1]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T2_:%.+]] = affine.apply #[[$MAP2]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T3_:%.+]] = affine.apply #[[$MAP3]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[RES_:%.+]] = memref.alloc([[T0_]], [[T1_]], [[T2_]], [[T3_]]) : memref<1x?x?x?x?x64xf32>
+    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x?x?x64xf32>, memref<1x?x?x?x?x64xf32>) -> ()
+    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x?x?x64xf32>
     // CHECK:           return
 }
 
@@ -164,10 +179,12 @@ func.func @test_norm_dynamic_not_tiled1(%arg0 : memref<1x?x?x14xf32, #map_not_ti
 
 #map_not_tiled2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 floordiv 64, d2 mod 32, d3 mod 32)>
 
-// CHECK-DAG: #[[$MAP7:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 floordiv 64, d2 mod 32, d3 mod 32)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2 - d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2 mod 32)>
 
 // CHECK-LABEL:  func @test_norm_dynamic_not_tiled2
-// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP7]]>) {
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x1x?x32xf32>) {
 func.func @test_norm_dynamic_not_tiled2(%arg0 : memref<1x?x?x14xf32, #map_not_tiled2>) -> () {
     %c1 = arith.constant 1 : index
     %c2 = arith.constant 2 : index
@@ -180,11 +197,16 @@ func.func @test_norm_dynamic_not_tiled2(%arg0 : memref<1x?x?x14xf32, #map_not_ti
     // CHECK-DAG:       [[CST_1_:%.+]] = arith.constant 1 : index
     // CHECK-DAG:       [[CST_2_:%.+]] = arith.constant 2 : index
     // CHECK-NOT: separator of consecutive DAGs
-    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP7]]>
-    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP7]]>
-    // CHECK:           [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP7]]>
-    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP7]]>, memref<1x?x?x14xf32, #[[$MAP7]]>) -> ()
-    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP7]]>
+    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x1x?x32xf32>
+    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x1x?x32xf32>
+    // CHECK-DAG:       [[C_1_:%.+]] = arith.constant 1 : index
+    // CHECK-DAG:       [[C_14_:%.+]] = arith.constant 14 : index
+    // CHECK:           [[T0_:%.+]] = affine.apply #[[$MAP]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T1_:%.+]] = affine.apply #[[$MAP1]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T2_:%.+]] = affine.apply #[[$MAP2]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[RES_:%.+]] = memref.alloc([[T0_]], [[T1_]], [[T2_]]) : memref<1x?x?x1x?x32xf32>
+    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x1x?x32xf32>, memref<1x?x?x1x?x32xf32>) -> ()
+    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x1x?x32xf32>
     // CHECK:           return
 }
 
@@ -195,10 +217,11 @@ func.func @test_norm_dynamic_not_tiled2(%arg0 : memref<1x?x?x14xf32, #map_not_ti
 
 #map_not_tiled3 = affine_map<(d0, d1, d2, d3) -> (d0, d1 floordiv 32, d2, d3, d1 mod 32, d1 mod 32)>
 
-// CHECK-DAG: #[[$MAP8:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 floordiv 32, d2, d3, d1 mod 32, d1 mod 32)>
-
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1 floordiv 32)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1 mod 32)>
 // CHECK-LABEL:  func @test_norm_dynamic_not_tiled3
-// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP8]]>) {
+// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x?x14x?x?xf32>) {
 func.func @test_norm_dynamic_not_tiled3(%arg0 : memref<1x?x?x14xf32, #map_not_tiled3>) -> () {
     %c1 = arith.constant 1 : index
     %c2 = arith.constant 2 : index
@@ -211,11 +234,17 @@ func.func @test_norm_dynamic_not_tiled3(%arg0 : memref<1x?x?x14xf32, #map_not_ti
     // CHECK-DAG:       [[CST_1_:%.+]] = arith.constant 1 : index
     // CHECK-DAG:       [[CST_2_:%.+]] = arith.constant 2 : index
     // CHECK-NOT: separator of consecutive DAGs
-    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP8]]>
-    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP8]]>
-    // CHECK:           [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP8]]>
-    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP8]]>, memref<1x?x?x14xf32, #[[$MAP8]]>) -> ()
-    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP8]]>
+    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14x?x?xf32>
+    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14x?x?xf32>
+    // CHECK-DAG:       [[C_1_:%.+]] = arith.constant 1 : index
+    // CHECK-DAG:       [[C_14_:%.+]] = arith.constant 14 : index
+    // CHECK:           [[T0_:%.+]] = affine.apply #[[$MAP]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T1_:%.+]] = affine.apply #[[$MAP1]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T2_:%.+]] = affine.apply #[[$MAP2]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T3_:%.+]]...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/129875


More information about the Mlir-commits mailing list