[Mlir-commits] [mlir] [mlir][memref] Support folding memref.load from global splat constants (PR #176627)

Longsheng Mou llvmlistbot at llvm.org
Sun Jan 18 08:17:41 PST 2026


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/176627

>From 04c8bd1a6580177c16d99bad7eb24dc454f58c66 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sun, 18 Jan 2026 14:03:17 +0800
Subject: [PATCH 1/2] [mlir][memref] Support folding memref.load from global
 splat constants

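This change extends the memref.load folding hook to fold a load whose
memref comes from memref.get_global of a constant memref.global with a
splat initializer. Every element of a splat has the same value, so the
load folds to that value irrespective of its indices.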
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 20 ++++++++++++++++++++
 mlir/test/Dialect/MemRef/canonicalize.mlir | 19 +++++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e0f7a8b452a1d..b15b2a3e24b95 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1705,6 +1705,26 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   /// load(memrefcast) -> load
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
+
+  // Fold load from a global constant memref.
+  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
+  if (!getGlobalOp)
+    return {};
+
+  // Get to the memref.global defining the symbol.
+  auto global = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
+      getGlobalOp, getGlobalOp.getNameAttr());
+  if (!global)
+    return {};
+  // Check if the global memref is a constant.
+  auto cstAttr =
+      dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
+  if (!cstAttr)
+    return {};
+  // If it's a splat constant, we can fold irrespective of indices.
+  if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
+    return splatAttr.getSplatValue<Attribute>();
+
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 122906037b952..a3af67076a1d2 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1416,6 +1416,25 @@ func.func @load_store_nontemporal(%input : memref<32xf32, affine_map<(d0) -> (d0
 
 // -----
 
+memref.global "private" constant @__constant_32xf32 : memref<32xf32> = dense<1.000000e+00>
+// CHECK-LABEL: func @fold_const_splat_global
+func.func @fold_const_splat_global() -> memref<32xf32> {
+  // CHECK-NEXT: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+  %0 = memref.get_global @__constant_32xf32 : memref<32xf32>
+  %alloc = memref.alloc() : memref<32xf32>
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg0 = %c0 to %c32 step %c1 {
+    %1 = memref.load %0[%arg0] : memref<32xf32>
+    // CHECK: memref.store %[[CST]], %{{.*}}
+    memref.store %1, %alloc[%arg0] : memref<32xf32>
+  }
+  return %alloc : memref<32xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_trivial_memory_space_cast(
 //  CHECK-SAME:     %[[arg:.*]]: memref<?xf32>
 //       CHECK:   return %[[arg]]

>From eab897f4323a30bbcc42a67cbabd5d0845691694 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Mon, 19 Jan 2026 00:17:01 +0800
Subject: [PATCH 2/2] directly check for SplatElementsAttr

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b15b2a3e24b95..7f796239ae121 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1716,16 +1716,13 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
       getGlobalOp, getGlobalOp.getNameAttr());
   if (!global)
     return {};
-  // Check if the global memref is a constant.
-  auto cstAttr =
-      dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
-  if (!cstAttr)
-    return {};
   // If it's a splat constant, we can fold irrespective of indices.
-  if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
-    return splatAttr.getSplatValue<Attribute>();
+  auto splatAttr =
+      dyn_cast_or_null<SplatElementsAttr>(global.getConstantInitValue());
+  if (!splatAttr)
+    return {};
 
-  return OpFoldResult();
+  return splatAttr.getSplatValue<Attribute>();
 }
 
 FailureOr<std::optional<SmallVector<Value>>>



More information about the Mlir-commits mailing list