[Mlir-commits] [mlir] Normalize reinterpret_cast op (PR #133417)
Arnab Dutta
llvmlistbot at llvm.org
Fri Mar 28 04:35:41 PDT 2025
https://github.com/arnab-polymage updated https://github.com/llvm/llvm-project/pull/133417
From ffbbcf7148502ba85c1e30346971ec4344b63f99 Mon Sep 17 00:00:00 2001
From: Arnab Dutta <arnab at polymagelabs.com>
Date: Fri, 28 Mar 2025 17:04:12 +0530
Subject: [PATCH] Normalize reinterpret_cast op
Normalize reinterpret_cast op for statically shaped input
and output memrefs. Also improve `replaceAllMemRefUsesWith`
to perform correct replacement in the case of memref.load/memref.store
ops.
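For illustration, a minimal before/after sketch of the rewrite, based on
the test added in this patch (the loop indices %i/%j stand in for the
actual induction variables):

Before normalization:

  #map = affine_map<(d0, d1) -> (d0 * 7 + d1)>
  %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [5, 7],
      strides: [7, 1] : memref<5x7xf32> to memref<5x7xf32, #map>
  %1 = affine.load %0[%i, %j] : memref<5x7xf32, #map>

After normalization:

  %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [35],
      strides: [1] : memref<5x7xf32> to memref<35xf32>
  %1 = affine.load %0[%i * 7 + %j] : memref<35xf32>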
---
mlir/include/mlir/Dialect/Affine/Utils.h | 9 +-
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 255 ++++++++++--------
.../MemRef/Transforms/NormalizeMemRefs.cpp | 9 +
.../Dialect/MemRef/normalize-memrefs-ops.mlir | 18 ++
.../Dialect/MemRef/normalize-memrefs.mlir | 4 +-
5 files changed, 175 insertions(+), 120 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index ff1900bc8f2eb..1032d4d92b589 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -32,6 +32,7 @@ class FuncOp;
namespace memref {
class AllocOp;
class AllocaOp;
+class ReinterpretCastOp;
} // namespace memref
namespace affine {
@@ -243,9 +244,9 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
ArrayRef<Value> symbolOperands = {},
bool allowNonDereferencingOps = false);
-/// Rewrites the memref defined by this alloc op to have an identity layout map
-/// and updates all its indexing uses. Returns failure if any of its uses
-/// escape (while leaving the IR in a valid state).
+/// Rewrites the memref defined by an alloc or reinterpret_cast op to have an
+/// identity layout map and updates all its indexing uses. Returns failure if
+/// any of its uses escape (while leaving the IR in a valid state).
template <typename AllocLikeOp>
LogicalResult normalizeMemRef(AllocLikeOp *op);
extern template LogicalResult
@@ -253,6 +254,8 @@ normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
extern template LogicalResult
normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
+LogicalResult normalizeMemRef(memref::ReinterpretCastOp *op);
+
/// Normalizes `memrefType` so that the affine layout map of the memref is
/// transformed to an identity map with a new shape being computed for the
normalized memref type and returns it. The old memref type is simply
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 2723cff6900d0..3ff73e37ac189 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1098,90 +1098,12 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
op->erase();
}
-// Private helper function to transform memref.load with reduced rank.
-// This function will modify the indices of the memref.load to match the
-// newMemRef.
-LogicalResult transformMemRefLoadWithReducedRank(
- Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
- ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
- ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
- unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
- unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
- unsigned oldMapNumInputs = oldMemRefRank;
- SmallVector<Value, 4> oldMapOperands(
- op->operand_begin() + memRefOperandPos + 1,
- op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
- SmallVector<Value, 4> oldMemRefOperands;
- oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
- SmallVector<Value, 4> remapOperands;
- remapOperands.reserve(extraOperands.size() + oldMemRefRank +
- symbolOperands.size());
- remapOperands.append(extraOperands.begin(), extraOperands.end());
- remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
- remapOperands.append(symbolOperands.begin(), symbolOperands.end());
-
- SmallVector<Value, 4> remapOutputs;
- remapOutputs.reserve(oldMemRefRank);
- SmallVector<Value, 4> affineApplyOps;
-
- OpBuilder builder(op);
-
- if (indexRemap &&
- indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
- // Remapped indices.
- for (auto resultExpr : indexRemap.getResults()) {
- auto singleResMap = AffineMap::get(
- indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
- auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
- remapOperands);
- remapOutputs.push_back(afOp);
- affineApplyOps.push_back(afOp);
- }
- } else {
- // No remapping specified.
- remapOutputs.assign(remapOperands.begin(), remapOperands.end());
- }
-
- SmallVector<Value, 4> newMapOperands;
- newMapOperands.reserve(newMemRefRank);
-
- // Prepend 'extraIndices' in 'newMapOperands'.
- for (Value extraIndex : extraIndices) {
- assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
- "invalid memory op index");
- newMapOperands.push_back(extraIndex);
- }
-
- // Append 'remapOutputs' to 'newMapOperands'.
- newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
-
- // Create new fully composed AffineMap for new op to be created.
- assert(newMapOperands.size() == newMemRefRank);
-
- OperationState state(op->getLoc(), op->getName());
- // Construct the new operation using this memref.
- state.operands.reserve(newMapOperands.size() + extraIndices.size());
- state.operands.push_back(newMemRef);
-
- // Insert the new memref map operands.
- state.operands.append(newMapOperands.begin(), newMapOperands.end());
-
- state.types.reserve(op->getNumResults());
- for (auto result : op->getResults())
- state.types.push_back(result.getType());
-
- // Copy over the attributes from the old operation to the new operation.
- for (auto namedAttr : op->getAttrs()) {
- state.attributes.push_back(namedAttr);
- }
-
- // Create the new operation.
- auto *repOp = builder.create(state);
- op->replaceAllUsesWith(repOp);
- op->erase();
-
- return success();
+// Checks if `op` is a dereferencing op, i.e., one that loads from or stores
+// to the memref it uses.
+// TODO: This hardcoded check will be removed once the right interface is added.
+static bool isDereferencingOp(Operation *op) {
+ return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
}
+
// Perform the replacement in `op`.
LogicalResult mlir::affine::replaceAllMemRefUsesWith(
Value oldMemRef, Value newMemRef, Operation *op,
@@ -1216,53 +1138,57 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
if (usePositions.empty())
return success();
- if (usePositions.size() > 1) {
- // TODO: extend it for this case when needed (rare).
- assert(false && "multiple dereferencing uses in a single op not supported");
- return failure();
- }
-
unsigned memRefOperandPos = usePositions.front();
OpBuilder builder(op);
// The following checks if op is dereferencing memref and performs the access
// index rewrites.
auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
- if (!affMapAccInterface) {
+ if (!isDereferencingOp(op)) {
if (!allowNonDereferencingOps) {
// Failure: memref used in a non-dereferencing context (potentially
// escapes); no replacement in these cases unless allowNonDereferencingOps
// is set.
return failure();
}
+ for (unsigned pos : usePositions)
+ op->setOperand(pos, newMemRef);
+ return success();
+ }
- // Check if it is a memref.load
- auto memrefLoad = dyn_cast<memref::LoadOp>(op);
- bool isReductionLike =
- indexRemap.getNumResults() < indexRemap.getNumInputs();
- if (!memrefLoad || !isReductionLike) {
- op->setOperand(memRefOperandPos, newMemRef);
- return success();
- }
+ if (usePositions.size() > 1) {
+ // TODO: extend it for this case when needed (rare).
+ assert(false && "multiple dereferencing uses in a single op not supported");
+ return failure();
+ }
- return transformMemRefLoadWithReducedRank(
- op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
- symbolOperands, indexRemap);
+ // Perform index rewrites for the dereferencing op and then replace the op.
+ SmallVector<Value, 4> oldMapOperands;
+ AffineMap oldMap;
+ unsigned oldMemRefNumIndices = oldMemRefRank;
+ if (affMapAccInterface) {
+ // If `op` implements AffineMapAccessInterface, we can get the indices by
+ // querying the number of map operands and reading them from the operand
+ // list at a certain offset (`memRefOperandPos` in this case).
+ NamedAttribute oldMapAttrPair =
+ affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
+ oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
+ oldMemRefNumIndices = oldMap.getNumInputs();
+ oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
+ op->operand_begin() + memRefOperandPos + 1 +
+ oldMemRefNumIndices);
+ } else {
+ oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
+ op->operand_begin() + memRefOperandPos + 1 +
+ oldMemRefRank);
}
- // Perform index rewrites for the dereferencing op and then replace the op
- NamedAttribute oldMapAttrPair =
- affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
- AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
- unsigned oldMapNumInputs = oldMap.getNumInputs();
- SmallVector<Value, 4> oldMapOperands(
- op->operand_begin() + memRefOperandPos + 1,
- op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
// Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
SmallVector<Value, 4> oldMemRefOperands;
SmallVector<Value, 4> affineApplyOps;
oldMemRefOperands.reserve(oldMemRefRank);
- if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+ if (affMapAccInterface &&
+ oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
for (auto resultExpr : oldMap.getResults()) {
auto singleResMap = AffineMap::get(oldMap.getNumDims(),
oldMap.getNumSymbols(), resultExpr);
@@ -1287,7 +1213,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
SmallVector<Value, 4> remapOutputs;
remapOutputs.reserve(oldMemRefRank);
-
if (indexRemap &&
indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
// Remapped indices.
@@ -1303,7 +1228,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// No remapping specified.
remapOutputs.assign(remapOperands.begin(), remapOperands.end());
}
-
SmallVector<Value, 4> newMapOperands;
newMapOperands.reserve(newMemRefRank);
@@ -1338,13 +1262,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
state.operands.push_back(newMemRef);
// Insert the new memref map operands.
- state.operands.append(newMapOperands.begin(), newMapOperands.end());
+ if (affMapAccInterface) {
+ state.operands.append(newMapOperands.begin(), newMapOperands.end());
+ } else {
+ // For dereferencing ops that do not implement AffineMapAccessInterface,
+ // we need to apply `newMap` to the values in `newMapOperands` to compute
+ // the correct indices.
+ for (unsigned i = 0; i < newMemRefRank; i++)
+ state.operands.push_back(builder.create<AffineApplyOp>(
+ op->getLoc(),
+ AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
+ newMap.getResult(i)),
+ newMapOperands));
+ }
// Insert the remaining operands unmodified.
+ unsigned oldMapNumInputs = oldMapOperands.size();
+
state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
oldMapNumInputs,
op->operand_end());
-
// Result types don't change. Both memref's are of the same elemental type.
state.types.reserve(op->getNumResults());
for (auto result : op->getResults())
@@ -1353,7 +1290,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// Add attribute for 'newMap', other Attributes do not change.
auto newMapAttr = AffineMapAttr::get(newMap);
for (auto namedAttr : op->getAttrs()) {
- if (namedAttr.getName() == oldMapAttrPair.getName())
+ if (affMapAccInterface &&
+ namedAttr.getName() ==
+ affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef).getName())
state.attributes.push_back({namedAttr.getName(), newMapAttr});
else
state.attributes.push_back(namedAttr);
@@ -1846,6 +1785,92 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
return success();
}
+LogicalResult
+mlir::affine::normalizeMemRef(memref::ReinterpretCastOp *reinterpretCastOp) {
+ MemRefType memrefType = reinterpretCastOp->getType();
+ AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
+ Value oldMemRef = reinterpretCastOp->getResult();
+
+ // If `oldLayoutMap` is identity, `memrefType` is already normalized.
+ if (oldLayoutMap.isIdentity())
+ return success();
+
+ // Fetch a new memref type after normalizing the old memref to have an
+ // identity map layout.
+ MemRefType newMemRefType = normalizeMemRefType(memrefType);
+ if (newMemRefType == memrefType)
+ // `oldLayoutMap` couldn't be transformed to an identity map.
+ return failure();
+
+ uint64_t newRank = newMemRefType.getRank();
+ SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
+ oldLayoutMap.getNumSymbols());
+ SmallVector<Value> oldStrides = reinterpretCastOp->getStrides();
+ Location loc = reinterpretCastOp->getLoc();
+ // As `newMemRefType` is normalized, it is unit strided.
+ SmallVector<int64_t> newStaticStrides(newRank, 1);
+ ArrayRef<int64_t> oldShape = memrefType.getShape();
+ mlir::ValueRange oldSizes = reinterpretCastOp->getSizes();
+ unsigned idx = 0;
+ SmallVector<int64_t> newStaticSizes;
+ OpBuilder b(*reinterpretCastOp);
+ // Collect the map operands which will be used to compute the new normalized
+ // memref shape.
+ for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
+ if (oldShape[i] == ShapedType::kDynamic)
+ mapOperands[i] =
+ b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
+ b.create<arith::ConstantIndexOp>(loc, 1));
+ else
+ mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+ }
+ for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
+ mapOperands[memrefType.getRank() + i] = oldStrides[i];
+ SmallVector<Value> newSizes;
+ ArrayRef<int64_t> newShape = newMemRefType.getShape();
+ // Compute size along all the dimensions of the new normalized memref.
+ for (unsigned i = 0; i < newRank; i++) {
+ if (newShape[i] != ShapedType::kDynamic)
+ continue;
+ newSizes.push_back(b.create<AffineApplyOp>(
+ loc,
+ AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
+ oldLayoutMap.getResult(i)),
+ mapOperands));
+ }
+ for (unsigned i = 0, e = newSizes.size(); i < e; i++)
+ newSizes[i] =
+ b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
+ b.create<arith::ConstantIndexOp>(loc, 1));
+ // Create the new reinterpret_cast op.
+ memref::ReinterpretCastOp newReinterpretCast =
+ b.create<memref::ReinterpretCastOp>(
+ loc, newMemRefType, reinterpretCastOp->getSource(),
+ reinterpretCastOp->getOffsets(), newSizes, mlir::ValueRange(),
+ /*static_offsets=*/reinterpretCastOp->getStaticOffsets(),
+ /*static_sizes=*/newShape,
+ /*static_strides=*/newStaticStrides);
+
+ // Replace all uses of the old memref.
+ if (failed(replaceAllMemRefUsesWith(oldMemRef,
+ /*newMemRef=*/newReinterpretCast,
+ /*extraIndices=*/{},
+ /*indexRemap=*/oldLayoutMap,
+ /*extraOperands=*/{},
+ /*symbolOperands=*/oldStrides,
+ /*domOpFilter=*/nullptr,
+ /*postDomOpFilter=*/nullptr,
+ /*allowNonDereferencingOps=*/true))) {
+ // If it failed (due to escapes for example), bail out.
+ newReinterpretCast->erase();
+ return failure();
+ }
+
+ oldMemRef.replaceAllUsesWith(newReinterpretCast);
+ reinterpretCastOp->erase();
+ return success();
+}
+
template LogicalResult
mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
template LogicalResult
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 08b853fe65b85..b8d8a99c33084 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -363,6 +363,15 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
for (memref::AllocaOp allocaOp : allocaOps)
(void)normalizeMemRef(&allocaOp);
+ // Turn memrefs' non-identity layout maps into identity ones. Collect the
+ // reinterpret_cast ops first and then process them, since normalizeMemRef
+ // replaces/erases ops during memref rewriting.
+ SmallVector<memref::ReinterpretCastOp> reinterpretCastOps;
+ funcOp.walk(
+ [&](memref::ReinterpretCastOp op) { reinterpretCastOps.push_back(op); });
+ for (memref::ReinterpretCastOp reinterpretCastOp : reinterpretCastOps)
+ (void)normalizeMemRef(&reinterpretCastOp);
+
// We use this OpBuilder to create new memref layout later.
OpBuilder b(funcOp);
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
index 3bede131325a7..d8d3b13cd498b 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
@@ -165,3 +165,21 @@ func.func @prefetch_normalize(%arg0: memref<512xf32, affine_map<(d0) -> (d0 floo
}
return
}
+
+#map_strided = affine_map<(d0, d1) -> (d0 * 7 + d1)>
+
+// CHECK-LABEL: test_reinterpret_cast
+func.func @test_reinterpret_cast(%arg0: memref<5x7xf32>, %arg1: memref<5x7xf32>, %arg2: memref<5x7xf32>) {
+ %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [5, 7], strides: [7, 1] : memref<5x7xf32> to memref<5x7xf32, #map_strided>
+ // CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [35], strides: [1] : memref<5x7xf32> to memref<35xf32>
+ affine.for %arg5 = 0 to 5 {
+ affine.for %arg6 = 0 to 7 {
+ %1 = affine.load %0[%arg5, %arg6] : memref<5x7xf32, #map_strided>
+ // CHECK: affine.load %reinterpret_cast[%{{.*}} * 7 + %{{.*}}] : memref<35xf32>
+ %2 = affine.load %arg1[%arg5, %arg6] : memref<5x7xf32>
+ %3 = arith.subf %1, %2 : f32
+ affine.store %3, %arg2[%arg5, %arg6] : memref<5x7xf32>
+ }
+ }
+ return
+}
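A note on the expected sizes in the CHECK lines above: normalizeMemRef
evaluates the layout map at the last valid index, i.e. (d0, d1) = (4, 6)
gives 4 * 7 + 6 = 34, and adds 1 to obtain the normalized size 35.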
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index e93a1a4ebae53..440f4776424cc 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -3,8 +3,8 @@
// This file tests whether memref types having non-trivial layout maps
// are normalized to trivial (identity) layouts.
-// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
-// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
+// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 + (d1 floordiv 2) * 6)>
+// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 2 + (d0 floordiv 2) * 6)>
// CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
// CHECK-LABEL: func @permute()