[Mlir-commits] [mlir] 3acf498 - [mlir][sparse] support integral types i32, i16, i8 for *numerical* values

Aart Bik llvmlistbot at llvm.org
Wed Apr 7 10:01:51 PDT 2021


Author: Aart Bik
Date: 2021-04-07T10:01:37-07:00
New Revision: 3acf49829c0064d5bcea5d8f6ca032559bf8e73a

URL: https://github.com/llvm/llvm-project/commit/3acf49829c0064d5bcea5d8f6ca032559bf8e73a
DIFF: https://github.com/llvm/llvm-project/commit/3acf49829c0064d5bcea5d8f6ca032559bf8e73a.diff

LOG: [mlir][sparse] support integral types i32, i16, i8 for *numerical* values

Some sparse matrices store integral values (in contrast with the more common
f32 and f64 values). This CL extends the compiler and runtime support to handle
several common combinations of integral value types and overhead storage types.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D99999

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
    mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
    mlir/lib/ExecutionEngine/SparseUtils.cpp
    mlir/test/Integration/Sparse/CPU/sparse_matvec.mlir
    mlir/test/Integration/data/wide.mtx

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp b/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
index ef8f1310d2ac..b1efd24d48e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
@@ -132,6 +132,12 @@ class TensorToValuesConverter
       name = "sparseValuesF64";
     else if (eltType.isF32())
       name = "sparseValuesF32";
+    else if (eltType.isInteger(32))
+      name = "sparseValuesI32";
+    else if (eltType.isInteger(16))
+      name = "sparseValuesI16";
+    else if (eltType.isInteger(8))
+      name = "sparseValuesI8";
     else
       return failure();
     rewriter.replaceOpWithNewOp<CallOp>(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 9ed3282b0210..aa162bf83e61 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -837,11 +837,19 @@ static void genReductionEnd(Merger &merger, CodeGen &codegen,
   assert(codegen.curVecLength == 1);
   codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
   unsigned lhs = op.getNumShapedOperands() - 1;
-  if (red.getType().isa<VectorType>()) {
+  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
     // TODO: assumes + reductions for now
+    StringAttr kind = rewriter.getStringAttr("add");
     Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
-    red = rewriter.create<vector::ReductionOp>(
-        op.getLoc(), ld.getType(), rewriter.getStringAttr("add"), red, ld);
+    // Integer reductions don't accept an accumulator.
+    if (vtp.getElementType().isa<IntegerType>()) {
+      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
+                                                 kind, red, ValueRange{});
+      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
+    } else {
+      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
+                                                 kind, red, ld);
+    }
   }
   genTensorStore(merger, codegen, rewriter, op, lhs, red);
 }

