[Mlir-commits] [mlir] 66088af - [mlir][sparse] Add arith-expand pass to the sparse-compiler pipeline.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 27 14:42:27 PDT 2022
Author: bixia1
Date: 2022-07-27T14:42:21-07:00
New Revision: 66088afbc806d963a717c4194c1980e271272de9
URL: https://github.com/llvm/llvm-project/commit/66088afbc806d963a717c4194c1980e271272de9
DIFF: https://github.com/llvm/llvm-project/commit/66088afbc806d963a717c4194c1980e271272de9.diff
LOG: [mlir][sparse] Add arith-expand pass to the sparse-compiler pipeline.
Modify an existing test (sparse_binary.mlir) to exercise the new pass: @vector_min now operates on i32 values and uses arith.minsi.
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D130658
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 8c7639baaffcd..62f03b32f1086 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
@@ -73,6 +74,8 @@ void mlir::sparse_tensor::buildSparseCompiler(
pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
pm.addPass(createMemRefToLLVMPass());
pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
+ pm.addNestedPass<mlir::func::FuncOp>(
+ mlir::arith::createArithmeticExpandOpsPass());
pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
pm.addPass(createConvertMathToLibmPass());
pm.addPass(createConvertComplexToLibmPass());
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
index e29cb36a735ca..bfce26d0ed23b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_binary.mlir
@@ -43,27 +43,26 @@
module {
// Creates a new sparse vector using the minimum values from two input sparse vectors.
// When there is no overlap, include the present value in the output.
- func.func @vector_min(%arga: tensor<?xf64, #SparseVector>,
- %argb: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+ func.func @vector_min(%arga: tensor<?xi32, #SparseVector>,
+ %argb: tensor<?xi32, #SparseVector>) -> tensor<?xi32, #SparseVector> {
%c = arith.constant 0 : index
- %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
- %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
+ %d = tensor.dim %arga, %c : tensor<?xi32, #SparseVector>
+ %xv = bufferization.alloc_tensor(%d) : tensor<?xi32, #SparseVector>
%0 = linalg.generic #trait_vec_op
- ins(%arga, %argb: tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
- outs(%xv: tensor<?xf64, #SparseVector>) {
- ^bb(%a: f64, %b: f64, %x: f64):
- %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+ ins(%arga, %argb: tensor<?xi32, #SparseVector>, tensor<?xi32, #SparseVector>)
+ outs(%xv: tensor<?xi32, #SparseVector>) {
+ ^bb(%a: i32, %b: i32, %x: i32):
+ %1 = sparse_tensor.binary %a, %b : i32, i32 to i32
overlap={
- ^bb0(%a0: f64, %b0: f64):
- %cmp = arith.cmpf "olt", %a0, %b0 : f64
- %2 = arith.select %cmp, %a0, %b0: f64
- sparse_tensor.yield %2 : f64
+ ^bb0(%a0: i32, %b0: i32):
+ %2 = arith.minsi %a0, %b0: i32
+ sparse_tensor.yield %2 : i32
}
left=identity
right=identity
- linalg.yield %1 : f64
- } -> tensor<?xf64, #SparseVector>
- return %0 : tensor<?xf64, #SparseVector>
+ linalg.yield %1 : i32
+ } -> tensor<?xi32, #SparseVector>
+ return %0 : tensor<?xi32, #SparseVector>
}
// Creates a new sparse vector by multiplying a sparse vector with a dense vector.
@@ -428,8 +427,13 @@ module {
[0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 0., 1.]
> : tensor<32xf64>
+ %v1_si = arith.fptosi %v1 : tensor<32xf64> to tensor<32xi32>
+ %v2_si = arith.fptosi %v2 : tensor<32xf64> to tensor<32xi32>
+
%sv1 = sparse_tensor.convert %v1 : tensor<32xf64> to tensor<?xf64, #SparseVector>
%sv2 = sparse_tensor.convert %v2 : tensor<32xf64> to tensor<?xf64, #SparseVector>
+ %sv1_si = sparse_tensor.convert %v1_si : tensor<32xi32> to tensor<?xi32, #SparseVector>
+ %sv2_si = sparse_tensor.convert %v2_si : tensor<32xi32> to tensor<?xi32, #SparseVector>
%dv3 = tensor.cast %v3 : tensor<32xf64> to tensor<?xf64>
// Setup sparse matrices.
@@ -459,9 +463,9 @@ module {
%sm4 = sparse_tensor.convert %m4 : tensor<4x4xf64> to tensor<4x4xf64, #DCSR>
// Call sparse vector kernels.
- %0 = call @vector_min(%sv1, %sv2)
- : (tensor<?xf64, #SparseVector>,
- tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+ %0 = call @vector_min(%sv1_si, %sv2_si)
+ : (tensor<?xi32, #SparseVector>,
+ tensor<?xi32, #SparseVector>) -> tensor<?xi32, #SparseVector>
%1 = call @vector_mul(%sv1, %dv3)
: (tensor<?xf64, #SparseVector>,
tensor<?xf64>) -> tensor<?xf64, #SparseVector>
@@ -494,7 +498,7 @@ module {
// CHECK-NEXT: ( 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 )
// CHECK-NEXT: ( 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, -1, -1, -1, -1, -1, -1 )
// CHECK-NEXT: ( 0, 11, 0, 12, 13, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 15, 0, 16, 0, 0, 17, 0, 0, 0, 0, 0, 0, 18, 19, 0, 20 )
- // CHECK-NEXT: ( 1, 11, 2, 13, 14, 3, 15, 4, 16, 5, 6, 7, 8, 9, -1, -1 )
+ // CHECK-NEXT: ( 1, 11, 2, 13, 14, 3, 15, 4, 16, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
// CHECK-NEXT: ( 1, 11, 0, 2, 13, 0, 0, 0, 0, 0, 14, 3, 0, 0, 0, 0, 15, 4, 16, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 8, 0, 9 )
// CHECK-NEXT: ( 0, 6, 3, 28, 0, 6, 56, 72, 9, -1, -1, -1, -1, -1, -1, -1 )
// CHECK-NEXT: ( 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 28, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 56, 72, 0, 9 )
@@ -518,7 +522,7 @@ module {
//
call @dump_vec(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
call @dump_vec(%sv2) : (tensor<?xf64, #SparseVector>) -> ()
- call @dump_vec(%0) : (tensor<?xf64, #SparseVector>) -> ()
+ call @dump_vec_i32(%0) : (tensor<?xi32, #SparseVector>) -> ()
call @dump_vec(%1) : (tensor<?xf64, #SparseVector>) -> ()
call @dump_vec(%2) : (tensor<?xf64, #SparseVector>) -> ()
call @dump_vec_i32(%3) : (tensor<?xi32, #SparseVector>) -> ()
@@ -537,7 +541,7 @@ module {
bufferization.dealloc_tensor %sm2 : tensor<?x?xf64, #DCSR>
bufferization.dealloc_tensor %sm3 : tensor<4x4xf64, #DCSR>
bufferization.dealloc_tensor %sm4 : tensor<4x4xf64, #DCSR>
- bufferization.dealloc_tensor %0 : tensor<?xf64, #SparseVector>
+ bufferization.dealloc_tensor %0 : tensor<?xi32, #SparseVector>
bufferization.dealloc_tensor %1 : tensor<?xf64, #SparseVector>
bufferization.dealloc_tensor %2 : tensor<?xf64, #SparseVector>
bufferization.dealloc_tensor %3 : tensor<?xi32, #SparseVector>
More information about the Mlir-commits mailing list