[Mlir-commits] [mlir] bc04a47 - [mlir][sparse] adding OverheadType::kIndex

wren romano llvmlistbot at llvm.org
Tue Jan 4 16:16:00 PST 2022


Author: wren romano
Date: 2022-01-04T16:15:54-08:00
New Revision: bc04a4703824c005490e7ae79f64e873e1bd6c92

URL: https://github.com/llvm/llvm-project/commit/bc04a4703824c005490e7ae79f64e873e1bd6c92
DIFF: https://github.com/llvm/llvm-project/commit/bc04a4703824c005490e7ae79f64e873e1bd6c92.diff

LOG: [mlir][sparse] adding OverheadType::kIndex

Depends On D115008

This change opens the way for D115012 and removes some corner cases in `CodegenUtils.cpp`. `SparseTensorAttrDefs.td` already specifies that a bitwidth of `0` is allowed for the two overhead types and is interpreted as the architecture's native width.
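
In miniature, the new convention is a total map from the attribute's bitwidth to the enum: 0 selects kIndex (the native "index" width), the four explicit widths select their fixed-width encodings, and any other width is rejected rather than silently widened to 64 bits as the old `default:` case did. A self-contained C++ sketch mirroring the functions changed below (the main driver and plain assert are illustrative stand-ins, not part of the commit):

  #include <cassert>
  #include <cstdint>

  // Mirrors the OverheadType enum added in SparseTensorUtils.h.
  enum class OverheadType : uint32_t { kIndex = 0, kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };

  // Mirrors overheadTypeEncoding() in CodegenUtils.cpp: bitwidth 0 now has
  // its own encoding instead of falling through to the 64-bit case.
  static OverheadType overheadTypeEncoding(unsigned width) {
    switch (width) {
    case 0:  return OverheadType::kIndex; // native "index" width
    case 8:  return OverheadType::kU8;
    case 16: return OverheadType::kU16;
    case 32: return OverheadType::kU32;
    case 64: return OverheadType::kU64;
    }
    assert(false && "Unsupported overhead bitwidth");
    return OverheadType::kU64; // not reached
  }

  int main() {
    assert(overheadTypeEncoding(0) == OverheadType::kIndex);
    assert(overheadTypeEncoding(64) == OverheadType::kU64);
    return 0;
  }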

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D115010

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
    mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
index 4361fc7d43e75..a1f1dd6ae32d1 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
@@ -18,8 +18,22 @@
 
 extern "C" {
 
-/// Encoding of the elemental type, for "overloading" @newSparseTensor.
-enum class OverheadType : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
+/// This type is used in the public API at all places where MLIR expects
+/// values with the built-in type "index". For now, we simply assume that
+/// type is 64-bit, but targets with different "index" bit widths should link
+/// with an alternatively built runtime support library.
+// TODO: support such targets?
+using index_t = uint64_t;
+
+/// Encoding of overhead types (both pointer overhead and indices
+/// overhead), for "overloading" @newSparseTensor.
+enum class OverheadType : uint32_t {
+  kIndex = 0,
+  kU64 = 1,
+  kU32 = 2,
+  kU16 = 3,
+  kU8 = 4
+};
 
 /// Encoding of the elemental type, for "overloading" @newSparseTensor.
 enum class PrimaryType : uint32_t {

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 602e1f7484432..0d45ff15e8998 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -20,7 +20,7 @@ using namespace mlir::sparse_tensor;
 
 OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
   switch (width) {
-  default:
+  case 64:
     return OverheadType::kU64;
   case 32:
     return OverheadType::kU32;
@@ -28,11 +28,16 @@ OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
     return OverheadType::kU16;
   case 8:
     return OverheadType::kU8;
+  case 0:
+    return OverheadType::kIndex;
   }
+  llvm_unreachable("Unsupported overhead bitwidth");
 }
 
 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
   switch (ot) {
+  case OverheadType::kIndex:
+    return builder.getIndexType();
   case OverheadType::kU64:
     return builder.getIntegerType(64);
   case OverheadType::kU32:
@@ -47,20 +52,13 @@ Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
 
 Type mlir::sparse_tensor::getPointerOverheadType(
     Builder &builder, const SparseTensorEncodingAttr &enc) {
-  // NOTE(wrengr): This workaround will be fixed in D115010.
-  unsigned width = enc.getPointerBitWidth();
-  if (width == 0)
-    return builder.getIndexType();
-  return getOverheadType(builder, overheadTypeEncoding(width));
+  return getOverheadType(builder,
+                         overheadTypeEncoding(enc.getPointerBitWidth()));
 }
 
 Type mlir::sparse_tensor::getIndexOverheadType(
     Builder &builder, const SparseTensorEncodingAttr &enc) {
-  // NOTE(wrengr): This workaround will be fixed in D115010.
-  unsigned width = enc.getIndexBitWidth();
-  if (width == 0)
-    return builder.getIndexType();
-  return getOverheadType(builder, overheadTypeEncoding(width));
+  return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth()));
 }
 
 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
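
The shape of the cleanup above, reduced to a stand-alone sketch: the width-0 special case that each get*OverheadType helper used to repeat now lives once in the width-to-encoding map, and the helpers become plain compositions. Enc, encode, lower, and overheadType below are illustrative stand-ins, not the MLIR API:

  #include <cassert>
  #include <string>

  enum class Enc { kIndex, kU64 };

  // The width-0 special case now lives in the encoding step alone...
  static Enc encode(unsigned width) { return width == 0 ? Enc::kIndex : Enc::kU64; }

  // ...and the lowering step maps each encoding to a type.
  static std::string lower(Enc e) { return e == Enc::kIndex ? "index" : "i64"; }

  // The helpers reduce to plain composition, with no per-caller 0-check.
  static std::string overheadType(unsigned width) { return lower(encode(width)); }

  int main() {
    assert(overheadType(0) == "index");
    assert(overheadType(64) == "i64");
    return 0;
  }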

diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 927284ec13f42..3681ca17674b4 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -686,13 +686,6 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
 
 extern "C" {
 
-/// This type is used in the public API at all places where MLIR expects
-/// values with the built-in type "index". For now, we simply assume that
-/// type is 64-bit, but targets with different "index" bit widths should link
-/// with an alternatively built runtime support library.
-// TODO: support such targets?
-using index_t = uint64_t;
-
 //===----------------------------------------------------------------------===//
 //
 // Public API with methods that operate on MLIR buffers (memrefs) to interact
@@ -821,6 +814,12 @@ using index_t = uint64_t;
         cursor, values, filled, added, count);                                 \
   }
 
+// Assume index_t is in fact uint64_t, so that _mlir_ciface_newSparseTensor
+// can safely rewrite kIndex to kU64.  We make this assertion to guarantee
+// that this file cannot get out of sync with its header.
+static_assert(std::is_same<index_t, uint64_t>::value,
+              "Expected index_t == uint64_t");
+
 /// Constructs a new sparse tensor. This is the "swiss army knife"
 /// method for materializing sparse tensors into the computation.
 ///
@@ -846,6 +845,13 @@ _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
   const index_t *perm = pref->data + pref->offset;
   uint64_t rank = aref->sizes[0];
 
+  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
+  // This is safe because of the static_assert above.
+  if (ptrTp == OverheadType::kIndex)
+    ptrTp = OverheadType::kU64;
+  if (indTp == OverheadType::kIndex)
+    indTp = OverheadType::kU64;
+
   // Double matrices with all combinations of overhead storage.
   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
        uint64_t, double);
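
The runtime-side pattern in isolation: because this library is built with index_t == uint64_t, kIndex can be normalized to kU64 before the dispatch macros run, and the static_assert turns any future divergence between the .cpp and its header into a compile error instead of a silent wrong-width dispatch. A self-contained sketch (normalize is an illustrative stand-in for the inline rewrite in _mlir_ciface_newSparseTensor, and the enum is abbreviated):

  #include <cstdint>
  #include <type_traits>

  using index_t = uint64_t;

  enum class OverheadType : uint32_t { kIndex = 0, kU64 = 1 };

  // Fails to compile if this file's index_t ever diverges from the header's.
  static_assert(std::is_same<index_t, uint64_t>::value,
                "Expected index_t == uint64_t");

  // With the assertion in place, kIndex folds into the existing 64-bit
  // dispatch cases instead of duplicating every CASE expansion.
  static OverheadType normalize(OverheadType tp) {
    return tp == OverheadType::kIndex ? OverheadType::kU64 : tp;
  }

  int main() {
    return normalize(OverheadType::kIndex) == OverheadType::kU64 ? 0 : 1;
  }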

diff --git a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
index 0b7c20392d348..2917685064afc 100644
--- a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
@@ -27,16 +27,15 @@
 //   CHECK-DAG: %[[PermS:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<1xindex>
-//   CHECK-DAG: %[[SecTp:.*]] = arith.constant 1 : i32
+//   CHECK-DAG: %[[zeroI32:.*]] = arith.constant 0 : i32
 //   CHECK-DAG: %[[ElemTp:.*]] = arith.constant 4 : i32
 //   CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
-//   CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[SecTp]], %[[SecTp]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[zeroI32]], %[[zeroI32]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
 //   CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<i32>
 //   CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<13xi32>
-//   CHECK-DAG: %[[E0:.*]] = arith.constant 0 : i32
-//   CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : i32, memref<13xi32>
+//   CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref<13xi32>
 //       CHECK: scf.while : () -> () {
 //       CHECK:   %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<i32>) -> i1
 //       CHECK:   scf.condition(%[[Cond]])
@@ -67,16 +66,15 @@ func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32>
 //   CHECK-DAG: %[[PermS:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<1xindex>
-//   CHECK-DAG: %[[SecTp:.*]] = arith.constant 1 : i32
+//   CHECK-DAG: %[[zeroI32:.*]] = arith.constant 0 : i32
 //   CHECK-DAG: %[[ElemTp:.*]] = arith.constant 4 : i32
 //   CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
-//   CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[SecTp]], %[[SecTp]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//   CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[zeroI32]], %[[zeroI32]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
 //   CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<i32>
 //   CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI0]]) : memref<?xi32>
-//   CHECK-DAG: %[[E0:.*]] = arith.constant 0 : i32
-//   CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : i32, memref<?xi32>
+//   CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref<?xi32>
 //       CHECK: scf.while : () -> () {
 //       CHECK:   %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<i32>) -> i1
 //       CHECK:   scf.condition(%[[Cond]])
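
For context on the test update: kIndex is encoded as 0 : i32, so the two overhead-type operands of @newSparseTensor change from the old constant 1 (kU64) to the new %[[zeroI32]] constant. Since the lowering now materializes a 0 : i32 anyway, the previously separate %[[E0]] zero used by linalg.fill coincides with the same constant, which is presumably why the checks reuse %[[zeroI32]] there.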

