[Mlir-commits] [mlir] [MLIR][MemRef] Normalize memref.alloc ops with non-trivial layout map (PR #129875)

Arnab Dutta llvmlistbot at llvm.org
Wed Mar 5 04:07:53 PST 2025


https://github.com/arnab-polymage created https://github.com/llvm/llvm-project/pull/129875

(No pull-request description was provided.)

>From 9f3b4a9e5cacad4fe355ef0d52a858a21474e067 Mon Sep 17 00:00:00 2001
From: Arnab Dutta <arnab at polymagelabs.com>
Date: Wed, 5 Mar 2025 17:33:56 +0530
Subject: [PATCH] [MLIR][MemRef] Normalize memref.alloc ops with non trivial
 layout map

---
 .../Analysis/FlatLinearValueConstraints.cpp   |  10 +-
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       |  56 ++++++---
 .../MemRef/Transforms/NormalizeMemRefs.cpp    |   4 +-
 .../Dialect/Affine/memref-bound-check.mlir    |   7 +-
 .../MemRef/normalize-memrefs-ops-dynamic.mlir | 109 ++++++++++++------
 .../Dialect/MemRef/normalize-memrefs-ops.mlir |   2 +-
 .../Dialect/MemRef/normalize-memrefs.mlir     |  34 ++++--
 7 files changed, 148 insertions(+), 74 deletions(-)

diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 6ad39a3a91293..fefce0e0c087b 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -118,9 +118,11 @@ struct SemiAffineExprFlattener : public AffineExprFlattener {
       // with a positive value." (see AffineExprKind in AffineExpr.h). If this
       // assumption does not hold constraints (added above) are a contradiction.
 
+      return success();
+    } else if (localExpr.getKind() == AffineExprKind::Mul) {
+      (void)localVarCst.appendVar(VarKind::Local);
       return success();
     }
-
     // TODO: Support other semi-affine expressions.
     return failure();
   }
@@ -163,7 +165,6 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
 
     return success();
   };
-
   if (addConservativeSemiAffineBounds) {
     SemiAffineExprFlattener flattener(numDims, numSymbols);
     return flattenExprs(flattener);
@@ -229,7 +230,8 @@ LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) {
   assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
 
   std::vector<SmallVector<int64_t, 8>> flatExprs;
-  if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
+  if (failed(flattenAlignedMapAndMergeLocals(
+          other, &flatExprs, /*addConservativeSemiAffineBounds=*/true)))
     return failure();
   assert(flatExprs.size() == other.getNumResults());
 
@@ -796,8 +798,6 @@ LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
                << "composition unimplemented for semi-affine maps\n");
     return failure();
   }
-
-  // Add localCst information.
   if (localCst.getNumLocalVars() > 0) {
     unsigned numLocalVars = getNumLocalVars();
     // Insert local dims of localCst at the beginning.
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 7ef016f88be37..6671d981f2e4b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1786,7 +1786,6 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
   }
 }
 
-// TODO: Currently works for static memrefs with a single layout map.
 template <typename AllocLikeOp>
 LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
   MemRefType memrefType = allocOp->getType();
@@ -1799,7 +1798,6 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
     // Either memrefType already had an identity map or the map couldn't be
     // transformed to an identity map.
     return failure();
-
   Value oldMemRef = allocOp->getResult();
 
   SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
