[Mlir-commits] [mlir] bfadd13 - [mlir][sparse] Moved _mlir_ciface_newSparseTensor closer to its macros

wren romano llvmlistbot at llvm.org
Mon May 16 17:53:31 PDT 2022


Author: wren romano
Date: 2022-05-16T17:53:25-07:00
New Revision: bfadd13df474aac157d759cea946f1e5c1297000

URL: https://github.com/llvm/llvm-project/commit/bfadd13df474aac157d759cea946f1e5c1297000
DIFF: https://github.com/llvm/llvm-project/commit/bfadd13df474aac157d759cea946f1e5c1297000.diff

LOG: [mlir][sparse] Moved _mlir_ciface_newSparseTensor closer to its macros

This is a follow-up to D125431, kept as a separate change so as not to confuse the machinery that generates diffs: combining the two changes into one would obscure the changes actually made in the previous differential.

Depends On D125431

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D125432
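
For readers unfamiliar with the entry point being moved: the sketch below shows, under stated assumptions, how _mlir_ciface_newSparseTensor could be driven directly from C++. It assumes the StridedMemRefType aggregate layout from mlir/ExecutionEngine/CRunnerUtils.h and the DimLevelType enumerators kDense/kCompressed; the helper name buildEmptyCsrF64 is made up for illustration. In practice this call is emitted by the sparse-tensor conversion pass rather than written by hand, so treat this as a sketch only. A schematic of the CASE macro dispatch used inside the function follows the diff at the end of this message.

  // Illustrative only: builds empty storage for a 4x8 CSR-like f64 tensor.
  // Assumes the SparseTensorUtils runtime enums and CRunnerUtils memref
  // structs are in scope; `buildEmptyCsrF64` is a hypothetical helper.
  static void *buildEmptyCsrF64() {
    DimLevelType lvl[2] = {DimLevelType::kDense, DimLevelType::kCompressed};
    index_type shape[2] = {4, 8};
    index_type perm[2] = {0, 1}; // identity dimension ordering
    StridedMemRefType<DimLevelType, 1> aref = {lvl, lvl, 0, {2}, {1}};
    StridedMemRefType<index_type, 1> sref = {shape, shape, 0, {2}, {1}};
    StridedMemRefType<index_type, 1> pref = {perm, perm, 0, {2}, {1}};
    // kIndex is rewritten to kU64 inside the callee (guarded by the
    // static_assert shown in the diff); Action::kEmpty returns empty
    // storage to be filled later, so `ptr` is unused here.
    return _mlir_ciface_newSparseTensor(&aref, &sref, &pref,
                                        OverheadType::kIndex,
                                        OverheadType::kIndex,
                                        PrimaryType::kF64,
                                        Action::kEmpty, /*ptr=*/nullptr);
  }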

Added: 
    

Modified: 
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index b8731ff1c745..24d77a6c3ec3 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -1410,8 +1410,138 @@ extern "C" {
   }
 
 #define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
-// TODO(D125432): move `_mlir_ciface_newSparseTensor` closer to these
-// macro definitions, but as a separate change so as not to muddy the diff.
+
+// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
+// can safely rewrite kIndex to kU64.  We make this assertion to guarantee
+// that this file cannot get out of sync with its header.
+static_assert(std::is_same<index_type, uint64_t>::value,
+              "Expected index_type == uint64_t");
+
+/// Constructs a new sparse tensor. This is the "swiss army knife"
+/// method for materializing sparse tensors into the computation.
+///
+/// Action:
+/// kEmpty = returns empty storage to fill later
+/// kFromFile = returns storage, where ptr contains filename to read
+/// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
+/// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
+/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
+/// kToIterator = returns iterator from storage in ptr (call getNext() to use)
+void *
+_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
+                             StridedMemRefType<index_type, 1> *sref,
+                             StridedMemRefType<index_type, 1> *pref,
+                             OverheadType ptrTp, OverheadType indTp,
+                             PrimaryType valTp, Action action, void *ptr) {
+  assert(aref && sref && pref);
+  assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
+         pref->strides[0] == 1);
+  assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
+  const DimLevelType *sparsity = aref->data + aref->offset;
+  const index_type *shape = sref->data + sref->offset;
+  const index_type *perm = pref->data + pref->offset;
+  uint64_t rank = aref->sizes[0];
+
+  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
+  // This is safe because of the static_assert above.
+  if (ptrTp == OverheadType::kIndex)
+    ptrTp = OverheadType::kU64;
+  if (indTp == OverheadType::kIndex)
+    indTp = OverheadType::kU64;
+
+  // Double matrices with all combinations of overhead storage.
+  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
+       uint64_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
+       uint32_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
+       uint16_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
+       uint8_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
+       uint64_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
+       uint32_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
+       uint16_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
+       uint8_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
+       uint64_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
+       uint32_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
+       uint16_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
+       uint8_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
+       uint64_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
+       uint32_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
+       uint16_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
+       uint8_t, double);
+
+  // Float matrices with all combinations of overhead storage.
+  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
+       uint64_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
+       uint32_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
+       uint16_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
+       uint8_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
+       uint64_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
+       uint32_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
+       uint16_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
+       uint8_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
+       uint64_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
+       uint32_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
+       uint16_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
+       uint8_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
+       uint64_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
+       uint32_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
+       uint16_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
+       uint8_t, float);
+
+  // Integral matrices with both overheads of the same type.
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
+
+  // Complex matrices with wide overhead.
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
+
+  // Unsupported case (add above if needed).
+  fputs("unsupported combination of types\n", stderr);
+  exit(1);
+}
+#undef CASE
+#undef CASE_SECSAME
 
 /// Methods that provide direct access to values.
 #define IMPL_SPARSEVALUES(VNAME, V)                                            \
@@ -1476,10 +1606,10 @@ FOREVERY_SIMPLEX_V(IMPL_ADDELT)
 IMPL_ADDELT(C64, complex64)
 // TODO: cleaner way to avoid ABI padding problem?
 IMPL_ADDELT(C32ABI, complex32)
-void *_mlir_ciface_addEltC32(void *tensor, float r, float i,
+void *_mlir_ciface_addEltC32(void *coo, float r, float i,
                              StridedMemRefType<index_type, 1> *iref,
                              StridedMemRefType<index_type, 1> *pref) {
-  return _mlir_ciface_addEltC32ABI(tensor, complex32(r, i), iref, pref);
+  return _mlir_ciface_addEltC32ABI(coo, complex32(r, i), iref, pref);
 }
 #undef IMPL_ADDELT
 
@@ -1549,138 +1679,6 @@ void _mlir_ciface_lexInsertC32(void *tensor,
 FOREVERY_V(IMPL_EXPINSERT)
 #undef IMPL_EXPINSERT
 
-// Assume index_type is in fact uint64_t, so that _mlir_ciface_newSparseTensor
-// can safely rewrite kIndex to kU64.  We make this assertion to guarantee
-// that this file cannot get out of sync with its header.
-static_assert(std::is_same<index_type, uint64_t>::value,
-              "Expected index_type == uint64_t");
-
-/// Constructs a new sparse tensor. This is the "swiss army knife"
-/// method for materializing sparse tensors into the computation.
-///
-/// Action:
-/// kEmpty = returns empty storage to fill later
-/// kFromFile = returns storage, where ptr contains filename to read
-/// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
-/// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
-/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
-/// kToIterator = returns iterator from storage in ptr (call getNext() to use)
-void *
-_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
-                             StridedMemRefType<index_type, 1> *sref,
-                             StridedMemRefType<index_type, 1> *pref,
-                             OverheadType ptrTp, OverheadType indTp,
-                             PrimaryType valTp, Action action, void *ptr) {
-  assert(aref && sref && pref);
-  assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
-         pref->strides[0] == 1);
-  assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
-  const DimLevelType *sparsity = aref->data + aref->offset;
-  const index_type *shape = sref->data + sref->offset;
-  const index_type *perm = pref->data + pref->offset;
-  uint64_t rank = aref->sizes[0];
-
-  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
-  // This is safe because of the static_assert above.
-  if (ptrTp == OverheadType::kIndex)
-    ptrTp = OverheadType::kU64;
-  if (indTp == OverheadType::kIndex)
-    indTp = OverheadType::kU64;
-
-  // Double matrices with all combinations of overhead storage.
-  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
-       uint64_t, double);
-  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
-       uint32_t, double);
-  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
-       uint16_t, double);
-  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
-       uint8_t, double);
-  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
-       uint64_t, double);
-  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
-       uint32_t, double);
-  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
-       uint16_t, double);
-  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
-       uint8_t, double);
-  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
-       uint64_t, double);
-  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
-       uint32_t, double);
-  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
-       uint16_t, double);
-  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
-       uint8_t, double);
-  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
-       uint64_t, double);
-  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
-       uint32_t, double);
-  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
-       uint16_t, double);
-  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
-       uint8_t, double);
-
-  // Float matrices with all combinations of overhead storage.
-  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
-       uint64_t, float);
-  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
-       uint32_t, float);
-  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
-       uint16_t, float);
-  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
-       uint8_t, float);
-  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
-       uint64_t, float);
-  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
-       uint32_t, float);
-  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
-       uint16_t, float);
-  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
-       uint8_t, float);
-  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
-       uint64_t, float);
-  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
-       uint32_t, float);
-  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
-       uint16_t, float);
-  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
-       uint8_t, float);
-  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
-       uint64_t, float);
-  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
-       uint32_t, float);
-  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
-       uint16_t, float);
-  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
-       uint8_t, float);
-
-  // Integral matrices with both overheads of the same type.
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
-  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
-  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
-  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
-  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
-  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
-  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
-  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
-  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
-  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
-
-  // Complex matrices with wide overhead.
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
-  CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
-
-  // Unsupported case (add above if needed).
-  fputs("unsupported combination of types\n", stderr);
-  exit(1);
-}
-#undef CASE
-#undef CASE_SECSAME
-
 /// Output a sparse tensor, one per value type.
 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                         \
   void outSparseTensor##VNAME(void *coo, void *dest, bool sort) {              \

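The long runs of CASE / CASE_SECSAME lines in the moved function above are a macro-driven type dispatch: each invocation expands to a test of (ptrTp, indTp, valTp) that, on a match, instantiates the storage scheme for the corresponding (pointer, index, value) C++ types and returns. The actual CASE body is defined earlier in SparseTensorUtils.cpp and is not part of this hunk, so the following is only a schematic sketch of the same pattern, with made-up names (makeStorage, SKETCH_CASE):

  // Schematic only -- not the actual CASE macro from SparseTensorUtils.cpp.
  // `makeStorage<P, I, V>` stands in for whatever the real macro calls.
  #define SKETCH_CASE(PTR_TP, IND_TP, VAL_TP, P, I, V)                         \
    if (ptrTp == (PTR_TP) && indTp == (IND_TP) && valTp == (VAL_TP))           \
      return makeStorage<P, I, V>(rank, shape, perm, sparsity, action, ptr)

  // The "secondary same" variant reuses one overhead type for both positions,
  // mirroring CASE_SECSAME(p, v, P, V) = CASE(p, p, v, P, P, V) in the diff:
  #define SKETCH_CASE_SECSAME(TP, VAL_TP, P, V)                                \
    SKETCH_CASE(TP, TP, VAL_TP, P, P, V)

  // Usage inside a dispatcher, falling through to an error if nothing matches:
  //   SKETCH_CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64,
  //               uint64_t, uint32_t, double);
  //   SKETCH_CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64,
  //                       uint64_t, int64_t);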



More information about the Mlir-commits mailing list