[Mlir-commits] [mlir] d390035 - [mlir][sparse] Support more complex operations.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 20 14:39:35 PDT 2022
Author: Bixia Zheng
Date: 2022-05-20T14:39:26-07:00
New Revision: d390035b46907f1c421d312bdd33f44ad7580415
URL: https://github.com/llvm/llvm-project/commit/d390035b46907f1c421d312bdd33f44ad7580415
DIFF: https://github.com/llvm/llvm-project/commit/d390035b46907f1c421d312bdd33f44ad7580415.diff
LOG: [mlir][sparse] Support more complex operations.
Add complex operations abs, neg, sin, log1p, sub and div.
Add test cases.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D126027
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index ceee28466449..039cc5ddefac 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -31,14 +31,18 @@ enum Kind {
kIndex,
// Unary operations.
kAbsF,
+ kAbsC,
kCeilF,
kFloorF,
kSqrtF,
kExpm1F,
kLog1pF,
+ kLog1pC,
kSinF,
+ kSinC,
kTanhF,
kNegF,
+ kNegC,
kNegI,
kTruncF,
kExtF,
@@ -60,12 +64,14 @@ enum Kind {
kMulC,
kMulI,
kDivF,
+ kDivC, // complex
kDivS, // signed
kDivU, // unsigned
kAddF,
kAddC,
kAddI,
kSubF,
+ kSubC,
kSubI,
kAndI,
kOrI,
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 9a589965a528..3cf7dc60d33c 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -37,14 +37,18 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
index = x;
break;
case kAbsF:
+ case kAbsC:
case kCeilF:
case kFloorF:
case kSqrtF:
case kExpm1F:
case kLog1pF:
+ case kLog1pC:
case kSinF:
+ case kSinC:
case kTanhF:
case kNegF:
+ case kNegC:
case kNegI:
case kCIm:
case kCRe:
@@ -151,6 +155,8 @@ unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
// TODO: move this if-else logic into buildLattices
if (kind == kSubF)
s1 = mapSet(kNegF, s1);
+ else if (kind == kSubC)
+ s1 = mapSet(kNegC, s1);
else if (kind == kSubI)
s1 = mapSet(kNegI, s1);
// Followed by all in s1.
@@ -274,14 +280,18 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kTensor:
return tensorExps[e].tensor == t;
case kAbsF:
+ case kAbsC:
case kCeilF:
case kFloorF:
case kSqrtF:
case kExpm1F:
case kLog1pF:
+ case kLog1pC:
case kSinF:
+ case kSinC:
case kTanhF:
case kNegF:
+ case kNegC:
case kNegI:
case kTruncF:
case kExtF:
@@ -298,6 +308,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kBitCast:
return isSingleCondition(t, tensorExps[e].children.e0);
case kDivF: // note: x / c only
+ case kDivC:
case kDivS:
case kDivU:
assert(!maybeZero(tensorExps[e].children.e1));
@@ -342,6 +353,7 @@ static const char *kindToOpSymbol(Kind kind) {
case kIndex:
return "index";
case kAbsF:
+ case kAbsC:
return "abs";
case kCeilF:
return "ceil";
@@ -352,13 +364,15 @@ static const char *kindToOpSymbol(Kind kind) {
case kExpm1F:
return "expm1";
case kLog1pF:
+ case kLog1pC:
return "log1p";
case kSinF:
+ case kSinC:
return "sin";
case kTanhF:
return "tanh";
case kNegF:
- return "-";
+ case kNegC:
case kNegI:
return "-";
case kTruncF:
@@ -386,6 +400,7 @@ static const char *kindToOpSymbol(Kind kind) {
case kMulI:
return "*";
case kDivF:
+ case kDivC:
case kDivS:
case kDivU:
return "/";
@@ -394,6 +409,7 @@ static const char *kindToOpSymbol(Kind kind) {
case kAddI:
return "+";
case kSubF:
+ case kSubC:
case kSubI:
return "-";
case kAndI:
@@ -533,6 +549,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
return s;
}
case kAbsF:
+ case kAbsC:
case kCeilF:
case kCIm:
case kCRe:
@@ -540,9 +557,12 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
case kSqrtF:
case kExpm1F:
case kLog1pF:
+ case kLog1pC:
case kSinF:
+ case kSinC:
case kTanhF:
case kNegF:
+ case kNegC:
case kNegI:
case kTruncF:
case kExtF:
@@ -607,6 +627,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
buildLattices(tensorExps[e].children.e0, i),
buildLattices(tensorExps[e].children.e1, i));
case kDivF:
+ case kDivC:
case kDivS:
case kDivU:
// A division is tricky, since 0/0, 0/c, c/0 all have
@@ -630,6 +651,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
case kAddC:
case kAddI:
case kSubF:
+ case kSubC:
case kSubI:
case kOrI:
case kXorI:
@@ -696,6 +718,11 @@ Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
if (tensorExps[e].kind == kInvariant) {
+ if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) { // zero iff re == 0 and im == 0
+ ArrayAttr arrayAttr = c.getValue(); // [re, im]
+ return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
+ arrayAttr[1].cast<FloatAttr>().getValue().isZero();
+ }
if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
return c.value() == 0;
if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
@@ -750,6 +777,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
unsigned e = x.getValue();
if (isa<math::AbsOp>(def))
return addExp(kAbsF, e);
+ if (isa<complex::AbsOp>(def))
+ return addExp(kAbsC, e);
if (isa<math::CeilOp>(def))
return addExp(kCeilF, e);
if (isa<math::FloorOp>(def))
@@ -760,12 +789,18 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kExpm1F, e);
if (isa<math::Log1pOp>(def))
return addExp(kLog1pF, e);
+ if (isa<complex::Log1pOp>(def))
+ return addExp(kLog1pC, e);
if (isa<math::SinOp>(def))
return addExp(kSinF, e);
+ if (isa<complex::SinOp>(def))
+ return addExp(kSinC, e);
if (isa<math::TanhOp>(def))
return addExp(kTanhF, e);
if (isa<arith::NegFOp>(def))
return addExp(kNegF, e); // no negi in std
+ if (isa<complex::NegOp>(def))
+ return addExp(kNegC, e);
if (isa<arith::TruncFOp>(def))
return addExp(kTruncF, e, v);
if (isa<arith::ExtFOp>(def))
@@ -813,6 +848,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kMulI, e0, e1);
if (isa<arith::DivFOp>(def) && !maybeZero(e1))
return addExp(kDivF, e0, e1);
+ if (isa<complex::DivOp>(def) && !maybeZero(e1))
+ return addExp(kDivC, e0, e1);
if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
return addExp(kDivS, e0, e1);
if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
@@ -825,6 +862,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kAddI, e0, e1);
if (isa<arith::SubFOp>(def))
return addExp(kSubF, e0, e1);
+ if (isa<complex::SubOp>(def))
+ return addExp(kSubC, e0, e1);
if (isa<arith::SubIOp>(def))
return addExp(kSubI, e0, e1);
if (isa<arith::AndIOp>(def))
@@ -902,6 +941,11 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
// Unary ops.
case kAbsF:
return rewriter.create<math::AbsOp>(loc, v0);
+ case kAbsC: {
+ auto type = v0.getType().template cast<ComplexType>();
+ auto eltType = type.getElementType().template cast<FloatType>();
+ return rewriter.create<complex::AbsOp>(loc, eltType, v0);
+ }
case kCeilF:
return rewriter.create<math::CeilOp>(loc, v0);
case kFloorF:
@@ -912,12 +956,18 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<math::ExpM1Op>(loc, v0);
case kLog1pF:
return rewriter.create<math::Log1pOp>(loc, v0);
+ case kLog1pC:
+ return rewriter.create<complex::Log1pOp>(loc, v0);
case kSinF:
return rewriter.create<math::SinOp>(loc, v0);
+ case kSinC:
+ return rewriter.create<complex::SinOp>(loc, v0);
case kTanhF:
return rewriter.create<math::TanhOp>(loc, v0);
case kNegF:
return rewriter.create<arith::NegFOp>(loc, v0);
+ case kNegC:
+ return rewriter.create<complex::NegOp>(loc, v0);
case kNegI: // no negi in std
return rewriter.create<arith::SubIOp>(
loc,
@@ -964,6 +1014,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::MulIOp>(loc, v0, v1);
case kDivF:
return rewriter.create<arith::DivFOp>(loc, v0, v1);
+ case kDivC:
+ return rewriter.create<complex::DivOp>(loc, v0, v1);
case kDivS:
return rewriter.create<arith::DivSIOp>(loc, v0, v1);
case kDivU:
@@ -976,6 +1028,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::AddIOp>(loc, v0, v1);
case kSubF:
return rewriter.create<arith::SubFOp>(loc, v0, v1);
+ case kSubC:
+ return rewriter.create<complex::SubOp>(loc, v0, v1);
case kSubI:
return rewriter.create<arith::SubIOp>(loc, v0, v1);
case kAndI:
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
new file mode 100644
index 000000000000..74c30e6f5ff8
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -0,0 +1,179 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> // 1-D, compressed storage
+
+#trait_op1 = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // a (in)
+ affine_map<(i) -> (i)> // x (out)
+ ],
+ iterator_types = ["parallel"],
+ doc = "x(i) = OP a(i)"
+}
+
+#trait_op2 = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // a (in)
+ affine_map<(i) -> (i)>, // b (in)
+ affine_map<(i) -> (i)> // x (out)
+ ],
+ iterator_types = ["parallel"],
+ doc = "x(i) = a(i) OP b(i)"
+}
+
+module {
+ func.func @cops(%arga: tensor<?xcomplex<f64>, #SparseVector>, // x(i) = a(i) - (-b(i)) = a(i) + b(i)
+ %argb: tensor<?xcomplex<f64>, #SparseVector>)
+ -> tensor<?xcomplex<f64>, #SparseVector> {
+ %c0 = arith.constant 0 : index
+ %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+ %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+ %0 = linalg.generic #trait_op2
+ ins(%arga, %argb: tensor<?xcomplex<f64>, #SparseVector>,
+ tensor<?xcomplex<f64>, #SparseVector>)
+ outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+ ^bb(%a: complex<f64>, %b: complex<f64>, %x: complex<f64>):
+ %1 = complex.neg %b : complex<f64>
+ %2 = complex.sub %a, %1 : complex<f64>
+ linalg.yield %2 : complex<f64>
+ } -> tensor<?xcomplex<f64>, #SparseVector>
+ return %0 : tensor<?xcomplex<f64>, #SparseVector>
+ }
+
+ func.func @csin(%arga: tensor<?xcomplex<f64>, #SparseVector>) // x(i) = sin(a(i))
+ -> tensor<?xcomplex<f64>, #SparseVector> {
+ %c0 = arith.constant 0 : index
+ %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+ %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+ %0 = linalg.generic #trait_op1
+ ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+ outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+ ^bb(%a: complex<f64>, %x: complex<f64>):
+ %1 = complex.sin %a : complex<f64>
+ linalg.yield %1 : complex<f64>
+ } -> tensor<?xcomplex<f64>, #SparseVector>
+ return %0 : tensor<?xcomplex<f64>, #SparseVector>
+ }
+
+ func.func @cdiv(%arga: tensor<?xcomplex<f64>, #SparseVector>) // x(i) = a(i) / (2 + 0i)
+ -> tensor<?xcomplex<f64>, #SparseVector> {
+ %c0 = arith.constant 0 : index
+ %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+ %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+ %c = complex.constant [2.0 : f64, 0.0 : f64] : complex<f64> // divisor must be provably nonzero for complex.div to be sparsified
+ %0 = linalg.generic #trait_op1
+ ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+ outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+ ^bb(%a: complex<f64>, %x: complex<f64>):
+ %1 = complex.div %a, %c : complex<f64>
+ linalg.yield %1 : complex<f64>
+ } -> tensor<?xcomplex<f64>, #SparseVector>
+ return %0 : tensor<?xcomplex<f64>, #SparseVector>
+ }
+
+ func.func @cabs(%arga: tensor<?xcomplex<f64>, #SparseVector>) // x(i) = |a(i)|, producing f64 values
+ -> tensor<?xf64, #SparseVector> {
+ %c0 = arith.constant 0 : index
+ %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+ %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+ %0 = linalg.generic #trait_op1
+ ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+ outs(%xv: tensor<?xf64, #SparseVector>) {
+ ^bb(%a: complex<f64>, %x: f64):
+ %1 = complex.abs %a : complex<f64>
+ linalg.yield %1 : f64
+ } -> tensor<?xf64, #SparseVector>
+ return %0 : tensor<?xf64, #SparseVector>
+ }
+
+ func.func @dumpc(%arg0: tensor<?xcomplex<f64>, #SparseVector>, %d: index) { // prints re, then im, of the first %d stored values
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f64>, #SparseVector> to memref<?xcomplex<f64>>
+ scf.for %i = %c0 to %d step %c1 {
+ %v = memref.load %mem[%i] : memref<?xcomplex<f64>>
+ %real = complex.re %v : complex<f64>
+ %imag = complex.im %v : complex<f64>
+ vector.print %real : f64
+ vector.print %imag : f64
+ }
+ return
+ }
+
+ func.func @dumpf(%arg0: tensor<?xf64, #SparseVector>) { // prints the first 3 stored values as one vector
+ %c0 = arith.constant 0 : index
+ %d0 = arith.constant 0.0 : f64
+ %values = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
+ %0 = vector.transfer_read %values[%c0], %d0: memref<?xf64>, vector<3xf64>
+ vector.print %0 : vector<3xf64>
+ return
+ }
+
+ // Driver: builds two sparse complex vectors, runs each kernel, and prints the results.
+ func.func @entry() {
+ // Set up two 32-element sparse complex vectors with 3 nonzeros each.
+ %v1 = arith.constant sparse<
+ [ [0], [28], [31] ],
+ [ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f64>>
+ %v2 = arith.constant sparse<
+ [ [1], [28], [31] ],
+ [ (1.0, 0.0), (-2.0, 0.0), (3.0, 0.0) ] > : tensor<32xcomplex<f64>>
+ %sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
+ %sv2 = sparse_tensor.convert %v2 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
+
+ // Call sparse vector kernels.
+ %0 = call @cops(%sv1, %sv2)
+ : (tensor<?xcomplex<f64>, #SparseVector>,
+ tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+ %1 = call @csin(%sv1)
+ : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+ %2 = call @cdiv(%sv1)
+ : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+ %3 = call @cabs(%sv1)
+ : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xf64, #SparseVector>
+
+ // Verify the results. @cops stores the union of both patterns (indices
+ // 0, 1, 28, 31), so 4 entries are printed; @csin and @cdiv keep the 3
+ // entries of %sv1, and @cabs yields their 3 magnitudes.
+ %d3 = arith.constant 3 : index
+ %d4 = arith.constant 4 : index
+ // CHECK: -5.13
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 4
+ // CHECK-NEXT: 8
+ // CHECK-NEXT: 6
+ call @dumpc(%0, %d4) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+ // CHECK-NEXT: 3.43887
+ // CHECK-NEXT: 1.47097
+ // CHECK-NEXT: 3.85374
+ // CHECK-NEXT: -27.0168
+ // CHECK-NEXT: -193.43
+ // CHECK-NEXT: 57.2184
+ call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+ // CHECK-NEXT: -2.565
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 1.5
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 2.5
+ // CHECK-NEXT: 3
+ call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+ // CHECK-NEXT: ( 5.50608, 5, 7.81025 )
+ call @dumpf(%3) : (tensor<?xf64, #SparseVector>) -> ()
+
+ // Release the resources.
+ sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.release %sv2 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.release %2 : tensor<?xcomplex<f64>, #SparseVector>
+ sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
+ return
+ }
+}
More information about the Mlir-commits
mailing list