[Mlir-commits] [mlir] Normalize reinterpret_cast op (PR #133417)

Arnab Dutta llvmlistbot at llvm.org
Fri Mar 28 06:19:24 PDT 2025


https://github.com/arnab-polymage updated https://github.com/llvm/llvm-project/pull/133417

>From b2de1157114ad843086f54b4174118121e8598e0 Mon Sep 17 00:00:00 2001
From: Arnab Dutta <arnab at polymagelabs.com>
Date: Fri, 28 Mar 2025 18:48:57 +0530
Subject: [PATCH] Normalize reinterpret_cast op

Normalize reinterpret_cast op for statically shaped input
and output memrefs. Also improve `replaceAllMemRefUsesWith`
to perform correct replacement in case of memref.load/memref.store
ops.
---
 mlir/include/mlir/Dialect/Affine/Utils.h      |   9 +-
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       | 257 ++++++++++--------
 .../MemRef/Transforms/NormalizeMemRefs.cpp    |   9 +
 .../Dialect/MemRef/normalize-memrefs-ops.mlir |  32 +++
 .../Dialect/MemRef/normalize-memrefs.mlir     |   4 +-
 5 files changed, 191 insertions(+), 120 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index ff1900bc8f2eb..1032d4d92b589 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -32,6 +32,7 @@ class FuncOp;
 namespace memref {
 class AllocOp;
 class AllocaOp;
+class ReinterpretCastOp;
 } // namespace memref
 
 namespace affine {
@@ -243,9 +244,9 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                        ArrayRef<Value> symbolOperands = {},
                                        bool allowNonDereferencingOps = false);
 
-/// Rewrites the memref defined by this alloc op to have an identity layout map
-/// and updates all its indexing uses. Returns failure if any of its uses
-/// escape (while leaving the IR in a valid state).
+/// Rewrites the memref defined by alloc or reinterpret_cast op to have an
+/// identity layout map and updates all its indexing uses. Returns failure if
+/// any of its uses escape (while leaving the IR in a valid state).
 template <typename AllocLikeOp>
 LogicalResult normalizeMemRef(AllocLikeOp *op);
 extern template LogicalResult
@@ -253,6 +254,8 @@ normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
 extern template LogicalResult
 normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
 
+LogicalResult normalizeMemRef(memref::ReinterpretCastOp *op);
+
 /// Normalizes `memrefType` so that the affine layout map of the memref is
 /// transformed to an identity map with a new shape being computed for the
 /// normalized memref type and returns it. The old memref type is simplify
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 2723cff6900d0..b56b6073c9a7b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1098,90 +1098,12 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
     op->erase();
 }
 
-// Private helper function to transform memref.load with reduced rank.
-// This function will modify the indices of the memref.load to match the
-// newMemRef.
-LogicalResult transformMemRefLoadWithReducedRank(
-    Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
-    ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
-    ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
-  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
-  unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
-  unsigned oldMapNumInputs = oldMemRefRank;
-  SmallVector<Value, 4> oldMapOperands(
-      op->operand_begin() + memRefOperandPos + 1,
-      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
-  SmallVector<Value, 4> oldMemRefOperands;
-  oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
-  SmallVector<Value, 4> remapOperands;
-  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
-                        symbolOperands.size());
-  remapOperands.append(extraOperands.begin(), extraOperands.end());
-  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
-  remapOperands.append(symbolOperands.begin(), symbolOperands.end());
-
-  SmallVector<Value, 4> remapOutputs;
-  remapOutputs.reserve(oldMemRefRank);
-  SmallVector<Value, 4> affineApplyOps;
-
-  OpBuilder builder(op);
-
-  if (indexRemap &&
-      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
-    // Remapped indices.
-    for (auto resultExpr : indexRemap.getResults()) {
-      auto singleResMap = AffineMap::get(
-          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
-      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
-                                                remapOperands);
-      remapOutputs.push_back(afOp);
-      affineApplyOps.push_back(afOp);
-    }
-  } else {
-    // No remapping specified.
-    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
-  }
-
-  SmallVector<Value, 4> newMapOperands;
-  newMapOperands.reserve(newMemRefRank);
-
-  // Prepend 'extraIndices' in 'newMapOperands'.
-  for (Value extraIndex : extraIndices) {
-    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
-           "invalid memory op index");
-    newMapOperands.push_back(extraIndex);
-  }
-
-  // Append 'remapOutputs' to 'newMapOperands'.
-  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
-
-  // Create new fully composed AffineMap for new op to be created.
-  assert(newMapOperands.size() == newMemRefRank);
-
-  OperationState state(op->getLoc(), op->getName());
-  // Construct the new operation using this memref.
-  state.operands.reserve(newMapOperands.size() + extraIndices.size());
-  state.operands.push_back(newMemRef);
-
-  // Insert the new memref map operands.
-  state.operands.append(newMapOperands.begin(), newMapOperands.end());
-
-  state.types.reserve(op->getNumResults());
-  for (auto result : op->getResults())
-    state.types.push_back(result.getType());
-
-  // Copy over the attributes from the old operation to the new operation.
-  for (auto namedAttr : op->getAttrs()) {
-    state.attributes.push_back(namedAttr);
-  }
-
-  // Create the new operation.
-  auto *repOp = builder.create(state);
-  op->replaceAllUsesWith(repOp);
-  op->erase();
-
-  return success();
+// Checks if `op` is a dereferencing op.
+// TODO: This hardcoded check will be removed once the right interface is added.
+static bool isDereferencingOp(Operation *op) {
+  return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
 }
