[Mlir-commits] [mlir] b140fb2 - [mlir][memref] Support folding memref.load from global splat constants (#176627)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 18 17:11:40 PST 2026
Author: Longsheng Mou
Date: 2026-01-19T09:11:36+08:00
New Revision: b140fb2f05d20d26d334cc94b235aaa5358c2d62
URL: https://github.com/llvm/llvm-project/commit/b140fb2f05d20d26d334cc94b235aaa5358c2d62
DIFF: https://github.com/llvm/llvm-project/commit/b140fb2f05d20d26d334cc94b235aaa5358c2d62.diff
LOG: [mlir][memref] Support folding memref.load from global splat constants (#176627)
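This change extends the memref.load folding hook to fold loads from
global constant memrefs initialized with splat values. Because every
element of a splat constant is identical, such a load folds to the splat
value regardless of its indices.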
Added: (none)
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed: (none)
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b782a8be19154..a9103b4d438ea 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1705,7 +1705,24 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
/// load(memrefcast) -> load
if (succeeded(foldMemRefCast(*this)))
return getResult();
- return OpFoldResult();
+
+ // Fold load from a global constant memref.
+ auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
+ if (!getGlobalOp)
+ return {};
+
+ // Get to the memref.global defining the symbol.
+ auto global = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
+ getGlobalOp, getGlobalOp.getNameAttr());
+ if (!global)
+ return {};
+ // If it's a splat constant, we can fold irrespective of indices.
+ auto splatAttr =
+ dyn_cast_or_null<SplatElementsAttr>(global.getConstantInitValue());
+ if (!splatAttr)
+ return {};
+
+ return splatAttr.getSplatValue<Attribute>();
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index d32f8f7efc5ff..17afd9a15b60d 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1416,6 +1416,25 @@ func.func @load_store_nontemporal(%input : memref<32xf32, affine_map<(d0) -> (d0
// -----
+memref.global "private" constant @__constant_32xf32 : memref<32xf32> = dense<1.000000e+00>
+// CHECK-LABEL: func @fold_const_splat_global
+func.func @fold_const_splat_global() -> memref<32xf32> {
+ // CHECK-NEXT: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+ %0 = memref.get_global @__constant_32xf32 : memref<32xf32>
+ %alloc = memref.alloc() : memref<32xf32>
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg0 = %c0 to %c32 step %c1 {
+ %1 = memref.load %0[%arg0] : memref<32xf32>
+ // CHECK: memref.store %[[CST]], %{{.*}}
+ memref.store %1, %alloc[%arg0] : memref<32xf32>
+ }
+ return %alloc : memref<32xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_trivial_memory_space_cast(
// CHECK-SAME: %[[arg:.*]]: memref<?xf32>
// CHECK: return %[[arg]]
More information about the Mlir-commits mailing list