[Mlir-commits] [mlir] d186277 - [MLIR][Bufferization] Fold LoadOp only when the buffer is read only (#172595)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 13 22:13:36 PST 2026


Author: Batzorig Zorigoo
Date: 2026-01-14T07:13:31+01:00
New Revision: d186277e6b5f051262422165fac2885ad925cca5

URL: https://github.com/llvm/llvm-project/commit/d186277e6b5f051262422165fac2885ad925cca5
DIFF: https://github.com/llvm/llvm-project/commit/d186277e6b5f051262422165fac2885ad925cca5.diff

LOG: [MLIR][Bufferization] Fold LoadOp only when the buffer is read only (#172595)

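Previously, a `memref.load` from a buffer produced by `bufferization.to_buffer`
was folded into a `tensor.extract` on the source tensor even when the buffer was
writable. If the buffer is written between the `to_buffer` and the load, the
folded code reads stale values from the original tensor and produces incorrect
results. This change restricts the fold to `to_buffer` ops marked `read_only`.
For example: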

```mlir
func.func @load_after_write_from_buffer_cast(%arg0: index, %arg1: index,
                            %arg2: tensor<?x?xf32>) -> f32 {
  %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
  linalg.ceil ins(%0 : memref<?x?xf32>) outs(%0 : memref<?x?xf32>)
  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
  return %1 : f32
}
```
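would be folded into the following. Here `tensor.extract` reads the original
`%arg2` and misses the values written into the buffer by `linalg.ceil`: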
```mlir
module {
  func.func @load_after_write_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
    %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
    linalg.ceil ins(%0 : memref<?x?xf32>) outs(%0 : memref<?x?xf32>)
    %extracted = tensor.extract %arg2[%arg0, %arg1] : tensor<?x?xf32>
    return %extracted : f32
  }
}
```

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/test/Dialect/Bufferization/canonicalize.mlir
    mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
    mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index eda6bf276be06..4515a5b5a2671 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -847,7 +847,7 @@ struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
   LogicalResult matchAndRewrite(memref::LoadOp load,
                                 PatternRewriter &rewriter) const override {
     auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
-    if (!toBuffer)
+    if (!toBuffer || !toBuffer.getReadOnly())
       return failure();
 
     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),

diff  --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index ae1d1fcfc19dc..df07511798b91 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -294,11 +294,29 @@ func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8
 
 // -----
 
+// Verify LoadOfToBuffer skips writable buffers
+// CHECK-LABEL: func @load_after_write_from_buffer_cast(
+func.func @load_after_write_from_buffer_cast(%arg0: index, %arg1: index,
+                            %arg2: tensor<?x?xf32>) -> f32 {
+  %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
+  linalg.ceil ins(%0 : memref<?x?xf32>) outs(%0 : memref<?x?xf32>)
+  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
+  return %1 : f32
+}
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+//      CHECK: %[[M:.+]] = bufferization.to_buffer %[[TENSOR]] : tensor<?x?xf32> to memref<?x?xf32>
+//      CHECK: linalg.ceil ins(%[[M]] : memref<?x?xf32>) outs(%[[M]] : memref<?x?xf32>)
+//      CHECK: %[[RES:.*]] = memref.load %[[M]][%[[IDX0]], %[[IDX1]]] : memref<?x?xf32>
+//      CHECK: return %[[RES]] : f32
+
+// -----
+
 // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_buffer_cast(
 func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
                             %arg2: tensor<?x?xf32>) -> f32 {
-  %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
+  %0 = bufferization.to_buffer %arg2 read_only : tensor<?x?xf32> to memref<?x?xf32>
   %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
   return %1 : f32
 }

diff  --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
index d828afe13c622..8c7315736819e 100644
--- a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
@@ -31,6 +31,7 @@
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_buffer %[[VAL_10]] :
+// CHECK-DAG:       %[[M:.*]] = bufferization.to_buffer %[[VAL_1]] :
 // CHECK-DAG:       linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
@@ -49,7 +50,7 @@
 // CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
 // CHECK:               %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_7]] : index
 // CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
-// CHECK:               %[[VAL_29:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : tensor<8x8xf32>
+// CHECK:               %[[VAL_29:.*]] = memref.load %[[M]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
 // CHECK:               %[[VAL_30:.*]] = arith.mulf %[[VAL_28]], %[[VAL_29]] : f32
 // CHECK:               memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
 // CHECK:             } {"Emitted from" = "linalg.generic"}

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index bf3473ead204e..912f78a0b81fc 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -31,7 +31,7 @@
 // CHECK:                 } do {
 // CHECK:                   %[[D3:.*]] = "subsect<trivial<compressed[0,1]>>.deref"
 // CHECK:                   "trivial<batch[1,1]>.locate"(%{{.*}}, %[[D3]])
-// CHECK:                   tensor.extract %{{.*}}{{\[}}%[[D2]], %[[D3]]]
+// CHECK:                   memref.load %{{.*}}{{\[}}%[[D2]], %[[D3]]]
 // CHECK:                   arith.muli
 // CHECK:                   arith.addi
 // CHECK:                   "subsect<trivial<compressed[0,1]>>.next

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 4546d3367b16d..ebbcc5fc7c7cf 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -22,7 +22,7 @@
 // CHECK:           %[[VAL_13:.*]] = sparse_tensor.storage_specifier.init
 // CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]]  lvl_sz at 0 with %[[VAL_4]]
 // CHECK:           %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  pos_mem_sz at 0 with %[[VAL_3]]
-// CHECK:           %[[VAL_16:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_5]]] : tensor<2xindex>
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<2xindex>
 // CHECK:           %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index
 // CHECK:           %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]]  crd_mem_sz at 0 with %[[VAL_17]]
 // CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  lvl_sz at 1 with %[[VAL_4]]

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index 4abaf03dff50f..e2c841b1ac7d5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -27,7 +27,7 @@
 // CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
 // CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
 // CHECK-HIR-DAG:       %[[VAL_10:.*]] = bufferization.to_buffer %[[VAL_1]] : tensor<f32> to memref<f32>
-// CHECK-HIR:           %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
+// CHECK-HIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
 // CHECK-HIR:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
 // CHECK-HIR:             %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_6]] : index
 // CHECK-HIR:             %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
@@ -59,7 +59,7 @@
 // CHECK-MIR-DAG:       %[[DimSize2:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I2]])
 // CHECK-MIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF32(%[[ARGA]]) : (!llvm.ptr) -> memref<?xf32>
 // CHECK-MIR-DAG:       %[[VAL_10:.*]] = bufferization.to_buffer %[[ARGX]] : tensor<f32> to memref<f32>
-// CHECK-MIR:           %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
+// CHECK-MIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
 // CHECK-MIR:           %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
 // CHECK-MIR:             %[[VAL_18:.*]] = arith.muli %[[D2]], %[[DimSize1]] : index
 // CHECK-MIR:             %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {


        


More information about the Mlir-commits mailing list