[Mlir-commits] [mlir] [MLIR] Add a utility pass to linearize `memref` (PR #136797)
Alan Li
llvmlistbot at llvm.org
Fri May 9 10:40:19 PDT 2025
================
@@ -0,0 +1,283 @@
+//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for flattening multi-rank memref-related
+// ops into 1-d memref ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <numeric>
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_FLATTENMEMREFSPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
+ OpFoldResult in) {
+ if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
+ return rewriter.create<arith::ConstantIndexOp>(
+ loc, cast<IntegerAttr>(offsetAttr).getInt());
+ }
+ return cast<Value>(in);
+}
+
+/// Given dimension sizes [d1, d2, ...] and strides [s1, s2, ...], compute the
+/// span of the memref.
+static OpFoldResult computeSize(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> dims,
+ ArrayRef<OpFoldResult> strides) {
+ assert(dims.size() == strides.size() &&
+ "number of dimensions and strides should be equal");
+ SmallVector<AffineExpr> symbols(2 * dims.size());
+ bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
+ SmallVector<AffineExpr> productExpressions;
+ SmallVector<OpFoldResult> values;
+ size_t symbolIndex = 0;
+ for (auto &&[dim, stride] : llvm::zip(dims, strides)) {
+ AffineExpr dimExpr = symbols[symbolIndex++];
+ AffineExpr strideExpr = symbols[symbolIndex++];
+ productExpressions.push_back(dimExpr * strideExpr);
+ values.push_back(dim);
+ values.push_back(stride);
+ }
+
+ AffineMap maxMap = AffineMap::get(0, symbols.size(), productExpressions,
+ builder.getContext());
+ return affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);
+}
+
+/// Returns a collapsed memref and the linearized index to access the element
+/// at the specified indices.
+static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
+ Location loc,
+ Value source,
+ ValueRange indices) {
+ int64_t sourceOffset;
+ SmallVector<int64_t, 4> sourceStrides;
+ auto sourceType = cast<MemRefType>(source.getType());
+ if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
+ assert(false);
+ }
+
+ memref::ExtractStridedMetadataOp stridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+
+ auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
+ OpFoldResult linearizedIndices;
+ memref::LinearizedMemRefInfo linearizedInfo;
+ std::tie(linearizedInfo, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, typeBit, typeBit,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(),
+ getAsOpFoldResult(indices));
+
+ return std::make_pair(
+ rewriter.create<memref::ReinterpretCastOp>(
+ loc, source,
+ /* offset = */ linearizedInfo.linearizedOffset,
+ /* shapes = */
+ ArrayRef<OpFoldResult>{computeSize(
+ rewriter, loc, stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides())},
+ /* strides = */
+ ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
+ getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
+}
+
+static bool needFlattening(Value val) {
+ auto type = cast<MemRefType>(val.getType());
+ return type.getRank() > 1;
+}
+
+static bool checkLayout(Value val) {
+ auto type = cast<MemRefType>(val.getType());
+ return type.getLayout().isIdentity() ||
+ isa<StridedLayoutAttr>(type.getLayout());
+}
+
+namespace {
+static Value getTargetMemref(Operation *op) {
+ return llvm::TypeSwitch<Operation *, Value>(op)
+ .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
+ memref::AllocOp>([](auto op) { return op.getMemref(); })
+ .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
+ vector::MaskedStoreOp>(
+ [](auto op) { return op.getBase(); })
+ .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
+ [](auto op) { return op.getSource(); })
+ .Default([](auto) { return Value{}; });
+}
+
+template <typename T>
+static void castResult(T oper, T newOper, Location loc,
+ PatternRewriter &rewriter) {
+ memref::ExtractStridedMetadataOp stridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ oper, cast<MemRefType>(oper.getType()), newOper,
+ /*offset=*/rewriter.getIndexAttr(0),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides());
+}
+
+template <typename T>
+static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
+ Value offset) {
+ auto loc = op->getLoc();
+ llvm::TypeSwitch<Operation *>(op.getOperation())
+ .template Case<memref::AllocOp>([&](auto oper) {
+ auto newAlloc = rewriter.create<memref::AllocOp>(
+ loc, cast<MemRefType>(flatMemref.getType()),
+ oper.getAlignmentAttr());
+ castResult(oper, newAlloc, loc, rewriter);
+ })
+ .template Case<memref::AllocaOp>([&](auto oper) {
+ auto newAlloca = rewriter.create<memref::AllocaOp>(
+ loc, cast<MemRefType>(flatMemref.getType()),
+ oper.getAlignmentAttr());
+ castResult(oper, newAlloca, loc, rewriter);
+ })
+ .template Case<memref::LoadOp>([&](auto op) {
+ auto newLoad = rewriter.create<memref::LoadOp>(
+ loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+ newLoad->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newLoad.getResult());
+ })
+ .template Case<memref::StoreOp>([&](auto op) {
+ auto newStore = rewriter.create<memref::StoreOp>(
+ loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+ newStore->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newStore);
+ })
+ .template Case<vector::LoadOp>([&](auto op) {
+ auto newLoad = rewriter.create<vector::LoadOp>(
+ loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+ newLoad->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newLoad.getResult());
+ })
+ .template Case<vector::StoreOp>([&](auto op) {
+ auto newStore = rewriter.create<vector::StoreOp>(
+ loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+ newStore->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newStore);
+ })
+ .template Case<vector::MaskedLoadOp>([&](auto op) {
+ auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
+ loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
+ op.getPassThru());
+ newMaskedLoad->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newMaskedLoad.getResult());
+ })
+ .template Case<vector::MaskedStoreOp>([&](auto op) {
+ auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
+ loc, flatMemref, ValueRange{offset}, op.getMask(),
+ op.getValueToStore());
+ newMaskedStore->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newMaskedStore);
+ })
+ .template Case<vector::TransferReadOp>([&](auto op) {
+ auto newTransferRead = rewriter.create<vector::TransferReadOp>(
----------------
lialan wrote:
Now that you mention this: I don't think we have a reliable way to linearize a TransferReadOp when it involves a permutation map or broadcasting; the op is simply too flexible.
I think we can only transform it under certain conditions.
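For example (an illustrative sketch, not taken from the patch; the shapes and SSA names are made up): a transfer_read whose permutation_map transposes or broadcasts dimensions does not correspond to a single linearized offset into a 1-d memref, so the rewrite as written cannot represent it faithfully.

    // Transposed read: vector dim 0 walks along memref dim 1, so the elements
    // touched are not contiguous in the flattened buffer.
    %t = vector.transfer_read %src[%c0, %c0], %pad
           {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
            in_bounds = [true, true]}
           : memref<8x8xf32>, vector<4x4xf32>

    // Broadcasting read: memref dim 0 is dropped by the map, so a single row
    // is replicated across vector dim 0 of the result.
    %b = vector.transfer_read %src[%c0, %c0], %pad
           {permutation_map = affine_map<(d0, d1) -> (0, d1)>,
            in_bounds = [true, true]}
           : memref<8x8xf32>, vector<4x8xf32>

One option would be to restrict the pattern to (minor-)identity permutation maps with no broadcast dimensions, which keeps the rewrite sound.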
https://github.com/llvm/llvm-project/pull/136797