+
 // Perform the replacement in `op`.
 LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     Value oldMemRef, Value newMemRef, Operation *op,
@@ -1216,53 +1138,57 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   if (usePositions.empty())
     return success();
 
-  if (usePositions.size() > 1) {
-    // TODO: extend it for this case when needed (rare).
-    assert(false && "multiple dereferencing uses in a single op not supported");
-    return failure();
-  }
-
   unsigned memRefOperandPos = usePositions.front();
 
   OpBuilder builder(op);
   // The following checks if op is dereferencing memref and performs the access
   // index rewrites.
   auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
-  if (!affMapAccInterface) {
+  if (!isDereferencingOp(op)) {
     if (!allowNonDereferencingOps) {
       // Failure: memref used in a non-dereferencing context (potentially
       // escapes); no replacement in these cases unless allowNonDereferencingOps
       // is set.
       return failure();
     }
+    for (unsigned pos : usePositions)
+      op->setOperand(pos, newMemRef);
+    return success();
+  }
 
-    // Check if it is a memref.load
-    auto memrefLoad = dyn_cast<memref::LoadOp>(op);
-    bool isReductionLike =
-        indexRemap.getNumResults() < indexRemap.getNumInputs();
-    if (!memrefLoad || !isReductionLike) {
-      op->setOperand(memRefOperandPos, newMemRef);
-      return success();
-    }
+  if (usePositions.size() > 1) {
+    // TODO: extend it for this case when needed (rare).
+    assert(false && "multiple dereferencing uses in a single op not supported");
+    return failure();
+  }
 
-    return transformMemRefLoadWithReducedRank(
-        op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
-        symbolOperands, indexRemap);
+  // Perform index rewrites for the dereferencing op and then replace the op.
+  SmallVector<Value, 4> oldMapOperands;
+  AffineMap oldMap;
+  unsigned oldMemRefNumIndices = oldMemRefRank;
+  if (affMapAccInterface) {
+    // If `op` implements AffineMapAccessInterface, we can get the indices by
+    // querying the number of map operands from the operand list from a certain
+    // offset (`memRefOperandPos` in this case).
+    NamedAttribute oldMapAttrPair =
+        affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
+    oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
+    oldMemRefNumIndices = oldMap.getNumInputs();
+    oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
+                          op->operand_begin() + memRefOperandPos + 1 +
+                              oldMemRefNumIndices);
+  } else {
+    oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
+                          op->operand_begin() + memRefOperandPos + 1 +
+                              oldMemRefRank);
   }
-  // Perform index rewrites for the dereferencing op and then replace the op
-  NamedAttribute oldMapAttrPair =
-      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
-  AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
-  unsigned oldMapNumInputs = oldMap.getNumInputs();
-  SmallVector<Value, 4> oldMapOperands(
-      op->operand_begin() + memRefOperandPos + 1,
-      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
 
   // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
   SmallVector<Value, 4> oldMemRefOperands;
   SmallVector<Value, 4> affineApplyOps;
   oldMemRefOperands.reserve(oldMemRefRank);
-  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+  if (affMapAccInterface &&
+      oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
     for (auto resultExpr : oldMap.getResults()) {
       auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                          oldMap.getNumSymbols(), resultExpr);
@@ -1287,7 +1213,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
 
   SmallVector<Value, 4> remapOutputs;
   remapOutputs.reserve(oldMemRefRank);
-
   if (indexRemap &&
       indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
     // Remapped indices.
@@ -1303,7 +1228,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     // No remapping specified.
     remapOutputs.assign(remapOperands.begin(), remapOperands.end());
   }
-
   SmallVector<Value, 4> newMapOperands;
   newMapOperands.reserve(newMemRefRank);
 
@@ -1338,13 +1262,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   state.operands.push_back(newMemRef);
 
   // Insert the new memref map operands.
-  state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  if (affMapAccInterface) {
+    state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  } else {
+    // In the case of dereferencing ops not implementing
+    // AffineMapAccessInterface, we need to apply the values of `newMapOperands`
+    // to the `newMap` to get the correct indices.
+    for (unsigned i = 0; i < newMemRefRank; i++)
+      state.operands.push_back(builder.create<AffineApplyOp>(
+          op->getLoc(),
+          AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
+                         newMap.getResult(i)),
+          newMapOperands));
+  }
 
   // Insert the remaining operands unmodified.
+  unsigned oldMapNumInputs = oldMapOperands.size();
+
   state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                             oldMapNumInputs,
                         op->operand_end());
