[Mlir-commits] [mlir] [MLIR][memref] Fix normalization issue in memref.load (PR #107771)

Kai Sasaki llvmlistbot at llvm.org
Mon Oct 7 23:46:14 PDT 2024


https://github.com/Lewuathe updated https://github.com/llvm/llvm-project/pull/107771

>From 48e081b431be17ee508d363dfd83e4559c718be7 Mon Sep 17 00:00:00 2001
From: Darshan Bhat <darshanbhatsirsi at gmail.com>
Date: Sun, 8 Sep 2024 22:22:48 +0530
Subject: [PATCH] [MLIR][memref] Fix normalization issue in memref.load

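Fix the normalization of memref.load when the associated affine map
reduces the rank of the memref, i.e. when the map has fewer results
than inputs. Previously only the memref operand was replaced and the
load kept its old indices; the indices are now remapped as well.
Fixes #82675.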
---
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       | 100 +++++++++++++++++-
 .../Dialect/MemRef/normalize-memrefs.mlir     |  38 +++++++
 2 files changed, 136 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 898467d573362b..910ad1733d03e8 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/LogicalResult.h"
 #include <optional>
 
 #define DEBUG_TYPE "affine-utils"
@@ -1093,6 +1094,90 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
     op->erase();
 }
 
+// Private helper: replaces `op` with a new op of the same name and attributes
+// that accesses `newMemRef`, indexed by `extraIndices` followed by the old
+// indices remapped through `indexRemap`. Erases `op`; returns success().
+LogicalResult transformMemRefLoadWithReducedRank(
+    Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
+    ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
+    ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
+  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
+  unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
+  unsigned oldMapNumInputs = oldMemRefRank;
+  SmallVector<Value, 4> oldMapOperands(
+      op->operand_begin() + memRefOperandPos + 1,
+      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+  SmallVector<Value, 4> oldMemRefOperands;
+  oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
+  SmallVector<Value, 4> remapOperands;
+  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+                        symbolOperands.size());
+  remapOperands.append(extraOperands.begin(), extraOperands.end());
+  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+  remapOperands.append(symbolOperands.begin(), symbolOperands.end());
+
+  SmallVector<Value, 4> remapOutputs;
+  remapOutputs.reserve(oldMemRefRank);
+  SmallVector<Value, 4> affineApplyOps;
+
+  OpBuilder builder(op);
+
+  if (indexRemap &&
+      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+    // Emit one single-result affine.apply per result of `indexRemap`.
+    for (auto resultExpr : indexRemap.getResults()) {
+      auto singleResMap = AffineMap::get(
+          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+                                                remapOperands);
+      remapOutputs.push_back(afOp);
+      affineApplyOps.push_back(afOp);
+    }
+  } else {
+    // Null or identity `indexRemap`: forward the remap operands unchanged.
+    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
+  }
+
+  SmallVector<Value, 4> newMapOperands;
+  newMapOperands.reserve(newMemRefRank);
+
+  // Start 'newMapOperands' with 'extraIndices'.
+  for (Value extraIndex : extraIndices) {
+    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+           "invalid memory op index");
+    newMapOperands.push_back(extraIndex);
+  }
+
+  // Then append 'remapOutputs' to 'newMapOperands'.
+  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+  // The new op must receive exactly one index per dimension of `newMemRef`.
+  assert(newMapOperands.size() == newMemRefRank);
+
+  OperationState state(op->getLoc(), op->getName());
+  // The new memref is the first operand of the replacement operation.
+  state.operands.reserve(newMapOperands.size() + extraIndices.size());
+  state.operands.push_back(newMemRef);
+
+  // The new indices follow the memref operand.
+  state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+  state.types.reserve(op->getNumResults());
+  for (auto result : op->getResults())
+    state.types.push_back(result.getType());
+
+  // Copy over the attributes from the old operation to the new operation.
+  for (auto namedAttr : op->getAttrs()) {
+    state.attributes.push_back(namedAttr);
+  }
+
+  // Create the replacement, redirect all uses of `op` to it, then erase `op`.
+  auto *repOp = builder.create(state);
+  op->replaceAllUsesWith(repOp);
+  op->erase();
+
+  return success();
+}
 // Perform the replacement in `op`.
 LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     Value oldMemRef, Value newMemRef, Operation *op,
@@ -1146,8 +1231,19 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
       // is set.
       return failure();
     }
-    op->setOperand(memRefOperandPos, newMemRef);
-    return success();
+
+    // Check if it is a memref.load
+    auto memrefLoad = dyn_cast<memref::LoadOp>(op);
+    bool isReductionLike =
+        indexRemap.getNumResults() < indexRemap.getNumInputs();
+    if (!memrefLoad || !isReductionLike) {
+      op->setOperand(memRefOperandPos, newMemRef);
+      return success();
+    }
+
+    return transformMemRefLoadWithReducedRank(
+        op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
+        symbolOperands, indexRemap);
   }
   // Perform index rewrites for the dereferencing op and then replace the op
   NamedAttribute oldMapAttrPair =
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index c7af033a22a2c6..11114bcf2b1ab1 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -3,6 +3,10 @@
 // This file tests whether the memref type having non-trivial map layouts
 // are normalized to trivial (identity) layouts.
 
+// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
+// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
+// CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
+
 // CHECK-LABEL: func @permute()
 func.func @permute() {
   %A = memref.alloc() : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
@@ -363,3 +367,37 @@ func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index,
   %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
   return %1 : tensor<16x512xf32>
 }
+
+#map0 = affine_map<(i,k) -> (2 * (i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8 * (k floordiv 2))>
+#map1 = affine_map<(k,j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4 * (j floordiv 2))>
+#map2 = affine_map<(i,j) -> (4 * i + j)>
+// CHECK-LABEL: func @memref_load_with_reduction_map
+func.func @memref_load_with_reduction_map(%arg0 :  memref<4x4xf32,#map2>) -> () {
+  %0 = memref.alloc() : memref<4x8xf32,#map0>
+  %1 = memref.alloc() : memref<8x4xf32,#map1>
+  %2 = memref.alloc() : memref<4x4xf32,#map2>
+  // CHECK-NOT:  memref<4x8xf32>
+  // CHECK-NOT:  memref<8x4xf32>
+  // CHECK-NOT:  memref<4x4xf32>
+  %cst = arith.constant 3.0 : f32
+  %cst0 = arith.constant 0 : index
+  affine.for %i = 0 to 4 {
+    affine.for %j = 0 to 8 {
+      affine.for %k = 0 to 8 {
+        // CHECK: %[[INDEX0:.*]] = affine.apply #[[$REDUCE_MAP1]](%{{.*}}, %{{.*}})
+        // CHECK: memref.load %alloc[%[[INDEX0]]] : memref<32xf32>
+        %a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
+        // CHECK: %[[INDEX1:.*]] = affine.apply #[[$REDUCE_MAP2]](%{{.*}}, %{{.*}})
+        // CHECK: memref.load %alloc_0[%[[INDEX1]]] : memref<32xf32>
+        %b = memref.load %1[%k, %j] :memref<8x4xf32,#map1>
+        // CHECK: %[[INDEX2:.*]] = affine.apply #[[$REDUCE_MAP3]](%{{.*}}, %{{.*}})
+        // CHECK: memref.load %alloc_1[%[[INDEX2]]] : memref<16xf32>
+        %c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
+        %3 = arith.mulf %a, %b : f32
+        %4 = arith.addf %3, %c : f32
+        affine.store %4, %arg0[%i, %j] : memref<4x4xf32,#map2>
+      }
+    }
+  }
+  return
+}
\ No newline at end of file



More information about the Mlir-commits mailing list