[Mlir-commits] [mlir] [MLIR] Fixing the memref linearization size computation (PR #138922)

Zhuoran Yin llvmlistbot at llvm.org
Thu May 8 07:38:17 PDT 2025


https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/138922
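
For readers skimming the patch series: the TODO removed in PATCH 3/4 states the underlying bug — the old helper multiplied all dimension sizes together instead of taking the maximum of stride * size over the dimensions, which under-counts the footprint of strided (non-contiguous) views. Below is a minimal, self-contained C++ sketch of that arithmetic (plain standard C++, not MLIR code; the concrete sizes, strides, and variable names are illustrative assumptions, not taken from the patch) contrasting the two computations for an i4 view emulated in an i8 buffer:

    // Hypothetical standalone illustration: old product-of-sizes vs. the
    // max-of-(stride * size) computation this PR switches to.
    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Example strided 2-D view: sizes 4x4, strides 16x1 (e.g. a subview of
      // a larger buffer), srcBits = 4 (i4), dstBits = 8 (i8) => scaler = 2.
      std::vector<int64_t> sizes = {4, 4};
      std::vector<int64_t> strides = {16, 1};
      int64_t scaler = 8 / 4;

      // Old computation: product of all sizes, scaled.
      int64_t product = 1;
      for (int64_t s : sizes)
        product *= s;
      int64_t oldSize = product / scaler; // 16 / 2 = 8 i8 elements

      // New computation: max over dimensions of stride * size, scaled.
      int64_t newSize = 0;
      for (size_t i = 0; i < sizes.size(); ++i)
        newSize = std::max(newSize, strides[i] * sizes[i] / scaler); // 64 / 2 = 32

      std::cout << "old: " << oldSize << ", new: " << newSize << "\n";
      // The old value (8) under-allocates: the view touches i4 offsets up to
      // 3*16 + 3 = 51, i.e. i8 offset 25, which is past the end of 8 elements.
    }

For a contiguous row-major memref the two formulas agree, which is why only the dynamic-shape test cases change below.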

From a372a077d05d81775242e3166e58932b7f2eeed6 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 5 May 2025 21:29:33 +0000
Subject: [PATCH 1/4] Fixing the memref linearization size computation

---
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 63 +++++++++++++++++--
 .../Dialect/MemRef/emulate-narrow-type.mlir   | 16 ++---
 .../Vector/vector-emulate-narrow-type.mlir    | 16 ++---
 3 files changed, 73 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index ac397b597fd14..ae92f1bffee75 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -66,7 +66,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
   SmallVector<AffineExpr> symbols(2 * sourceRank);
   bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
   AffineExpr addMulMap = builder.getAffineConstantExpr(0);
-  AffineExpr mulMap = builder.getAffineConstantExpr(1);
 
   SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
 
@@ -75,18 +74,70 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
     offsetValues[offsetIdx] = indicesVec[i];
     offsetValues[offsetIdx + 1] = strides[i];
-
-    mulMap = mulMap * symbols[i];
   }
 
   // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
-  mulMap = mulMap.floorDiv(scaler);
+
+  // If all strides and sizes are constant, we can compute the result
+  // directly without creating the AffineMaxOp.
+  int64_t constResult = 0;
+  int64_t constStride = 0;
+  int64_t constSize = 0;
+  bool isAllConstant = true;
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    if (auto constantStride = getConstantIntValue(strides[i])) {
+      constStride = *constantStride;
+    } else {
+      isAllConstant = false;
+      break;
+    }
+    if (auto constantSize = getConstantIntValue(sizes[i])) {
+      constSize = *constantSize;
+    } else {
+      isAllConstant = false;
+      break;
+    }
+    constResult = std::max(constResult, constStride * constSize / scaler);
+  }
+
+  size_t symbolIndex = 0;
+  SmallVector<Value> values;
+  SmallVector<AffineExpr> productExpressions;
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    AffineExpr strideExpr, sizeExpr;
+    OpFoldResult stride = strides[i];
+    OpFoldResult size = sizes[i];
+    if (auto constantStride = getConstantIntValue(stride)) {
+      strideExpr = builder.getAffineConstantExpr(*constantStride);
+    } else {
+      strideExpr = symbols[symbolIndex++];
+      values.push_back(getValueOrCreateConstantIndexOp(builder, loc, stride));
+    }
+
+    if (auto constantSize = getConstantIntValue(size)) {
+      sizeExpr = builder.getAffineConstantExpr(*constantSize);
+    } else {
+      sizeExpr = symbols[symbolIndex++];
+      values.push_back(getValueOrCreateConstantIndexOp(builder, loc, size));
+    }
+
+    productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
+  }
+  AffineMap maxMap = AffineMap::get(
+      /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
+      builder.getContext());
+
+  OpFoldResult linearizedSize;
+  if (isAllConstant) {
+    linearizedSize = builder.getIndexAttr(constResult);
+  } else {
+    Value totalSize = builder.create<affine::AffineMaxOp>(loc, maxMap, values);
+    linearizedSize = totalSize;
+  }
 
   OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
       builder, loc, addMulMap.floorDiv(scaler), offsetValues);
-  OpFoldResult linearizedSize =
-      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
 
   // Adjust baseOffset by the scale factor (dstBits / srcBits).
   AffineExpr s0;
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 1d6cbfa343ba5..f6740fae3046e 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -104,7 +104,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
   %1 = memref.load %0[%arg2, %arg3] : memref<?x?xi4>
   return %1 : i4
 }
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
 //      CHECK: func @memref_load_i4_dynamic(
@@ -112,7 +112,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
-//      CHECK:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //      CHECK:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
 //      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -122,7 +122,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
 //      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
 //      CHECK:   return %[[TRUNC]]
 
-//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
 //      CHECK32: func @memref_load_i4_dynamic(
@@ -130,7 +130,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
 // CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
-//      CHECK32:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
 //      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK32:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -399,7 +399,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
   memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
   return
 }
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
 //      CHECK: func @memref_store_i4_dynamic(
@@ -408,7 +408,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
-//  CHECK-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
 //  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
 //  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
@@ -423,7 +423,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
 //      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
 //      CHECK:   return
 
-//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
 //      CHECK32: func @memref_store_i4_dynamic(
@@ -432,7 +432,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
-//  CHECK32-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
 //  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
 //  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 9e2d131f421b7..faadf48ad3984 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -58,27 +58,27 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
   %1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
   return %1 : vector<8xi4>
 }
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //      CHECK: func.func @vector_load_i4_dynamic(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
-//      CHECK:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //      CHECK:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
 //      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
 //      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
 
-//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //      CHECK32: func.func @vector_load_i4_dynamic(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
-//      CHECK32:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
 //      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK32:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
@@ -450,7 +450,7 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
     return
 }
 
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //      CHECK: func @vector_store_i4_dynamic
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
@@ -458,13 +458,13 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9]+]]: index
-//      CHECK: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
 //      CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
 //      CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
 //      CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
 //      CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi8>, vector<4xi8>
 
-//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //      CHECK32: func @vector_store_i4_dynamic
 // CHECK32-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
@@ -472,7 +472,7 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 // CHECK32-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:   %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:   %[[ARG4:[a-zA-Z0-9]+]]: index
-//      CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK32: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
 //      CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
 //      CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
 //      CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>

From 2e39f5db12b69a2a4583fde4adf399930054be07 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Wed, 7 May 2025 19:07:58 +0000
Subject: [PATCH 2/4] Use constant total size if available

---
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 30 +++----------------
 1 file changed, 4 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index ae92f1bffee75..420001d7202ae 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -78,29 +78,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
 
   // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
-
-  // If all strides and sizes are constant, we can compute the result
-  // directly without creating the AffineMaxOp.
-  int64_t constResult = 0;
-  int64_t constStride = 0;
-  int64_t constSize = 0;
-  bool isAllConstant = true;
-  for (unsigned i = 0; i < sourceRank; ++i) {
-    if (auto constantStride = getConstantIntValue(strides[i])) {
-      constStride = *constantStride;
-    } else {
-      isAllConstant = false;
-      break;
-    }
-    if (auto constantSize = getConstantIntValue(sizes[i])) {
-      constSize = *constantSize;
-    } else {
-      isAllConstant = false;
-      break;
-    }
-    constResult = std::max(constResult, constStride * constSize / scaler);
-  }
-
   size_t symbolIndex = 0;
   SmallVector<Value> values;
   SmallVector<AffineExpr> productExpressions;
@@ -129,10 +106,11 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
       builder.getContext());
 
   OpFoldResult linearizedSize;
-  if (isAllConstant) {
-    linearizedSize = builder.getIndexAttr(constResult);
+  Value totalSize =
+      builder.createOrFold<affine::AffineMaxOp>(loc, maxMap, values);
+  if (auto constantSize = getConstantIntValue(totalSize)) {
+    linearizedSize = builder.getIndexAttr(*constantSize);
   } else {
-    Value totalSize = builder.create<affine::AffineMaxOp>(loc, maxMap, values);
     linearizedSize = totalSize;
   }
 

From cde70c4629d2f6ab8654f52bb71638e4b6022f97 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Wed, 7 May 2025 19:53:12 +0000
Subject: [PATCH 3/4] Amending amdgpu transfer-read to use new linearized size

---
 .../AMDGPU/Transforms/TransferReadToLoad.cpp  | 48 ++-----------------
 .../Dialect/AMDGPU/transfer-read-to-load.mlir | 10 ++--
 2 files changed, 9 insertions(+), 49 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index b1527a5c3f838..a2ac41c263cc6 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -162,60 +162,20 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
         stridedMetadata.getConstifiedMixedStrides();
     SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
     OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
+    memref::LinearizedMemRefInfo linearizedInfo;
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
                                                  elementBitWidth, offset, sizes,
                                                  strides, indices);
 
-    // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
-    // Note below doesn't give the correct result for the linearized size.
-    // Value totalSize = getValueOrCreateConstantIndexOp(
-    //    rewriter, loc, linearizedInfo.linearizedSize);
-    // It computes the multiplied sizes of all dimensions instead of taking
-    // the maximum of each dimension size * stride.
-    SmallVector<AffineExpr> productExpressions;
-    unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
-
-    SmallVector<AffineExpr> symbols(2 * sourceRank);
-    SmallVector<Value> offsetValues;
-    bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
-
-    size_t symbolIndex = 0;
-    for (size_t i = 0; i < sourceRank; ++i) {
-      AffineExpr strideExpr, sizeExpr;
-      OpFoldResult stride = strides[i];
-      OpFoldResult size = sizes[i];
-      if (auto constantStride = getConstantIntValue(stride)) {
-        strideExpr = rewriter.getAffineConstantExpr(*constantStride);
-      } else {
-        strideExpr = symbols[symbolIndex++];
-        offsetValues.push_back(
-            getValueOrCreateConstantIndexOp(rewriter, loc, stride));
-      }
-
-      if (auto constantSize = getConstantIntValue(size)) {
-        sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
-      } else {
-        sizeExpr = symbols[symbolIndex++];
-        offsetValues.push_back(
-            getValueOrCreateConstantIndexOp(rewriter, loc, size));
-      }
-
-      productExpressions.push_back(strideExpr * sizeExpr);
-    }
-
-    AffineMap maxMap = AffineMap::get(
-        /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
-        rewriter.getContext());
-    Value totalSize =
-        rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
-
     // delta = bufferSize - linearizedOffset
     Value vectorSizeOffset =
         rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
     Value linearIndex =
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    Value totalSize = getValueOrCreateConstantIndexOp(
+        rewriter, loc, linearizedInfo.linearizedSize);
     Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
 
     // 1) check if delta < vectorSize
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
index 22b425680fb05..ba2179d333f4b 100644
--- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -52,9 +52,9 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
 
 // -----
 
-// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
-// CHECK: #map1 = affine_map<()[s0, s1, s2] -> (s0 * s1, s2)>
-// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> (s0 * s1, s2)>
+// CHECK: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
 // CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>
 // CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
 // CHECK-SAME: %[[ARG3:.*]]: vector<4xi1>
@@ -68,8 +68,8 @@ func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8,
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[C4:.*]] = arith.constant 4 : index
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
-// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
-// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK: %[[LINEAR:.*]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
 // CHECK: %[[IF:.*]] = scf.if
 // CHECK: return
 

From 4942b1a5edfc9d6d75b26d53b0ad7d9867ba51a8 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 8 May 2025 14:38:04 +0000
Subject: [PATCH 4/4] Use makeComposedFoldedAffineMax

---
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 40 ++++++-------------
 .../Dialect/MemRef/emulate-narrow-type.mlir   | 16 ++++----
 .../Vector/vector-emulate-narrow-type.mlir    | 24 +++++------
 3 files changed, 32 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 420001d7202ae..98b94451855ea 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 
@@ -75,47 +76,30 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     offsetValues[offsetIdx] = indicesVec[i];
     offsetValues[offsetIdx + 1] = strides[i];
   }
-
   // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
+  OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, addMulMap.floorDiv(scaler), offsetValues);
+
   size_t symbolIndex = 0;
-  SmallVector<Value> values;
+  SmallVector<OpFoldResult> values;
   SmallVector<AffineExpr> productExpressions;
   for (unsigned i = 0; i < sourceRank; ++i) {
-    AffineExpr strideExpr, sizeExpr;
+    AffineExpr strideExpr = symbols[symbolIndex++];
     OpFoldResult stride = strides[i];
-    OpFoldResult size = sizes[i];
-    if (auto constantStride = getConstantIntValue(stride)) {
-      strideExpr = builder.getAffineConstantExpr(*constantStride);
-    } else {
-      strideExpr = symbols[symbolIndex++];
-      values.push_back(getValueOrCreateConstantIndexOp(builder, loc, stride));
-    }
+    values.push_back(getValueOrCreateConstantIndexOp(builder, loc, stride));
 
-    if (auto constantSize = getConstantIntValue(size)) {
-      sizeExpr = builder.getAffineConstantExpr(*constantSize);
-    } else {
-      sizeExpr = symbols[symbolIndex++];
-      values.push_back(getValueOrCreateConstantIndexOp(builder, loc, size));
-    }
+    AffineExpr sizeExpr = symbols[symbolIndex++];
+    OpFoldResult size = sizes[i];
+    values.push_back(getValueOrCreateConstantIndexOp(builder, loc, size));
 
     productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
   }
   AffineMap maxMap = AffineMap::get(
       /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
       builder.getContext());
-
-  OpFoldResult linearizedSize;
-  Value totalSize =
-      builder.createOrFold<affine::AffineMaxOp>(loc, maxMap, values);
-  if (auto constantSize = getConstantIntValue(totalSize)) {
-    linearizedSize = builder.getIndexAttr(*constantSize);
-  } else {
-    linearizedSize = totalSize;
-  }
-
-  OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
-      builder, loc, addMulMap.floorDiv(scaler), offsetValues);
+  OpFoldResult linearizedSize =
+      affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);
 
   // Adjust baseOffset by the scale factor (dstBits / srcBits).
   AffineExpr s0;
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index f6740fae3046e..0cb3b7b744476 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -104,7 +104,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
   %1 = memref.load %0[%arg2, %arg3] : memref<?x?xi4>
   return %1 : i4
 }
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
 //      CHECK: func @memref_load_i4_dynamic(
@@ -112,7 +112,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
-//      CHECK:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
 //      CHECK:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
 //      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -122,7 +122,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
 //      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
 //      CHECK:   return %[[TRUNC]]
 
-//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
 //      CHECK32: func @memref_load_i4_dynamic(
@@ -130,7 +130,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
 // CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
-//      CHECK32:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
 //      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK32:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -399,7 +399,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
   memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
   return
 }
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
 //      CHECK: func @memref_store_i4_dynamic(
@@ -408,7 +408,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
-//  CHECK-DAG:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
 //  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
 //  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
 //  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
@@ -423,7 +423,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
 //      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
 //      CHECK:   return
 
-//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
 //      CHECK32: func @memref_store_i4_dynamic(
@@ -432,7 +432,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
-//  CHECK32-DAG:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
 //  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
 //  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
 //  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index faadf48ad3984..6c924492b513e 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -58,27 +58,27 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
   %1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
   return %1 : vector<8xi4>
 }
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //      CHECK: func.func @vector_load_i4_dynamic(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
-//      CHECK:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
 //      CHECK:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
 //      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
 //      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
 
-//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //      CHECK32: func.func @vector_load_i4_dynamic(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
-//      CHECK32:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
 //      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK32:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
@@ -450,7 +450,7 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
     return
 }
 
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //      CHECK: func @vector_store_i4_dynamic
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
@@ -458,13 +458,13 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9]+]]: index
-//      CHECK: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
+//      CHECK: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
 //      CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
 //      CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
 //      CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
 //      CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi8>, vector<4xi8>
 
-//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //      CHECK32: func @vector_store_i4_dynamic
 // CHECK32-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
@@ -472,7 +472,7 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 // CHECK32-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:   %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:   %[[ARG4:[a-zA-Z0-9]+]]: index
-//      CHECK32: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
+//      CHECK32: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
 //      CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
 //      CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
 //      CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
@@ -537,7 +537,7 @@ func.func @vector_maskedstore_i4(
 // CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
 // CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
 
-// CHECK-LABEL:   func.func @vector_maskedstore_i4(
+// CHECK:         func.func @vector_maskedstore_i4(
 // CHECK-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:      %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
@@ -557,7 +557,7 @@ func.func @vector_maskedstore_i4(
 // CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
 // CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
 
-// CHECK32-LABEL:   func.func @vector_maskedstore_i4(
+// CHECK32:         func.func @vector_maskedstore_i4(
 // CHECK32-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:      %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
@@ -623,7 +623,7 @@ func.func @vector_maskedstore_i4_constant_mask(
 }
 
 // CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-// CHECK-LABEL:   func.func @vector_maskedstore_i4_constant_mask(
+// CHECK:         func.func @vector_maskedstore_i4_constant_mask(
 // CHECK-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
@@ -639,7 +639,7 @@ func.func @vector_maskedstore_i4_constant_mask(
 // CHECK:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
 
 // CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-// CHECK32-LABEL:   func.func @vector_maskedstore_i4_constant_mask(
+// CHECK32:         func.func @vector_maskedstore_i4_constant_mask(
 // CHECK32-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {


