[Mlir-commits] [mlir] [mlir][sparse] support complex type for sparse_tensor.print (PR #83934)

Aart Bik llvmlistbot at llvm.org
Mon Mar 4 17:07:19 PST 2024


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/83934

Adds complex type support to sparse_tensor.print, together with an integration test example.

>From e43e242c01456f60429a769a630e81d5225400ab Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Mon, 4 Mar 2024 17:04:13 -0800
Subject: [PATCH] [mlir][sparse] support complex type for sparse_tensor.print

With an integration test example
---
 .../Transforms/SparseTensorRewriting.cpp      |  16 +-
 .../SparseTensor/CPU/sparse_complex_ops.mlir  | 145 +++++++++---------
 2 files changed, 87 insertions(+), 74 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 158845d88a4478..a65bce78d095cf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -692,7 +692,21 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
     rewriter.setInsertionPointToStart(forOp.getBody());
     auto idx = forOp.getInductionVar();
     auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
-    rewriter.create<vector::PrintOp>(loc, val, vector::PrintPunctuation::Comma);
+    if (llvm::isa<ComplexType>(val.getType())) {
+      // Since the vector dialect does not support complex types in any op,
+      // we split those into (real, imag) pairs here.
+      Value real = rewriter.create<complex::ReOp>(loc, val);
+      Value imag = rewriter.create<complex::ImOp>(loc, val);
+      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+      rewriter.create<vector::PrintOp>(loc, real,
+                                       vector::PrintPunctuation::Comma);
+      rewriter.create<vector::PrintOp>(loc, imag,
+                                       vector::PrintPunctuation::Close);
+      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+    } else {
+      rewriter.create<vector::PrintOp>(loc, val,
+                                       vector::PrintPunctuation::Comma);
+    }
     rewriter.setInsertionPointAfter(forOp);
     // Close bracket and end of line.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
index f233a92fa14a7f..c4fc8b08078775 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -10,7 +10,7 @@
 // DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
 // DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
 // DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
-// DEFINE: %{run_opts} = -e entry -entry-point-result=void
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
 // DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
 // DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
 //
@@ -162,31 +162,8 @@ module {
     return %0 : tensor<?xf64, #SparseVector>
   }
 
-  func.func @dumpc(%arg0: tensor<?xcomplex<f64>, #SparseVector>, %d: index) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f64>, #SparseVector> to memref<?xcomplex<f64>>
-    scf.for %i = %c0 to %d step %c1 {
-       %v = memref.load %mem[%i] : memref<?xcomplex<f64>>
-       %real = complex.re %v : complex<f64>
-       %imag = complex.im %v : complex<f64>
-       vector.print %real : f64
-       vector.print %imag : f64
-    }
-    return
-  }
-
-  func.func @dumpf(%arg0: tensor<?xf64, #SparseVector>) {
-    %c0 = arith.constant 0 : index
-    %d0 = arith.constant 0.0 : f64
-    %values = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
-    %0 = vector.transfer_read %values[%c0], %d0: memref<?xf64>, vector<3xf64>
-    vector.print %0 : vector<3xf64>
-    return
-  }
-
   // Driver method to call and verify complex kernels.
-  func.func @entry() {
+  func.func @main() {
     // Setup sparse vectors.
     %v1 = arith.constant sparse<
        [ [0], [28], [31] ],
@@ -217,54 +194,76 @@ module {
     //
     // Verify the results.
     //
-    %d3 = arith.constant 3 : index
-    %d4 = arith.constant 4 : index
-    // CHECK: -5.13
-    // CHECK-NEXT: 2
-    // CHECK-NEXT: 1
-    // CHECK-NEXT: 0
-    // CHECK-NEXT: 1
-    // CHECK-NEXT: 4
-    // CHECK-NEXT: 8
-    // CHECK-NEXT: 6
-    call @dumpc(%0, %d4) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: 3.43887
-    // CHECK-NEXT: 1.47097
-    // CHECK-NEXT: 3.85374
-    // CHECK-NEXT: -27.0168
-    // CHECK-NEXT: -193.43
-    // CHECK-NEXT: 57.2184
-    call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: 0.433635
-    // CHECK-NEXT: 2.30609
-    // CHECK-NEXT: 2
-    // CHECK-NEXT: 1
-    // CHECK-NEXT: 2.53083
-    // CHECK-NEXT: 1.18538
-    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: 0.761594
-    // CHECK-NEXT: 0
-    // CHECK-NEXT: -0.964028
-    // CHECK-NEXT: 0
-    // CHECK-NEXT: 0.995055
-    // CHECK-NEXT: 0
-    call @dumpc(%3, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: -5.13
-    // CHECK-NEXT: 2
-    // CHECK-NEXT: 3
-    // CHECK-NEXT: 4
-    // CHECK-NEXT: 5
-    // CHECK-NEXT: 6
-    call @dumpc(%4, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: -2.565
-    // CHECK-NEXT: 1
-    // CHECK-NEXT: 1.5
-    // CHECK-NEXT: 2
-    // CHECK-NEXT: 2.5
-    // CHECK-NEXT: 3
-    call @dumpc(%5, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: ( 5.50608, 5, 7.81025 )
-    call @dumpf(%6) : (tensor<?xf64, #SparseVector>) -> ()
+    // CHECK:      ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 4
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 4,
+    // CHECK-NEXT: crd[0] : ( 0, 1, 28, 31,
+    // CHECK-NEXT: values : ( ( -5.13, 2 ), ( 1, 0 ), ( 1, 4 ), ( 8, 6 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+    // CHECK-NEXT: values : ( ( 3.43887, 1.47097 ), ( 3.85374, -27.0168 ), ( -193.43, 57.2184 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+    // CHECK-NEXT: values : ( ( 0.433635, 2.30609 ), ( 2, 1 ), ( 2.53083, 1.18538 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 1, 28, 31,
+    // CHECK-NEXT: values : ( ( 0.761594, 0 ), ( -0.964028, 0 ), ( 0.995055, 0 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+    // CHECK-NEXT: values : ( ( -5.13, 2 ), ( 3, 4 ), ( 5, 6 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+    // CHECK-NEXT: values : ( ( -2.565, 1 ), ( 1.5, 2 ), ( 2.5, 3 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+    // CHECK-NEXT: values : ( 5.50608, 5, 7.81025,
+    // CHECK-NEXT: ----
+    //
+    sparse_tensor.print %0 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %1 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %2 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %3 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %4 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %5 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %6 : tensor<?xf64, #SparseVector>
 
     // Release the resources.
     bufferization.dealloc_tensor %sv1 : tensor<?xcomplex<f64>, #SparseVector>



More information about the Mlir-commits mailing list