[Mlir-commits] [mlir] a14057d - [mlir][sparse] Add more complex operations.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 25 16:38:32 PDT 2022
Author: bixia1
Date: 2022-05-25T16:38:09-07:00
New Revision: a14057d4bddbb78aed7beaf282e75137b2dc60b9
URL: https://github.com/llvm/llvm-project/commit/a14057d4bddbb78aed7beaf282e75137b2dc60b9
DIFF: https://github.com/llvm/llvm-project/commit/a14057d4bddbb78aed7beaf282e75137b2dc60b9.diff
LOG: [mlir][sparse] Add more complex operations.
Support complex operations sqrt, expm1, and tanh.
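Add integration tests for complex sqrt and tanh; the complex.expm1 test line is left commented out until complex.expm1 has a complex-to-standard lowering.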
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D126393
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 039cc5ddefac5..db83ce054c0ff 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -35,12 +35,15 @@ enum Kind {
kCeilF,
kFloorF,
kSqrtF,
+ kSqrtC,
kExpm1F,
+ kExpm1C,
kLog1pF,
kLog1pC,
kSinF,
kSinC,
kTanhF,
+ kTanhC,
kNegF,
kNegC,
kNegI,
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 3cf7dc60d33c3..80d2dbba187b8 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -41,12 +41,15 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
case kCeilF:
case kFloorF:
case kSqrtF:
+ case kSqrtC:
case kExpm1F:
+ case kExpm1C:
case kLog1pF:
case kLog1pC:
case kSinF:
case kSinC:
case kTanhF:
+ case kTanhC:
case kNegF:
case kNegC:
case kNegI:
@@ -284,12 +287,15 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kCeilF:
case kFloorF:
case kSqrtF:
+ case kSqrtC:
case kExpm1F:
+ case kExpm1C:
case kLog1pF:
case kLog1pC:
case kSinF:
case kSinC:
case kTanhF:
+ case kTanhC:
case kNegF:
case kNegC:
case kNegI:
@@ -360,8 +366,10 @@ static const char *kindToOpSymbol(Kind kind) {
case kFloorF:
return "floor";
case kSqrtF:
+ case kSqrtC:
return "sqrt";
case kExpm1F:
+ case kExpm1C:
return "expm1";
case kLog1pF:
case kLog1pC:
@@ -370,6 +378,7 @@ static const char *kindToOpSymbol(Kind kind) {
case kSinC:
return "sin";
case kTanhF:
+ case kTanhC:
return "tanh";
case kNegF:
case kNegC:
@@ -449,10 +458,13 @@ void Merger::dumpExp(unsigned e) const {
case kCeilF:
case kFloorF:
case kSqrtF:
+ case kSqrtC:
case kExpm1F:
+ case kExpm1C:
case kLog1pF:
case kSinF:
case kTanhF:
+ case kTanhC:
case kNegF:
case kNegI:
case kTruncF:
@@ -555,12 +567,15 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
case kCRe:
case kFloorF:
case kSqrtF:
+ case kSqrtC:
case kExpm1F:
+ case kExpm1C:
case kLog1pF:
case kLog1pC:
case kSinF:
case kSinC:
case kTanhF:
+ case kTanhC:
case kNegF:
case kNegC:
case kNegI:
@@ -785,8 +800,12 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kFloorF, e);
if (isa<math::SqrtOp>(def))
return addExp(kSqrtF, e);
+ if (isa<complex::SqrtOp>(def))
+ return addExp(kSqrtC, e);
if (isa<math::ExpM1Op>(def))
return addExp(kExpm1F, e);
+ if (isa<complex::Expm1Op>(def))
+ return addExp(kExpm1C, e);
if (isa<math::Log1pOp>(def))
return addExp(kLog1pF, e);
if (isa<complex::Log1pOp>(def))
@@ -797,6 +816,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kSinC, e);
if (isa<math::TanhOp>(def))
return addExp(kTanhF, e);
+ if (isa<complex::TanhOp>(def))
+ return addExp(kTanhC, e);
if (isa<arith::NegFOp>(def))
return addExp(kNegF, e); // no negi in std
if (isa<complex::NegOp>(def))
@@ -952,8 +973,12 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<math::FloorOp>(loc, v0);
case kSqrtF:
return rewriter.create<math::SqrtOp>(loc, v0);
+ case kSqrtC:
+ return rewriter.create<complex::SqrtOp>(loc, v0);
case kExpm1F:
return rewriter.create<math::ExpM1Op>(loc, v0);
+ case kExpm1C:
+ return rewriter.create<complex::Expm1Op>(loc, v0);
case kLog1pF:
return rewriter.create<math::Log1pOp>(loc, v0);
case kLog1pC:
@@ -964,6 +989,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<complex::SinOp>(loc, v0);
case kTanhF:
return rewriter.create<math::TanhOp>(loc, v0);
+ case kTanhC:
+ return rewriter.create<complex::TanhOp>(loc, v0);
case kNegF:
return rewriter.create<arith::NegFOp>(loc, v0);
case kNegC:
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
index 74c30e6f5ff83..0fbb2b7800f76 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -59,6 +59,54 @@ module {
return %0 : tensor<?xcomplex<f64>, #SparseVector>
}
+ func.func @complex_sqrt(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+ -> tensor<?xcomplex<f64>, #SparseVector> {
+ %c0 = arith.constant 0 : index
+ %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+ %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+ %0 = linalg.generic #trait_op1
+ ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+ outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+ ^bb(%a: complex<f64>, %x: complex<f64>):
+ %1 = complex.sqrt %a : complex<f64>
+ linalg.yield %1 : complex<f64>
+ } -> tensor<?xcomplex<f64>, #SparseVector>
+ return %0 : tensor<?xcomplex<f64>, #SparseVector>
+ }
+
+ func.func @complex_tanh(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+ -> tensor<?xcomplex<f64>, #SparseVector> {
+ %c0 = arith.constant 0 : index
+ %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+ %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+ %0 = linalg.generic #trait_op1
+ ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+ outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+ ^bb(%a: complex<f64>, %x: complex<f64>):
+ %1 = complex.tanh %a : complex<f64>
+ linalg.yield %1 : complex<f64>
+ } -> tensor<?xcomplex<f64>, #SparseVector>
+ return %0 : tensor<?xcomplex<f64>, #SparseVector>
+ }
+
+ func.func @clog1p_expm1(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+ -> tensor<?xcomplex<f64>, #SparseVector> {
+ %c0 = arith.constant 0 : index
+ %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+ %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+ %0 = linalg.generic #trait_op1
+ ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+ outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+ ^bb(%a: complex<f64>, %x: complex<f64>):
+ %1 = complex.log1p %a : complex<f64>
+ // TODO(bixia): Enable this line after adding complex.expm1 to
+ // complex to standard lowering.
+ // %2 = complex.expm1 %1 : complex<f64>
+ linalg.yield %1 : complex<f64>
+ } -> tensor<?xcomplex<f64>, #SparseVector>
+ return %0 : tensor<?xcomplex<f64>, #SparseVector>
+ }
+
func.func @cdiv(%arga: tensor<?xcomplex<f64>, #SparseVector>)
-> tensor<?xcomplex<f64>, #SparseVector> {
%c0 = arith.constant 0 : index
@@ -131,9 +179,15 @@ module {
tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
%1 = call @csin(%sv1)
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
- %2 = call @cdiv(%sv1)
+ %2 = call @complex_sqrt(%sv1)
+ : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+ %3 = call @complex_tanh(%sv2)
+ : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+ %4 = call @clog1p_expm1(%sv1)
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
- %3 = call @cabs(%sv1)
+ %5 = call @cdiv(%sv1)
+ : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+ %6 = call @cabs(%sv1)
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xf64, #SparseVector>
//
@@ -157,15 +211,36 @@ module {
// CHECK-NEXT: -193.43
// CHECK-NEXT: 57.2184
call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+ // CHECK-NEXT: 0.433635
+ // CHECK-NEXT: 2.30609
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 2.53083
+ // CHECK-NEXT: 1.18538
+ call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+ // CHECK-NEXT: 0.761594
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: -0.964028
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 0.995055
+ // CHECK-NEXT: 0
+ call @dumpc(%3, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+ // CHECK-NEXT: 1.52361
+ // CHECK-NEXT: 2.69061
+ // CHECK-NEXT: 1.73287
+ // CHECK-NEXT: 0.785398
+ // CHECK-NEXT: 2.13833
+ // CHECK-NEXT: 0.785398
+ call @dumpc(%4, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
// CHECK-NEXT: -2.565
// CHECK-NEXT: 1
// CHECK-NEXT: 1.5
// CHECK-NEXT: 2
// CHECK-NEXT: 2.5
// CHECK-NEXT: 3
- call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+ call @dumpc(%5, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
// CHECK-NEXT: ( 5.50608, 5, 7.81025 )
- call @dumpf(%3) : (tensor<?xf64, #SparseVector>) -> ()
+ call @dumpf(%6) : (tensor<?xf64, #SparseVector>) -> ()
// Release the resources.
sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
@@ -173,7 +248,10 @@ module {
sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
sparse_tensor.release %2 : tensor<?xcomplex<f64>, #SparseVector>
- sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
+ sparse_tensor.release %3 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.release %4 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.release %5 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.release %6 : tensor<?xf64, #SparseVector>
return
}
}
More information about the Mlir-commits
mailing list