[Mlir-commits] [mlir] [MLIR][memref] Fix normalization issue in memref.load (PR #107771)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Sep 14 06:38:33 PDT 2024
https://github.com/DarshanRamakant updated https://github.com/llvm/llvm-project/pull/107771
>From f8611e556eac8fdd8764a196a1fd5f7ad43708f8 Mon Sep 17 00:00:00 2001
From: Darshan Bhat <darshanbhatsirsi at gmail.com>
Date: Sun, 8 Sep 2024 22:22:48 +0530
Subject: [PATCH] [MLIR][memref] Fix normalization issue in memref.load
This change fixes memref normalization for memref.load
when the associated affine layout map reduces the rank,
i.e. the map has fewer results than inputs.
---
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 101 +++++++++++++++++-
.../Dialect/MemRef/normalize-memrefs.mlir | 30 ++++++
2 files changed, 129 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 898467d573362b..9496c4b219a033 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/LogicalResult.h"
#include <optional>
#define DEBUG_TYPE "affine-utils"
@@ -1093,6 +1094,90 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
op->erase();
}
+// Helper for replaceAllMemRefUsesWith: replaces `op` (a memref.load on
+// `oldMemRef`) with an op of the same name on `newMemRef`, indexed by
+// `extraIndices` followed by the remapped old indices. Erases `op`.
+LogicalResult transformMemRefLoadWithReducedRank(
+ Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
+ ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
+ ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
+ unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
+ unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
+ unsigned oldMapNumInputs = oldMemRefRank;
+ SmallVector<Value, 4> oldMapOperands(
+ op->operand_begin() + memRefOperandPos + 1,
+ op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+ SmallVector<Value, 4> oldMemRefOperands;
+ oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
+ SmallVector<Value, 4> remapOperands;
+ remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+ symbolOperands.size());
+ remapOperands.append(extraOperands.begin(), extraOperands.end());
+ remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+ remapOperands.append(symbolOperands.begin(), symbolOperands.end());
+
+ SmallVector<Value, 4> remapOutputs;
+ remapOutputs.reserve(oldMemRefRank);
+ SmallVector<Value, 4> affineApplyOps; // NOTE(review): filled, never read.
+
+ OpBuilder builder(op);
+
+ if (indexRemap &&
+ indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+ // Emit one single-result affine.apply per result of 'indexRemap'.
+ for (auto resultExpr : indexRemap.getResults()) {
+ auto singleResMap = AffineMap::get(
+ indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+ auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+ remapOperands);
+ remapOutputs.push_back(afOp);
+ affineApplyOps.push_back(afOp);
+ }
+ } else {
+ // Null or identity 'indexRemap': forward the remap operands unchanged.
+ remapOutputs.assign(remapOperands.begin(), remapOperands.end());
+ }
+
+ SmallVector<Value, 4> newMapOperands;
+ newMapOperands.reserve(newMemRefRank);
+
+ // 'extraIndices' go first; each must be a valid affine dim or symbol.
+ for (Value extraIndex : extraIndices) {
+ assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+ "invalid memory op index");
+ newMapOperands.push_back(extraIndex);
+ }
+
+ // Then the remapped (or forwarded) old indices held in 'remapOutputs'.
+ newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+ // The new op must receive exactly one index per dimension of 'newMemRef'.
+ assert(newMapOperands.size() == newMemRefRank);
+
+ OperationState state(op->getLoc(), op->getName());
+ // Operand list of the replacement op: 'newMemRef' first, then its indices.
+ state.operands.reserve(newMapOperands.size() + extraIndices.size());
+ state.operands.push_back(newMemRef);
+
+ // Index operands. NOTE(review): the reserve above counts 'extraIndices' twice (already in 'newMapOperands') and omits the memref.
+ state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+ state.types.reserve(op->getNumResults());
+ for (auto result : op->getResults())
+ state.types.push_back(result.getType());
+
+ // Carry every attribute of the old op over to the replacement unchanged.
+ for (auto namedAttr : op->getAttrs()) {
+ state.attributes.push_back(namedAttr);
+ }
+
+ // Build the replacement, redirect all uses of 'op' to it, then erase 'op'.
+ auto *repOp = builder.create(state);
+ op->replaceAllUsesWith(repOp);
+ op->erase();
+
+ return success();
+}
// Perform the replacement in `op`.
LogicalResult mlir::affine::replaceAllMemRefUsesWith(
Value oldMemRef, Value newMemRef, Operation *op,
@@ -1146,8 +1231,20 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// is set.
return failure();
}
- op->setOperand(memRefOperandPos, newMemRef);
- return success();
+
+ // Check if it is a memref.load
+ auto memrefLoad = dyn_cast<memref::LoadOp>(op);
+ bool isReductionLike =
+ indexRemap.getNumResults() < indexRemap.getNumInputs();
+ if (!memrefLoad || !isReductionLike) {
+ op->setOperand(memRefOperandPos, newMemRef);
+ return success();
+ }
+
+ return transformMemRefLoadWithReducedRank(op, oldMemRef, newMemRef,
+ memRefOperandPos, extraIndices,
+ extraOperands, symbolOperands,
+ indexRemap);
}
// Perform index rewrites for the dereferencing op and then replace the op
NamedAttribute oldMapAttrPair =
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index c7af033a22a2c6..3fc6c62b33a1be 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -363,3 +363,33 @@ func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index,
%1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
return %1 : tensor<16x512xf32>
}
+
+#map0 = affine_map<(i,k) -> (2 * (i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8 * (k floordiv 2))> // 2 dims -> 1 result
+#map1 = affine_map<(k,j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4 * (j floordiv 2))> // 2 dims -> 1 result
+#map2 = affine_map<(i,j) -> (4 * i + j)> // 2 dims -> 1 result
+// CHECK-LABEL: func @memref_load_with_reduction_map
+func.func @memref_load_with_reduction_map(%arg0 : memref<4x4xf32,#map2>) -> () {
+ %0 = memref.alloc() : memref<4x8xf32,#map0>
+ %1 = memref.alloc() : memref<8x4xf32,#map1>
+ %2 = memref.alloc() : memref<4x4xf32,#map2>
+ // CHECK-NOT: memref<4x8xf32>
+ // CHECK-NOT: memref<8x4xf32>
+ // CHECK-NOT: memref<4x4xf32>
+ %cst = arith.constant 3.0 : f32
+ %cst0 = arith.constant 0 : index
+ affine.for %i = 0 to 4 {
+ affine.for %j = 0 to 8 {
+ affine.for %k = 0 to 8 {
+ // CHECK: affine.apply #map{{.*}}(%{{.*}}, %{{.*}})
+ // CHECK: memref.load %alloc[%{{.*}}] : memref<32xf32>
+ %a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
+ %b = memref.load %1[%k, %j] :memref<8x4xf32,#map1>
+ %c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
+ %3 = arith.mulf %a, %b : f32
+ %4 = arith.addf %3, %c : f32
+ affine.store %4, %arg0[%i, %j] : memref<4x4xf32,#map2>
+ }
+ }
+ }
+ return
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list