diff  --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
index 5b4af4ca85dd..8f0dd538126b 100644
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -127,6 +127,9 @@ class SparseTensorStorageBase {
   // Primary storage.
   virtual void getValues(std::vector<double> **) { fatal("valf64"); }
   virtual void getValues(std::vector<float> **) { fatal("valf32"); }
+  virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
+  virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
+  virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
 
   virtual ~SparseTensorStorageBase() {}
 
@@ -453,64 +456,58 @@ char *getTensorFilename(uint64_t id) {
 // implementation of a bufferized SparseTensor in MLIR. This could be replaced
 // by actual codegen in MLIR.
 //
+// Because we cannot use C++ templates with C linkage, some macro magic is used
+// to generate implementations for all required type combinations that can be
+// called from MLIR generated code.
+//
 //===----------------------------------------------------------------------===//
 
-// Cannot use templates with C linkage.
-
-struct MemRef1DU64 {
-  const uint64_t *base;
-  const uint64_t *data;
-  uint64_t off;
-  uint64_t sizes[1];
-  uint64_t strides[1];
-};
-
-struct MemRef1DU32 {
-  const uint32_t *base;
-  const uint32_t *data;
-  uint64_t off;
-  uint64_t sizes[1];
-  uint64_t strides[1];
-};
+#define TEMPLATE(NAME, TYPE)                                                   \
+  struct NAME {                                                                \
+    const TYPE *base;                                                          \
+    const TYPE *data;                                                          \
+    uint64_t off;                                                              \
+    uint64_t sizes[1];                                                         \
+    uint64_t strides[1];                                                       \
+  }
 
-struct MemRef1DU16 {
-  const uint16_t *base;
-  const uint16_t *data;
-  uint64_t off;
-  uint64_t sizes[1];
-  uint64_t strides[1];
-};
+#define CASE(p, i, v, P, I, V)                                                 \
+  if (ptrTp == (p) && indTp == (i) && valTp == (v))                            \
+  return newSparseTensor<P, I, V>(filename, sparsity, asize)
 
-struct MemRef1DU8 {
-  const uint8_t *base;
-  const uint8_t *data;
-  uint64_t off;
-  uint64_t sizes[1];
-  uint64_t strides[1];
-};
+#define IMPL1(RET, NAME, TYPE, LIB)                                            \
+  RET NAME(void *tensor) {                                                     \
+    std::vector<TYPE> *v;                                                      \
+    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
+    return {v->data(), v->data(), 0, {v->size()}, {1}};                        \
+  }
 
-struct MemRef1DF64 {
-  const double *base;
-  const double *data;
-  uint64_t off;
-  uint64_t sizes[1];
-  uint64_t strides[1];
-};
+#define IMPL2(RET, NAME, TYPE, LIB)                                            \
+  RET NAME(void *tensor, uint64_t d) {                                         \
+    std::vector<TYPE> *v;                                                      \
+    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
+    return {v->data(), v->data(), 0, {v->size()}, {1}};                        \
+  }
 
-struct MemRef1DF32 {
-  const float *base;
-  const float *data;
-  uint64_t off;
-  uint64_t sizes[1];
-  uint64_t strides[1];
-};
+TEMPLATE(MemRef1DU64, uint64_t);
+TEMPLATE(MemRef1DU32, uint32_t);
+TEMPLATE(MemRef1DU16, uint16_t);
+TEMPLATE(MemRef1DU8, uint8_t);
+TEMPLATE(MemRef1DI32, int32_t);
+TEMPLATE(MemRef1DI16, int16_t);
+TEMPLATE(MemRef1DI8, int8_t);
+TEMPLATE(MemRef1DF64, double);
+TEMPLATE(MemRef1DF32, float);
 
 enum OverheadTypeEnum : uint64_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
-enum PrimaryTypeEnum : uint64_t { kF64 = 1, kF32 = 2 };
 
-#define CASE(p, i, v, P, I, V)                                                 \
-  if (ptrTp == (p) && indTp == (i) && valTp == (v))                            \
-  return newSparseTensor<P, I, V>(filename, sparsity, asize)
+enum PrimaryTypeEnum : uint64_t {
+  kF64 = 1,
+  kF32 = 2,
+  kI32 = 3,
+  kI16 = 4,
+  kI8 = 5
+};
 
 void *newSparseTensor(char *filename, bool *abase, bool *adata, uint64_t aoff,
                       uint64_t asize, uint64_t astride, uint64_t ptrTp,
@@ -534,6 +531,17 @@ void *newSparseTensor(char *filename, bool *abase, bool *adata, uint64_t aoff,
   CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
   CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
 
+  // Integral matrices with low overhead storage.
+  CASE(kU32, kU32, kI32, uint32_t, uint32_t, int32_t);
+  CASE(kU32, kU32, kI16, uint32_t, uint32_t, int16_t);
+  CASE(kU32, kU32, kI8, uint32_t, uint32_t, int8_t);
+  CASE(kU16, kU16, kI32, uint16_t, uint16_t, int32_t);
+  CASE(kU16, kU16, kI16, uint16_t, uint16_t, int16_t);
+  CASE(kU16, kU16, kI8, uint16_t, uint16_t, int8_t);
+  CASE(kU8, kU8, kI32, uint8_t, uint8_t, int32_t);
+  CASE(kU8, kU8, kI16, uint8_t, uint8_t, int16_t);
+  CASE(kU8, kU8, kI8, uint8_t, uint8_t, int8_t);
+
   // Unsupported case (add above if needed).
   fputs("unsupported combination of types\n", stderr);
   exit(1);
@@ -545,70 +553,29 @@ uint64_t sparseDimSize(void *tensor, uint64_t d) {
   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
 }
 
-MemRef1DU64 sparsePointers64(void *tensor, uint64_t d) {
-  std::vector<uint64_t> *v;
-  static_cast<SparseTensorStorageBase *>(tensor)->getPointers(&v, d);
-  return {v->data(), v->data(), 0, {v->size()}, {1}};
-}
-
-MemRef1DU32 sparsePointers32(void *tensor, uint64_t d) {
-  std::vector<uint32_t> *v;
-  static_cast<SparseTensorStorageBase *>(tensor)->getPointers(&v, d);
-  return {v->data(), v->data(), 0, {v->size()}, {1}};
-}
-
-MemRef1DU16 sparsePointers16(void *tensor, uint64_t d) {
-  std::vector<uint16_t> *v;
-  static_cast<SparseTensorStorageBase *>(tensor)->getPointers(&v, d);
-  return {v->data(), v->data(), 0, {v->size()}, {1}};
-}
-
-MemRef1DU8 sparsePointers8(void *tensor, uint64_t d) {
-  std::vector<uint8_t> *v;
-  static_cast<SparseTensorStorageBase *>(tensor)->getPointers(&v, d);
-  return {v->data(), v->data(), 0, {v->size()}, {1}};
-}
-
-MemRef1DU64 sparseIndices64(void *tensor, uint64_t d) {
-  std::vector<uint64_t> *v;
-  static_cast<SparseTensorStorageBase *>(tensor)->getIndices(&v, d);
-  return {v->data(), v->data(), 0, {v->size()}, {1}};
-}
-
-MemRef1DU32 sparseIndices32(void *tensor, uint64_t d) {
-  std::vector<uint32_t> *v;
-  static_cast<SparseTensorStorageBase *>(tensor)->getIndices(&v, d);
-  return {v->data(), v->data(), 0, {v->size()}, {1}};
-}
-
-MemRef1DU16 sparseIndices16(void *tensor, uint64_t d) {
-  std::vector<uint16_t> *v;
-  static_cast<SparseTensorStorageBase *>(tensor)->getIndices(&v, d);
-  return {v->data(), v->data(), 0, {v->size()}, {1}};
-}
-
-MemRef1DU8 sparseIndices8(void *tensor, uint64_t d) {
-  std::vector<uint8_t> *v;
-  static_cast<SparseTensorStorageBase *>(tensor)->getIndices(&v, d);
-  return {v->data(), v->data(), 0, {v->size()}, {1}};
-}
-
-MemRef1DF64 sparseValuesF64(void *tensor) {
-  std::vector<double> *v;
-  static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);
-  return {v->data(), v->data(), 0, {v->size()}, {1}};
-}
-
-MemRef1DF32 sparseValuesF32(void *tensor) {
-  std::vector<float> *v;
-  static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);
-  return {v->data(), v->data(), 0, {v->size()}, {1}};
-}
+IMPL2(MemRef1DU64, sparsePointers64, uint64_t, getPointers)
+IMPL2(MemRef1DU32, sparsePointers32, uint32_t, getPointers)
+IMPL2(MemRef1DU16, sparsePointers16, uint16_t, getPointers)
+IMPL2(MemRef1DU8, sparsePointers8, uint8_t, getPointers)
+IMPL2(MemRef1DU64, sparseIndices64, uint64_t, getIndices)
+IMPL2(MemRef1DU32, sparseIndices32, uint32_t, getIndices)
+IMPL2(MemRef1DU16, sparseIndices16, uint16_t, getIndices)
+IMPL2(MemRef1DU8, sparseIndices8, uint8_t, getIndices)
+IMPL1(MemRef1DF64, sparseValuesF64, double, getValues)
+IMPL1(MemRef1DF32, sparseValuesF32, float, getValues)
+IMPL1(MemRef1DI32, sparseValuesI32, int32_t, getValues)
+IMPL1(MemRef1DI16, sparseValuesI16, int16_t, getValues)
+IMPL1(MemRef1DI8, sparseValuesI8, int8_t, getValues)
 
 void delSparseTensor(void *tensor) {
   delete static_cast<SparseTensorStorageBase *>(tensor);
 }
 
+#undef TEMPLATE
+#undef CASE
+#undef IMPL1
+#undef IMPL2
+
 } // extern "C"
 
 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

diff  --git a/mlir/test/Integration/Sparse/CPU/sparse_matvec.mlir b/mlir/test/Integration/Sparse/CPU/sparse_matvec.mlir
index 41ee9ccc63b8..fde494760982 100644
--- a/mlir/test/Integration/Sparse/CPU/sparse_matvec.mlir
+++ b/mlir/test/Integration/Sparse/CPU/sparse_matvec.mlir
@@ -54,18 +54,18 @@ module {
   // a sparse matrix A with a dense vector b into a dense vector x.
   //
   func @kernel_matvec(%argA: !SparseTensor,
-                      %argb: tensor<?xf32>,
-                      %argx: tensor<?xf32>) -> tensor<?xf32> {
-    %arga = linalg.sparse_tensor %argA : !SparseTensor to tensor<?x?xf32>
+                      %argb: tensor<?xi32>,
+                      %argx: tensor<?xi32>) -> tensor<?xi32> {
+    %arga = linalg.sparse_tensor %argA : !SparseTensor to tensor<?x?xi32>
     %0 = linalg.generic #matvec
-      ins(%arga, %argb: tensor<?x?xf32>, tensor<?xf32>)
-      outs(%argx: tensor<?xf32>) {
-      ^bb(%a: f32, %b: f32, %x: f32):
-        %0 = mulf %a, %b : f32
-        %1 = addf %x, %0 : f32
-        linalg.yield %1 : f32
-    } -> tensor<?xf32>
-    return %0 : tensor<?xf32>
+      ins(%arga, %argb: tensor<?x?xi32>, tensor<?xi32>)
+      outs(%argx: tensor<?xi32>) {
+      ^bb(%a: i32, %b: i32, %x: i32):
+        %0 = muli %a, %b : i32
+        %1 = addi %x, %0 : i32
+        linalg.yield %1 : i32
+    } -> tensor<?xi32>
+    return %0 : tensor<?xi32>
   }
 
   //
@@ -79,7 +79,7 @@ module {
   // Main driver that reads matrix from file and calls the sparse kernel.
   //
   func @entry() {
-    %f0 = constant 0.0 : f32
+    %i0 = constant 0 : i32
     %c0 = constant 0 : index
     %c1 = constant 1 : index
     %c2 = constant 2 : index
@@ -89,51 +89,51 @@ module {
     // Mark inner dimension of the matrix as sparse and encode the
     // storage scheme types (this must match the metadata in the
     // alias above and compiler switches). In this case, we test
-    // that 8-bit indices and pointers work correctly.
+    // that 8-bit indices and pointers work correctly on a matrix
+    // with i32 elements.
     %annotations = memref.alloc(%c2) : memref<?xi1>
     %sparse = constant true
     %dense = constant false
     memref.store %dense, %annotations[%c0] : memref<?xi1>
     memref.store %sparse, %annotations[%c1] : memref<?xi1>
     %u8 = constant 4 : index
-    %f32 = constant 2 : index
+    %i32 = constant 3 : index
 
     // Read the sparse matrix from file, construct sparse storage.
     %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
-    %a = call @newSparseTensor(%fileName, %annotations, %u8, %u8, %f32)
+    %a = call @newSparseTensor(%fileName, %annotations, %u8, %u8, %i32)
       : (!Filename, memref<?xi1>, index, index, index) -> (!SparseTensor)
 
     // Initialize dense vectors.
-    %bdata = memref.alloc(%c256) : memref<?xf32>
-    %xdata = memref.alloc(%c4) : memref<?xf32>
+    %bdata = memref.alloc(%c256) : memref<?xi32>
+    %xdata = memref.alloc(%c4) : memref<?xi32>
     scf.for %i = %c0 to %c256 step %c1 {
       %k = addi %i, %c1 : index
-      %l = index_cast %k : index to i32
-      %f = sitofp %l : i32 to f32
-      memref.store %f, %bdata[%i] : memref<?xf32>
+      %j = index_cast %k : index to i32
+      memref.store %j, %bdata[%i] : memref<?xi32>
     }
     scf.for %i = %c0 to %c4 step %c1 {
-      memref.store %f0, %xdata[%i] : memref<?xf32>
+      memref.store %i0, %xdata[%i] : memref<?xi32>
     }
-    %b = memref.tensor_load %bdata : memref<?xf32>
-    %x = memref.tensor_load %xdata : memref<?xf32>
+    %b = memref.tensor_load %bdata : memref<?xi32>
+    %x = memref.tensor_load %xdata : memref<?xi32>
 
     // Call kernel.
     %0 = call @kernel_matvec(%a, %b, %x)
-      : (!SparseTensor, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+      : (!SparseTensor, tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
     // Print the result for verification.
     //
-    // CHECK: ( 1659, 1534, 21, 18315 )
+    // CHECK: ( 889, 1514, -21, -3431 )
     //
-    %m = memref.buffer_cast %0 : memref<?xf32>
-    %v = vector.transfer_read %m[%c0], %f0: memref<?xf32>, vector<4xf32>
-    vector.print %v : vector<4xf32>
+    %m = memref.buffer_cast %0 : memref<?xi32>
+    %v = vector.transfer_read %m[%c0], %i0: memref<?xi32>, vector<4xi32>
+    vector.print %v : vector<4xi32>
 
     // Release the resources.
     call @delSparseTensor(%a) : (!SparseTensor) -> ()
-    memref.dealloc %bdata : memref<?xf32>
-    memref.dealloc %xdata : memref<?xf32>
+    memref.dealloc %bdata : memref<?xi32>
+    memref.dealloc %xdata : memref<?xi32>
 
     return
   }

diff  --git a/mlir/test/Integration/data/wide.mtx b/mlir/test/Integration/data/wide.mtx
index 6b5ee208afe1..9e0d5f2a1132 100644
--- a/mlir/test/Integration/data/wide.mtx
+++ b/mlir/test/Integration/data/wide.mtx
@@ -4,20 +4,20 @@
 % see https://math.nist.gov/MatrixMarket
 %
 4 256 17
-1 1     1.0
-1 127   2.0
-1 128   3.0
-1 255   4.0
-2 2     5.0
-2 254   6.0
-3 3     7.0
-4 1     8.0
-4 2     9.0
-4 4    10.0
-4 99   11.0
-4 127  12.0
-4 128  13.0
-4 129  14.0
-4 250  15.0
-4 254  16.0
-4 256  17.0
+1 1    -1
+1 127   2
+1 128  -3
+1 255   4
+2 2    -5
+2 254   6
+3 3    -7
+4 1     8
+4 2    -9
+4 4    10
+4 99  -11
+4 127  12
+4 128 -13
+4 129  14
+4 250 -15
+4 254  16
+4 256 -17


        


More information about the Mlir-commits mailing list