[llvm-branch-commits] [mlir] [mlir][memref][NFC] Simplify `constifyIndexValues` (PR #135940)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Apr 16 02:10:12 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/135940

>From ea19bcfab213967b0e86aa1346734432e4843e0f Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 16 Apr 2025 10:02:41 +0200
Subject: [PATCH] [mlir][memref][NFC] Simplify `constifyIndexValues`

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 153 ++++++++---------------
 1 file changed, 50 insertions(+), 103 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 63f5251398716..d174a05ffceb0 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -88,101 +88,30 @@ SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
 // Utility functions for propagating static information
 //===----------------------------------------------------------------------===//

-/// Helper function that infers the constant values from a list of \p values,
-/// a \p memRefTy, and another helper function \p getAttributes.
-/// The inferred constant values replace the related `OpFoldResult` in
-/// \p values.
+/// Helper function that sets values[i] to constValues[i] if the latter is a
+/// static value, i.e., not ShapedType::kDynamic.
 ///
-/// \note This function shouldn't be used directly, instead, use the
-/// `getConstifiedMixedXXX` methods from the related operations.
-///
-/// \p getAttributes retuns a list of potentially constant values, as determined
-/// by \p isDynamic, from the given \p memRefTy. The returned list must have as
-/// many elements as \p values or be empty.
-///
-/// E.g., consider the following example:
-/// ```
-/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
-///     memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-/// ```
-/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
-/// Now using this helper function with:
-/// - `values == [2, %dyn_stride]`,
-/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
-/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
-/// `getStridesAndOffset`), and
-/// - `isDynamic == ShapedType::isDynamic`
-/// Will yield: `values == [2, 1]`
-static void constifyIndexValues(
-    SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
-    MLIRContext *ctxt,
-    llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
-    llvm::function_ref<bool(int64_t)> isDynamic) {
-  SmallVector<int64_t> constValues = getAttributes(memRefTy);
-  Builder builder(ctxt);
-  for (const auto &it : llvm::enumerate(constValues)) {
-    int64_t constValue = it.value();
-    if (!isDynamic(constValue))
-      values[it.index()] = builder.getIndexAttr(constValue);
-  }
-  for (OpFoldResult &ofr : values) {
-    if (auto attr = dyn_cast<Attribute>(ofr)) {
-      // FIXME: We shouldn't need to do that, but right now, the static indices
-      // are created with the wrong type: `i64` instead of `index`.
-      // As a result, if we were to keep the attribute as is, we may fail to see
-      // that two attributes are equal because one would have the i64 type and
-      // the other the index type.
-      // The alternative would be to create constant indices with getI64Attr in
-      // this and the previous loop, but it doesn't logically make sense (we are
-      // dealing with indices here) and would only strenghten the inconsistency
-      // around how static indices are created (some places use getI64Attr,
-      // others use getIndexAttr).
-      // The workaround here is to stick to the IndexAttr type for all the
-      // values, hence we recreate the attribute even when it is already static
-      // to make sure the type is consistent.
-      ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
+/// If constValues[i] is dynamic, tries to extract a constant value from
+/// values[i] to allow for additional folding opportunities. Also converts all
+/// existing attributes to index attributes. (They may be i64 attributes.)
+static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
+                                ArrayRef<int64_t> constValues) {
+  assert(constValues.size() == values.size() &&
+         "incorrect number of const values");
+  for (int64_t i = 0, e = constValues.size(); i < e; ++i) {
+    Builder builder(values[i].getContext());
+    if (!ShapedType::isDynamic(constValues[i])) {
+      // Constant value is known, use it directly.
+      values[i] = builder.getIndexAttr(constValues[i]);
       continue;
     }
-    std::optional<int64_t> maybeConstant =
-        getConstantIntValue(cast<Value>(ofr));
-    if (maybeConstant)
-      ofr = builder.getIndexAttr(*maybeConstant);
+    if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
+      // Try to extract a constant or convert an existing attribute to index.
+      values[i] = builder.getIndexAttr(*cst);
+    }
   }
 }

-/// Wrapper around `getShape` that conforms to the function signature
-/// expected for `getAttributes` in `constifyIndexValues`.
-static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
-  ArrayRef<int64_t> sizes = memRefTy.getShape();
-  return SmallVector<int64_t>(sizes);
-}
-
-/// Wrapper around `getStridesAndOffset` that returns only the offset and
-/// conforms to the function signature expected for `getAttributes` in
-/// `constifyIndexValues`.
-static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  LogicalResult hasStaticInformation =
-      memrefType.getStridesAndOffset(strides, offset);
-  if (failed(hasStaticInformation))
-    return SmallVector<int64_t>();
-  return SmallVector<int64_t>(1, offset);
-}
-
-/// Wrapper around `getStridesAndOffset` that returns only the strides and
-/// conforms to the function signature expected for `getAttributes` in
-/// `constifyIndexValues`.
-static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  LogicalResult hasStaticInformation =
-      memrefType.getStridesAndOffset(strides, offset);
-  if (failed(hasStaticInformation))
-    return SmallVector<int64_t>();
-  return strides;
-}
-
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
@@ -1445,24 +1374,34 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,

 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
   SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