-
   // Result types don't change. Both memref's are of the same elemental type.
   state.types.reserve(op->getNumResults());
   for (auto result : op->getResults())
@@ -1353,7 +1290,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   // Add attribute for 'newMap', other Attributes do not change.
   auto newMapAttr = AffineMapAttr::get(newMap);
   for (auto namedAttr : op->getAttrs()) {
-    if (namedAttr.getName() == oldMapAttrPair.getName())
+    if (affMapAccInterface &&
+        namedAttr.getName() ==
+            affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef).getName())
       state.attributes.push_back({namedAttr.getName(), newMapAttr});
     else
       state.attributes.push_back(namedAttr);
@@ -1846,6 +1785,94 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
   return success();
 }
 
+LogicalResult
+mlir::affine::normalizeMemRef(memref::ReinterpretCastOp *reinterpretCastOp) {
+  MemRefType memrefType = reinterpretCastOp->getType();
+  AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
+  Value oldMemRef = reinterpretCastOp->getResult();
+
+  // If `oldLayoutMap` is identity, `memrefType` is already normalized.
+  if (oldLayoutMap.isIdentity())
+    return success();
+
+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
+  MemRefType newMemRefType = normalizeMemRefType(memrefType);
+  newMemRefType.dump();
+  if (newMemRefType == memrefType)
+    // `oldLayoutMap` couldn't be transformed to an identity map.
+    return failure();
+
+  uint64_t newRank = newMemRefType.getRank();
+  SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
+                                 oldLayoutMap.getNumSymbols());
+  SmallVector<Value> oldStrides = reinterpretCastOp->getStrides();
+  Location loc = reinterpretCastOp->getLoc();
+  // As `newMemRefType` is normalized, it is unit strided.
+  SmallVector<int64_t> newStaticStrides(newRank, 1);
+  SmallVector<int64_t> newStaticOffsets(newRank, 0);
+  ArrayRef<int64_t> oldShape = memrefType.getShape();
+  mlir::ValueRange oldSizes = reinterpretCastOp->getSizes();
+  unsigned idx = 0;
+  SmallVector<int64_t> newStaticSizes;
+  OpBuilder b(*reinterpretCastOp);
+  // Collect the map operands which will be used to compute the new normalized
+  // memref shape.
+  for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
+    if (oldShape[i] == ShapedType::kDynamic)
+      mapOperands[i] =
+          b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
+                                  b.create<arith::ConstantIndexOp>(loc, 1));
+    else
+      mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+  }
+  for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
+    mapOperands[memrefType.getRank() + i] = oldStrides[i];
+  SmallVector<Value> newSizes;
+  ArrayRef<int64_t> newShape = newMemRefType.getShape();
+  // Compute size along all the dimensions of the new normalized memref.
+  for (unsigned i = 0; i < newRank; i++) {
+    if (newShape[i] != ShapedType::kDynamic)
+      continue;
+    newSizes.push_back(b.create<AffineApplyOp>(
+        loc,
+        AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
+                       oldLayoutMap.getResult(i)),
+        mapOperands));
+  }
+  for (unsigned i = 0, e = newSizes.size(); i < e; i++)
+    newSizes[i] =
+        b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
+                                b.create<arith::ConstantIndexOp>(loc, 1));
+  // Create the new reinterpret_cast op.
+  memref::ReinterpretCastOp newReinterpretCast =
+      b.create<memref::ReinterpretCastOp>(
+          loc, newMemRefType, reinterpretCastOp->getSource(),
+          /*offsets=*/mlir::ValueRange(), newSizes, /*strides=*/mlir::ValueRange(),
+          /*static_offsets=*/newStaticOffsets,
+          /*static_sizes=*/newShape,
+          /*static_strides=*/newStaticStrides);
+
+  // Replace all uses of the old memref.
+  if (failed(replaceAllMemRefUsesWith(oldMemRef,
+                                      /*newMemRef=*/newReinterpretCast,
+                                      /*extraIndices=*/{},
+                                      /*indexRemap=*/oldLayoutMap,
+                                      /*extraOperands=*/{},
+                                      /*symbolOperands=*/oldStrides,
+                                      /*domOpFilter=*/nullptr,
+                                      /*postDomOpFilter=*/nullptr,
+                                      /*allowNonDereferencingOps=*/true))) {
+    // If it failed (due to escapes for example), bail out.
+    newReinterpretCast->erase();
+    return failure();
+  }
+
+  oldMemRef.replaceAllUsesWith(newReinterpretCast);
+  reinterpretCastOp->erase();
+  return success();
+}
+
 template LogicalResult
 mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
 template LogicalResult
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 08b853fe65b85..b8d8a99c33084 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -363,6 +363,15 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
   for (memref::AllocaOp allocaOp : allocaOps)
     (void)normalizeMemRef(&allocaOp);
 