@@ -1819,8 +1817,40 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
         b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType, newDynamicSizes,
                               allocOp->getAlignmentAttr());
   } else {
+    mlir::ValueRange dynamicSizes = allocOp->getDynamicSizes();
+    mlir::ValueRange symbolOperands = allocOp->getSymbolOperands();
+    ArrayRef<int64_t> newShape = newMemRefType.getShape();
+    ArrayRef<int64_t> oldShape = memrefType.getShape();
+    SmallVector<Value> mapOperands(oldShape.size() + symbolOperands.size());
+    SmallVector<Value> dimensionOperands;
+    unsigned dimId = 0, symId = 0;
+    // Collect all the map operands of `allocOp` (both dynamic sizes and symbol
+    // operands), which will help us to compute the dynamic sizes of the new
+    // alloc op we are going to create.
+    for (unsigned i = 0, e = oldShape.size(); i < e; i++) {
+      if (oldShape[i] == ShapedType::kDynamic)
+        mapOperands[i] = dynamicSizes[dimId++];
+      else
+        mapOperands[i] =
+            b.create<arith::ConstantIndexOp>(allocOp->getLoc(), oldShape[i]);
+    }
+    for (unsigned i = oldShape.size(), e = mapOperands.size(); i < e; i++)
+      mapOperands[i] = symbolOperands[symId++];
+    // Compute the dynamic sizes operands for the new alloc op. If `newShape` is
+    // dynamic along a dimension, compute its shape using the layout map and
+    // dynamic sizes and symbol operands of the old `allocOp`.
+    for (unsigned i = 0, e = newShape.size(); i < e; i++) {
+      if (newShape[i] != ShapedType::kDynamic)
+        continue;
+      AffineExpr resExpr = layoutMap.getResult(i);
+      auto resMap = AffineMap::get(layoutMap.getNumDims(),
+                                   layoutMap.getNumSymbols(), resExpr);
+      dimensionOperands.push_back(
+          b.create<AffineApplyOp>(allocOp->getLoc(), resMap, mapOperands));
+    }
     newAlloc = b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
-                                     allocOp->getAlignmentAttr());
+                                             dimensionOperands,
+                                             allocOp->getAlignmentAttr());
   }
   // Replace all uses of the old memref.
   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
@@ -1868,11 +1898,8 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
 
   // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
   // for now.
-  // TODO: Normalize the other types of dynamic memrefs.
   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
   (void)getTileSizePos(layoutMap, tileSizePos);
-  if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
-    return memrefType;
 
   // We have a single map that is not an identity map. Create a new memref
   // with the right shape and an identity layout map.
@@ -1894,7 +1921,6 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
   unsigned newRank = layoutMap.getNumResults();
   if (failed(fac.composeMatchingMap(layoutMap)))
     return memrefType;
-  // TODO: Handle semi-affine maps.
   // Project out the old data dimensions.
   fac.projectOut(newRank, fac.getNumVars() - newRank - fac.getNumLocalVars());
   SmallVector<int64_t, 4> newShape(newRank);
@@ -1910,14 +1936,14 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
     // For a static memref and an affine map with no symbols, this is
     // always bounded. However, when we have symbols, we may not be able to
     // obtain a constant upper bound. Also, mapping to a negative space is
-    // invalid for normalization.
-    if (!ubConst.has_value() || *ubConst < 0) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "can't normalize map due to unknown/invalid upper bound");
-      return memrefType;
-    }
-    // If dimension of new memrefType is dynamic, the value is -1.
-    newShape[d] = *ubConst + 1;
+    // invalid for normalization. If dimension of new memrefType is dynamic,
+    // the value is `ShapedType::kDynamic`.
+    if (!ubConst.has_value())
+        newShape[d] = ShapedType::kDynamic;
+      else if (*ubConst >= 0)
+        newShape[d] = *ubConst + 1;
+      else
+        return memrefType;
   }
 
   // Create the new memref type after trivializing the old layout map.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 08b853fe65b85..d2bd95b5996c8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -445,8 +445,10 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
           if (oldMemRefType == newMemRefType)
             continue;
           // TODO: Assume single layout map. Multiple maps not supported.
+          // TODO: Semi-affine layout not supported.
           AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
