[Mlir-commits] [mlir] [mlir][sparse] first end-to-end linalg.generic op on BSR (PR #70880)
Aart Bik
llvmlistbot at llvm.org
Tue Oct 31 18:05:42 PDT 2023
https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/70880
>From 83a4ae330e0640d92bf9702b98ae6a9f31aa64e9 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 31 Oct 2023 17:07:01 -0700
Subject: [PATCH 1/2] [mlir][sparse] first end-to-end linalg.generic op on BSR
---
.../SparseTensor/IR/SparseTensorType.h | 8 +++---
.../Transforms/SparseReinterpretMap.cpp | 8 +++---
.../SparsificationAndBufferizationPass.cpp | 14 ++++++++--
.../Dialect/SparseTensor/CPU/block.mlir | 27 +++++++++++++++++--
4 files changed, 44 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 1fd91d0c02e4d1b..3e9cada83c6d50b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -341,11 +341,9 @@ template <typename T>
inline SparseTensorType getSparseTensorType(T t) {
return SparseTensorType(getRankedTensorType(t));
}
-template <typename T>
-inline std::optional<SparseTensorType> tryGetSparseTensorType(T t) {
- RankedTensorType rtp = getRankedTensorType(t);
- if (rtp)
- return SparseTensorType(rtp);
+inline std::optional<SparseTensorType> tryGetSparseTensorType(Value v) {
+ if (isa<RankedTensorType>(v.getType()))
+ return getSparseTensorType(v);
return std::nullopt;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 31cc8525725d43d..a822effbb2ab78c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -116,11 +116,11 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
if (map.getResult(i).getKind() != AffineExprKind::DimId)
return failure();
// Inspect sparse operands.
- auto stt = getSparseTensorType(t.get());
- if (stt.hasEncoding()) {
- if (stt.isPermutation())
+ auto stt = tryGetSparseTensorType(t.get());
+ if (stt && stt->hasEncoding()) {
+ if (stt->isPermutation())
continue;
- assert(stt.getDimRank() < stt.getLvlRank()); // only allowed non-perm
+ assert(stt->getDimRank() < stt->getLvlRank()); // only allowed non-perm
if (tx)
return failure(); // more than one non-perm
if (!map.isIdentity())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 41940f731e76c17..354e2e4bd4facc6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -104,8 +104,12 @@ class SparsificationAndBufferizationPass
}
void runOnOperation() override {
+#ifdef AART
+ llvm::dbgs() << "\n\n**** BEGIN MINI PIPELINE ****\n\n";
+ getOperation().dump();
+#endif
+ // Run enabling transformations.
{
- // Run enabling transformations.
OpPassManager pm("builtin.module");
pm.addPass(createPreSparsificationRewritePass());
pm.addNestedPass<func::FuncOp>(
@@ -128,7 +132,7 @@ class SparsificationAndBufferizationPass
bufferizationOptions)))
return signalPassFailure();
- // `testAnalysisOnly` is a debug/testing flag. If set, the results of
+ // Option `testAnalysisOnly` is a debug/testing flag. If set, the results of
// OneShotAnalysis are added to the IR via attributes. In that case, do not
// continue with the remaining pipeline.
if (bufferizationOptions.testAnalysisOnly)
@@ -139,6 +143,8 @@ class SparsificationAndBufferizationPass
// of `bufferization.alloc_tensor` ops.
{
OpPassManager pm("builtin.module");
+ pm.addPass(
+ createSparseReinterpretMapPass(ReinterpretMapScope::kGenericOnly));
pm.addPass(createSparsificationPass(sparsificationOptions));
pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
@@ -166,6 +172,10 @@ class SparsificationAndBufferizationPass
// Bufferize all dense ops.
if (failed(runDenseBufferization()))
signalPassFailure();
+#ifdef AART
+ llvm::dbgs() << "\n\n**** END MINI PIPELINE ****\n\n";
+ getOperation().dump();
+#endif
}
private:
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
index 78d35ada6acc11c..e1cdc9ed6ba3d41 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir
@@ -25,6 +25,8 @@
// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false
// R_UN: %{compile} | env %{env} %{run} | FileCheck %s
+!Filename = !llvm.ptr<i8>
+
#BSR = #sparse_tensor.encoding<{
map = (i, j) ->
( i floordiv 2 : dense
@@ -38,8 +40,12 @@
map = (i, j, k, l) -> ( i : dense, j : compressed, k : dense, l : dense)
}>
-
-!Filename = !llvm.ptr<i8>
+#trait_scale_inplace = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel"]
+}
//
// Example 2x2 block storage:
@@ -62,6 +68,17 @@ module {
func.func private @getTensorFilename(index) -> (!Filename)
+ func.func @scale(%arg0: tensor<?x?xf64, #BSR>) -> tensor<?x?xf64, #BSR> {
+ %c = arith.constant 3.0 : f64
+ %0 = linalg.generic #trait_scale_inplace
+ outs(%arg0: tensor<?x?xf64, #BSR>) {
+ ^bb(%x: f64):
+ %1 = arith.mulf %x, %c : f64
+ linalg.yield %1 : f64
+ } -> tensor<?x?xf64, #BSR>
+ return %0 : tensor<?x?xf64, #BSR>
+ }
+
func.func @entry() {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f64
@@ -89,6 +106,12 @@ module {
%vecdsdd = vector.transfer_read %vdsdd[%c0], %f0 : memref<?xf64>, vector<12xf64>
vector.print %vecdsdd : vector<12xf64>
+ // CHECK-NEXT: ( 3, 6, 0, 9, 12, 0, 0, 15, 18, 21, 24, 0 )
+ %As = call @scale(%A) : (tensor<?x?xf64, #BSR>) -> (tensor<?x?xf64, #BSR>)
+ %vals = sparse_tensor.values %As : tensor<?x?xf64, #BSR> to memref<?xf64>
+ %vecs = vector.transfer_read %vals[%c0], %f0 : memref<?xf64>, vector<12xf64>
+ vector.print %vecs : vector<12xf64>
+
// Release the resources.
bufferization.dealloc_tensor %A: tensor<?x?xf64, #BSR>
>From 5042ad2e9adf20649f445a789adacf992a9850cf Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 31 Oct 2023 18:05:18 -0700
Subject: [PATCH 2/2] remove leftover #ifdef AART debug output
---
.../Transforms/SparsificationAndBufferizationPass.cpp | 8 --------
1 file changed, 8 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 354e2e4bd4facc6..4a293f6819d0976 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -104,10 +104,6 @@ class SparsificationAndBufferizationPass
}
void runOnOperation() override {
-#ifdef AART
- llvm::dbgs() << "\n\n**** BEGIN MINI PIPELINE ****\n\n";
- getOperation().dump();
-#endif
// Run enabling transformations.
{
OpPassManager pm("builtin.module");
@@ -172,10 +168,6 @@ class SparsificationAndBufferizationPass
// Bufferize all dense ops.
if (failed(runDenseBufferization()))
signalPassFailure();
-#ifdef AART
- llvm::dbgs() << "\n\n**** END MINI PIPELINE ****\n\n";
- getOperation().dump();
-#endif
}
private:
More information about the Mlir-commits
mailing list