[Mlir-commits] [mlir] 736c1b6 - [mlir][sparse] introduce complex type to sparse tensor support

Aart Bik llvmlistbot at llvm.org
Mon May 16 13:17:52 PDT 2022


Author: Aart Bik
Date: 2022-05-16T13:17:36-07:00
New Revision: 736c1b66ef332014d0e183627d32edb39a3016dd

URL: https://github.com/llvm/llvm-project/commit/736c1b66ef332014d0e183627d32edb39a3016dd
DIFF: https://github.com/llvm/llvm-project/commit/736c1b66ef332014d0e183627d32edb39a3016dd.diff

LOG: [mlir][sparse] introduce complex type to sparse tensor support

This is the first implementation of complex (f64 and f32) support
in the sparse compiler, with complex add/mul as the first operations.
Note that various features are still TBD, such as other ops and
reading complex values from file. Also note that std::complex<float>
has an ABI issue when passed as a single argument; whether a better
solution is possible is still TBD.
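
For readers unfamiliar with the issue, the workaround adopted in the
runtime library (see the addEltC32 and lexInsertC32 wrappers in the
diff below) is to accept the real and imaginary parts as two separate
floats and rebuild the std::complex<float> inside the callee. The
following is a minimal standalone C++ sketch of that pattern, not the
actual runtime code; the function names are illustrative only:

  #include <complex>
  #include <cstdio>

  using complex32 = std::complex<float>;

  // "Natural" entry point: takes the complex value as one argument.
  // When called from lowered IR, the value arrives as an
  // !llvm.struct<(f32, f32)>, whose calling convention may not match
  // std::complex<float> passed as a single scalar argument.
  static void insertValueABI(complex32 v) {
    std::printf("inserting (%f, %f)\n", v.real(), v.imag());
  }

  // Exported wrapper: takes the real and imaginary parts as two plain
  // floats (matching the struct field by field) and reconstructs the
  // std::complex<float> before forwarding.
  extern "C" void insertValue(float r, float i) {
    insertValueABI(complex32(r, i));
  }

  int main() {
    insertValue(3.0f, 4.0f); // prints: inserting (3.000000, 4.000000)
    return 0;
  }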

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D125596

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
    mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 44e322da16fd3..d4aafb74093ea 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -55,11 +55,13 @@ enum Kind {
   kUnary,        // semiring unary op
   // Binary operations.
   kMulF,
+  kMulC,
   kMulI,
   kDivF,
   kDivS, // signed
   kDivU, // unsigned
   kAddF,
+  kAddC,
   kAddI,
   kSubF,
   kSubI,

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
index e0c10c0662c04..85549121ff03c 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
@@ -42,7 +42,9 @@ enum class PrimaryType : uint32_t {
   kI64 = 3,
   kI32 = 4,
   kI16 = 5,
-  kI8 = 6
+  kI8 = 6,
+  kC64 = 7,
+  kC32 = 8
 };
 
 /// The actions performed by @newSparseTensor.

diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt
index 879c0f5abc31d..4b20a31572d9b 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt
@@ -8,6 +8,8 @@ add_mlir_dialect_library(MLIRSparseTensorPipelines
   MLIRArithmeticTransforms
   MLIRAffineToStandard
   MLIRBufferizationTransforms
+  MLIRComplexToLLVM
+  MLIRComplexToStandard
   MLIRFuncTransforms
   MLIRLinalgTransforms
   MLIRMathToLibm

diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index f226db43d1bc9..1e817d1a68a49 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -48,7 +48,9 @@ void mlir::sparse_tensor::buildSparseCompiler(
   pm.addPass(createLowerAffinePass());
   pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
   pm.addPass(createMemRefToLLVMPass());
+  pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
   pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
+  pm.addPass(createConvertComplexToLLVMPass());
   pm.addPass(createConvertMathToLibmPass());
   pm.addPass(createConvertFuncToLLVMPass());
   pm.addPass(createReconcileUnrealizedCastsPass());

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 38f15a1ee7c01..5eb0785baf938 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -111,6 +111,13 @@ PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
     return PrimaryType::kI16;
   if (elemTp.isInteger(8))
     return PrimaryType::kI8;
+  if (auto complexTp = elemTp.dyn_cast<ComplexType>()) {
+    auto complexEltTp = complexTp.getElementType();
+    if (complexEltTp.isF64())
+      return PrimaryType::kC64;
+    if (complexEltTp.isF32())
+      return PrimaryType::kC32;
+  }
   llvm_unreachable("Unknown primary type");
 }
 
@@ -128,6 +135,10 @@ StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
     return "I16";
   case PrimaryType::kI8:
     return "I8";
+  case PrimaryType::kC64:
+    return "C64";
+  case PrimaryType::kC32:
+    return "C32";
   }
   llvm_unreachable("Unknown PrimaryType");
 }

diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
index fda9499581b73..ad436bf1e0224 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorUtils
 
   LINK_LIBS PUBLIC
   MLIRArithmetic
+  MLIRComplex
   MLIRIR
   MLIRLinalg
 )

diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 1a119c943dd6b..c5cd2e4f4cb65 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 
@@ -303,6 +304,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
     assert(isInvariant(tensorExps[e].children.e1));
     return isSingleCondition(t, tensorExps[e].children.e0);
   case kMulF:
+  case kMulC:
   case kMulI:
   case kAndI:
     if (isSingleCondition(t, tensorExps[e].children.e0))
@@ -312,6 +314,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
       return isInvariant(tensorExps[e].children.e0);
     return false;
   case kAddF:
+  case kAddC:
   case kAddI:
     return isSingleCondition(t, tensorExps[e].children.e0) &&
            isSingleCondition(t, tensorExps[e].children.e1);
@@ -371,21 +374,18 @@ static const char *kindToOpSymbol(Kind kind) {
   case kUnary:
     return "unary";
   case kMulF:
-    return "*";
+  case kMulC:
   case kMulI:
     return "*";
   case kDivF:
-    return "/";
   case kDivS:
-    return "/";
   case kDivU:
     return "/";
   case kAddF:
-    return "+";
+  case kAddC:
   case kAddI:
     return "+";
   case kSubF:
-    return "-";
   case kSubI:
     return "-";
   case kAndI:
@@ -581,6 +581,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
       return takeDisj(kind, child0, buildLattices(rhs, i), unop);
     }
   case kMulF:
+  case kMulC:
   case kMulI:
   case kAndI:
     // A multiplicative operation only needs to be performed
@@ -590,6 +591,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     //  ---+---+---+
     //  !x | 0 | 0 |
     //   x | 0 |x*y|
+    //
+    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
     return takeConj(kind, // take binary conjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
@@ -614,6 +617,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
   case kAddF:
+  case kAddC:
   case kAddI:
   case kSubF:
   case kSubI:
@@ -789,6 +793,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       unsigned e1 = y.getValue();
       if (isa<arith::MulFOp>(def))
         return addExp(kMulF, e0, e1);
+      if (isa<complex::MulOp>(def))
+        return addExp(kMulC, e0, e1);
       if (isa<arith::MulIOp>(def))
         return addExp(kMulI, e0, e1);
       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
@@ -799,6 +805,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         return addExp(kDivU, e0, e1);
       if (isa<arith::AddFOp>(def))
         return addExp(kAddF, e0, e1);
+      if (isa<complex::AddOp>(def))
+        return addExp(kAddC, e0, e1);
       if (isa<arith::AddIOp>(def))
         return addExp(kAddI, e0, e1);
       if (isa<arith::SubFOp>(def))
@@ -927,6 +935,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
   // Binary ops.
   case kMulF:
     return rewriter.create<arith::MulFOp>(loc, v0, v1);
+  case kMulC:
+    return rewriter.create<complex::MulOp>(loc, v0, v1);
   case kMulI:
     return rewriter.create<arith::MulIOp>(loc, v0, v1);
   case kDivF:
@@ -937,6 +947,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
     return rewriter.create<arith::DivUIOp>(loc, v0, v1);
   case kAddF:
     return rewriter.create<arith::AddFOp>(loc, v0, v1);
+  case kAddC:
+    return rewriter.create<complex::AddOp>(loc, v0, v1);
   case kAddI:
     return rewriter.create<arith::AddIOp>(loc, v0, v1);
   case kSubF:

diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 336c500ce1d33..ca5f48759b7d9 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -21,6 +21,7 @@
 
 #include <algorithm>
 #include <cassert>
+#include <complex>
 #include <cctype>
 #include <cinttypes>
 #include <cstdio>
@@ -33,6 +34,9 @@
 #include <numeric>
 #include <vector>
 
+using complex64 = std::complex<double>;
+using complex32 = std::complex<float>;
+
 //===----------------------------------------------------------------------===//
 //
 // Internal support for storing and reading sparse tensors.
@@ -287,6 +291,8 @@ class SparseTensorStorageBase {
   virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
   virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
   virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
+  virtual void getValues(std::vector<complex64> **) { fatal("valc64"); }
+  virtual void getValues(std::vector<complex32> **) { fatal("valc32"); }
 
   /// Element-wise insertion in lexicographic index order.
   virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
@@ -295,6 +301,8 @@ class SparseTensorStorageBase {
   virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
   virtual void lexInsert(const uint64_t *, int16_t) { fatal("ins16"); }
   virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
+  virtual void lexInsert(const uint64_t *, complex64) { fatal("insc64"); }
+  virtual void lexInsert(const uint64_t *, complex32) { fatal("insc32"); }
 
   /// Expanded insertion.
   virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
@@ -315,6 +323,14 @@ class SparseTensorStorageBase {
   virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
     fatal("expi8");
   }
+  virtual void expInsert(uint64_t *, complex64 *, bool *, uint64_t *,
+                         uint64_t) {
+    fatal("expc64");
+  }
+  virtual void expInsert(uint64_t *, complex32 *, bool *, uint64_t *,
+                         uint64_t) {
+    fatal("expc32");
+  }
 
   /// Finishes insertion.
   virtual void endInsert() = 0;
@@ -898,7 +914,7 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
            "dimension size mismatch");
   SparseTensorCOO<V> *tensor =
       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
-  //  Read all nonzero elements.
+  // Read all nonzero elements.
   std::vector<uint64_t> indices(rank);
   for (uint64_t k = 0; k < nnz; k++) {
     if (!fgets(line, kColWidth, file)) {
@@ -1006,6 +1022,7 @@ template <typename V>
 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
                                  uint64_t **pShape, V **pValues,
                                  uint64_t **pIndices) {
+  assert(tensor);
   auto sparseTensor =
       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
   uint64_t rank = sparseTensor->getRank();
@@ -1293,6 +1310,10 @@ _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
   CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
 
+  // Complex matrices with wide overhead.
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
+
   // Unsupported case (add above if needed).
   fputs("unsupported combination of types\n", stderr);
   exit(1);
@@ -1319,6 +1340,8 @@ IMPL_SPARSEVALUES(sparseValuesI64, int64_t, getValues)
 IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
 IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
 IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
+IMPL_SPARSEVALUES(sparseValuesC64, complex64, getValues)
+IMPL_SPARSEVALUES(sparseValuesC32, complex32, getValues)
 
 /// Helper to add value to coordinate scheme, one per value type.
 IMPL_ADDELT(addEltF64, double)
@@ -1327,6 +1350,17 @@ IMPL_ADDELT(addEltI64, int64_t)
 IMPL_ADDELT(addEltI32, int32_t)
 IMPL_ADDELT(addEltI16, int16_t)
 IMPL_ADDELT(addEltI8, int8_t)
+IMPL_ADDELT(addEltC64, complex64)
+IMPL_ADDELT(addEltC32ABI, complex32)
+// Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
+// any padding (which seem to happen for complex32 when passed as scalar;
+// all other cases, e.g. pointer to array, work as expected).
+// TODO: cleaner way to avoid ABI padding problem?
+void *_mlir_ciface_addEltC32(void *tensor, float r, float i,
+                             StridedMemRefType<index_type, 1> *iref,
+                             StridedMemRefType<index_type, 1> *pref) {
+  return _mlir_ciface_addEltC32ABI(tensor, complex32(r, i), iref, pref);
+}
 
 /// Helper to enumerate elements of coordinate scheme, one per value type.
 IMPL_GETNEXT(getNextF64, double)
@@ -1335,6 +1369,8 @@ IMPL_GETNEXT(getNextI64, int64_t)
 IMPL_GETNEXT(getNextI32, int32_t)
 IMPL_GETNEXT(getNextI16, int16_t)
 IMPL_GETNEXT(getNextI8, int8_t)
+IMPL_GETNEXT(getNextC64, complex64)
+IMPL_GETNEXT(getNextC32, complex32)
 
 /// Insert elements in lexicographical index order, one per value type.
 IMPL_LEXINSERT(lexInsertF64, double)
@@ -1343,6 +1379,17 @@ IMPL_LEXINSERT(lexInsertI64, int64_t)
 IMPL_LEXINSERT(lexInsertI32, int32_t)
 IMPL_LEXINSERT(lexInsertI16, int16_t)
 IMPL_LEXINSERT(lexInsertI8, int8_t)
+IMPL_LEXINSERT(lexInsertC64, complex64)
+IMPL_LEXINSERT(lexInsertC32ABI, complex32)
+// Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
+// any padding (which seem to happen for complex32 when passed as scalar;
+// all other cases, e.g. pointer to array, work as expected).
+// TODO: cleaner way to avoid ABI padding problem?
+void _mlir_ciface_lexInsertC32(void *tensor,
+                               StridedMemRefType<index_type, 1> *cref, float r,
+                               float i) {
+  _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
+}
 
 /// Insert using expansion, one per value type.
 IMPL_EXPINSERT(expInsertF64, double)
@@ -1351,6 +1398,8 @@ IMPL_EXPINSERT(expInsertI64, int64_t)
 IMPL_EXPINSERT(expInsertI32, int32_t)
 IMPL_EXPINSERT(expInsertI16, int16_t)
 IMPL_EXPINSERT(expInsertI8, int8_t)
+IMPL_EXPINSERT(expInsertC64, complex64)
+IMPL_EXPINSERT(expInsertC32, complex32)
 
 #undef CASE
 #undef IMPL_SPARSEVALUES
@@ -1379,6 +1428,12 @@ void outSparseTensorI16(void *tensor, void *dest, bool sort) {
 void outSparseTensorI8(void *tensor, void *dest, bool sort) {
   return outSparseTensor<int8_t>(tensor, dest, sort);
 }
+void outSparseTensorC64(void *tensor, void *dest, bool sort) {
+  return outSparseTensor<complex64>(tensor, dest, sort);
+}
+void outSparseTensorC32(void *tensor, void *dest, bool sort) {
+  return outSparseTensor<complex32>(tensor, dest, sort);
+}
 
 //===----------------------------------------------------------------------===//
 //
@@ -1428,6 +1483,8 @@ IMPL_DELCOO(I64, int64_t)
 IMPL_DELCOO(I32, int32_t)
 IMPL_DELCOO(I16, int16_t)
 IMPL_DELCOO(I8, int8_t)
+IMPL_DELCOO(C64, complex64)
+IMPL_DELCOO(C32, complex32)
 #undef IMPL_DELCOO
 
 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
@@ -1489,6 +1546,18 @@ void *convertToMLIRSparseTensorI8(uint64_t rank, uint64_t nse, uint64_t *shape,
   return toMLIRSparseTensor<int8_t>(rank, nse, shape, values, indices, perm,
                                     sparse);
 }
+void *convertToMLIRSparseTensorC64(uint64_t rank, uint64_t nse, uint64_t *shape,
+                                   complex64 *values, uint64_t *indices,
+                                   uint64_t *perm, uint8_t *sparse) {
+  return toMLIRSparseTensor<complex64>(rank, nse, shape, values, indices, perm,
+                                       sparse);
+}
+void *convertToMLIRSparseTensorC32(uint64_t rank, uint64_t nse, uint64_t *shape,
+                                   complex32 *values, uint64_t *indices,
+                                   uint64_t *perm, uint8_t *sparse) {
+  return toMLIRSparseTensor<complex32>(rank, nse, shape, values, indices, perm,
+                                       sparse);
+}
 
 /// Converts a sparse tensor to COO-flavored format expressed using C-style
 /// data structures. The expected output parameters are pointers for these
@@ -1540,6 +1609,18 @@ void convertFromMLIRSparseTensorI8(void *tensor, uint64_t *pRank,
                                    int8_t **pValues, uint64_t **pIndices) {
   fromMLIRSparseTensor<int8_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
 }
+void convertFromMLIRSparseTensorC64(void *tensor, uint64_t *pRank,
+                                    uint64_t *pNse, uint64_t **pShape,
+                                    complex64 **pValues, uint64_t **pIndices) {
+  fromMLIRSparseTensor<complex64>(tensor, pRank, pNse, pShape, pValues,
+                                  pIndices);
+}
+void convertFromMLIRSparseTensorC32(void *tensor, uint64_t *pRank,
+                                    uint64_t *pNse, uint64_t **pShape,
+                                    complex32 **pValues, uint64_t **pIndices) {
+  fromMLIRSparseTensor<complex32>(tensor, pRank, pNse, pShape, pValues,
+                                  pIndices);
+}
 
 } // extern "C"
 

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
new file mode 100644
index 0000000000000..976f1c5f778c5
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+#trait_op = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a (in)
+    affine_map<(i) -> (i)>,  // b (in)
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) OP b(i)"
+}
+
+module {
+  func.func @cadd(%arga: tensor<?xcomplex<f32>, #SparseVector>,
+                  %argb: tensor<?xcomplex<f32>, #SparseVector>)
+                      -> tensor<?xcomplex<f32>, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f32>, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xcomplex<f32>, #SparseVector>,
+                         tensor<?xcomplex<f32>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f32>, #SparseVector>) {
+        ^bb(%a: complex<f32>, %b: complex<f32>, %x: complex<f32>):
+          %1 = complex.add %a, %b : complex<f32>
+          linalg.yield %1 : complex<f32>
+    } -> tensor<?xcomplex<f32>, #SparseVector>
+    return %0 : tensor<?xcomplex<f32>, #SparseVector>
+  }
+
+  func.func @cmul(%arga: tensor<?xcomplex<f32>, #SparseVector>,
+                  %argb: tensor<?xcomplex<f32>, #SparseVector>)
+                      -> tensor<?xcomplex<f32>, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f32>, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xcomplex<f32>, #SparseVector>,
+                         tensor<?xcomplex<f32>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f32>, #SparseVector>) {
+        ^bb(%a: complex<f32>, %b: complex<f32>, %x: complex<f32>):
+          %1 = complex.mul %a, %b : complex<f32>
+          linalg.yield %1 : complex<f32>
+    } -> tensor<?xcomplex<f32>, #SparseVector>
+    return %0 : tensor<?xcomplex<f32>, #SparseVector>
+  }
+
+  func.func @dump(%arg0: tensor<?xcomplex<f32>, #SparseVector>, %d: index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f32>, #SparseVector> to memref<?xcomplex<f32>>
+    scf.for %i = %c0 to %d step %c1 {
+       %v = memref.load %mem[%i] : memref<?xcomplex<f32>>
+       %real = complex.re %v : complex<f32>
+       %imag = complex.im %v : complex<f32>
+       vector.print %real : f32
+       vector.print %imag : f32
+    }
+    return
+  }
+
+  // Driver method to call and verify complex kernels.
+  func.func @entry() {
+    // Setup sparse vectors.
+    %v1 = arith.constant sparse<
+       [ [0], [28], [31] ],
+         [ (511.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f32>>
+    %v2 = arith.constant sparse<
+       [ [1], [28], [31] ],
+         [ (1.0, 0.0), (2.0, 0.0), (3.0, 0.0) ] > : tensor<32xcomplex<f32>>
+    %sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f32>> to tensor<?xcomplex<f32>, #SparseVector>
+    %sv2 = sparse_tensor.convert %v2 : tensor<32xcomplex<f32>> to tensor<?xcomplex<f32>, #SparseVector>
+
+    // Call sparse vector kernels.
+    %0 = call @cadd(%sv1, %sv2)
+       : (tensor<?xcomplex<f32>, #SparseVector>,
+          tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xcomplex<f32>, #SparseVector>
+    %1 = call @cmul(%sv1, %sv2)
+       : (tensor<?xcomplex<f32>, #SparseVector>,
+          tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xcomplex<f32>, #SparseVector>
+
+    //
+    // Verify the results.
+    //
+    // CHECK: 511.13
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 4
+    // CHECK-NEXT: 8
+    // CHECK-NEXT: 6
+    // CHECK-NEXT: 6
+    // CHECK-NEXT: 8
+    // CHECK-NEXT: 15
+    // CHECK-NEXT: 18
+    //
+    %d1 = arith.constant 4 : index
+    %d2 = arith.constant 2 : index
+    call @dump(%0, %d1) : (tensor<?xcomplex<f32>, #SparseVector>, index) -> ()
+    call @dump(%1, %d2) : (tensor<?xcomplex<f32>, #SparseVector>, index) -> ()
+
+    // Release the resources.
+    sparse_tensor.release %sv1 : tensor<?xcomplex<f32>, #SparseVector>
+    sparse_tensor.release %sv2 : tensor<?xcomplex<f32>, #SparseVector>
+    sparse_tensor.release %0 : tensor<?xcomplex<f32>, #SparseVector>
+    sparse_tensor.release %1 : tensor<?xcomplex<f32>, #SparseVector>
+    return
+  }
+}

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir
new file mode 100644
index 0000000000000..f544fae1449b9
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+#trait_op = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a (in)
+    affine_map<(i) -> (i)>,  // b (in)
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) OP b(i)"
+}
+
+module {
+  func.func @cadd(%arga: tensor<?xcomplex<f64>, #SparseVector>,
+                  %argb: tensor<?xcomplex<f64>, #SparseVector>)
+                      -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xcomplex<f64>, #SparseVector>,
+                         tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %b: complex<f64>, %x: complex<f64>):
+          %1 = complex.add %a, %b : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @cmul(%arga: tensor<?xcomplex<f64>, #SparseVector>,
+                  %argb: tensor<?xcomplex<f64>, #SparseVector>)
+                      -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xcomplex<f64>, #SparseVector>,
+                         tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %b: complex<f64>, %x: complex<f64>):
+          %1 = complex.mul %a, %b : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @dump(%arg0: tensor<?xcomplex<f64>, #SparseVector>, %d: index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f64>, #SparseVector> to memref<?xcomplex<f64>>
+    scf.for %i = %c0 to %d step %c1 {
+       %v = memref.load %mem[%i] : memref<?xcomplex<f64>>
+       %real = complex.re %v : complex<f64>
+       %imag = complex.im %v : complex<f64>
+       vector.print %real : f64
+       vector.print %imag : f64
+    }
+    return
+  }
+
+  // Driver method to call and verify complex kernels.
+  func.func @entry() {
+    // Setup sparse vectors.
+    %v1 = arith.constant sparse<
+       [ [0], [28], [31] ],
+         [ (511.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f64>>
+    %v2 = arith.constant sparse<
+       [ [1], [28], [31] ],
+         [ (1.0, 0.0), (2.0, 0.0), (3.0, 0.0) ] > : tensor<32xcomplex<f64>>
+    %sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
+    %sv2 = sparse_tensor.convert %v2 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
+
+    // Call sparse vector kernels.
+    %0 = call @cadd(%sv1, %sv2)
+       : (tensor<?xcomplex<f64>, #SparseVector>,
+          tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %1 = call @cmul(%sv1, %sv2)
+       : (tensor<?xcomplex<f64>, #SparseVector>,
+          tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+
+    //
+    // Verify the results.
+    //
+    // CHECK: 511.13
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 4
+    // CHECK-NEXT: 8
+    // CHECK-NEXT: 6
+    // CHECK-NEXT: 6
+    // CHECK-NEXT: 8
+    // CHECK-NEXT: 15
+    // CHECK-NEXT: 18
+    //
+    %d1 = arith.constant 4 : index
+    %d2 = arith.constant 2 : index
+    call @dump(%0, %d1) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    call @dump(%1, %d2) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+
+    // Release the resources.
+    sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %sv2 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
+    return
+  }
+}

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9cf7c5dc776b4..93b21872f254a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2009,6 +2009,7 @@ cc_library(
     includes = ["include"],
     deps = [
         ":ArithmeticDialect",
+        ":ComplexDialect",
         ":IR",
         ":LinalgOps",
         ":MathDialect",


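For reference, the CHECK values in the two new integration tests follow
directly from element-wise complex arithmetic on the nonzero entries
(%sv1 has nonzeros at indices 0, 28, 31; %sv2 at indices 1, 28, 31):

  cadd (union of index sets):
    index  0: 511.13 + 2i
    index  1: 1 + 0i
    index 28: (3 + 4i) + (2 + 0i) = 5 + 4i
    index 31: (5 + 6i) + (3 + 0i) = 8 + 6i

  cmul (intersection of index sets):
    index 28: (3 + 4i) * (2 + 0i) = 6 + 8i
    index 31: (5 + 6i) * (3 + 0i) = 15 + 18i

which matches the printed real/imaginary pairs
511.13, 2, 1, 0, 5, 4, 8, 6 (cadd) and 6, 8, 15, 18 (cmul).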