[Mlir-commits] [mlir] a14057d - [mlir][sparse] Add more complex operations.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 25 16:38:32 PDT 2022


Author: bixia1
Date: 2022-05-25T16:38:09-07:00
New Revision: a14057d4bddbb78aed7beaf282e75137b2dc60b9

URL: https://github.com/llvm/llvm-project/commit/a14057d4bddbb78aed7beaf282e75137b2dc60b9
DIFF: https://github.com/llvm/llvm-project/commit/a14057d4bddbb78aed7beaf282e75137b2dc60b9.diff

LOG: [mlir][sparse] Add more complex operations.

Support the complex dialect operations sqrt, expm1, and tanh in the sparse tensor Merger.

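Add integration tests for complex sqrt and tanh. The complex expm1 test is disabled until complex.expm1 is supported by the complex-to-standard lowering.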

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D126393

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 039cc5ddefac5..db83ce054c0ff 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -35,12 +35,15 @@ enum Kind {
   kCeilF,
   kFloorF,
   kSqrtF,
+  kSqrtC,
   kExpm1F,
+  kExpm1C,
   kLog1pF,
   kLog1pC,
   kSinF,
   kSinC,
   kTanhF,
+  kTanhC,
   kNegF,
   kNegC,
   kNegI,

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 3cf7dc60d33c3..80d2dbba187b8 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -41,12 +41,15 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
@@ -284,12 +287,15 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
@@ -360,8 +366,10 @@ static const char *kindToOpSymbol(Kind kind) {
   case kFloorF:
     return "floor";
   case kSqrtF:
+  case kSqrtC:
     return "sqrt";
   case kExpm1F:
+  case kExpm1C:
     return "expm1";
   case kLog1pF:
   case kLog1pC:
@@ -370,6 +378,7 @@ static const char *kindToOpSymbol(Kind kind) {
   case kSinC:
     return "sin";
   case kTanhF:
+  case kTanhC:
     return "tanh";
   case kNegF:
   case kNegC:
@@ -449,10 +458,13 @@ void Merger::dumpExp(unsigned e) const {
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kSinF:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegI:
   case kTruncF:
@@ -555,12 +567,15 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
   case kCRe:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
@@ -785,8 +800,12 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         return addExp(kFloorF, e);
       if (isa<math::SqrtOp>(def))
         return addExp(kSqrtF, e);
+      if (isa<complex::SqrtOp>(def))
+        return addExp(kSqrtC, e);
       if (isa<math::ExpM1Op>(def))
         return addExp(kExpm1F, e);
+      if (isa<complex::Expm1Op>(def))
+        return addExp(kExpm1C, e);
       if (isa<math::Log1pOp>(def))
         return addExp(kLog1pF, e);
       if (isa<complex::Log1pOp>(def))
@@ -797,6 +816,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         return addExp(kSinC, e);
       if (isa<math::TanhOp>(def))
         return addExp(kTanhF, e);
+      if (isa<complex::TanhOp>(def))
+        return addExp(kTanhC, e);
       if (isa<arith::NegFOp>(def))
         return addExp(kNegF, e); // no negi in std
       if (isa<complex::NegOp>(def))
@@ -952,8 +973,12 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
     return rewriter.create<math::FloorOp>(loc, v0);
   case kSqrtF:
     return rewriter.create<math::SqrtOp>(loc, v0);
+  case kSqrtC:
+    return rewriter.create<complex::SqrtOp>(loc, v0);
   case kExpm1F:
     return rewriter.create<math::ExpM1Op>(loc, v0);
+  case kExpm1C:
+    return rewriter.create<complex::Expm1Op>(loc, v0);
   case kLog1pF:
     return rewriter.create<math::Log1pOp>(loc, v0);
   case kLog1pC:
@@ -964,6 +989,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
     return rewriter.create<complex::SinOp>(loc, v0);
   case kTanhF:
     return rewriter.create<math::TanhOp>(loc, v0);
+  case kTanhC:
+    return rewriter.create<complex::TanhOp>(loc, v0);
   case kNegF:
     return rewriter.create<arith::NegFOp>(loc, v0);
   case kNegC:

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
index 74c30e6f5ff83..0fbb2b7800f76 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -59,6 +59,54 @@ module {
     return %0 : tensor<?xcomplex<f64>, #SparseVector>
   }
 
+  func.func @complex_sqrt(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                 -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.sqrt %a : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @complex_tanh(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                 -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+       ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.tanh %a : complex<f64>
+          linalg.yield %1 : complex<f64>
+   } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @clog1p_expm1(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                 -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.log1p %a : complex<f64>
+          // TODO(bixia): Enable this line after adding complex.expm1 to
+          // complex to standard lowering.
+          // %2 = complex.expm1 %1 : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
   func.func @cdiv(%arga: tensor<?xcomplex<f64>, #SparseVector>)
                  -> tensor<?xcomplex<f64>, #SparseVector> {
     %c0 = arith.constant 0 : index
@@ -131,9 +179,15 @@ module {
           tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
     %1 = call @csin(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
-    %2 = call @cdiv(%sv1)
+    %2 = call @complex_sqrt(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %3 = call @complex_tanh(%sv2)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %4 = call @clog1p_expm1(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
-    %3 = call @cabs(%sv1)
+    %5 = call @cdiv(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %6 = call @cabs(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xf64, #SparseVector>
 
     //
@@ -157,15 +211,36 @@ module {
     // CHECK-NEXT: -193.43
     // CHECK-NEXT: 57.2184
     call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 0.433635
+    // CHECK-NEXT: 2.30609
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2.53083
+    // CHECK-NEXT: 1.18538
+    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 0.761594
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: -0.964028
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0.995055
+    // CHECK-NEXT: 0
+    call @dumpc(%3, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 1.52361
+    // CHECK-NEXT: 2.69061
+    // CHECK-NEXT: 1.73287
+    // CHECK-NEXT: 0.785398
+    // CHECK-NEXT: 2.13833
+    // CHECK-NEXT: 0.785398
+    call @dumpc(%4, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
     // CHECK-NEXT: -2.565
     // CHECK-NEXT: 1
     // CHECK-NEXT: 1.5
     // CHECK-NEXT: 2
     // CHECK-NEXT: 2.5
     // CHECK-NEXT: 3
-    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    call @dumpc(%5, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
     // CHECK-NEXT: ( 5.50608, 5, 7.81025 )
-    call @dumpf(%3) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dumpf(%6) : (tensor<?xf64, #SparseVector>) -> ()
 
     // Release the resources.
     sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
@@ -173,7 +248,10 @@ module {
     sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %2 : tensor<?xcomplex<f64>, #SparseVector>
-    sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %3 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %4 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %5 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %6 : tensor<?xf64, #SparseVector>
     return
   }
 }


        


More information about the Mlir-commits mailing list