[Mlir-commits] [mlir] a27d886 - [mlir][linalg][bufferize] Fix element-wise access optimization for sparse tensors (#87305)

llvmlistbot at llvm.org
Tue Apr 2 17:57:29 PDT 2024


Author: Matthias Springer
Date: 2024-04-03T09:57:25+09:00
New Revision: a27d886ce4cc8be8f67a8331c400d6fe2a273ebd

URL: https://github.com/llvm/llvm-project/commit/a27d886ce4cc8be8f67a8331c400d6fe2a273ebd
DIFF: https://github.com/llvm/llvm-project/commit/a27d886ce4cc8be8f67a8331c400d6fe2a273ebd.diff

LOG: [mlir][linalg][bufferize] Fix element-wise access optimization for sparse tensors (#87305)

`linalg.generic` ops with sparse tensors do not necessarily bufferize to
element-wise access, because insertions into a sparse tensor may change
the layout of (or reallocate) the underlying sparse data structures.
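
For illustration, a minimal C++ sketch of the new guard follows. It mirrors the hunk in BufferizableOpInterfaceImpl.cpp shown below; the free-standing helper name and the trailing `return true` are invented for readability, and the real check lives inside the LinalgOpInterface bufferization implementation, which performs further analysis after these early exits:

    // Sketch only: the enclosing helper is hypothetical; the two early exits
    // match the change in the diff below.
    #include "mlir/Dialect/Linalg/IR/Linalg.h"
    #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

    using namespace mlir;

    static bool hasElementwiseAccessPattern(Operation *op) {
      auto linalgOp = cast<linalg::LinalgOp>(op);

      // New: insertions into a sparse tensor may reallocate or re-layout the
      // underlying sparse storage, so element-wise access must not be assumed.
      if (sparse_tensor::hasAnySparseOperand(linalgOp))
        return false;

      // Pre-existing: all loops must be parallel.
      if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
        return false;

      // (Remaining checks of the actual implementation are omitted here.)
      return true;
    }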

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 58fb2e91b4f637..899b8c87d0df77 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
@@ -110,6 +111,10 @@ struct LinalgOpInterface
                                      ArrayRef<OpOperand *> opOperands) const {
     auto linalgOp = cast<linalg::LinalgOp>(op);
 
+    // Accesses into sparse data structures are not necessarily elementwise.
+    if (sparse_tensor::hasAnySparseOperand(linalgOp))
+      return false;
+
     // All loops must be parallel.
     if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
       return false;

diff --git a/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir b/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
index 6c2292be161a53..b769acdc7825ce 100644
--- a/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
+++ b/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
@@ -70,3 +70,39 @@ func.func @update_notinplace(%argb: tensor<10xf32>, %arga: tensor<10xf32, #SV>)
   } -> tensor<10xf32>
   return %0, %argb : tensor<10xf32>, tensor<10xf32>
 }
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
+
+// linalg.generic with sparse tensors does not necessarily bufferize to
+// element-wise access into the underlying sparse data structures.
+
+// CHECK-LABEL: func @sparse_non_elementwise(
+func.func @sparse_non_elementwise(%arg0: tensor<64x64xf32, #sparse>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[alloc0:.*]] = bufferization.alloc_tensor()
+  // CHECK: %[[alloc1:.*]] = bufferization.alloc_tensor()
+  %0 = bufferization.alloc_tensor() : tensor<64x64xf32>
+  // CHECK: %[[generic0:.*]] = linalg.generic {{.*}} outs(%[[alloc1]] : {{.*}})
+  %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%0 : tensor<64x64xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[generic0]] : {{.*}})
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %arg2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%1 : tensor<64x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    %5 = arith.addf %out, %4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[alloc0]] : {{.*}})
+  %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %2 : tensor<64x64xf32, #sparse>, tensor<64x64xf32>) outs(%0 : tensor<64x64xf32>) attrs =  {sorted = true} {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    linalg.yield %4 : f32
+  } -> tensor<64x64xf32>
+  return %3 : tensor<64x64xf32>
+}