[Mlir-commits] [mlir] 69edacb - [mlir][sparse] Add support for complex.im and complex.re to the sparse compiler.

Bixia Zheng llvmlistbot at llvm.org
Wed May 18 08:53:11 PDT 2022


Author: Bixia Zheng
Date: 2022-05-18T15:53:07Z
New Revision: 69edacbcf0c232de6213297cb600b0f0313c6397

URL: https://github.com/llvm/llvm-project/commit/69edacbcf0c232de6213297cb600b0f0313c6397
DIFF: https://github.com/llvm/llvm-project/commit/69edacbcf0c232de6213297cb600b0f0313c6397.diff

LOG: [mlir][sparse] Add support for complex.im and complex.re to the sparse compiler.

Add a test.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D125834

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index d4aafb74093e..ceee28466449 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -50,6 +50,8 @@ enum Kind {
   kCastU,  // unsigned
   kCastIdx,
   kTruncI,
+  kCIm, // complex.im
+  kCRe, // complex.re
   kBitCast,
   kBinaryBranch, // semiring unary branch created from a binary op
   kUnary,        // semiring unary op

diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index c5cd2e4f4cb6..9a589965a528 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -46,6 +46,8 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
   case kTanhF:
   case kNegF:
   case kNegI:
+  case kCIm:
+  case kCRe:
     assert(x != -1u && y == -1u && !v && !o);
     children.e0 = x;
     children.e1 = y;
@@ -291,6 +293,8 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   case kCastU:
   case kCastIdx:
   case kTruncI:
+  case kCIm:
+  case kCRe:
   case kBitCast:
     return isSingleCondition(t, tensorExps[e].children.e0);
   case kDivF: // note: x / c only
@@ -367,6 +371,10 @@ static const char *kindToOpSymbol(Kind kind) {
   case kCastU:
   case kCastIdx:
   case kTruncI:
+  case kCIm:
+    return "complex.im";
+  case kCRe:
+    return "complex.re";
   case kBitCast:
     return "cast";
   case kBinaryBranch:
@@ -526,6 +534,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
   }
   case kAbsF:
   case kCeilF:
+  case kCIm:
+  case kCRe:
   case kFloorF:
   case kSqrtF:
   case kExpm1F:
@@ -776,6 +786,10 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         return addExp(kCastIdx, e, v);
       if (isa<arith::TruncIOp>(def))
         return addExp(kTruncI, e, v);
+      if (isa<complex::ImOp>(def))
+        return addExp(kCIm, e);
+      if (isa<complex::ReOp>(def))
+        return addExp(kCRe, e);
       if (isa<arith::BitcastOp>(def))
         return addExp(kBitCast, e, v);
       if (isa<sparse_tensor::UnaryOp>(def))
@@ -930,6 +944,15 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
   case kTruncI:
     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
+  case kCIm:
+  case kCRe: {
+    auto type = v0.getType().template cast<ComplexType>();
+    auto eltType = type.getElementType().template cast<FloatType>();
+    if (tensorExps[e].kind == kCIm)
+      return rewriter.create<complex::ImOp>(loc, eltType, v0);
+
+    return rewriter.create<complex::ReOp>(loc, eltType, v0);
+  }
   case kBitCast:
     return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
   // Binary ops.

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
new file mode 100644
index 000000000000..2656eb555165
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+#trait_op = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a (in)
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = OP a(i)"
+}
+
+module {
+  func.func @cre(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+                -> tensor<?xf32, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf32, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+        outs(%xv: tensor<?xf32, #SparseVector>) {
+        ^bb(%a: complex<f32>, %x: f32):
+          %1 = complex.re %a : complex<f32>
+          linalg.yield %1 : f32
+    } -> tensor<?xf32, #SparseVector>
+    return %0 : tensor<?xf32, #SparseVector>
+  }
+
+  func.func @cim(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+                -> tensor<?xf32, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf32, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+        outs(%xv: tensor<?xf32, #SparseVector>) {
+        ^bb(%a: complex<f32>, %x: f32):
+          %1 = complex.im %a : complex<f32>
+          linalg.yield %1 : f32
+    } -> tensor<?xf32, #SparseVector>
+    return %0 : tensor<?xf32, #SparseVector>
+  }
+
+  func.func @dump(%arg0: tensor<?xf32, #SparseVector>) {
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f32
+    %values = sparse_tensor.values %arg0 : tensor<?xf32, #SparseVector> to memref<?xf32>
+    %0 = vector.transfer_read %values[%c0], %d0: memref<?xf32>, vector<4xf32>
+    vector.print %0 : vector<4xf32>
+    %indices = sparse_tensor.indices %arg0, %c0 : tensor<?xf32, #SparseVector> to memref<?xindex>
+    %1 = vector.transfer_read %indices[%c0], %c0: memref<?xindex>, vector<4xindex>
+    vector.print %1 : vector<4xindex>
+    return
+  }
+
+  // Driver method to call and verify functions cim and cre.
+  func.func @entry() {
+    // Setup sparse vectors.
+    %v1 = arith.constant sparse<
+       [ [0], [20], [31] ],
+         [ (5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f32>>
+    %sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f32>> to tensor<?xcomplex<f32>, #SparseVector>
+
+    // Call sparse vector kernels.
+    %0 = call @cre(%sv1)
+       : (tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xf32, #SparseVector>
+
+    %1 = call @cim(%sv1)
+       : (tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xf32, #SparseVector>
+
+    //
+    // Verify the results.
+    //
+    // CHECK: ( 5.13, 3, 5, -1 )
+    // CHECK-NEXT: ( 0, 20, 31, 0 )
+    // CHECK-NEXT: ( 2, 4, 6, -1 )
+    // CHECK-NEXT: ( 0, 20, 31, 0 )
+    //
+    call @dump(%0) : (tensor<?xf32, #SparseVector>) -> ()
+    call @dump(%1) : (tensor<?xf32, #SparseVector>) -> ()
+
+    // Release the resources.
+    sparse_tensor.release %sv1 : tensor<?xcomplex<f32>, #SparseVector>
+    sparse_tensor.release %0 : tensor<?xf32, #SparseVector>
+    sparse_tensor.release %1 : tensor<?xf32, #SparseVector>
+    return
+  }
+}


        


More information about the Mlir-commits mailing list