+  // Turn memrefs' non-identity layouts maps into ones with identity. Collect
+  // reinterpret_cast ops first and then process since normalizeMemRef
+  // replaces/erases ops during memref rewriting.
+  SmallVector<memref::ReinterpretCastOp> reinterpretCastOps;
+  funcOp.walk(
+      [&](memref::ReinterpretCastOp op) { reinterpretCastOps.push_back(op); });
+  for (memref::ReinterpretCastOp reinterpretCastOp : reinterpretCastOps)
+    (void)normalizeMemRef(&reinterpretCastOp);
+
   // We use this OpBuilder to create new memref layout later.
   OpBuilder b(funcOp);
 
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
index 3bede131325a7..c3791d02f6cee 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
@@ -165,3 +165,35 @@ func.func @prefetch_normalize(%arg0: memref<512xf32, affine_map<(d0) -> (d0 floo
   }
   return
 }
+
+#map_strided = affine_map<(d0, d1) -> (d0 * 7 + d1)>
+
+// CHECK-LABEL: test_reinterpret_cast
+func.func @test_reinterpret_cast(%arg0: memref<5x7xf32>, %arg1: memref<5x7xf32>, %arg2: memref<5x7xf32>) {
+  %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [5, 7], strides: [7, 1] : memref<5x7xf32> to memref<5x7xf32, #map_strided>
+  // CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [35], strides: [1] : memref<5x7xf32> to memref<35xf32>
+  affine.for %arg5 = 0 to 5 {
+    affine.for %arg6 = 0 to 7 {
+      %1 = affine.load %0[%arg5, %arg6] : memref<5x7xf32, #map_strided>
+      // CHECK: affine.load %reinterpret_cast[%{{.*}} * 7 + %{{.*}}] : memref<35xf32>
+      %2 = affine.load %arg1[%arg5, %arg6] : memref<5x7xf32>
+      %3 = arith.subf %1, %2 : f32
+      affine.store %3, %arg2[%arg5, %arg6] : memref<5x7xf32>
+    }
+  }
+  return
+}
+
+
+// CHECK-LABEL: reinterpret_cast_non_zero_offset
+func.func @reinterpret_cast_non_zero_offset(%arg0: index, %arg1: memref<1x10x17xi32, strided<[?, ?, ?], offset: ?>>, %arg2: memref<1x10x17xi32, strided<[?, ?, ?], offset: ?>>, %arg3: memref<1x10x17xi32, strided<[?, ?, ?], offset: ?>>) -> (memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>) {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x10x17xi32>
+  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x17xf32>
+  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x10x17xf32>
+  cf.br ^bb3
+^bb3:  // pred: ^bb1
+  // CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [32], strides: [1] : memref<2x17xf32> to memref<32xf32>
+  // CHECK: return %[[REINTERPRET_CAST]], %[[REINTERPRET_CAST]], %{{.*}}, %{{.*}}, %{{.*}} : memref<32xf32>, memref<32xf32>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
+  %reinterpret_cast = memref.reinterpret_cast %alloc_0 to offset: [27], sizes: [1, 5], strides: [17, 1] : memref<2x17xf32> to memref<1x5xf32, strided<[17, 1], offset: 27>>
+  return %reinterpret_cast, %reinterpret_cast, %alloc_0, %alloc, %alloc_1 : memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
+}
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index e93a1a4ebae53..440f4776424cc 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -3,8 +3,8 @@
 // This file tests whether the memref type having non-trivial map layouts
 // are normalized to trivial (identity) layouts.
 
-// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
-// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
+// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 + (d1 floordiv 2) * 6)>
+// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 2 + (d0 floordiv 2) * 6)>
 // CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
 
 // CHECK-LABEL: func @permute()



More information about the Mlir-commits mailing list