[llvm-branch-commits] [mlir] [mlir][memref][NFC] Simplify `constifyIndexValues` (PR #135940)
Quentin Colombet via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Apr 16 08:53:10 PDT 2025
================
@@ -88,101 +88,30 @@ SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
// Utility functions for propagating static information
//===----------------------------------------------------------------------===//
-/// Helper function that infers the constant values from a list of \p values,
-/// a \p memRefTy, and another helper function \p getAttributes.
-/// The inferred constant values replace the related `OpFoldResult` in
-/// \p values.
+/// Helper function that sets values[i] to constValues[i] if the latter is a
+/// static value, as indicated by ShapedType::kDynamic.
///
-/// \note This function shouldn't be used directly, instead, use the
-/// `getConstifiedMixedXXX` methods from the related operations.
-///
-/// \p getAttributes retuns a list of potentially constant values, as determined
-/// by \p isDynamic, from the given \p memRefTy. The returned list must have as
-/// many elements as \p values or be empty.
-///
-/// E.g., consider the following example:
-/// ```
-/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
-/// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-/// ```
-/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
-/// Now using this helper function with:
-/// - `values == [2, %dyn_stride]`,
-/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
-/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
-/// `getStridesAndOffset`), and
-/// - `isDynamic == ShapedType::isDynamic`
-/// Will yield: `values == [2, 1]`
-static void constifyIndexValues(
- SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
- MLIRContext *ctxt,
- llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
- llvm::function_ref<bool(int64_t)> isDynamic) {
- SmallVector<int64_t> constValues = getAttributes(memRefTy);
- Builder builder(ctxt);
- for (const auto &it : llvm::enumerate(constValues)) {
- int64_t constValue = it.value();
- if (!isDynamic(constValue))
- values[it.index()] = builder.getIndexAttr(constValue);
- }
- for (OpFoldResult &ofr : values) {
- if (auto attr = dyn_cast<Attribute>(ofr)) {
- // FIXME: We shouldn't need to do that, but right now, the static indices
- // are created with the wrong type: `i64` instead of `index`.
- // As a result, if we were to keep the attribute as is, we may fail to see
- // that two attributes are equal because one would have the i64 type and
- // the other the index type.
- // The alternative would be to create constant indices with getI64Attr in
- // this and the previous loop, but it doesn't logically make sense (we are
- // dealing with indices here) and would only strenghten the inconsistency
- // around how static indices are created (some places use getI64Attr,
- // others use getIndexAttr).
- // The workaround here is to stick to the IndexAttr type for all the
- // values, hence we recreate the attribute even when it is already static
- // to make sure the type is consistent.
- ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
+/// If constValues[i] is dynamic, tries to extract a constant value from
+/// values[i] to allow for additional folding opportunities. Also converts all
+/// existing attributes to index attributes. (They may be i64 attributes.)
+static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
+ ArrayRef<int64_t> constValues) {
+ assert(constValues.size() == values.size() &&
+ "incorrect number of const values");
+ for (int64_t i = 0, e = constValues.size(); i < e; ++i) {
----------------
qcolombet wrote:
Could we use `enumerate` here?
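For anyone following the semantics change, here is a minimal standalone C++ sketch of what the new helper appears to do, under a simplified model. `FoldResult` is a hypothetical stand-in for `OpFoldResult`. `definingConst` is a hypothetical field meaning "this SSA value is produced by a constant op". `kDynamic` is assumed to mirror `ShapedType::kDynamic`. The sketch does not model the i64-to-index attribute conversion. It replays the example from the removed doc comment: strides `[2, %dyn_stride]` checked against a type with strides `[?, 1]` become `[2, 1]`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <vector>

// Assumed to mirror MLIR's ShapedType::kDynamic sentinel.
constexpr std::int64_t kDynamic = std::numeric_limits<std::int64_t>::min();

// Hypothetical, simplified stand-in for OpFoldResult.
struct FoldResult {
  std::optional<std::int64_t> attr;          // set for a static entry (Attribute case)
  std::string ssaName;                       // name of a dynamic entry (Value case)
  std::optional<std::int64_t> definingConst; // hypothetical: Value comes from a constant op
};

// Sketch of the new semantics: a static entry in constValues wins;
// otherwise, try to fold a dynamic value that is known to be constant.
// (The real helper also rebuilds every attribute as an index attribute;
// this model has no attribute types, so that step is omitted.)
void constifyIndexValues(std::vector<FoldResult> &values,
                         const std::vector<std::int64_t> &constValues) {
  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
  for (std::size_t i = 0, e = constValues.size(); i < e; ++i) {
    if (constValues[i] != kDynamic) {
      // The type carries a static value: use it.
      values[i] = FoldResult{constValues[i], "", std::nullopt};
      continue;
    }
    // The type is dynamic here: fall back to what values[i] itself knows.
    if (!values[i].attr && values[i].definingConst)
      values[i] = FoldResult{values[i].definingConst, "", std::nullopt};
  }
}
```

The sketch keeps the index-based loop from the patch because it reads both `values[i]` and `constValues[i]`. The removed code iterated with `llvm::enumerate(constValues)` and `it.index()` / `it.value()`, which is presumably the style the review comment has in mind.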
https://github.com/llvm/llvm-project/pull/135940