[Mlir-commits] [mlir] 275fe3a - [mlir][sparse] support complex type for sparse_tensor.print (#83934)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 4 17:14:35 PST 2024


Author: Aart Bik
Date: 2024-03-04T17:14:31-08:00
New Revision: 275fe3ae2dced8275a1dd85a4f892fee99d322e2

URL: https://github.com/llvm/llvm-project/commit/275fe3ae2dced8275a1dd85a4f892fee99d322e2
DIFF: https://github.com/llvm/llvm-project/commit/275fe3ae2dced8275a1dd85a4f892fee99d322e2.diff

LOG: [mlir][sparse] support complex type for sparse_tensor.print (#83934)

Includes an integration test example.

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 158845d88a4478..a65bce78d095cf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -692,7 +692,21 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
     rewriter.setInsertionPointToStart(forOp.getBody());
     auto idx = forOp.getInductionVar();
     auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
-    rewriter.create<vector::PrintOp>(loc, val, vector::PrintPunctuation::Comma);
+    if (llvm::isa<ComplexType>(val.getType())) {
+      // Since the vector dialect does not support complex types in any op,
+      // we split those into (real, imag) pairs here.
+      Value real = rewriter.create<complex::ReOp>(loc, val);
+      Value imag = rewriter.create<complex::ImOp>(loc, val);
+      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+      rewriter.create<vector::PrintOp>(loc, real,
+                                       vector::PrintPunctuation::Comma);
+      rewriter.create<vector::PrintOp>(loc, imag,
+                                       vector::PrintPunctuation::Close);
+      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+    } else {
+      rewriter.create<vector::PrintOp>(loc, val,
+                                       vector::PrintPunctuation::Comma);
+    }
     rewriter.setInsertionPointAfter(forOp);
     // Close bracket and end of line.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
index f233a92fa14a7f..c4fc8b08078775 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -10,7 +10,7 @@
 // DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
 // DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
 // DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
-// DEFINE: %{run_opts} = -e entry -entry-point-result=void
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
 // DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
 // DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
 //
@@ -162,31 +162,8 @@ module {
     return %0 : tensor<?xf64, #SparseVector>
   }
 
-  func.func @dumpc(%arg0: tensor<?xcomplex<f64>, #SparseVector>, %d: index) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f64>, #SparseVector> to memref<?xcomplex<f64>>
-    scf.for %i = %c0 to %d step %c1 {
-       %v = memref.load %mem[%i] : memref<?xcomplex<f64>>
-       %real = complex.re %v : complex<f64>
-       %imag = complex.im %v : complex<f64>
-       vector.print %real : f64
-       vector.print %imag : f64
-    }
-    return
-  }
-
-  func.func @dumpf(%arg0: tensor<?xf64, #SparseVector>) {
-    %c0 = arith.constant 0 : index
-    %d0 = arith.constant 0.0 : f64
-    %values = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
-    %0 = vector.transfer_read %values[%c0], %d0: memref<?xf64>, vector<3xf64>
-    vector.print %0 : vector<3xf64>
-    return
-  }
-
   // Driver method to call and verify complex kernels.
-  func.func @entry() {
+  func.func @main() {
     // Setup sparse vectors.
     %v1 = arith.constant sparse<
        [ [0], [28], [31] ],
@@ -217,54 +194,76 @@ module {
     //
     // Verify the results.
     //
-    %d3 = arith.constant 3 : index
-    %d4 = arith.constant 4 : index
-    // CHECK: -5.13
-    // CHECK-NEXT: 2
-    // CHECK-NEXT: 1
-    // CHECK-NEXT: 0
-    // CHECK-NEXT: 1
-    // CHECK-NEXT: 4
-    // CHECK-NEXT: 8
-    // CHECK-NEXT: 6
-    call @dumpc(%0, %d4) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: 3.43887
-    // CHECK-NEXT: 1.47097
-    // CHECK-NEXT: 3.85374
-    // CHECK-NEXT: -27.0168
-    // CHECK-NEXT: -193.43
-    // CHECK-NEXT: 57.2184
-    call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: 0.433635
-    // CHECK-NEXT: 2.30609
-    // CHECK-NEXT: 2
-    // CHECK-NEXT: 1
-    // CHECK-NEXT: 2.53083
-    // CHECK-NEXT: 1.18538
-    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: 0.761594
-    // CHECK-NEXT: 0
-    // CHECK-NEXT: -0.964028
-    // CHECK-NEXT: 0
-    // CHECK-NEXT: 0.995055
-    // CHECK-NEXT: 0
-    call @dumpc(%3, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: -5.13
-    // CHECK-NEXT: 2
-    // CHECK-NEXT: 3
-    // CHECK-NEXT: 4
-    // CHECK-NEXT: 5
-    // CHECK-NEXT: 6
-    call @dumpc(%4, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: -2.565
-    // CHECK-NEXT: 1
-    // CHECK-NEXT: 1.5
-    // CHECK-NEXT: 2
-    // CHECK-NEXT: 2.5
-    // CHECK-NEXT: 3
-    call @dumpc(%5, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
-    // CHECK-NEXT: ( 5.50608, 5, 7.81025 )
-    call @dumpf(%6) : (tensor<?xf64, #SparseVector>) -> ()
+    // CHECK:      ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 4
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 4,
+    // CHECK-NEXT: crd[0] : ( 0, 1, 28, 31,
+    // CHECK-NEXT: values : ( ( -5.13, 2 ), ( 1, 0 ), ( 1, 4 ), ( 8, 6 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+    // CHECK-NEXT: values : ( ( 3.43887, 1.47097 ), ( 3.85374, -27.0168 ), ( -193.43, 57.2184 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+    // CHECK-NEXT: values : ( ( 0.433635, 2.30609 ), ( 2, 1 ), ( 2.53083, 1.18538 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 1, 28, 31,
+    // CHECK-NEXT: values : ( ( 0.761594, 0 ), ( -0.964028, 0 ), ( 0.995055, 0 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+    // CHECK-NEXT: values : ( ( -5.13, 2 ), ( 3, 4 ), ( 5, 6 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+    // CHECK-NEXT: values : ( ( -2.565, 1 ), ( 1.5, 2 ), ( 2.5, 3 ),
+    // CHECK-NEXT: ----
+    //
+    // CHECK-NEXT: ---- Sparse Tensor ----
+    // CHECK-NEXT: nse = 3
+    // CHECK-NEXT: dim = ( 32 )
+    // CHECK-NEXT: lvl = ( 32 )
+    // CHECK-NEXT: pos[0] : ( 0, 3,
+    // CHECK-NEXT: crd[0] : ( 0, 28, 31,
+    // CHECK-NEXT: values : ( 5.50608, 5, 7.81025,
+    // CHECK-NEXT: ----
+    //
+    sparse_tensor.print %0 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %1 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %2 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %3 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %4 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %5 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.print %6 : tensor<?xf64, #SparseVector>
 
     // Release the resources.
     bufferization.dealloc_tensor %sv1 : tensor<?xcomplex<f64>, #SparseVector>


        


More information about the Mlir-commits mailing list