-  constifyIndexValues(values, getSource().getType(), getContext(),
-                      getConstantSizes, ShapedType::isDynamic);
+  constifyIndexValues(values, getSource().getType().getShape());
   return values;
 }

 SmallVector<OpFoldResult>
 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
   SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
-  constifyIndexValues(values, getSource().getType(), getContext(),
-                      getConstantStrides, ShapedType::isDynamic);
+  SmallVector<int64_t> staticStrides;
+  int64_t unused;
+  LogicalResult status =
+      getSource().getType().getStridesAndOffset(staticStrides, unused);
+  (void)status;
+  assert(succeeded(status) && "could not get strides from type");
+  constifyIndexValues(values, staticStrides);
   return values;
 }

 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
   OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
   SmallVector<OpFoldResult> values(1, offsetOfr);
-  constifyIndexValues(values, getSource().getType(), getContext(),
-                      getConstantOffset, ShapedType::isDynamic);
+  SmallVector<int64_t> unused;
+  int64_t offset;
+  LogicalResult status =
+      getSource().getType().getStridesAndOffset(unused, offset);
+  (void)status;
+  assert(succeeded(status) && "could not get offset from type");
+  // A single int64_t converts implicitly to a one-element ArrayRef.
+  constifyIndexValues(values, offset);
   return values[0];
 }

@@ -1975,15 +1914,18 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {

 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
   SmallVector<OpFoldResult> values = getMixedSizes();
-  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
-                      ShapedType::isDynamic);
+  constifyIndexValues(values, getType().getShape());
   return values;
 }

 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
   SmallVector<OpFoldResult> values = getMixedStrides();
-  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
-                      ShapedType::isDynamic);
+  SmallVector<int64_t> staticStrides;
+  int64_t unused;
+  LogicalResult status = getType().getStridesAndOffset(staticStrides, unused);
+  (void)status;
+  assert(succeeded(status) && "could not get strides from type");
+  constifyIndexValues(values, staticStrides);
   return values;
 }

@@ -1991,8 +1933,13 @@ OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
   SmallVector<OpFoldResult> values = getMixedOffsets();
   assert(values.size() == 1 &&
          "reinterpret_cast must have one and only one offset");
-  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
-                      ShapedType::isDynamic);
+  SmallVector<int64_t> unused;
+  int64_t offset;
+  LogicalResult status = getType().getStridesAndOffset(unused, offset);
+  (void)status;
+  assert(succeeded(status) && "could not get offset from type");
+  // A single int64_t converts implicitly to a one-element ArrayRef.
+  constifyIndexValues(values, offset);
   return values[0];
 }

@@ -2062,7 +2009,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
       // Second, check the sizes.
       if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                        op.getConstifiedMixedSizes()))
-          return false;
+        return false;

       // Finally, check the offset.
       assert(op.getMixedOffsets().size() == 1 &&



More information about the llvm-branch-commits mailing list