[Mlir-commits] [mlir] [mlir][linalg][bufferize] Fix element-wise access optimization for sparse tensors (PR #87305)

Matthias Springer llvmlistbot at llvm.org
Mon Apr 1 20:08:46 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/87305

`linalg.generic` ops with sparse tensors do not necessarily bufferize to element-wise access, because insertions into a sparse tensor may change the layout of (or reallocate) the underlying sparse data structures.
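For illustration only (not part of the patch), here is a minimal MLIR sketch of the situation the description refers to; the names `#CSR` and `@scale_into_sparse` and the 8x8 shapes are made up. The kernel is element-wise at the tensor level, but every value yielded to the sparse output is an insertion into compressed storage, which may shift or reallocate the underlying positions/coordinates/values arrays:

  #CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
  #id = affine_map<(d0, d1) -> (d0, d1)>

  // Element-wise at the tensor level, but not at the buffer level: each
  // yielded value is an insertion into the compressed sparse output.
  func.func @scale_into_sparse(%a: tensor<8x8xf32>,
                               %b: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR> {
    %0 = linalg.generic {indexing_maps = [#id, #id],
                         iterator_types = ["parallel", "parallel"]}
        ins(%a : tensor<8x8xf32>)
        outs(%b : tensor<8x8xf32, #CSR>) {
    ^bb0(%in: f32, %out: f32):
      %1 = arith.mulf %in, %out : f32
      linalg.yield %1 : f32
    } -> tensor<8x8xf32, #CSR>
    return %0 : tensor<8x8xf32, #CSR>
  }

The added sparse_tensor::hasAnySparseOperand check below makes the analysis conservatively treat any linalg op with a sparse operand (input or output) as not bufferizing to element-wise access, so One-Shot Bufferize's tensor copy insertion (exercised by the new test) no longer assumes in-place element-wise updates are safe.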

From e4c3746fc09b6bad9ac0304a8d91d049d47f46b3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 2 Apr 2024 03:07:31 +0000
Subject: [PATCH] [mlir][linalg][bufferize] Fix elementwise access optimization
 for sparse tensors

`linalg.generic` ops with sparse tensors do not necessarily bufferize to element-wise access, because insertions into a sparse tensor may change the layout of (or reallocate) the underlying sparse data structures.
---
 .../BufferizableOpInterfaceImpl.cpp           |  5 +++
 ..._shot_bufferize_tensor_copy_insertion.mlir | 36 +++++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 58fb2e91b4f637..899b8c87d0df77 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
@@ -110,6 +111,10 @@ struct LinalgOpInterface
                                      ArrayRef<OpOperand *> opOperands) const {
     auto linalgOp = cast<linalg::LinalgOp>(op);
 
+    // Accesses into sparse data structures are not necessarily elementwise.
+    if (sparse_tensor::hasAnySparseOperand(linalgOp))
+      return false;
+
     // All loops must be parallel.
     if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
       return false;
diff --git a/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir b/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
index 6c2292be161a53..b769acdc7825ce 100644
--- a/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
+++ b/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
@@ -70,3 +70,39 @@ func.func @update_notinplace(%argb: tensor<10xf32>, %arga: tensor<10xf32, #SV>)
   } -> tensor<10xf32>
   return %0, %argb : tensor<10xf32>, tensor<10xf32>
 }
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
+
+// linalg.generic with sparse tensors does not necessarily bufferize to
+// element-wise access into the underlying sparse data structures.
+
+// CHECK-LABEL: func @sparse_non_elementwise(
+func.func @sparse_non_elementwise(%arg0: tensor<64x64xf32, #sparse>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[alloc0:.*]] = bufferization.alloc_tensor()
+  // CHECK: %[[alloc1:.*]] = bufferization.alloc_tensor()
+  %0 = bufferization.alloc_tensor() : tensor<64x64xf32>
+  // CHECK: %[[generic0:.*]] = linalg.generic {{.*}} outs(%[[alloc1]] : {{.*}})
+  %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%0 : tensor<64x64xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[generic0]] : {{.*}})
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %arg2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%1 : tensor<64x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    %5 = arith.addf %out, %4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[alloc0]] : {{.*}})
+  %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %2 : tensor<64x64xf32, #sparse>, tensor<64x64xf32>) outs(%0 : tensor<64x64xf32>) attrs =  {sorted = true} {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    linalg.yield %4 : f32
+  } -> tensor<64x64xf32>
+  return %3 : tensor<64x64xf32>
+}


