[Mlir-commits] [mlir] [mlir][sparse] first end-to-end linalg.generic op on BSR (PR #70880)

Aart Bik llvmlistbot at llvm.org
Tue Oct 31 18:05:42 PDT 2023


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/70880

>From 83a4ae330e0640d92bf9702b98ae6a9f31aa64e9 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 31 Oct 2023 17:07:01 -0700
Subject: [PATCH 1/2] [mlir][sparse] first end-to-end linalg.generic op on BSR

---
 .../SparseTensor/IR/SparseTensorType.h        |  8 +++---
 .../Transforms/SparseReinterpretMap.cpp       |  8 +++---
 .../SparsificationAndBufferizationPass.cpp    | 14 ++++++++--
 .../Dialect/SparseTensor/CPU/block.mlir       | 27 +++++++++++++++++--
 4 files changed, 44 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 1fd91d0c02e4d1b..3e9cada83c6d50b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -341,11 +341,9 @@ template <typename T>
 inline SparseTensorType getSparseTensorType(T t) {
   return SparseTensorType(getRankedTensorType(t));
 }
-template <typename T>
-inline std::optional<SparseTensorType> tryGetSparseTensorType(T t) {
-  RankedTensorType rtp = getRankedTensorType(t);
-  if (rtp)
-    return SparseTensorType(rtp);
+inline std::optional<SparseTensorType> tryGetSparseTensorType(Value v) {
+  if (isa<RankedTensorType>(v.getType()))
+    return getSparseTensorType(v);
   return std::nullopt;
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 31cc8525725d43d..a822effbb2ab78c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -116,11 +116,11 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
         if (map.getResult(i).getKind() != AffineExprKind::DimId)
           return failure();
       // Inspect sparse operands.
-      auto stt = getSparseTensorType(t.get());
-      if (stt.hasEncoding()) {
-        if (stt.isPermutation())
+      auto stt = tryGetSparseTensorType(t.get());
+      if (stt && stt->hasEncoding()) {
+        if (stt->isPermutation())
           continue;
-        assert(stt.getDimRank() < stt.getLvlRank()); // only allowed non-perm
+        assert(stt->getDimRank() < stt->getLvlRank()); // only allowed non-perm
         if (tx)
           return failure(); // more than one non-perm
         if (!map.isIdentity())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 41940f731e76c17..354e2e4bd4facc6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -104,8 +104,12 @@ class SparsificationAndBufferizationPass
   }
 
   void runOnOperation() override {
+#ifdef AART
+    llvm::dbgs() << "\n\n**** BEGIN MINI PIPELINE ****\n\n";
+    getOperation().dump();
+#endif
+    // Run enabling transformations.
     {
-      // Run enabling transformations.
       OpPassManager pm("builtin.module");
       pm.addPass(createPreSparsificationRewritePass());
       pm.addNestedPass<func::FuncOp>(
@@ -128,7 +132,7 @@ class SparsificationAndBufferizationPass
                                                  bufferizationOptions)))
       return signalPassFailure();
 
-    // `testAnalysisOnly` is a debug/testing flag. If set, the results of
+    // Option `testAnalysisOnly` is a debug/testing flag. If set, the results of
     // OneShotAnalysis are added to the IR via attributes. In that case, do not
     // continue with the remaining pipeline.
     if (bufferizationOptions.testAnalysisOnly)
@@ -139,6 +143,8 @@ class SparsificationAndBufferizationPass
     // of `bufferization.alloc_tensor` ops.
     {
       OpPassManager pm("builtin.module");
+      pm.addPass(
+          createSparseReinterpretMapPass(ReinterpretMapScope::kGenericOnly));
       pm.addPass(createSparsificationPass(sparsificationOptions));
       pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
       pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
@@ -166,6 +172,10 @@ class SparsificationAndBufferizationPass
     // Bufferize all dense ops.
     if (failed(runDenseBufferization()))
       signalPassFailure();
+#ifdef AART
+    llvm::dbgs() << "\n\n**** END MINI PIPELINE ****\n\n";
+    getOperation().dump();
+#endif
   }
 
 private:
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
index 78d35ada6acc11c..e1cdc9ed6ba3d41 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
@@ -25,6 +25,8 @@
 // REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false
 // R_UN: %{compile} | env %{env} %{run} | FileCheck %s
 
+!Filename = !llvm.ptr<i8>
+
 #BSR = #sparse_tensor.encoding<{
   map = (i, j) ->
     ( i floordiv 2 : dense
@@ -38,8 +40,12 @@
   map = (i, j, k, l) -> ( i  : dense, j  : compressed, k  : dense, l  : dense)
 }>
 
-
-!Filename = !llvm.ptr<i8>
+#trait_scale_inplace = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
 
 //
 // Example 2x2 block storage:
@@ -62,6 +68,17 @@ module {
 
   func.func private @getTensorFilename(index) -> (!Filename)
 
+  func.func @scale(%arg0: tensor<?x?xf64, #BSR>) -> tensor<?x?xf64, #BSR> {
+    %c = arith.constant 3.0 : f64
+    %0 = linalg.generic #trait_scale_inplace
+      outs(%arg0: tensor<?x?xf64, #BSR>) {
+        ^bb(%x: f64):
+          %1 = arith.mulf %x, %c : f64
+          linalg.yield %1 : f64
+      } -> tensor<?x?xf64, #BSR>
+    return %0 : tensor<?x?xf64, #BSR>
+  }
+
   func.func @entry() {
     %c0 = arith.constant 0   : index
     %f0 = arith.constant 0.0 : f64
@@ -89,6 +106,12 @@ module {
     %vecdsdd = vector.transfer_read %vdsdd[%c0], %f0 : memref<?xf64>, vector<12xf64>
     vector.print %vecdsdd : vector<12xf64>
 
+    // CHECK-NEXT: ( 3, 6, 0, 9, 12, 0, 0, 15, 18, 21, 24, 0 )
+    %As = call @scale(%A) : (tensor<?x?xf64, #BSR>) -> (tensor<?x?xf64, #BSR>)
+    %vals = sparse_tensor.values %As : tensor<?x?xf64, #BSR> to memref<?xf64>
+    %vecs = vector.transfer_read %vals[%c0], %f0 : memref<?xf64>, vector<12xf64>
+    vector.print %vecs : vector<12xf64>
+
     // Release the resources.
     bufferization.dealloc_tensor %A: tensor<?x?xf64, #BSR>
 

>From 5042ad2e9adf20649f445a789adacf992a9850cf Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 31 Oct 2023 18:05:18 -0700
Subject: [PATCH 2/2] remove leftover #ifdef AART debug dumps

---
 .../Transforms/SparsificationAndBufferizationPass.cpp     | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 354e2e4bd4facc6..4a293f6819d0976 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -104,10 +104,6 @@ class SparsificationAndBufferizationPass
   }
 
   void runOnOperation() override {
-#ifdef AART
-    llvm::dbgs() << "\n\n**** BEGIN MINI PIPELINE ****\n\n";
-    getOperation().dump();
-#endif
     // Run enabling transformations.
     {
       OpPassManager pm("builtin.module");
@@ -172,10 +168,6 @@ class SparsificationAndBufferizationPass
     // Bufferize all dense ops.
     if (failed(runDenseBufferization()))
       signalPassFailure();
-#ifdef AART
-    llvm::dbgs() << "\n\n**** END MINI PIPELINE ****\n\n";
-    getOperation().dump();
-#endif
   }
 
 private:



More information about the Mlir-commits mailing list