[Mlir-commits] [mlir] [MLIR] Fixing the memref linearization size computation (PR #138922)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 8 06:52:23 PDT 2025


================
@@ -75,18 +74,48 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
     offsetValues[offsetIdx] = indicesVec[i];
     offsetValues[offsetIdx + 1] = strides[i];
-
-    mulMap = mulMap * symbols[i];
   }
 
   // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
-  mulMap = mulMap.floorDiv(scaler);
+  size_t symbolIndex = 0;
+  SmallVector<Value> values;
+  SmallVector<AffineExpr> productExpressions;
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    AffineExpr strideExpr, sizeExpr;
+    OpFoldResult stride = strides[i];
+    OpFoldResult size = sizes[i];
+    if (auto constantStride = getConstantIntValue(stride)) {
+      strideExpr = builder.getAffineConstantExpr(*constantStride);
+    } else {
+      strideExpr = symbols[symbolIndex++];
+      values.push_back(getValueOrCreateConstantIndexOp(builder, loc, stride));
+    }
+
+    if (auto constantSize = getConstantIntValue(size)) {
+      sizeExpr = builder.getAffineConstantExpr(*constantSize);
+    } else {
+      sizeExpr = symbols[symbolIndex++];
+      values.push_back(getValueOrCreateConstantIndexOp(builder, loc, size));
+    }
+
+    productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
+  }
+  AffineMap maxMap = AffineMap::get(
+      /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
+      builder.getContext());
+
+  OpFoldResult linearizedSize;
+  Value totalSize =
+      builder.createOrFold<affine::AffineMaxOp>(loc, maxMap, values);
+  if (auto constantSize = getConstantIntValue(totalSize)) {
+    linearizedSize = builder.getIndexAttr(*constantSize);
+  } else {
+    linearizedSize = totalSize;
+  }
----------------
Max191 wrote:

nit: I think you could get rid of all the if/else branches on `getConstantIntValue` if you use `affine::makeComposedFoldedAffineMax`:
https://github.com/llvm/llvm-project/blob/be6c6e2f902c71f267f91852e3391a5301f949ac/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h#L447-L452

I.e., just make all the exprs symbols, and then pass the `OpFoldResult` operands directly.

https://github.com/llvm/llvm-project/pull/138922


More information about the Mlir-commits mailing list