-          if (failed(replaceAllMemRefUsesWith(oldMemRef,
+          if (!layoutMap.getResult(0).isPureAffine() ||
+              failed(replaceAllMemRefUsesWith(oldMemRef,
                                               /*newMemRef=*/newMemRef,
                                               /*extraIndices=*/{},
                                               /*indexRemap=*/layoutMap,
diff --git a/mlir/test/Dialect/Affine/memref-bound-check.mlir b/mlir/test/Dialect/Affine/memref-bound-check.mlir
index 80909abee51d6..321b2ba4a914f 100644
--- a/mlir/test/Dialect/Affine/memref-bound-check.mlir
+++ b/mlir/test/Dialect/Affine/memref-bound-check.mlir
@@ -124,13 +124,14 @@ func.func @mod_floordiv_nested() {
   return
 }
 
-// CHECK-LABEL: func @test_semi_affine_bailout
-func.func @test_semi_affine_bailout(%N : index) {
+// CHECK-LABEL: func @test_semi_affine_access
+func.func @test_semi_affine_access(%N : index) {
   %B = memref.alloc() : memref<10 x i32>
   affine.for %i = 0 to 10 {
     %idx = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%i)[%N]
     %y = affine.load %B[%idx] : memref<10 x i32>
-    // expected-error at -1 {{getMemRefRegion: compose affine map failed}}
+    // expected-error at -1 {{'affine.load' op memref out of upper bound access along dimension #1}}
+    // expected-error at -2 {{'affine.load' op memref out of lower bound access along dimension #1}}
   }
   return
 }
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs-ops-dynamic.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs-ops-dynamic.mlir
index a3f256b30c6a0..0cf8668561395 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs-ops-dynamic.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs-ops-dynamic.mlir
@@ -98,14 +98,15 @@ func.func @test_norm_dynamic1234(%arg0 : memref<?x?x?x?xf32, #map_tiled1>) -> ()
 // -----
 
 // Same test with maps that are not tiled layout maps in the arguments and the operations in the function.
-// This is not normalized since this is not tiled-layout map. No mod and floordiv.
 
 #map_not_tiled0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 - d2)>
 
-// CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 - d2)>
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2 - d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3 - d2)>
 
 // CHECK-LABEL:  func @test_norm_dynamic_not_tiled0
-// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP6]]>) {
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x?xf32>) {
 func.func @test_norm_dynamic_not_tiled0(%arg0 : memref<1x?x?x14xf32, #map_not_tiled0>) -> () {
     %c1 = arith.constant 1 : index
     %c2 = arith.constant 2 : index
@@ -118,11 +119,16 @@ func.func @test_norm_dynamic_not_tiled0(%arg0 : memref<1x?x?x14xf32, #map_not_ti
     // CHECK-DAG:       [[CST_1_:%.+]] = arith.constant 1 : index
     // CHECK-DAG:       [[CST_2_:%.+]] = arith.constant 2 : index
     // CHECK-NOT: separator of consecutive DAGs
-    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK:           [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP6]]>, memref<1x?x?x14xf32, #[[$MAP6]]>) -> ()
-    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
+    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x?xf32>
+    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x?xf32>
+    // CHECK-DAG:       [[C_1_:%.+]] = arith.constant 1 : index
+    // CHECK-DAG:       [[C_14_:%.+]] = arith.constant 14 : index
+    // CHECK:           [[T0_:%.+]] = affine.apply #[[$MAP]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T1_:%.+]] = affine.apply #[[$MAP1]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T2_:%.+]] = affine.apply #[[$MAP2]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[RES_:%.+]] = memref.alloc([[T0_]], [[T1_]], [[T2_]]) : memref<1x?x?x?xf32>
+    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x?xf32>, memref<1x?x?x?xf32>) -> ()
+    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x?xf32>
     // CHECK:           return
 }
 
@@ -133,10 +139,13 @@ func.func @test_norm_dynamic_not_tiled0(%arg0 : memref<1x?x?x14xf32, #map_not_ti
 
 #map_not_tiled1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 - d2, d2 mod 32, d3 mod 64)>
 
-// CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 - d2, d2 mod 32, d3 mod 64)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2 - d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3 - d2)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d2 mod 32)>
 
 // CHECK-LABEL:  func @test_norm_dynamic_not_tiled1
-// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP6]]>) {
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x?x?x64xf32>) {
 func.func @test_norm_dynamic_not_tiled1(%arg0 : memref<1x?x?x14xf32, #map_not_tiled1>) -> () {
     %c1 = arith.constant 1 : index
     %c2 = arith.constant 2 : index
@@ -149,11 +158,17 @@ func.func @test_norm_dynamic_not_tiled1(%arg0 : memref<1x?x?x14xf32, #map_not_ti
     // CHECK-DAG:       [[CST_1_:%.+]] = arith.constant 1 : index
     // CHECK-DAG:       [[CST_2_:%.+]] = arith.constant 2 : index
     // CHECK-NOT: separator of consecutive DAGs
-    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK:           [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP6]]>
-    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP6]]>, memref<1x?x?x14xf32, #[[$MAP6]]>) -> ()
-    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
+    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x?x?x64xf32>
+    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x?x?x64xf32>
+    // CHECK-DAG:       [[C_1_:%.+]] = arith.constant 1 : index
+    // CHECK-DAG:       [[C_14_:%.+]] = arith.constant 14 : index
+    // CHECK:           [[T0_:%.+]] = affine.apply #[[$MAP]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T1_:%.+]] = affine.apply #[[$MAP1]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T2_:%.+]] = affine.apply #[[$MAP2]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T3_:%.+]] = affine.apply #[[$MAP3]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[RES_:%.+]] = memref.alloc([[T0_]], [[T1_]], [[T2_]], [[T3_]]) : memref<1x?x?x?x?x64xf32>
+    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x?x?x64xf32>, memref<1x?x?x?x?x64xf32>) -> ()
+    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x?x?x64xf32>
     // CHECK:           return
 }
 
@@ -164,10 +179,12 @@ func.func @test_norm_dynamic_not_tiled1(%arg0 : memref<1x?x?x14xf32, #map_not_ti
 
 #map_not_tiled2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 floordiv 64, d2 mod 32, d3 mod 32)>
 
-// CHECK-DAG: #[[$MAP7:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 floordiv 64, d2 mod 32, d3 mod 32)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2 - d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2 mod 32)>
 
 // CHECK-LABEL:  func @test_norm_dynamic_not_tiled2
-// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP7]]>) {
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x1x?x32xf32>) {
 func.func @test_norm_dynamic_not_tiled2(%arg0 : memref<1x?x?x14xf32, #map_not_tiled2>) -> () {
     %c1 = arith.constant 1 : index
     %c2 = arith.constant 2 : index
@@ -180,11 +197,16 @@ func.func @test_norm_dynamic_not_tiled2(%arg0 : memref<1x?x?x14xf32, #map_not_ti
     // CHECK-DAG:       [[CST_1_:%.+]] = arith.constant 1 : index
     // CHECK-DAG:       [[CST_2_:%.+]] = arith.constant 2 : index
     // CHECK-NOT: separator of consecutive DAGs
-    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP7]]>
-    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP7]]>
-    // CHECK:           [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP7]]>
-    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP7]]>, memref<1x?x?x14xf32, #[[$MAP7]]>) -> ()
-    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP7]]>
+    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x1x?x32xf32>
+    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x1x?x32xf32>
+    // CHECK-DAG:       [[C_1_:%.+]] = arith.constant 1 : index
+    // CHECK-DAG:       [[C_14_:%.+]] = arith.constant 14 : index
+    // CHECK:           [[T0_:%.+]] = affine.apply #[[$MAP]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T1_:%.+]] = affine.apply #[[$MAP1]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T2_:%.+]] = affine.apply #[[$MAP2]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[RES_:%.+]] = memref.alloc([[T0_]], [[T1_]], [[T2_]]) : memref<1x?x?x1x?x32xf32>
+    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x1x?x32xf32>, memref<1x?x?x1x?x32xf32>) -> ()
+    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x1x?x32xf32>
     // CHECK:           return
 }
 
@@ -195,10 +217,11 @@ func.func @test_norm_dynamic_not_tiled2(%arg0 : memref<1x?x?x14xf32, #map_not_ti
 
 #map_not_tiled3 = affine_map<(d0, d1, d2, d3) -> (d0, d1 floordiv 32, d2, d3, d1 mod 32, d1 mod 32)>
 
-// CHECK-DAG: #[[$MAP8:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 floordiv 32, d2, d3, d1 mod 32, d1 mod 32)>
-
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1 floordiv 32)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1 mod 32)>
 // CHECK-LABEL:  func @test_norm_dynamic_not_tiled3
-// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP8]]>) {
+// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x?x14x?x?xf32>) {
 func.func @test_norm_dynamic_not_tiled3(%arg0 : memref<1x?x?x14xf32, #map_not_tiled3>) -> () {
     %c1 = arith.constant 1 : index
     %c2 = arith.constant 2 : index
@@ -211,11 +234,17 @@ func.func @test_norm_dynamic_not_tiled3(%arg0 : memref<1x?x?x14xf32, #map_not_ti
     // CHECK-DAG:       [[CST_1_:%.+]] = arith.constant 1 : index
     // CHECK-DAG:       [[CST_2_:%.+]] = arith.constant 2 : index
     // CHECK-NOT: separator of consecutive DAGs
-    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP8]]>
-    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP8]]>
-    // CHECK:           [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP8]]>
-    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP8]]>, memref<1x?x?x14xf32, #[[$MAP8]]>) -> ()
-    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP8]]>
+    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14x?x?xf32>
+    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14x?x?xf32>
+    // CHECK-DAG:       [[C_1_:%.+]] = arith.constant 1 : index
+    // CHECK-DAG:       [[C_14_:%.+]] = arith.constant 14 : index
+    // CHECK:           [[T0_:%.+]] = affine.apply #[[$MAP]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T1_:%.+]] = affine.apply #[[$MAP1]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T2_:%.+]] = affine.apply #[[$MAP2]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T3_:%.+]] = affine.apply #[[$MAP2]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])  
+    // CHECK:           [[RES_:%.+]] = memref.alloc([[T0_]], [[T1_]], [[T2_]], [[T3_]]) : memref<1x?x?x14x?x?xf32>
+    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14x?x?xf32>, memref<1x?x?x14x?x?xf32>) -> ()
+    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x14x?x?xf32>
     // CHECK:           return
 }
 
