[Mlir-commits] [mlir] [MLIR][memref] Fix normalization issue in memref.load (PR #107771)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Sep 8 10:49:25 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (DarshanRamakant)

<details>
<summary>Changes</summary>

This change fixes the normalization of `memref.load`
when the associated affine map reduces the number of
dimensions (i.e., the map has fewer results than inputs).

---
Full diff: https://github.com/llvm/llvm-project/pull/107771.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+83-1) 
- (modified) mlir/test/Dialect/MemRef/normalize-memrefs.mlir (+30) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 898467d573362b..70bfb322932346 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/LogicalResult.h"
 #include <optional>
 
 #define DEBUG_TYPE "affine-utils"
@@ -1146,7 +1147,88 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
       // is set.
       return failure();
     }
-    op->setOperand(memRefOperandPos, newMemRef);
+
+    // Check if it is a memref.load
+    auto memrefLoad = dyn_cast<memref::LoadOp>(op);
+    bool isReductionLike =
+        indexRemap.getNumResults() < indexRemap.getNumInputs();
+    if (!memrefLoad || !isReductionLike) {
+      op->setOperand(memRefOperandPos, newMemRef);
+      return success();
+    }
+
+    unsigned oldMapNumInputs = oldMemRefRank;
+    SmallVector<Value, 4> oldMapOperands(
+        op->operand_begin() + memRefOperandPos + 1,
+        op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+    SmallVector<Value, 4> oldMemRefOperands;
+    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
+    SmallVector<Value, 4> remapOperands;
+    remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+                          symbolOperands.size());
+    remapOperands.append(extraOperands.begin(), extraOperands.end());
+    remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+    remapOperands.append(symbolOperands.begin(), symbolOperands.end());
+
+    SmallVector<Value, 4> remapOutputs;
+    remapOutputs.reserve(oldMemRefRank);
+    SmallVector<Value, 4> affineApplyOps;
+
+    if (indexRemap &&
+        indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+      // Remapped indices.
+      for (auto resultExpr : indexRemap.getResults()) {
+        auto singleResMap = AffineMap::get(
+            indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+        auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+                                                  remapOperands);
+        remapOutputs.push_back(afOp);
+        affineApplyOps.push_back(afOp);
+      }
+    } else {
+      // No remapping specified.
+      remapOutputs.assign(remapOperands.begin(), remapOperands.end());
+    }
+
+    SmallVector<Value, 4> newMapOperands;
+    newMapOperands.reserve(newMemRefRank);
+
+    // Prepend 'extraIndices' in 'newMapOperands'.
+    for (Value extraIndex : extraIndices) {
+      assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+             "invalid memory op index");
+      newMapOperands.push_back(extraIndex);
+    }
+
+    // Append 'remapOutputs' to 'newMapOperands'.
+    newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+    // Create new fully composed AffineMap for new op to be created.
+    assert(newMapOperands.size() == newMemRefRank);
+
+    OperationState state(op->getLoc(), op->getName());
+    // Construct the new operation using this memref.
+    state.operands.reserve(newMapOperands.size() + extraIndices.size());
+    state.operands.push_back(newMemRef);
+
+    // Insert the new memref map operands.
+    state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+    state.types.reserve(op->getNumResults());
+    for (auto result : op->getResults())
+      state.types.push_back(result.getType());
+
+    // Add attribute for 'newMap', other Attributes do not change.
+    // auto newMapAttr = AffineMapAttr::get(newMap);
+    for (auto namedAttr : op->getAttrs()) {
+      state.attributes.push_back(namedAttr);
+    }
+
+    // Create the new operation.
+    auto *repOp = builder.create(state);
+    op->replaceAllUsesWith(repOp);
+    op->erase();
+
     return success();
   }
   // Perform index rewrites for the dereferencing op and then replace the op
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index c7af033a22a2c6..ca485f9fddbc8d 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -363,3 +363,33 @@ func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index,
   %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
   return %1 : tensor<16x512xf32>
 }
+
+#map0 = affine_map<(i,k) -> (2*(i mod 2) + (k mod 2) + 4 * (i floordiv 2) + 8*(k floordiv 2))>
+#map1 = affine_map<(k,j) -> ((k mod 2) + 2 * (j mod 2) + 8 * (k floordiv 2) + 4*(j floordiv 2))>
+#map2 = affine_map<(i,j) -> (4*i+j)>
+// CHECK-LABEL: func @memref_load_with_reduction_map
+func.func @memref_load_with_reduction_map(%arg0 :  memref<4x4xf32,#map2>) -> () {
+  %0 = memref.alloc() : memref<4x8xf32,#map0>
+  %1 = memref.alloc() : memref<8x4xf32,#map1>
+  %2 = memref.alloc() : memref<4x4xf32,#map2>
+  // CHECK-NOT:  memref<4x8xf32>
+  // CHECK-NOT:  memref<8x4xf32>
+  // CHECK-NOT:  memref<4x4xf32>
+  %cst = arith.constant 3.0 : f32
+  %cst0 = arith.constant 0 : index
+  affine.for %i = 0 to 4 {
+    affine.for %j = 0 to 8 {
+      affine.for %k = 0 to 8 {
+        // CHECK: affine.apply #map{{.*}}(%{{.*}}, %{{.*}})
+        // CHECK: memref.load %alloc[%{{.*}}] : memref<32xf32>
+        %a = memref.load %0[%i, %k] : memref<4x8xf32,#map0>
+        %b = memref.load %1[%k, %j] :memref<8x4xf32,#map1>
+        %c = memref.load %2[%i, %j] : memref<4x4xf32,#map2>
+        %3 = arith.mulf %a, %b : f32
+        %4 = arith.addf %3, %c : f32
+        affine.store %4, %arg0[%i, %j] : memref<4x4xf32,#map2>
+      }
+    }
+  }
+  return
+}
\ No newline at end of file

``````````

</details>


https://github.com/llvm/llvm-project/pull/107771


More information about the Mlir-commits mailing list