[Mlir-commits] [mlir] 933fefb - [mlir][sparse] Adjusting DimLevelType numeric values for faster predicates

wren romano llvmlistbot at llvm.org
Wed Oct 5 17:40:50 PDT 2022


Author: wren romano
Date: 2022-10-05T17:40:38-07:00
New Revision: 933fefb6a834836a4a9b044f6351f53daed7a2a0

URL: https://github.com/llvm/llvm-project/commit/933fefb6a834836a4a9b044f6351f53daed7a2a0
DIFF: https://github.com/llvm/llvm-project/commit/933fefb6a834836a4a9b044f6351f53daed7a2a0.diff

LOG: [mlir][sparse] Adjusting DimLevelType numeric values for faster predicates

This differential adjusts the numeric values of the DimLevelType enum: the low-order two bits record the "No" (not ordered) and "Nu" (not unique) properties, and the high-order bits record the formats per se.  (The choice of encoding may seem a bit peculiar, since the bits are mapped to negative properties rather than positive ones.  But this was done in order to preserve the collation order of DimLevelType values.  If we don't care about collation order, then we may prefer to flip the semantics of the property bits, so that they're less surprising to readers.)
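(For illustration only, not part of the patch: a minimal C++ sketch of the bit layout described above.  The mask names here are hypothetical, chosen just to make the encoding explicit.)

  // Low-order two bits record the *negative* properties; the
  // remaining high-order bits record the format itself.
  enum : uint8_t {
    kNuBit = 1,          // 0b000_01: "Nu" = not unique
    kNoBit = 2,          // 0b000_10: "No" = not ordered
    kDenseBit = 4,       // 0b001_00
    kCompressedBit = 8,  // 0b010_00
    kSingletonBit = 16,  // 0b100_00
  };
  // E.g., CompressedNuNo == kCompressedBit | kNuBit | kNoBit == 11,
  // and the values still collate in the original enum order.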

Using distinct bits for the properties and formats enables faster implementations of the predicates that detect those properties/formats, which matters because these predicates live in the runtime library itself (rather than on the codegen side of things).  This differential pushes through the changes to the enum values and optimizes the basic predicates.  However, it does not yet optimize the places where we check compound predicates (e.g., "is compressed or singleton"), in order to reduce rebasing conflicts with D134933.  Those optimizations will be done after this differential and D134933 have landed.
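(Again for illustration, not part of this patch: with the new encoding, a compound predicate such as "is compressed or singleton" could reduce to a single mask test.  The helper below is hypothetical; the actual follow-up optimizations land separately, as noted above.)

  // Dense = 4 has neither format bit set, so the mask test is false
  // for it; all compressed (8-11) and singleton (16-19) variants
  // have exactly one of the two format bits set.
  constexpr bool isCompressedOrSingletonDLT(DimLevelType dlt) {
    return static_cast<uint8_t>(dlt) &
           (static_cast<uint8_t>(DimLevelType::kCompressed) |
            static_cast<uint8_t>(DimLevelType::kSingleton));
  }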

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135004

Added: 
    

Modified: 
    mlir/include/mlir-c/Dialect/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
    mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
    mlir/test/CAPI/sparse_tensor.c
    mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
    mlir/test/Dialect/SparseTensor/sparse_concat.mlir
    mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 765e9c2936703..8027f319b28cf 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,15 +26,15 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
 enum MlirSparseTensorDimLevelType {
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO,
-  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO,
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4,             // 0b001_00
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8,        // 0b010_00
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9,     // 0b010_01
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10,    // 0b010_10
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b010_11
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16,        // 0b100_00
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17,     // 0b100_01
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18,     // 0b100_10
+  MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19,  // 0b100_11
 };
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 459cc8dd867dc..74c0eacaa5146 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -168,10 +168,16 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     //
     // TODO: separate type and property in encoding
     //
-    enum class DimLevelType {
-      Dense,
-      Compressed, CompressedNu, CompressedNo, CompressedNuNo,
-      Singleton, SingletonNu, SingletonNo, SingletonNuNo,
+    enum class DimLevelType : uint8_t {
+      Dense            = 4,  // 0b001_00
+      Compressed       = 8,  // 0b010_00
+      CompressedNu     = 9,  // 0b010_01
+      CompressedNo     = 10, // 0b010_10
+      CompressedNuNo   = 11, // 0b010_11
+      Singleton        = 16, // 0b100_00
+      SingletonNu      = 17, // 0b100_01
+      SingletonNo      = 18, // 0b100_10
+      SingletonNuNo    = 19, // 0b100_11
     };
   }];
 }

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
index 84ac9f965018d..14f52a550444b 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
@@ -146,15 +146,15 @@ enum class MLIR_SPARSETENSOR_EXPORT Action : uint32_t {
 /// breaking dependency cycles.  `SparseTensorEncodingAttr::DimLevelType`
 /// is the source of truth and this enum should be kept consistent with it.
 enum class MLIR_SPARSETENSOR_EXPORT DimLevelType : uint8_t {
-  kDense = 0,
-  kCompressed = 1,
-  kCompressedNu = 2,
-  kCompressedNo = 3,
-  kCompressedNuNo = 4,
-  kSingleton = 5,
-  kSingletonNu = 6,
-  kSingletonNo = 7,
-  kSingletonNuNo = 8,
+  kDense = 4,           // 0b001_00
+  kCompressed = 8,      // 0b010_00
+  kCompressedNu = 9,    // 0b010_01
+  kCompressedNo = 10,   // 0b010_10
+  kCompressedNuNo = 11, // 0b010_11
+  kSingleton = 16,      // 0b100_00
+  kSingletonNu = 17,    // 0b100_01
+  kSingletonNo = 18,    // 0b100_10
+  kSingletonNuNo = 19,  // 0b100_11
 };
 
 /// Check if the `DimLevelType` is dense.
@@ -164,56 +164,71 @@ constexpr MLIR_SPARSETENSOR_EXPORT bool isDenseDLT(DimLevelType dlt) {
 
 /// Check if the `DimLevelType` is compressed (regardless of properties).
 constexpr MLIR_SPARSETENSOR_EXPORT bool isCompressedDLT(DimLevelType dlt) {
-  switch (dlt) {
-  case DimLevelType::kCompressed:
-  case DimLevelType::kCompressedNu:
-  case DimLevelType::kCompressedNo:
-  case DimLevelType::kCompressedNuNo:
-    return true;
-  default:
-    return false;
-  }
+  return static_cast<uint8_t>(dlt) &
+         static_cast<uint8_t>(DimLevelType::kCompressed);
 }
 
 /// Check if the `DimLevelType` is singleton (regardless of properties).
 constexpr MLIR_SPARSETENSOR_EXPORT bool isSingletonDLT(DimLevelType dlt) {
-  switch (dlt) {
-  case DimLevelType::kSingleton:
-  case DimLevelType::kSingletonNu:
-  case DimLevelType::kSingletonNo:
-  case DimLevelType::kSingletonNuNo:
-    return true;
-  default:
-    return false;
-  }
+  return static_cast<uint8_t>(dlt) &
+         static_cast<uint8_t>(DimLevelType::kSingleton);
 }
 
 /// Check if the `DimLevelType` is ordered (regardless of storage format).
 constexpr MLIR_SPARSETENSOR_EXPORT bool isOrderedDLT(DimLevelType dlt) {
-  switch (dlt) {
-  case DimLevelType::kCompressedNo:
-  case DimLevelType::kCompressedNuNo:
-  case DimLevelType::kSingletonNo:
-  case DimLevelType::kSingletonNuNo:
-    return false;
-  default:
-    return true;
-  }
+  return !(static_cast<uint8_t>(dlt) & 2);
 }
 
 /// Check if the `DimLevelType` is unique (regardless of storage format).
 constexpr MLIR_SPARSETENSOR_EXPORT bool isUniqueDLT(DimLevelType dlt) {
-  switch (dlt) {
-  case DimLevelType::kCompressedNu:
-  case DimLevelType::kCompressedNuNo:
-  case DimLevelType::kSingletonNu:
-  case DimLevelType::kSingletonNuNo:
-    return false;
-  default:
-    return true;
-  }
+  return !(static_cast<uint8_t>(dlt) & 1);
 }
 
+// Ensure the above predicates work as intended.
+static_assert((!isCompressedDLT(DimLevelType::kDense) &&
+               isCompressedDLT(DimLevelType::kCompressed) &&
+               isCompressedDLT(DimLevelType::kCompressedNu) &&
+               isCompressedDLT(DimLevelType::kCompressedNo) &&
+               isCompressedDLT(DimLevelType::kCompressedNuNo) &&
+               !isCompressedDLT(DimLevelType::kSingleton) &&
+               !isCompressedDLT(DimLevelType::kSingletonNu) &&
+               !isCompressedDLT(DimLevelType::kSingletonNo) &&
+               !isCompressedDLT(DimLevelType::kSingletonNuNo)),
+              "isCompressedDLT definition is broken");
+
+static_assert((!isSingletonDLT(DimLevelType::kDense) &&
+               !isSingletonDLT(DimLevelType::kCompressed) &&
+               !isSingletonDLT(DimLevelType::kCompressedNu) &&
+               !isSingletonDLT(DimLevelType::kCompressedNo) &&
+               !isSingletonDLT(DimLevelType::kCompressedNuNo) &&
+               isSingletonDLT(DimLevelType::kSingleton) &&
+               isSingletonDLT(DimLevelType::kSingletonNu) &&
+               isSingletonDLT(DimLevelType::kSingletonNo) &&
+               isSingletonDLT(DimLevelType::kSingletonNuNo)),
+              "isSingletonDLT definition is broken");
+
+static_assert((isOrderedDLT(DimLevelType::kDense) &&
+               isOrderedDLT(DimLevelType::kCompressed) &&
+               isOrderedDLT(DimLevelType::kCompressedNu) &&
+               !isOrderedDLT(DimLevelType::kCompressedNo) &&
+               !isOrderedDLT(DimLevelType::kCompressedNuNo) &&
+               isOrderedDLT(DimLevelType::kSingleton) &&
+               isOrderedDLT(DimLevelType::kSingletonNu) &&
+               !isOrderedDLT(DimLevelType::kSingletonNo) &&
+               !isOrderedDLT(DimLevelType::kSingletonNuNo)),
+              "isOrderedDLT definition is broken");
+
+static_assert((isUniqueDLT(DimLevelType::kDense) &&
+               isUniqueDLT(DimLevelType::kCompressed) &&
+               !isUniqueDLT(DimLevelType::kCompressedNu) &&
+               isUniqueDLT(DimLevelType::kCompressedNo) &&
+               !isUniqueDLT(DimLevelType::kCompressedNuNo) &&
+               isUniqueDLT(DimLevelType::kSingleton) &&
+               !isUniqueDLT(DimLevelType::kSingletonNu) &&
+               isUniqueDLT(DimLevelType::kSingletonNo) &&
+               !isUniqueDLT(DimLevelType::kSingletonNuNo)),
+              "isUniqueDLT definition is broken");
+
 } // namespace sparse_tensor
 } // namespace mlir
 

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index f8633dedf54bd..c85e2c11f1566 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -634,7 +634,10 @@ class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
              "Value position is out of bounds");
       // TODO: <https://github.com/llvm/llvm-project/issues/54179>
       yield(this->cursor, src.values[parentPos]);
-    } else if (src.isCompressedDim(d)) {
+      return;
+    }
+    const auto dlt = src.getDimType(d);
+    if (isCompressedDLT(dlt)) {
       // Look up the bounds of the `d`-level segment determined by the
       // `d-1`-level position `parentPos`.
       const std::vector<P> &pointersD = src.pointers[d];
@@ -650,11 +653,11 @@ class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
         cursorReordD = static_cast<uint64_t>(indicesD[pos]);
         forallElements(yield, pos, d + 1);
       }
-    } else if (src.isSingletonDim(d)) {
+    } else if (isSingletonDLT(dlt)) {
       this->cursor[this->reord[d]] = src.getIndex(d, parentPos);
       forallElements(yield, parentPos, d + 1);
-    } else { // Dense dimension.
-      assert(src.isDenseDim(d)); // TODO: reuse the ASSERT_DENSE_DIM message
+    } else {
+      assert(isDenseDLT(dlt)); // TODO: reuse the ASSERT_DENSE_DIM message
       const uint64_t sz = src.getDimSizes()[d];
       const uint64_t pstart = parentPos * sz;
       uint64_t &cursorReordD = this->cursor[this->reord[d]];

diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 3eca2ba158c58..f02bb9ad35ca7 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -87,11 +87,12 @@ toMLIRSparseTensor(uint64_t rank, uint64_t nse, const uint64_t *shape,
   }
 
   // Verify that the sparsity values are supported.
+  // TODO: update this check to match what we actually support.
   for (uint64_t i = 0; i < rank; ++i)
     if (sparsity[i] != DimLevelType::kDense &&
         sparsity[i] != DimLevelType::kCompressed)
-      MLIR_SPARSETENSOR_FATAL("Unsupported sparsity value %d\n",
-                              static_cast<int>(sparsity[i]));
+      MLIR_SPARSETENSOR_FATAL("unsupported dimension level type: %d\n",
+                              static_cast<uint8_t>(sparsity[i]));
 #endif
 
   // Convert external format to internal COO.

diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index fc3437b039fc1..58af741c19b84 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -43,9 +43,9 @@ static int testRoundtripEncoding(MlirContext ctx) {
       mlirSparseTensorEncodingAttrGetHigherOrdering(originalAttr);
   // CHECK: (d0, d1)[s0] -> (s0, d0, d1)
   mlirAffineMapDump(higherOrdering);
-  // CHECK: level_type: 0
-  // CHECK: level_type: 1
-  // CHECK: level_type: 1
+  // CHECK: level_type: 4
+  // CHECK: level_type: 8
+  // CHECK: level_type: 8
   int numLevelTypes = mlirSparseTensorEncodingGetNumDimLevelTypes(originalAttr);
   enum MlirSparseTensorDimLevelType *levelTypes =
       malloc(sizeof(enum MlirSparseTensorDimLevelType) * numLevelTypes);

diff --git a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
index 32308d17d3ff5..ee7499a1c1201 100644
--- a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
@@ -19,8 +19,8 @@
 //   CHECK-DAG: %[[I13:.*]] = arith.constant 13 : index
 //   CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<1xi8>
 //   CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<1xi8> to memref<?xi8>
-//   CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<1xi8>
+//   CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<1xi8>
 //   CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: memref.store %[[I13]], %[[SizesS]][%[[I0]]] : memref<1xindex>
@@ -56,8 +56,8 @@ func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13x
 //   CHECK-DAG: %[[I0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<1xi8>
 //   CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<1xi8> to memref<?xi8>
-//   CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<1xi8>
+//   CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<1xi8>
 //   CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<1xindex> to memref<?xindex>
 //   CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
@@ -97,9 +97,9 @@ func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<
 //   CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index
 //   CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8>
 //   CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref<?xi8>
-//   CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
+//   CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
 //   CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: memref.store %[[I2]], %[[SizesS]][%[[I0]]] : memref<2xindex>
@@ -140,9 +140,9 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x
 //   CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index
 //   CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8>
 //   CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref<?xi8>
-//   CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
+//   CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
 //   CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
@@ -184,9 +184,9 @@ func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tens
 //   CHECK-DAG: %[[I2:.*]] = arith.constant 2 : index
 //   CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8>
 //   CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref<?xi8>
-//   CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
+//   CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
 //   CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: %[[SizeI1:.*]] = call @sparseDimSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
@@ -227,9 +227,9 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
 //   CHECK-DAG: %[[I1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8>
 //   CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref<?xi8>
-//   CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
+//   CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8>
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8>
 //   CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
@@ -274,10 +274,10 @@ func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tens
 //   CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index
 //   CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<3xi8>
 //   CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<3xi8> to memref<?xi8>
-//   CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<3xi8>
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<3xi8>
-//   CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I2]]] : memref<3xi8>
+//   CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<3xi8>
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<3xi8>
+//   CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I2]]] : memref<3xi8>
 //   CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<3xindex>
 //   CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<3xindex> to memref<?xindex>
 //   CHECK-DAG: memref.store %[[I2]], %[[SizesS]][%[[I0]]] : memref<3xindex>

diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index 4bdb5dd8a711f..b51c72e2d6e62 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -14,7 +14,7 @@
 // CHECK-DAG:     %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
 // CHECK-DAG:     %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
 // CHECK-DAG:     %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG:     %[[TMP_c1_i8:.*]] = arith.constant 1 : i8
+// CHECK-DAG:     %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
 // CHECK-DAG:     %[[TMP_c3:.*]] = arith.constant 3 : index
 // CHECK-DAG:     %[[TMP_c1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
@@ -33,8 +33,8 @@
 // CHECK:         }
 // CHECK:         %[[TMP_1:.*]] = memref.alloca() : memref<2xi8>
 // CHECK:         %[[TMP_2:.*]] = memref.cast %[[TMP_1]] : memref<2xi8> to memref<?xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_1]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_1]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_1]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_1]][%[[TMP_c1]]] : memref<2xi8>
 // CHECK:         %[[TMP_3:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:         %[[TMP_4:.*]] = memref.cast %[[TMP_3]] : memref<2xindex> to memref<?xindex>
 // CHECK:         memref.store %[[TMP_c3]], %[[TMP_3]][%[[TMP_c0]]] : memref<2xindex>
@@ -83,11 +83,11 @@ func.func @concat_mix_dense(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spar
 // CHECK-DAG:     %[[TMP_c0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[TMP_c5:.*]] = arith.constant 5 : index
 // CHECK-DAG:     %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK-DAG:     %[[TMP_c1_i8:.*]] = arith.constant 1 : i8
+// CHECK-DAG:     %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
 // CHECK:         %[[TMP_0:.*]] = memref.alloca() : memref<2xi8>
 // CHECK:         %[[TMP_1:.*]] = memref.cast %[[TMP_0]] : memref<2xi8> to memref<?xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xi8>
 // CHECK:         %[[TMP_2:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:         %[[TMP_3:.*]] = memref.cast %[[TMP_2]] : memref<2xindex> to memref<?xindex>
 // CHECK:         memref.store %[[TMP_c5]], %[[TMP_2]][%[[TMP_c0]]] : memref<2xindex>
@@ -115,8 +115,8 @@ func.func @concat_mix_dense(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spar
 // CHECK:         }
 // CHECK:         %[[TMP_11:.*]] = memref.alloca() : memref<2xi8>
 // CHECK:         %[[TMP_12:.*]] = memref.cast %[[TMP_11]] : memref<2xi8> to memref<?xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_11]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_11]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_11]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_11]][%[[TMP_c1]]] : memref<2xi8>
 // CHECK:         %[[TMP_13:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:         %[[TMP_14:.*]] = memref.cast %[[TMP_13]] : memref<2xindex> to memref<?xindex>
 // CHECK:         memref.store %[[TMP_c3]], %[[TMP_13]][%[[TMP_c0]]] : memref<2xindex>
@@ -167,11 +167,11 @@ func.func @concat_mix_sparse(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spa
 // CHECK-DAG:     %[[TMP_c0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[TMP_c4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG:     %[[TMP_c1_i8:.*]] = arith.constant 1 : i8
+// CHECK-DAG:     %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
 // CHECK:         %[[TMP_0:.*]] = memref.alloca() : memref<2xi8>
 // CHECK:         %[[TMP_1:.*]] = memref.cast %[[TMP_0]] : memref<2xi8> to memref<?xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xi8>
 // CHECK:         %[[TMP_2:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:         %[[TMP_3:.*]] = memref.cast %[[TMP_2]] : memref<2xindex> to memref<?xindex>
 // CHECK:         memref.store %[[TMP_c4]], %[[TMP_2]][%[[TMP_c0]]] : memref<2xindex>
@@ -199,8 +199,8 @@ func.func @concat_mix_sparse(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spa
 // CHECK:         }
 // CHECK:         %[[TMP_11:.*]] = memref.alloca() : memref<2xi8>
 // CHECK:         %[[TMP_12:.*]] = memref.cast %[[TMP_11]] : memref<2xi8> to memref<?xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_11]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_11]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_11]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_11]][%[[TMP_c1]]] : memref<2xi8>
 // CHECK:         %[[TMP_13:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:         %[[TMP_14:.*]] = memref.cast %[[TMP_13]] : memref<2xindex> to memref<?xindex>
 // CHECK:         memref.store %[[TMP_c4]], %[[TMP_13]][%[[TMP_c0]]] : memref<2xindex>
@@ -243,7 +243,7 @@ func.func @concat_mix_sparse_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3
 // CHECK-DAG:         %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
 // CHECK-DAG:         %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
 // CHECK-DAG:         %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG:         %[[TMP_c1_i8:.*]] = arith.constant 1 : i8
+// CHECK-DAG:         %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
 // CHECK-DAG:         %[[TMP_c3:.*]] = arith.constant 3 : index
 // CHECK-DAG:         %[[TMP_c1:.*]] = arith.constant 1 : index
 // CHECK-DAG:         %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
@@ -262,8 +262,8 @@ func.func @concat_mix_sparse_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3
 // CHECK:         }
 // CHECK:         %[[TMP_1:.*]] = memref.alloca() : memref<2xi8>
 // CHECK:         %[[TMP_2:.*]] = memref.cast %[[TMP_1]] : memref<2xi8> to memref<?xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_1]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK:         memref.store %[[TMP_c1_i8]], %[[TMP_1]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_1]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK:         memref.store %[[TMP_c8_i8]], %[[TMP_1]][%[[TMP_c1]]] : memref<2xi8>
 // CHECK:         %[[TMP_3:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:         %[[TMP_4:.*]] = memref.cast %[[TMP_3]] : memref<2xindex> to memref<?xindex>
 // CHECK:         memref.store %[[TMP_c4]], %[[TMP_3]][%[[TMP_c0]]] : memref<2xindex>
@@ -304,7 +304,7 @@ func.func @concat_mix_dense_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3x
 // CHECK-DAG:       %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
 // CHECK-DAG:       %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
 // CHECK-DAG:       %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG:       %[[TMP_c1_i8:.*]] = arith.constant 1 : i8
+// CHECK-DAG:       %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
 // CHECK-DAG:       %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK-DAG:       %[[TMP_c0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[TMP_c3:.*]] = arith.constant 3 : index
@@ -323,8 +323,8 @@ func.func @concat_mix_dense_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3x
 // CHECK:           }
 // CHECK:           %[[TMP_2:.*]] = memref.alloca() : memref<2xi8>
 // CHECK:           %[[TMP_3:.*]] = memref.cast %[[TMP_2]] : memref<2xi8> to memref<?xi8>
-// CHECK:           memref.store %[[TMP_c1_i8]], %[[TMP_2]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK:           memref.store %[[TMP_c1_i8]], %[[TMP_2]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK:           memref.store %[[TMP_c8_i8]], %[[TMP_2]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK:           memref.store %[[TMP_c8_i8]], %[[TMP_2]][%[[TMP_c1]]] : memref<2xi8>
 // CHECK:           %[[TMP_4:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:           %[[TMP_5:.*]] = memref.cast %[[TMP_4]] : memref<2xindex> to memref<?xindex>
 // CHECK:           memref.store %[[TMP_c3]], %[[TMP_4]][%[[TMP_c0]]] : memref<2xindex>

diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 2b388bafc2cfd..250e06f6e55e8 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,7 +14,7 @@
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 100 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 1 : i8
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 8 : i8
 // CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
 // CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
 // CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>