@@ -226,10 +255,10 @@ func.func @test_norm_dynamic_not_tiled3(%arg0 : memref<1x?x?x14xf32, #map_not_ti
 
 #map_not_tiled4 = affine_map<(d0, d1, d2, d3) -> (d0 floordiv 32, d1 floordiv 32, d0, d3, d0 mod 32, d1 mod 32)>
 
-// CHECK-DAG: #[[$MAP9:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 floordiv 32, d1 floordiv 32, d0, d3, d0 mod 32, d1 mod 32)>
-
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1 floordiv 32)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1 mod 32)>
 // CHECK-LABEL:  func @test_norm_dynamic_not_tiled4
-// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP9]]>) {
+// CHECK-SAME:   ([[ARG_0_:%.+]]: memref<1x?x1x14x32x?xf32>) {
 func.func @test_norm_dynamic_not_tiled4(%arg0 : memref<1x?x?x14xf32, #map_not_tiled4>) -> () {
     %c1 = arith.constant 1 : index
     %c2 = arith.constant 2 : index
@@ -242,10 +271,14 @@ func.func @test_norm_dynamic_not_tiled4(%arg0 : memref<1x?x?x14xf32, #map_not_ti
     // CHECK-DAG:       [[CST_1_:%.+]] = arith.constant 1 : index
     // CHECK-DAG:       [[CST_2_:%.+]] = arith.constant 2 : index
     // CHECK-NOT: separator of consecutive DAGs
-    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP9]]>
-    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP9]]>
-    // CHECK:           [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP9]]>
-    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP9]]>, memref<1x?x?x14xf32, #[[$MAP9]]>) -> ()
-    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP9]]>
+    // CHECK-DAG:       [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x1x14x32x?xf32>
+    // CHECK-DAG:       [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x1x14x32x?xf32>
+    // CHECK-DAG:       [[C_1_:%.+]] = arith.constant 1 : index
+    // CHECK-DAG:       [[C_14_:%.+]] = arith.constant 14 : index
+    // CHECK:           [[T0_:%.+]] = affine.apply #[[$MAP]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[T1_:%.+]] = affine.apply #[[$MAP1]]([[C_1_]], [[DIM_0_]], [[DIM_1_]], [[C_14_]])
+    // CHECK:           [[RES_:%.+]] = memref.alloc([[T0_]], [[T1_]]) :  memref<1x?x1x14x32x?xf32>
+    // CHECK:           "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x1x14x32x?xf32>, memref<1x?x1x14x32x?xf32>) -> ()
+    // CHECK:           memref.dealloc [[RES_]] : memref<1x?x1x14x32x?xf32>
     // CHECK:           return
 }
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
index 3bede131325a7..9f843cafc0a4d 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
@@ -114,7 +114,7 @@ func.func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) {
 func.func @test_norm_ret(%arg0: memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) {
     %0 = memref.alloc() : memref<1x16x14x14xf32, #map_tile>
-    // CHECK-NEXT: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x32xf32>
+    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x32xf32>
     %1, %2 = "test.op_norm_ret"(%arg0) : (memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>)
     // CHECK-NEXT: %[[v1:.*]], %[[v2:.*]] = "test.op_norm_ret"
     // CHECK-SAME: (memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>)
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index e93a1a4ebae53..6af561e0349a6 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -6,6 +6,8 @@
 // CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
 // CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
 // CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
+// CHECK-DAG: #[[$STRIDED_ACCESS:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+// CHECK-DAG: #[[$STRIDED_ACCESS_1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 3 + s0 + d1)>
 
 // CHECK-LABEL: func @permute()
 func.func @permute() {
@@ -33,7 +35,7 @@ func.func @permute() {
 
 // CHECK-LABEL: func @alloca
 func.func @alloca(%idx : index) {
-  // CHECK-NEXT: memref.alloca() : memref<65xf32>
+  // CHECK: memref.alloca() : memref<65xf32>
   %A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
   // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
   affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
@@ -47,7 +49,7 @@ func.func @alloca(%idx : index) {
 
 // CHECK-LABEL: func @shift
 func.func @shift(%idx : index) {
-  // CHECK-NEXT: memref.alloc() : memref<65xf32>
+  // CHECK: memref.alloc() : memref<65xf32>
   %A = memref.alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
   // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
   affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
@@ -144,13 +146,20 @@ func.func @symbolic_operands(%s : index) {
   return
 }
 
-// Semi-affine maps, normalization not implemented yet.
+// -----
+
+// Semi-affine maps.
 // CHECK-LABEL: func @semi_affine_layout_map
+// CHECK-SAME: %[[S0:.*]]: index, %[[S1:.*]]: index
 func.func @semi_affine_layout_map(%s0: index, %s1: index) {
+  // CHECK-DAG:  %[[C256:.*]] = arith.constant 256 : index
+  // CHECK-DAG:  %[[C1024:.*]] = arith.constant 1024 : index
+  // CHECK:      %[[DYNAMIC_SIZE:.*]] = affine.apply #[[$STRIDED_ACCESS]](%[[C256]], %[[C1024]])[%[[S0]], %[[S1]]]
+  // CHECK:      %[[ALLOC:.*]] = memref.alloc(%[[DYNAMIC_SIZE]]) : memref<?xf32>
   %A = memref.alloc()[%s0, %s1] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
   affine.for %i = 0 to 256 {
     affine.for %j = 0 to 1024 {
-      // CHECK: memref<256x1024xf32, #map{{[0-9a-zA-Z_]+}}>
+      // CHECK: affine.load %[[ALLOC]][%{{.*}} * symbol(%[[S0]]) + %{{.*}} * symbol(%[[S1]])] : memref<?xf32>
       affine.load %A[%i, %j] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
     }
   }
@@ -160,7 +169,7 @@ func.func @semi_affine_layout_map(%s0: index, %s1: index) {
 // CHECK-LABEL: func @alignment
 func.func @alignment() {
   %A = memref.alloc() {alignment = 32 : i64}: memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
-  // CHECK-NEXT: memref.alloc() {alignment = 32 : i64} : memref<256x64x128xf32>
+  // CHECK: memref.alloc() {alignment = 32 : i64} : memref<256x64x128xf32>
   return
 }
 
@@ -355,12 +364,15 @@ func.func @affine_parallel_norm() ->  memref<8xf32, #tile> {
 // CHECK-LABEL: func.func @map_symbol
 func.func @map_symbol() -> memref<2x3xf32, #map> {
   %c1 = arith.constant 1 : index
-  // The constant isn't propagated here and the utility can't compute a constant
-  // upper bound for the memref dimension in the absence of that.
-  // CHECK: memref.alloc()[%{{.*}}]
   %0 = memref.alloc()[%c1] : memref<2x3xf32, #map>
   return %0 : memref<2x3xf32, #map>
 }
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// CHECK:       %[[DYNAMIC_SIZE:.*]] = affine.apply #[[$STRIDED_ACCESS_1]](%[[C2]], %[[C3]])[%[[C1]]]
+// CHECK-NEXT:  %[[ALLOC:.*]] = memref.alloc(%[[DYNAMIC_SIZE]]) : memref<?xf32>
+// CHECK-NEXT:  return %[[ALLOC]] : memref<?xf32>
 
 #neg = affine_map<(d0, d1) -> (d0, d1 - 100)>
 // CHECK-LABEL: func.func @neg_map
@@ -399,13 +411,13 @@ func.func @memref_load_with_reduction_map(%arg0 :  memref<4x4xf32,#map2>) -> ()
     affine.for %j = 0 to 8 {
       affine.for %k = 0 to 8 {
         // CHECK: %[[INDEX0:.*]] = affine.apply #[[$REDUCE_MAP1]](%{{.*}}, %{{.*}})
-        // CHECK: memref.load %alloc[%[[INDEX0]]] : memref<32xf32>
+        // CHECK: memref.load %{{.*}}[%[[INDEX0]]] : memref<32xf32>
         %a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
         // CHECK: %[[INDEX1:.*]] = affine.apply #[[$REDUCE_MAP2]](%{{.*}}, %{{.*}})
-        // CHECK: memref.load %alloc_0[%[[INDEX1]]] : memref<32xf32>
+        // CHECK: memref.load %{{.*}}[%[[INDEX1]]] : memref<32xf32>
         %b = memref.load %1[%k, %j] :memref<8x4xf32,#map1>
         // CHECK: %[[INDEX2:.*]] = affine.apply #[[$REDUCE_MAP3]](%{{.*}}, %{{.*}})
-        // CHECK: memref.load %alloc_1[%[[INDEX2]]] : memref<16xf32>
+        // CHECK: memref.load %{{.*}}[%[[INDEX2]]] : memref<16xf32>
         %c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
         %3 = arith.mulf %a, %b : f32
         %4 = arith.addf %3, %c : f32



More information about the Mlir-commits mailing list