[Mlir-commits] [mlir] [mlir][sparse] support complex type for sparse_tensor.print (PR #83934)
Aart Bik
llvmlistbot at llvm.org
Mon Mar 4 17:07:19 PST 2024
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/83934
Includes an integration test example.
From e43e242c01456f60429a769a630e81d5225400ab Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Mon, 4 Mar 2024 17:04:13 -0800
Subject: [PATCH] [mlir][sparse] support complex type for sparse_tensor.print
With an integration test example
---
.../Transforms/SparseTensorRewriting.cpp | 16 +-
.../SparseTensor/CPU/sparse_complex_ops.mlir | 145 +++++++++---------
2 files changed, 87 insertions(+), 74 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 158845d88a4478..a65bce78d095cf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -692,7 +692,21 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
rewriter.setInsertionPointToStart(forOp.getBody());
auto idx = forOp.getInductionVar();
auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
- rewriter.create<vector::PrintOp>(loc, val, vector::PrintPunctuation::Comma);
+ if (llvm::isa<ComplexType>(val.getType())) {
+ // Since the vector dialect does not support complex types in any op,
+ // we split those into (real, imag) pairs here.
+ Value real = rewriter.create<complex::ReOp>(loc, val);
+ Value imag = rewriter.create<complex::ImOp>(loc, val);
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+ rewriter.create<vector::PrintOp>(loc, real,
+ vector::PrintPunctuation::Comma);
+ rewriter.create<vector::PrintOp>(loc, imag,
+ vector::PrintPunctuation::Close);
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+ } else {
+ rewriter.create<vector::PrintOp>(loc, val,
+ vector::PrintPunctuation::Comma);
+ }
rewriter.setInsertionPointAfter(forOp);
// Close bracket and end of line.
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
index f233a92fa14a7f..c4fc8b08078775 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -10,7 +10,7 @@
// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
-// DEFINE: %{run_opts} = -e entry -entry-point-result=void
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
//
@@ -162,31 +162,8 @@ module {
return %0 : tensor<?xf64, #SparseVector>
}
- func.func @dumpc(%arg0: tensor<?xcomplex<f64>, #SparseVector>, %d: index) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f64>, #SparseVector> to memref<?xcomplex<f64>>
- scf.for %i = %c0 to %d step %c1 {
- %v = memref.load %mem[%i] : memref<?xcomplex<f64>>
- %real = complex.re %v : complex<f64>
- %imag = complex.im %v : complex<f64>
- vector.print %real : f64
- vector.print %imag : f64
- }
- return
- }
-
- func.func @dumpf(%arg0: tensor<?xf64, #SparseVector>) {
- %c0 = arith.constant 0 : index
- %d0 = arith.constant 0.0 : f64
- %values = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
- %0 = vector.transfer_read %values[%c0], %d0: memref<?xf64>, vector<3xf64>
- vector.print %0 : vector<3xf64>
- return
- }
-
// Driver method to call and verify complex kernels.
- func.func @entry() {
+ func.func @main() {
// Setup sparse vectors.
%v1 = arith.constant sparse<
[ [0], [28], [31] ],
@@ -217,54 +194,76 @@ module {
//
// Verify the results.
//
- %d3 = arith.constant 3 : index
- %d4 = arith.constant 4 : index
- // CHECK: -5.13
- // CHECK-NEXT: 2
- // CHECK-NEXT: 1
- // CHECK-NEXT: 0
- // CHECK-NEXT: 1
- // CHECK-NEXT: 4
- // CHECK-NEXT: 8
- // CHECK-NEXT: 6
- call @dumpc(%0, %d4) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
- // CHECK-NEXT: 3.43887
- // CHECK-NEXT: 1.47097
- // CHECK-NEXT: 3.85374
- // CHECK-NEXT: -27.0168
- // CHECK-NEXT: -193.43
- // CHECK-NEXT: 57.2184
- call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
- // CHECK-NEXT: 0.433635
- // CHECK-NEXT: 2.30609
- // CHECK-NEXT: 2
- // CHECK-NEXT: 1
- // CHECK-NEXT: 2.53083
- // CHECK-NEXT: 1.18538
- call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
- // CHECK-NEXT: 0.761594
- // CHECK-NEXT: 0
- // CHECK-NEXT: -0.964028
- // CHECK-NEXT: 0
- // CHECK-NEXT: 0.995055
- // CHECK-NEXT: 0
- call @dumpc(%3, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
- // CHECK-NEXT: -5.13
- // CHECK-NEXT: 2
- // CHECK-NEXT: 3
- // CHECK-NEXT: 4
- // CHECK-NEXT: 5
- // CHECK-NEXT: 6
- call @dumpc(%4, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
- // CHECK-NEXT: -2.565
- // CHECK-NEXT: 1
- // CHECK-NEXT: 1.5
- // CHECK-NEXT: 2
- // CHECK-NEXT: 2.5
- // CHECK-NEXT: 3
- call @dumpc(%5, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
- // CHECK-NEXT: ( 5.50608, 5, 7.81025 )
- call @dumpf(%6) : (tensor<?xf64, #SparseVector>) -> ()
+ // CHECK: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 4
+ // CHECK-NEXT: dim = ( 32 )
+ // CHECK-NEXT: lvl = ( 32 )
+ // CHECK-NEXT: pos[0] : ( 0, 4,
+ // CHECK-NEXT: crd[0] : ( 0, 1, 28, 31,
+ // CHECK-NEXT: values : ( ( -5.13, 2 ), ( 1, 0 ), ( 1, 4 ), ( 8, 6 ),
+ // CHECK-NEXT: ----
+ //
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 3
+ // CHECK-NEXT: dim = ( 32 )
+ // CHECK-NEXT: lvl = ( 32 )
+ // CHECK-NEXT: pos[0] : ( 0, 3,
+ // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+ // CHECK-NEXT: values : ( ( 3.43887, 1.47097 ), ( 3.85374, -27.0168 ), ( -193.43, 57.2184 ),
+ // CHECK-NEXT: ----
+ //
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 3
+ // CHECK-NEXT: dim = ( 32 )
+ // CHECK-NEXT: lvl = ( 32 )
+ // CHECK-NEXT: pos[0] : ( 0, 3,
+ // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+ // CHECK-NEXT: values : ( ( 0.433635, 2.30609 ), ( 2, 1 ), ( 2.53083, 1.18538 ),
+ // CHECK-NEXT: ----
+ //
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 3
+ // CHECK-NEXT: dim = ( 32 )
+ // CHECK-NEXT: lvl = ( 32 )
+ // CHECK-NEXT: pos[0] : ( 0, 3,
+ // CHECK-NEXT: crd[0] : ( 1, 28, 31,
+ // CHECK-NEXT: values : ( ( 0.761594, 0 ), ( -0.964028, 0 ), ( 0.995055, 0 ),
+ // CHECK-NEXT: ----
+ //
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 3
+ // CHECK-NEXT: dim = ( 32 )
+ // CHECK-NEXT: lvl = ( 32 )
+ // CHECK-NEXT: pos[0] : ( 0, 3,
+ // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+ // CHECK-NEXT: values : ( ( -5.13, 2 ), ( 3, 4 ), ( 5, 6 ),
+ // CHECK-NEXT: ----
+ //
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 3
+ // CHECK-NEXT: dim = ( 32 )
+ // CHECK-NEXT: lvl = ( 32 )
+ // CHECK-NEXT: pos[0] : ( 0, 3,
+ // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+ // CHECK-NEXT: values : ( ( -2.565, 1 ), ( 1.5, 2 ), ( 2.5, 3 ),
+ // CHECK-NEXT: ----
+ //
+ // CHECK-NEXT: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 3
+ // CHECK-NEXT: dim = ( 32 )
+ // CHECK-NEXT: lvl = ( 32 )
+ // CHECK-NEXT: pos[0] : ( 0, 3,
+ // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+ // CHECK-NEXT: values : ( 5.50608, 5, 7.81025,
+ // CHECK-NEXT: ----
+ //
+ sparse_tensor.print %0 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.print %1 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.print %2 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.print %3 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.print %4 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.print %5 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.print %6 : tensor<?xf64, #SparseVector>
// Release the resources.
bufferization.dealloc_tensor %sv1 : tensor<?xcomplex<f64>, #SparseVector>
More information about the Mlir-commits mailing list