[Mlir-commits] [mlir] bc04a47 - [mlir][sparse] adding OverheadType::kIndex
wren romano
llvmlistbot at llvm.org
Tue Jan 4 16:16:00 PST 2022
Author: wren romano
Date: 2022-01-04T16:15:54-08:00
New Revision: bc04a4703824c005490e7ae79f64e873e1bd6c92
URL: https://github.com/llvm/llvm-project/commit/bc04a4703824c005490e7ae79f64e873e1bd6c92
DIFF: https://github.com/llvm/llvm-project/commit/bc04a4703824c005490e7ae79f64e873e1bd6c92.diff
LOG: [mlir][sparse] adding OverheadType::kIndex
Depends On D115008
This change opens the way for D115012, and removes some corner cases in `CodegenUtils.cpp`. The `SparseTensorAttrDefs.td` already specifies that we allow `0` bitwidth for the two overhead types and that it is interpreted to mean the architecture's native width.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D115010
Added:
Modified:
mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
index 4361fc7d43e75..a1f1dd6ae32d1 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
@@ -18,8 +18,22 @@
extern "C" {
-/// Encoding of the elemental type, for "overloading" @newSparseTensor.
-enum class OverheadType : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
+/// This type is used in the public API at all places where MLIR expects
+/// values with the built-in type "index". For now, we simply assume that
+/// type is 64-bit, but targets with different "index" bit widths should link
+/// with an alternatively built runtime support library.
+// TODO: support such targets?
+using index_t = uint64_t;
+
+/// Encoding of overhead types (both pointer overhead and indices
+/// overhead), for "overloading" @newSparseTensor.
+enum class OverheadType : uint32_t {
+ kIndex = 0,
+ kU64 = 1,
+ kU32 = 2,
+ kU16 = 3,
+ kU8 = 4
+};
/// Encoding of the elemental type, for "overloading" @newSparseTensor.
enum class PrimaryType : uint32_t {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 602e1f7484432..0d45ff15e8998 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -20,7 +20,7 @@ using namespace mlir::sparse_tensor;
OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
switch (width) {
- default:
+ case 64:
return OverheadType::kU64;
case 32:
return OverheadType::kU32;
@@ -28,11 +28,16 @@ OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
return OverheadType::kU16;
case 8:
return OverheadType::kU8;
+ case 0:
+ return OverheadType::kIndex;
}
+ llvm_unreachable("Unsupported overhead bitwidth");
}
Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
switch (ot) {
+ case OverheadType::kIndex:
+ return builder.getIndexType();
case OverheadType::kU64:
return builder.getIntegerType(64);
case OverheadType::kU32:
@@ -47,20 +52,13 @@ Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
Type mlir::sparse_tensor::getPointerOverheadType(
Builder &builder, const SparseTensorEncodingAttr &enc) {
- // NOTE(wrengr): This workaround will be fixed in D115010.
- unsigned width = enc.getPointerBitWidth();
- if (width == 0)
- return builder.getIndexType();
- return getOverheadType(builder, overheadTypeEncoding(width));
+ return getOverheadType(builder,
+ overheadTypeEncoding(enc.getPointerBitWidth()));
}
Type mlir::sparse_tensor::getIndexOverheadType(
Builder &builder, const SparseTensorEncodingAttr &enc) {
- // NOTE(wrengr): This workaround will be fixed in D115010.
- unsigned width = enc.getIndexBitWidth();
- if (width == 0)
- return builder.getIndexType();
- return getOverheadType(builder, overheadTypeEncoding(width));
+ return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth()));
}
PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 927284ec13f42..3681ca17674b4 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -686,13 +686,6 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
extern "C" {
-/// This type is used in the public API at all places where MLIR expects
-/// values with the built-in type "index". For now, we simply assume that
-/// type is 64-bit, but targets with different "index" bit widths should link
-/// with an alternatively built runtime support library.
-// TODO: support such targets?
-using index_t = uint64_t;
-
//===----------------------------------------------------------------------===//
//
// Public API with methods that operate on MLIR buffers (memrefs) to interact
@@ -821,6 +814,12 @@ using index_t = uint64_t;
cursor, values, filled, added, count); \
}
+// Assume index_t is in fact uint64_t, so that _mlir_ciface_newSparseTensor
+// can safely rewrite kIndex to kU64. We make this assertion to guarantee
+// that this file cannot get out of sync with its header.
+static_assert(std::is_same<index_t, uint64_t>::value,
+ "Expected index_t == uint64_t");
+
/// Constructs a new sparse tensor. This is the "swiss army knife"
/// method for materializing sparse tensors into the computation.
///
@@ -846,6 +845,13 @@ _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
const index_t *perm = pref->data + pref->offset;
uint64_t rank = aref->sizes[0];
+ // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
+ // This is safe because of the static_assert above.
+ if (ptrTp == OverheadType::kIndex)
+ ptrTp = OverheadType::kU64;
+ if (indTp == OverheadType::kIndex)
+ indTp = OverheadType::kU64;
+
// Double matrices with all combinations of overhead storage.
CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
uint64_t, double);
diff --git a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
index 0b7c20392d348..2917685064afc 100644
--- a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
@@ -27,16 +27,15 @@
// CHECK-DAG: %[[PermS:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<1xindex>
-// CHECK-DAG: %[[SecTp:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[zeroI32:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[ElemTp:.*]] = arith.constant 4 : i32
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
-// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[SecTp]], %[[SecTp]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[zeroI32]], %[[zeroI32]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<i32>
// CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<13xi32>
-// CHECK-DAG: %[[E0:.*]] = arith.constant 0 : i32
-// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : i32, memref<13xi32>
+// CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref<13xi32>
// CHECK: scf.while : () -> () {
// CHECK: %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<i32>) -> i1
// CHECK: scf.condition(%[[Cond]])
@@ -67,16 +66,15 @@ func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32>
// CHECK-DAG: %[[PermS:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<1xindex>
-// CHECK-DAG: %[[SecTp:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[zeroI32:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[ElemTp:.*]] = arith.constant 4 : i32
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
-// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[SecTp]], %[[SecTp]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[zeroI32]], %[[zeroI32]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<i32>
// CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI0]]) : memref<?xi32>
-// CHECK-DAG: %[[E0:.*]] = arith.constant 0 : i32
-// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : i32, memref<?xi32>
+// CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref<?xi32>
// CHECK: scf.while : () -> () {
// CHECK: %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<i32>) -> i1
// CHECK: scf.condition(%[[Cond]])
More information about the Mlir-commits
mailing list