[Mlir-commits] [mlir] [mlir] Add memref normalization support for reinterpret_cast op (PR #133417)
Arnab Dutta
llvmlistbot at llvm.org
Tue Apr 29 23:07:32 PDT 2025
https://github.com/arnab-polymage updated https://github.com/llvm/llvm-project/pull/133417
>From 3ca4a4746d450eb02a666db76fa6264106c7ed2e Mon Sep 17 00:00:00 2001
From: Arnab Dutta <arnab at polymagelabs.com>
Date: Wed, 30 Apr 2025 10:40:10 +0530
Subject: [PATCH] Normalize reinterpret_cast op
Rewrites the memref defined by a reinterpret_cast op to have an identity layout
map and updates all its indexing uses. Also extends the `replaceAllMemRefUsesWith`
utility to work when there are multiple occurrences of `oldMemRef` in `op`'s
operand list and `op` is non-dereferencing.
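
For example, with this change the following IR (adapted from the added test,
with %i and %j standing in for the loop induction variables):

  #map = affine_map<(d0, d1) -> (d0 * 7 + d1)>
  %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [5, 7], strides: [7, 1]
         : memref<5x7xf32> to memref<5x7xf32, #map>
  %1 = affine.load %0[%i, %j] : memref<5x7xf32, #map>

is normalized to an identity-layout memref:

  %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [35], strides: [1]
         : memref<5x7xf32> to memref<35xf32>
  %1 = affine.load %0[%i * 7 + %j] : memref<35xf32>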
---
mlir/include/mlir/Dialect/Affine/Utils.h | 8 +-
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 255 ++++++++++--------
.../MemRef/Transforms/NormalizeMemRefs.cpp | 36 ++-
.../Dialect/MemRef/normalize-memrefs-ops.mlir | 31 +++
.../Dialect/MemRef/normalize-memrefs.mlir | 4 +-
5 files changed, 200 insertions(+), 134 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 3b4bb34105581..ae5a68a6be157 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -32,6 +32,7 @@ class FuncOp;
namespace memref {
class AllocOp;
class AllocaOp;
+class ReinterpretCastOp;
} // namespace memref
namespace affine {
@@ -243,15 +244,16 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
ArrayRef<Value> symbolOperands = {},
bool allowNonDereferencingOps = false);
-/// Rewrites the memref defined by this alloc op to have an identity layout map
-/// and updates all its indexing uses. Returns failure if any of its uses
-/// escape (while leaving the IR in a valid state).
+/// Rewrites the memref defined by an alloc or reinterpret_cast op to have an
+/// identity layout map and updates all its indexing uses. Returns failure if
+/// any of its uses escape (while leaving the IR in a valid state).
template <typename AllocLikeOp>
LogicalResult normalizeMemRef(AllocLikeOp op);
extern template LogicalResult
normalizeMemRef<memref::AllocaOp>(memref::AllocaOp op);
extern template LogicalResult
normalizeMemRef<memref::AllocOp>(memref::AllocOp op);
+LogicalResult normalizeMemRef(memref::ReinterpretCastOp op);
/// Normalizes `memrefType` so that the affine layout map of the memref is
/// transformed to an identity map with a new shape being computed for the
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 2925aa918cb1c..db000652196d1 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1098,90 +1098,12 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
op->erase();
}
-// Private helper function to transform memref.load with reduced rank.
-// This function will modify the indices of the memref.load to match the
-// newMemRef.
-LogicalResult transformMemRefLoadWithReducedRank(
- Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
- ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
- ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
- unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
- unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
- unsigned oldMapNumInputs = oldMemRefRank;
- SmallVector<Value, 4> oldMapOperands(
- op->operand_begin() + memRefOperandPos + 1,
- op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
- SmallVector<Value, 4> oldMemRefOperands;
- oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
- SmallVector<Value, 4> remapOperands;
- remapOperands.reserve(extraOperands.size() + oldMemRefRank +
- symbolOperands.size());
- remapOperands.append(extraOperands.begin(), extraOperands.end());
- remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
- remapOperands.append(symbolOperands.begin(), symbolOperands.end());
-
- SmallVector<Value, 4> remapOutputs;
- remapOutputs.reserve(oldMemRefRank);
- SmallVector<Value, 4> affineApplyOps;
-
- OpBuilder builder(op);
-
- if (indexRemap &&
- indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
- // Remapped indices.
- for (auto resultExpr : indexRemap.getResults()) {
- auto singleResMap = AffineMap::get(
- indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
- auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
- remapOperands);
- remapOutputs.push_back(afOp);
- affineApplyOps.push_back(afOp);
- }
- } else {
- // No remapping specified.
- remapOutputs.assign(remapOperands.begin(), remapOperands.end());
- }
-
- SmallVector<Value, 4> newMapOperands;
- newMapOperands.reserve(newMemRefRank);
-
- // Prepend 'extraIndices' in 'newMapOperands'.
- for (Value extraIndex : extraIndices) {
- assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
- "invalid memory op index");
- newMapOperands.push_back(extraIndex);
- }
-
- // Append 'remapOutputs' to 'newMapOperands'.
- newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
-
- // Create new fully composed AffineMap for new op to be created.
- assert(newMapOperands.size() == newMemRefRank);
-
- OperationState state(op->getLoc(), op->getName());
- // Construct the new operation using this memref.
- state.operands.reserve(newMapOperands.size() + extraIndices.size());
- state.operands.push_back(newMemRef);
-
- // Insert the new memref map operands.
- state.operands.append(newMapOperands.begin(), newMapOperands.end());
-
- state.types.reserve(op->getNumResults());
- for (auto result : op->getResults())
- state.types.push_back(result.getType());
-
- // Copy over the attributes from the old operation to the new operation.
- for (auto namedAttr : op->getAttrs()) {
- state.attributes.push_back(namedAttr);
- }
-
- // Create the new operation.
- auto *repOp = builder.create(state);
- op->replaceAllUsesWith(repOp);
- op->erase();
-
- return success();
+// Returns true if `op` is a dereferencing op.
+// TODO: This hardcoded check will be removed once the right interface is added.
+static bool isDereferencingOp(Operation *op) {
+ return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
}
+
// Perform the replacement in `op`.
LogicalResult mlir::affine::replaceAllMemRefUsesWith(
Value oldMemRef, Value newMemRef, Operation *op,
@@ -1216,53 +1138,53 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
if (usePositions.empty())
return success();
- if (usePositions.size() > 1) {
- // TODO: extend it for this case when needed (rare).
- assert(false && "multiple dereferencing uses in a single op not supported");
- return failure();
- }
-
unsigned memRefOperandPos = usePositions.front();
OpBuilder builder(op);
// The following checks if op is dereferencing memref and performs the access
// index rewrites.
- auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
- if (!affMapAccInterface) {
+ if (!isDereferencingOp(op)) {
if (!allowNonDereferencingOps) {
// Failure: memref used in a non-dereferencing context (potentially
// escapes); no replacement in these cases unless allowNonDereferencingOps
// is set.
return failure();
}
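+    // `op` is non-dereferencing, so no index rewriting is needed; simply
+    // replace every occurrence of `oldMemRef` in its operand list.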
+ for (unsigned pos : usePositions)
+ op->setOperand(pos, newMemRef);
+ return success();
+ }
- // Check if it is a memref.load
- auto memrefLoad = dyn_cast<memref::LoadOp>(op);
- bool isReductionLike =
- indexRemap.getNumResults() < indexRemap.getNumInputs();
- if (!memrefLoad || !isReductionLike) {
- op->setOperand(memRefOperandPos, newMemRef);
- return success();
- }
+ if (usePositions.size() > 1) {
+ // TODO: extend it for this case when needed (rare).
+    LLVM_DEBUG(llvm::dbgs()
+               << "multiple dereferencing uses in a single op not supported\n");
+ return failure();
+ }
- return transformMemRefLoadWithReducedRank(
- op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
- symbolOperands, indexRemap);
+ // Perform index rewrites for the dereferencing op and then replace the op.
+ SmallVector<Value, 4> oldMapOperands;
+ AffineMap oldMap;
+ unsigned oldMemRefNumIndices = oldMemRefRank;
+ auto startIdx = op->operand_begin() + memRefOperandPos + 1;
+ auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
+ if (affMapAccInterface) {
+    // If `op` implements AffineMapAccessInterface, we can get the indices by
+    // querying the number of map operands and reading that many operands from
+    // the operand list, starting right after `memRefOperandPos`.
+ NamedAttribute oldMapAttrPair =
+ affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
+ oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
+ oldMemRefNumIndices = oldMap.getNumInputs();
}
- // Perform index rewrites for the dereferencing op and then replace the op
- NamedAttribute oldMapAttrPair =
- affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
- AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
- unsigned oldMapNumInputs = oldMap.getNumInputs();
- SmallVector<Value, 4> oldMapOperands(
- op->operand_begin() + memRefOperandPos + 1,
- op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+ oldMapOperands.assign(startIdx, startIdx + oldMemRefNumIndices);
// Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
SmallVector<Value, 4> oldMemRefOperands;
SmallVector<Value, 4> affineApplyOps;
oldMemRefOperands.reserve(oldMemRefRank);
- if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+ if (affMapAccInterface &&
+ oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
for (auto resultExpr : oldMap.getResults()) {
auto singleResMap = AffineMap::get(oldMap.getNumDims(),
oldMap.getNumSymbols(), resultExpr);
@@ -1287,7 +1209,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
SmallVector<Value, 4> remapOutputs;
remapOutputs.reserve(oldMemRefRank);
-
if (indexRemap &&
indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
// Remapped indices.
@@ -1303,7 +1224,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// No remapping specified.
remapOutputs.assign(remapOperands.begin(), remapOperands.end());
}
-
SmallVector<Value, 4> newMapOperands;
newMapOperands.reserve(newMemRefRank);
@@ -1338,13 +1258,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
state.operands.push_back(newMemRef);
// Insert the new memref map operands.
- state.operands.append(newMapOperands.begin(), newMapOperands.end());
+ if (affMapAccInterface) {
+ state.operands.append(newMapOperands.begin(), newMapOperands.end());
+ } else {
+    // For dereferencing ops that do not implement AffineMapAccessInterface,
+    // we need to apply `newMap` to `newMapOperands` to compute the correct
+    // indices.
+ for (unsigned i = 0; i < newMemRefRank; i++) {
+ state.operands.push_back(builder.create<AffineApplyOp>(
+ op->getLoc(),
+ AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
+ newMap.getResult(i)),
+ newMapOperands));
+ }
+ }
// Insert the remaining operands unmodified.
+ unsigned oldMapNumInputs = oldMapOperands.size();
state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
oldMapNumInputs,
op->operand_end());
-
// Result types don't change. Both memref's are of the same elemental type.
state.types.reserve(op->getNumResults());
for (auto result : op->getResults())
@@ -1353,7 +1286,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// Add attribute for 'newMap', other Attributes do not change.
auto newMapAttr = AffineMapAttr::get(newMap);
for (auto namedAttr : op->getAttrs()) {
- if (namedAttr.getName() == oldMapAttrPair.getName())
+ if (affMapAccInterface &&
+ namedAttr.getName() ==
+ affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef).getName())
state.attributes.push_back({namedAttr.getName(), newMapAttr});
else
state.attributes.push_back(namedAttr);
@@ -1845,6 +1780,94 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) {
return success();
}
+LogicalResult
+mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
+ MemRefType memrefType = reinterpretCastOp.getType();
+ AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
+ Value oldMemRef = reinterpretCastOp.getResult();
+
+ // If `oldLayoutMap` is identity, `memrefType` is already normalized.
+ if (oldLayoutMap.isIdentity())
+ return success();
+
+ // Fetch a new memref type after normalizing the old memref to have an
+ // identity map layout.
+ MemRefType newMemRefType = normalizeMemRefType(memrefType);
+ if (newMemRefType == memrefType)
+ // `oldLayoutMap` couldn't be transformed to an identity map.
+ return failure();
+
+ uint64_t newRank = newMemRefType.getRank();
+ SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
+ oldLayoutMap.getNumSymbols());
+ SmallVector<Value> oldStrides = reinterpretCastOp.getStrides();
+ Location loc = reinterpretCastOp.getLoc();
+ // As `newMemRefType` is normalized, it is unit strided.
+ SmallVector<int64_t> newStaticStrides(newRank, 1);
+ SmallVector<int64_t> newStaticOffsets(newRank, 0);
+ ArrayRef<int64_t> oldShape = memrefType.getShape();
+ ValueRange oldSizes = reinterpretCastOp.getSizes();
+ unsigned idx = 0;
+ SmallVector<int64_t> newStaticSizes;
+ OpBuilder b(reinterpretCastOp);
+ // Collect the map operands which will be used to compute the new normalized
+ // memref shape.
+ for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
+ if (memrefType.isDynamicDim(i))
+ mapOperands[i] =
+ b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
+ b.create<arith::ConstantIndexOp>(loc, 1));
+ else
+ mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+ }
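+  // Bind the old memref's strides to the layout map's symbol operands.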
+ for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
+ mapOperands[memrefType.getRank() + i] = oldStrides[i];
+ SmallVector<Value> newSizes;
+ ArrayRef<int64_t> newShape = newMemRefType.getShape();
+ // Compute size along all the dimensions of the new normalized memref.
+ for (unsigned i = 0; i < newRank; i++) {
+ if (!newMemRefType.isDynamicDim(i))
+ continue;
+ newSizes.push_back(b.create<AffineApplyOp>(
+ loc,
+ AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
+ oldLayoutMap.getResult(i)),
+ mapOperands));
+ }
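+  // `mapOperands` held the maximal indices (size - 1 along each dimension),
+  // so the applied results are maximal linearized indices; add 1 to turn
+  // them into sizes.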
+ for (unsigned i = 0, e = newSizes.size(); i < e; i++) {
+ newSizes[i] =
+ b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
+ b.create<arith::ConstantIndexOp>(loc, 1));
+ }
+ // Create the new reinterpret_cast op.
+ auto newReinterpretCast = b.create<memref::ReinterpretCastOp>(
+ loc, newMemRefType, reinterpretCastOp.getSource(),
+ /*offsets=*/ValueRange(), newSizes,
+ /*strides=*/ValueRange(),
+ /*static_offsets=*/newStaticOffsets,
+ /*static_sizes=*/newShape,
+ /*static_strides=*/newStaticStrides);
+
+ // Replace all uses of the old memref.
+ if (failed(replaceAllMemRefUsesWith(oldMemRef,
+ /*newMemRef=*/newReinterpretCast,
+ /*extraIndices=*/{},
+ /*indexRemap=*/oldLayoutMap,
+ /*extraOperands=*/{},
+ /*symbolOperands=*/oldStrides,
+ /*domOpFilter=*/nullptr,
+ /*postDomOpFilter=*/nullptr,
+ /*allowNonDereferencingOps=*/true))) {
+ // If it failed (due to escapes for example), bail out.
+ newReinterpretCast.erase();
+ return failure();
+ }
+
+ oldMemRef.replaceAllUsesWith(newReinterpretCast);
+ reinterpretCastOp.erase();
+ return success();
+}
+
template LogicalResult
mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp op);
template LogicalResult
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 95fed04a7864e..756c555494b71 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -30,6 +30,7 @@ namespace memref {
using namespace mlir;
using namespace mlir::affine;
+using namespace mlir::memref;
namespace {
@@ -164,7 +165,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
return true;
if (funcOp
- .walk([&](memref::AllocOp allocOp) -> WalkResult {
+ .walk([&](AllocOp allocOp) -> WalkResult {
Value oldMemRef = allocOp.getResult();
if (!allocOp.getType().getLayout().isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
@@ -175,7 +176,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
return false;
if (funcOp
- .walk([&](memref::AllocaOp allocaOp) -> WalkResult {
+ .walk([&](AllocaOp allocaOp) -> WalkResult {
Value oldMemRef = allocaOp.getResult();
if (!allocaOp.getType().getLayout().isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
@@ -346,22 +347,31 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
}
/// Normalizes the memrefs within a function which includes those arising as a
-/// result of AllocOps, AllocaOps, CallOps and function's argument. The ModuleOp
-/// argument is used to help update function's signature after normalization.
+/// result of AllocOps, AllocaOps, CallOps, ReinterpretCastOps and the
+/// function's arguments. The ModuleOp argument is used to help update the
+/// function's signature after normalization.
void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
ModuleOp moduleOp) {
// Turn memrefs' non-identity layouts maps into ones with identity. Collect
- // alloc/alloca ops first and then process since normalizeMemRef
- // replaces/erases ops during memref rewriting.
- SmallVector<memref::AllocOp, 4> allocOps;
- funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
- for (memref::AllocOp allocOp : allocOps)
+  // alloc, alloca and reinterpret_cast ops first and then process them, since
+  // normalizeMemRef replaces/erases ops during memref rewriting.
+ SmallVector<AllocOp, 4> allocOps;
+ SmallVector<AllocaOp> allocaOps;
+ SmallVector<ReinterpretCastOp> reinterpretCastOps;
+ funcOp.walk([&](Operation *op) {
+ if (auto allocOp = dyn_cast<AllocOp>(op))
+ allocOps.push_back(allocOp);
+ else if (auto allocaOp = dyn_cast<AllocaOp>(op))
+ allocaOps.push_back(allocaOp);
+ else if (auto reinterpretCastOp = dyn_cast<ReinterpretCastOp>(op))
+ reinterpretCastOps.push_back(reinterpretCastOp);
+ });
+ for (AllocOp allocOp : allocOps)
(void)normalizeMemRef(allocOp);
-
- SmallVector<memref::AllocaOp> allocaOps;
- funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); });
- for (memref::AllocaOp allocaOp : allocaOps)
+ for (AllocaOp allocaOp : allocaOps)
(void)normalizeMemRef(allocaOp);
+ for (ReinterpretCastOp reinterpretCastOp : reinterpretCastOps)
+ (void)normalizeMemRef(reinterpretCastOp);
// We use this OpBuilder to create new memref layout later.
OpBuilder b(funcOp);
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
index 3bede131325a7..344da4e5e2462 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
@@ -165,3 +165,34 @@ func.func @prefetch_normalize(%arg0: memref<512xf32, affine_map<(d0) -> (d0 floo
}
return
}
+
+#map_strided = affine_map<(d0, d1) -> (d0 * 7 + d1)>
+
+// CHECK-LABEL: test_reinterpret_cast
+func.func @test_reinterpret_cast(%arg0: memref<5x7xf32>, %arg1: memref<5x7xf32>, %arg2: memref<5x7xf32>) {
+ %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [5, 7], strides: [7, 1] : memref<5x7xf32> to memref<5x7xf32, #map_strided>
+ // CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [35], strides: [1] : memref<5x7xf32> to memref<35xf32>
+ affine.for %arg5 = 0 to 5 {
+ affine.for %arg6 = 0 to 7 {
+ %1 = affine.load %0[%arg5, %arg6] : memref<5x7xf32, #map_strided>
+ // CHECK: affine.load %reinterpret_cast[%{{.*}} * 7 + %{{.*}}] : memref<35xf32>
+ %2 = affine.load %arg1[%arg5, %arg6] : memref<5x7xf32>
+ %3 = arith.subf %1, %2 : f32
+ affine.store %3, %arg2[%arg5, %arg6] : memref<5x7xf32>
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: reinterpret_cast_non_zero_offset
+func.func @reinterpret_cast_non_zero_offset(%arg0: index, %arg1: memref<1x10x17xi32, strided<[?, ?, ?], offset: ?>>, %arg2: memref<1x10x17xi32, strided<[?, ?, ?], offset: ?>>, %arg3: memref<1x10x17xi32, strided<[?, ?, ?], offset: ?>>) -> (memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x10x17xi32>
+ %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x17xf32>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x10x17xf32>
+ cf.br ^bb3
+^bb3: // pred: ^bb1
+ // CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [32], strides: [1] : memref<2x17xf32> to memref<32xf32>
+ // CHECK: return %[[REINTERPRET_CAST]], %[[REINTERPRET_CAST]], %{{.*}}, %{{.*}}, %{{.*}} : memref<32xf32>, memref<32xf32>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
+ %reinterpret_cast = memref.reinterpret_cast %alloc_0 to offset: [27], sizes: [1, 5], strides: [17, 1] : memref<2x17xf32> to memref<1x5xf32, strided<[17, 1], offset: 27>>
+ return %reinterpret_cast, %reinterpret_cast, %alloc_0, %alloc, %alloc_1 : memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
+}
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index e93a1a4ebae53..440f4776424cc 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -3,8 +3,8 @@
// This file tests whether the memref type having non-trivial map layouts
// are normalized to trivial (identity) layouts.
-// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
-// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
+// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 + (d1 floordiv 2) * 6)>
+// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 2 + (d0 floordiv 2) * 6)>
// CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
// CHECK-LABEL: func @permute()