[Mlir-commits] [mlir] [MLIR] Add a utility pass to linearize `memref` (PR #136797)

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Apr 24 14:58:50 PDT 2025


================
@@ -0,0 +1,356 @@
+//===----- FlattenMemRefs.cpp - MemRef ops flattener pass  ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for flattening multi-rank memref-related
+// ops into 1-D memref ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_FLATTENMEMREFSPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+/// Moves `builder`'s insertion point to the earliest spot where `val` is
+/// available: immediately after the operation that defines `val`, or, when
+/// `val` has no defining operation (e.g. it is a block argument), at the very
+/// beginning of the block that owns `val`.
+static void setInsertionPointToStart(OpBuilder &builder, Value val) {
+  Operation *definingOp = val.getDefiningOp();
+  if (!definingOp) {
+    builder.setInsertionPointToStart(val.getParentBlock());
+    return;
+  }
+  builder.setInsertionPointAfter(definingOp);
+}
+
+static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
+                  OpFoldResult>
+getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
+                        ArrayRef<OpFoldResult> subOffsets,
+                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
+  auto sourceType = cast<MemRefType>(source.getType());
+  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
+
+  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    setInsertionPointToStart(rewriter, source);
+    newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+  }
+
+  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
+
+  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
+    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
+                                      : rewriter.getIndexAttr(dim);
+  };
+
+  OpFoldResult origOffset =
+      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
+  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
+  OpFoldResult outmostDim =
+      getDim(sourceType.getShape().front(),
+             newExtractStridedMetadata.getSizes().front());
+
+  SmallVector<OpFoldResult> origStrides;
+  origStrides.reserve(sourceRank);
+
+  SmallVector<OpFoldResult> strides;
+  strides.reserve(sourceRank);
+
+  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+  for (auto i : llvm::seq(0u, sourceRank)) {
+    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
+
+    if (!subStrides.empty()) {
+      strides.push_back(affine::makeComposedFoldedAffineApply(
+          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
+    }
+
+    origStrides.emplace_back(origStride);
+  }
+
+  // Compute linearized index:
+  auto &&[expr, values] =
+      computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets);
+  OpFoldResult linearizedIndex =
+      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
+
+  // Compute collapsed size: (the outmost stride * outmost dimension).
+  SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
----------------
krzysz00 wrote:

I don't think you want the "outermost" size here, but rather max(size[dim] * stride[dim]) over all dims, unless you know the layout is contiguous / row-major-ish. Consider a column-major memref.

https://github.com/llvm/llvm-project/pull/136797


More information about the Mlir-commits mailing list