[mlir] [llvm] [mlir][sparse] Change LevelType enum to 64 bit (PR #80501)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 2 14:40:56 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Yinying Li (yinying-lisa-li)

<details>
<summary>Changes</summary>

1. The C++ enum is now declared as `enum class LevelType : uint64_t`.
2. On the C side, the value is carried as `typedef uint64_t level_type` instead. This works around a limitation in the Windows build: C does not support setting an enum's underlying type to a 64-bit integer (see the sketch below).
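
For reference, a minimal sketch (not part of the patch) of the pattern the C header adopts: the enum keeps the named constants, while values cross the C API as a plain 64-bit typedef, since C prior to C23 cannot pin an enum's underlying type. The helper `to_level_type` is hypothetical, for illustration only.

```c
#include <stdint.h>

/* The enum keeps the named constants at their encoded values... */
typedef uint64_t level_type;

enum MlirSparseTensorLevelType {
  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4,      /* 0b00001_00 */
  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, /* 0b00010_00 */
};

/* ...while the ABI-visible type in the C API signatures is the 64-bit
 * typedef, so callers store and pass level_type directly. */
static inline level_type to_level_type(enum MlirSparseTensorLevelType e) {
  return (level_type)e;
}
```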

---
Full diff: https://github.com/llvm/llvm-project/pull/80501.diff


9 Files Affected:

- (modified) mlir/include/mlir-c/Dialect/SparseTensor.h (+5-4) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+19-19) 
- (modified) mlir/lib/Bindings/Python/DialectSparseTensor.cpp (+2-2) 
- (modified) mlir/lib/CAPI/Dialect/SparseTensor.cpp (+8-8) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (+1-1) 
- (modified) mlir/test/CAPI/sparse_tensor.c (+2-3) 
- (modified) mlir/test/Dialect/SparseTensor/conversion.mlir (+8-8) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+6-6) 
- (modified) mlir/test/python/dialects/sparse_tensor/dialect.py (+2-2) 


``````````diff
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 41d024db04964..25656f1b7654c 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -25,6 +25,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// These correspond to SparseTensorEncodingAttr::LevelType in the C++ API.
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
+typedef uint64_t level_type;
+
 enum MlirSparseTensorLevelType {
   MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4,                   // 0b00001_00
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8,              // 0b00010_00
@@ -52,16 +54,15 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
 
 /// Creates a `sparse_tensor.encoding` attribute with the given parameters.
 MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
-    MlirContext ctx, intptr_t lvlRank,
-    enum MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
-    MlirAffineMap lvlTodim, int posWidth, int crdWidth);
+    MlirContext ctx, intptr_t lvlRank, level_type const *lvlTypes,
+    MlirAffineMap dimToLvl, MlirAffineMap lvlTodim, int posWidth, int crdWidth);
 
 /// Returns the level-rank of the `sparse_tensor.encoding` attribute.
 MLIR_CAPI_EXPORTED intptr_t
 mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr);
 
 /// Returns a specified level-type of the `sparse_tensor.encoding` attribute.
-MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelType
+MLIR_CAPI_EXPORTED level_type
 mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
 
 /// Returns the dimension-to-level mapping of the `sparse_tensor.encoding`
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index ac91bfa5ae622..1f662e2042304 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -165,7 +165,7 @@ enum class Action : uint32_t {
 /// where we need to store an undefined or indeterminate `LevelType`.
 /// It should not be used externally, since it does not indicate an
 /// actual/representable format.
-enum class LevelType : uint8_t {
+enum class LevelType : uint64_t {
   Undef = 0,                // 0b00000_00
   Dense = 4,                // 0b00001_00
   Compressed = 8,           // 0b00010_00
@@ -184,7 +184,7 @@ enum class LevelType : uint8_t {
 };
 
 /// This enum defines all supported storage format without the level properties.
-enum class LevelFormat : uint8_t {
+enum class LevelFormat : uint64_t {
   Dense = 4,            // 0b00001_00
   Compressed = 8,       // 0b00010_00
   Singleton = 16,       // 0b00100_00
@@ -193,7 +193,7 @@ enum class LevelFormat : uint8_t {
 };
 
 /// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint8_t {
+enum class LevelPropertyNondefault : uint64_t {
   Nonunique = 1,  // 0b00000_01
   Nonordered = 2, // 0b00000_10
 };
@@ -237,8 +237,8 @@ constexpr const char *toMLIRString(LevelType lt) {
 
 /// Check that the `LevelType` contains a valid (possibly undefined) value.
 constexpr bool isValidLT(LevelType lt) {
-  const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
-  const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
+  const uint64_t formatBits = static_cast<uint64_t>(lt) >> 2;
+  const uint64_t propertyBits = static_cast<uint64_t>(lt) & 3;
   // If undefined or dense, then must be unique and ordered.
   // Otherwise, the format must be one of the known ones.
   return (formatBits <= 1 || formatBits == 16)
@@ -251,32 +251,32 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
 
 /// Check if the `LevelType` is dense (regardless of properties).
 constexpr bool isDenseLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Dense);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::Dense);
 }
 
 /// Check if the `LevelType` is compressed (regardless of properties).
 constexpr bool isCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Compressed);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::Compressed);
 }
 
 /// Check if the `LevelType` is singleton (regardless of properties).
 constexpr bool isSingletonLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Singleton);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::Singleton);
 }
 
 /// Check if the `LevelType` is loose compressed (regardless of properties).
 constexpr bool isLooseCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::LooseCompressed);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::LooseCompressed);
 }
 
 /// Check if the `LevelType` is 2OutOf4 (regardless of properties).
 constexpr bool is2OutOf4LT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::TwoOutOfFour);
+  return (static_cast<uint64_t>(lt) & ~3) ==
+         static_cast<uint64_t>(LevelType::TwoOutOfFour);
 }
 
 /// Check if the `LevelType` needs positions array.
@@ -292,12 +292,12 @@ constexpr bool isWithCrdLT(LevelType lt) {
 
 /// Check if the `LevelType` is ordered (regardless of storage format).
 constexpr bool isOrderedLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 2);
+  return !(static_cast<uint64_t>(lt) & 2);
 }
 
 /// Check if the `LevelType` is unique (regardless of storage format).
 constexpr bool isUniqueLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 1);
+  return !(static_cast<uint64_t>(lt) & 1);
 }
 
 /// Convert a LevelType to its corresponding LevelFormat.
@@ -305,7 +305,7 @@ constexpr bool isUniqueLT(LevelType lt) {
 constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
   if (lt == LevelType::Undef)
     return std::nullopt;
-  return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
+  return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & ~3);
 }
 
 /// Convert a LevelFormat to its corresponding LevelType with the given
@@ -313,7 +313,7 @@ constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
 /// for the input level format.
 constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
                                                   bool unique) {
-  auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
+  auto lt = static_cast<LevelType>(static_cast<uint64_t>(lf) |
                                    (ordered ? 0 : 2) | (unique ? 0 : 1));
   return isValidLT(lt) ? std::optional(lt) : std::nullopt;
 }
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 8706c523988b1..3fe6a9e495dc5 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -46,7 +46,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
                           mlirAttributeIsASparseTensorEncodingAttr)
       .def_classmethod(
           "get",
-          [](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
+          [](py::object cls, std::vector<level_type> lvlTypes,
              std::optional<MlirAffineMap> dimToLvl,
              std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
              MlirContext context) {
@@ -64,7 +64,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
           "lvl_types",
           [](MlirAttribute self) {
             const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
-            std::vector<MlirSparseTensorLevelType> ret;
+            std::vector<level_type> ret;
             ret.reserve(lvlRank);
             for (int l = 0; l < lvlRank; ++l)
               ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index e4534ad132385..3f17b740e813c 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -44,11 +44,11 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
   return isa<SparseTensorEncodingAttr>(unwrap(attr));
 }
 
-MlirAttribute
-mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
-                                MlirSparseTensorLevelType const *lvlTypes,
-                                MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
-                                int posWidth, int crdWidth) {
+MlirAttribute mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
+                                              level_type const *lvlTypes,
+                                              MlirAffineMap dimToLvl,
+                                              MlirAffineMap lvlToDim,
+                                              int posWidth, int crdWidth) {
   SmallVector<LevelType> cppLvlTypes;
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
@@ -70,9 +70,9 @@ intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
 }
 
-MlirSparseTensorLevelType
-mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
-  return static_cast<MlirSparseTensorLevelType>(
+level_type mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr,
+                                                  intptr_t lvl) {
+  return static_cast<level_type>(
       cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index 8d54b5959d871..cc119bc704559 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -423,7 +423,7 @@ inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
 /// Generates a constant of the internal dimension level type encoding.
 inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
                                        LevelType lt) {
-  return constantI8(builder, loc, static_cast<uint8_t>(lt));
+  return constantI64(builder, loc, static_cast<uint64_t>(lt));
 }
 
 inline bool isZeroRankedTensorOrScalar(Type type) {
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index b0bc9bb6e881a..f3edce81c35f5 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -43,11 +43,10 @@ static int testRoundtripEncoding(MlirContext ctx) {
   MlirAffineMap lvlToDim =
       mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
-  enum MlirSparseTensorLevelType *lvlTypes =
-      malloc(sizeof(enum MlirSparseTensorLevelType) * lvlRank);
+  level_type *lvlTypes = malloc(sizeof(level_type) * lvlRank);
   for (int l = 0; l < lvlRank; ++l) {
     lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr, l);
-    fprintf(stderr, "level_type: %d\n", lvlTypes[l]);
+    fprintf(stderr, "level_type: %lu\n", lvlTypes[l]);
   }
   // CHECK: posWidth: 32
   int posWidth = mlirSparseTensorEncodingAttrGetPosWidth(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index e4e825bf85043..465f210862660 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -78,8 +78,8 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
 //   CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -96,8 +96,8 @@ func.func @sparse_new1d(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> {
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -114,8 +114,8 @@ func.func @sparse_new2d(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CSR> {
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
 //   CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
@@ -136,10 +136,10 @@ func.func @sparse_new3d(%arg0: !llvm.ptr) -> tensor<?x?x?xf32, #SparseTensor> {
 //   CHECK-DAG: %[[Empty:.*]] = arith.constant 0 : i32
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
 //   CHECK-DAG: %[[Sizes0:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Sizes:.*]] = memref.cast %[[Sizes0]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: memref.store %[[I]], %[[Sizes0]][%[[C0]]] : memref<2xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 40367f12f85a4..7c494b2bcfe1d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,11 +14,11 @@
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 100 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 8 : i8
-// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
-// CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
-// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi8>
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 8 : i64
+// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
+// CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64>
 // CHECK:           %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
 // CHECK:           memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
@@ -28,7 +28,7 @@
 // CHECK:           memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
 // CHECK:           memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
 // CHECK:           %[[VAL_18:.*]] = llvm.mlir.zero : !llvm.ptr
-// CHECK:           %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
+// CHECK:           %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
 // CHECK:           %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
 // CHECK:           %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
 // CHECK:           %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 88a5595d75aea..946a224dab064 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [<LevelType.compressed: 8>]
+        # CHECK: lvl_types: [8]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0) -> (d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -70,7 +70,7 @@ def testEncodingAttr2D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [<LevelType.dense: 4>, <LevelType.compressed: 8>]
+        # CHECK: lvl_types: [4, 8]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")

``````````
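
One more note for reviewers, since only the carrier width changes: the bit layout that `isValidLT` and `getLevelFormat` read is untouched. A small standalone sketch (hypothetical helpers, mirroring the shifts and masks in Enums.h above) of how a value decomposes:

```c
#include <assert.h>
#include <stdint.h>

/* The low two bits carry the nondefault properties (nonunique = 1,
 * nonordered = 2); the remaining bits carry the format. Only the
 * carrier type widened from uint8_t to uint64_t. */
static uint64_t format_bits(uint64_t lt) { return lt >> 2; }
static uint64_t property_bits(uint64_t lt) { return lt & 3; }

int main(void) {
  uint64_t compressed = 8; /* 0b00010_00, LevelType::Compressed */
  assert(format_bits(compressed) == 2);   /* the compressed format */
  assert(property_bits(compressed) == 0); /* ordered and unique */
  return 0;
}
```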

</details>


https://github.com/llvm/llvm-project/pull/80501

