[clang-tools-extra] [mlir][sparse] Print new syntax (PR #68130)

Yinying Li via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 4 08:36:59 PDT 2023


https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/68130

>From 47b34bb327e1078678d3ba0c96ebce3fc89cf2ae Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 3 Oct 2023 16:43:50 +0000
Subject: [PATCH 1/5] [mlir][sparse] Print new syntax

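Change the printed form of the encoding from #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }> to #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>. Level properties (nonunique, nonordered), ELL-style symbolic maps, and dimension slices are also printed in the new syntax.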
---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |  20 +--
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  64 ++++---
 mlir/test/Dialect/SparseTensor/codegen.mlir   |   8 +-
 .../SparseTensor/roundtrip_encoding.mlir      |  32 ++--
 .../Dialect/SparseTensor/sparse_reshape.mlir  |   8 +-
 .../SparseTensor/sparse_tensor_reshape.mlir   |   2 +-
 .../python/dialects/sparse_tensor/dialect.py  | 160 +++++++++---------
 7 files changed, 159 insertions(+), 135 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index bc351ec52c0946b..2920ef79f461c6a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -215,29 +215,29 @@ constexpr const char *toMLIRString(DimLevelType dlt) {
   case DimLevelType::Compressed:
     return "compressed";
   case DimLevelType::CompressedNu:
-    return "compressed_nu";
+    return "compressed(nonunique)";
   case DimLevelType::CompressedNo:
-    return "compressed_no";
+    return "compressed(nonordered)";
   case DimLevelType::CompressedNuNo:
-    return "compressed_nu_no";
+    return "compressed(nonunique, nonordered)";
   case DimLevelType::Singleton:
     return "singleton";
   case DimLevelType::SingletonNu:
-    return "singleton_nu";
+    return "singleton(nonunique)";
   case DimLevelType::SingletonNo:
-    return "singleton_no";
+    return "singleton(nonordered)";
   case DimLevelType::SingletonNuNo:
-    return "singleton_nu_no";
+    return "singleton(nonunique, nonordered)";
   case DimLevelType::LooseCompressed:
     return "loose_compressed";
   case DimLevelType::LooseCompressedNu:
-    return "loose_compressed_nu";
+    return "loose_compressed(nonunique)";
   case DimLevelType::LooseCompressedNo:
-    return "loose_compressed_no";
+    return "loose_compressed(nonordered)";
   case DimLevelType::LooseCompressedNuNo:
-    return "loose_compressed_nu_no";
+    return "loose_compressed(nonunique, nonordered)";
   case DimLevelType::TwoOutOfFour:
-    return "compressed24";
+    return "block2_4";
   }
   return "";
 }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 3897e1b9ea3597c..4c8dccdda6c0c7c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -586,30 +586,56 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 }
 
 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
-  // Print the struct-like storage in dictionary fashion.
-  printer << "<{ lvlTypes = [ ";
-  llvm::interleaveComma(getLvlTypes(), printer, [&](DimLevelType dlt) {
-    printer << "\"" << toMLIRString(dlt) << "\"";
-  });
-  printer << " ]";
+  auto map = static_cast<AffineMap>(getDimToLvl());
+  auto lvlTypes = getLvlTypes();
+  // A null dimToLvl map denotes the identity map; materialize it so that the
+  // dimension and level lists below are printed uniformly.
+  if (!map)
+    map = AffineMap::getMultiDimIdentityMap(lvlTypes.size(), getContext());
+  // Modified version of AsmPrinter::Impl::printAffineMap. Every list emits
+  // its ", " separator before each element except the first, so an empty
+  // list never evaluates `count - 1` on an unsigned count (which would wrap).
+  printer << "<{ map = ";
+  // Symbolic identifiers, e.g. `[s0, s1]`; omitted when there are none.
+  const unsigned numSymbols = map.getNumSymbols();
+  if (numSymbols != 0) {
+    printer << '[';
+    for (unsigned i = 0; i < numSymbols; ++i) {
+      if (i != 0)
+        printer << ", ";
+      printer << 's' << i;
+    }
+    printer << ']';
+  }
+  // Dimension identifiers, each annotated with its slice when slices exist.
+  // NOTE: when slices are present, one slice per dimension is assumed.
+  const auto dimSlices = getDimSlices();
+  const unsigned numDims = map.getNumDims();
+  printer << '(';
+  for (unsigned i = 0; i < numDims; ++i) {
+    if (i != 0)
+      printer << ", ";
+    printer << 'd' << i;
+    if (!dimSlices.empty())
+      printer << " : " << dimSlices[i];
+  }
+  printer << ')';
+  // Level expressions, each annotated with its level format and properties.
+  // Result i of the map is paired with lvlTypes[i].
+  const unsigned numLvls = map.getNumResults();
+  printer << " -> (";
+  for (unsigned i = 0; i < numLvls; ++i) {
+    if (i != 0)
+      printer << ", ";
+    map.getResult(i).print(printer.getStream());
+    printer << " : " << toMLIRString(lvlTypes[i]);
+  }
+  printer << ')';
   // Print remaining members only for non-default values.
-  if (!isIdentity())
-    printer << ", dimToLvl = affine_map<" << getDimToLvl() << ">";
   if (getPosWidth())
     printer << ", posWidth = " << getPosWidth();
   if (getCrdWidth())
     printer << ", crdWidth = " << getCrdWidth();
-  if (!getDimSlices().empty()) {
-    printer << ", dimSlices = [ ";
-    llvm::interleaveComma(getDimSlices(), printer,
-                          [&](SparseTensorDimSliceAttr attr) {
-                            // Calls SparseTensorDimSliceAttr::print directly to
-                            // skip mnemonic.
-                            attr.print(printer);
-                          });
-    printer << " ]";
-  }
-
   printer << " }>";
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 69a9c274a861ce1..c3b16807a7c18a6 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -507,7 +507,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
   return %1 : tensor<8x8xf64, #CSR>
 }
 
-// CHECK-LABEL: func.func private @_insert_dense_compressed_no_8_8_f64_0_0(
+// CHECK-LABEL: func.func private @"_insert_dense_compressed(nonordered)_8_8_f64_0_0"(
 //  CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
 //  CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
 //  CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -533,7 +533,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
 //       CHECK:     %[[A13:.*]]:4 = scf.for %[[A14:.*]] = %[[A11]] to %[[A7]] step %[[A12]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       %[[A19:.*]] = memref.load %[[A6]]{{\[}}%[[A14]]] : memref<?xindex>
 //       CHECK:       %[[A20:.*]] = memref.load %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
-//       CHECK:       %[[A21:.*]]:4 = func.call @_insert_dense_compressed_no_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+//       CHECK:       %[[A21:.*]]:4 = func.call @"_insert_dense_compressed(nonordered)_8_8_f64_0_0"(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       memref.store %[[A10]], %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
 //       CHECK:       memref.store %[[A9]], %[[A5]]{{\[}}%[[A19]]] : memref<?xi1>
 //       CHECK:       scf.yield %[[A21]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
@@ -611,7 +611,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
   return %1 : tensor<128xf64, #SparseVector>
 }
 
-// CHECK-LABEL: func.func private @_insert_compressed_nu_singleton_5_6_f64_0_0(
+// CHECK-LABEL: func.func private @"_insert_compressed(nonunique)_singleton_5_6_f64_0_0"(
 //  CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
 //  CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
 //  CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -627,7 +627,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
 //  CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier
 //  CHECK-SAME: %[[A4:.*4]]: index,
 //  CHECK-SAME: %[[A5:.*5]]: f64)
-//       CHECK: %[[R:.*]]:4 = call @_insert_compressed_nu_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
+//       CHECK: %[[R:.*]]:4 = call @"_insert_compressed(nonunique)_singleton_5_6_f64_0_0"(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
 //       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
 func.func @sparse_insert_coo(%arg0: tensor<5x6xf64, #Coo>, %arg1: index, %arg2: f64) -> tensor<5x6xf64, #Coo> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 39e3ef102423524..c4ef50bee01ea2c 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: func private @sparse_1d_tensor(
-// CHECK-SAME: tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>)
+// CHECK-SAME: tensor<32xf64, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>)
 func.func private @sparse_1d_tensor(tensor<32xf64, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>)
 
 // -----
@@ -13,7 +13,7 @@ func.func private @sparse_1d_tensor(tensor<32xf64, #sparse_tensor.encoding<{ map
 }>
 
 // CHECK-LABEL: func private @sparse_csr(
-// CHECK-SAME: tensor<?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], posWidth = 64, crdWidth = 64 }>>)
+// CHECK-SAME: tensor<?x?xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>>)
 func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
 
 // -----
@@ -23,7 +23,7 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
 }>
 
 // CHECK-LABEL: func private @CSR_explicit(
-// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>>
 func.func private @CSR_explicit(%arg0: tensor<?x?xf64, #CSR_explicit>) {
   return
 }
@@ -37,7 +37,7 @@ func.func private @CSR_explicit(%arg0: tensor<?x?xf64, #CSR_explicit>) {
 }>
 
 // CHECK-LABEL: func private @sparse_csc(
-// CHECK-SAME: tensor<?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimToLvl = affine_map<(d0, d1) -> (d1, d0)> }>>)
+// CHECK-SAME: tensor<?x?xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed) }>>)
 func.func private @sparse_csc(tensor<?x?xf32, #CSC>)
 
 // -----
@@ -49,7 +49,7 @@ func.func private @sparse_csc(tensor<?x?xf32, #CSC>)
 }>
 
 // CHECK-LABEL: func private @sparse_dcsc(
-// CHECK-SAME: tensor<?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ], dimToLvl = affine_map<(d0, d1) -> (d1, d0)>, crdWidth = 64 }>>)
+// CHECK-SAME: tensor<?x?xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : compressed, d0 : compressed), crdWidth = 64 }>>)
 func.func private @sparse_dcsc(tensor<?x?xf32, #DCSC>)
 
 // -----
@@ -59,7 +59,7 @@ func.func private @sparse_dcsc(tensor<?x?xf32, #DCSC>)
 }>
 
 // CHECK-LABEL: func private @sparse_coo(
-// CHECK-SAME: tensor<?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed_nu_no", "singleton_no" ] }>>)
+// CHECK-SAME: tensor<?x?xf32, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered)) }>>)
 func.func private @sparse_coo(tensor<?x?xf32, #COO>)
 
 // -----
@@ -69,7 +69,7 @@ func.func private @sparse_coo(tensor<?x?xf32, #COO>)
 }>
 
 // CHECK-LABEL: func private @sparse_bcoo(
-// CHECK-SAME: tensor<?x?x?xf32, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "loose_compressed_nu", "singleton" ] }>>)
+// CHECK-SAME: tensor<?x?x?xf32, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton) }>>)
 func.func private @sparse_bcoo(tensor<?x?x?xf32, #BCOO>)
 
 // -----
@@ -79,7 +79,7 @@ func.func private @sparse_bcoo(tensor<?x?x?xf32, #BCOO>)
 }>
 
 // CHECK-LABEL: func private @sparse_sorted_coo(
-// CHECK-SAME: tensor<10x10xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed_nu", "singleton" ] }>>)
+// CHECK-SAME: tensor<10x10xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>>)
 func.func private @sparse_sorted_coo(tensor<10x10xf64, #SortedCOO>)
 
 // -----
@@ -94,7 +94,7 @@ func.func private @sparse_sorted_coo(tensor<10x10xf64, #SortedCOO>)
 }>
 
 // CHECK-LABEL: func private @sparse_bcsr(
-// CHECK-SAME: tensor<10x60xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ], dimToLvl = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)> }>>
+// CHECK-SAME: tensor<10x60xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : compressed, d1 floordiv 3 : compressed, d0 mod 2 : dense, d1 mod 3 : dense) }>>
 func.func private @sparse_bcsr(tensor<10x60xf64, #BCSR>)
 
 
@@ -105,7 +105,7 @@ func.func private @sparse_bcsr(tensor<10x60xf64, #BCSR>)
 }>
 
 // CHECK-LABEL: func private @sparse_ell(
-// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "dense", "compressed" ], dimToLvl = affine_map<(d0, d1)[s0] -> (d0 * (s0 * 4), d0, d1)> }>>
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ map = [s0](d0, d1) -> (d0 * (s0 * 4) : dense, d0 : dense, d1 : compressed) }>>
 func.func private @sparse_ell(tensor<?x?xf64, #ELL>)
 
 // -----
@@ -115,7 +115,7 @@ func.func private @sparse_ell(tensor<?x?xf64, #ELL>)
 }>
 
 // CHECK-LABEL: func private @sparse_slice(
-// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimSlices = [ (1, 4, 1), (1, 4, 2) ] }>>
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ map = (d0 : #sparse_tensor<slice(1, 4, 1)>, d1 : #sparse_tensor<slice(1, 4, 2)>) -> (d0 : dense, d1 : compressed) }>>
 func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
 
 // -----
@@ -125,7 +125,7 @@ func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
 }>
 
 // CHECK-LABEL: func private @sparse_slice(
-// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimSlices = [ (1, ?, 1), (?, 4, 2) ] }>>
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ map = (d0 : #sparse_tensor<slice(1, ?, 1)>, d1 : #sparse_tensor<slice(?, 4, 2)>) -> (d0 : dense, d1 : compressed) }>>
 func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
 
 // -----
@@ -138,7 +138,7 @@ func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
 }>
 
 // CHECK-LABEL: func private @sparse_2_out_of_4(
-// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed24" ] }>>
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : block2_4) }>>
 func.func private @sparse_2_out_of_4(tensor<?x?xf64, #NV_24>)
 
 // -----
@@ -153,7 +153,7 @@ func.func private @sparse_2_out_of_4(tensor<?x?xf64, #NV_24>)
 }>
 
 // CHECK-LABEL: func private @BCSR(
-// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ], dimToLvl = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)> }>>
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : compressed, d1 floordiv 3 : compressed, d0 mod 2 : dense, d1 mod 3 : dense) }>>
 func.func private @BCSR(%arg0: tensor<?x?xf64, #BCSR>) {
   return
 }
@@ -174,7 +174,7 @@ func.func private @BCSR(%arg0: tensor<?x?xf64, #BCSR>) {
 }>
 
 // CHECK-LABEL: func private @BCSR_explicit(
-// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ], dimToLvl = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)> }>>
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : compressed, d1 floordiv 3 : compressed, d0 mod 2 : dense, d1 mod 3 : dense) }>>
 func.func private @BCSR_explicit(%arg0: tensor<?x?xf64, #BCSR_explicit>) {
   return
 }
@@ -190,7 +190,7 @@ func.func private @BCSR_explicit(%arg0: tensor<?x?xf64, #BCSR_explicit>) {
 }>
 
 // CHECK-LABEL: func private @NV_24(
-// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "dense", "compressed24" ], dimToLvl = affine_map<(d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)> }>>
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block2_4) }>>
 func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
   return
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index 7f8edac15302616..3a2376f75654af9 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -16,7 +16,7 @@
 //      CHECK-ROUND:  return %[[E]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
 //
 // CHECK-LABEL:   func.func @sparse_expand(
-// CHECK-SAME:    %[[S:.*]]:
+// CHECK-SAME:    %[[S:[a-zA-Z0-9_]*]]:
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
@@ -53,7 +53,7 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x
 //      CHECK-ROUND:  return %[[C]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
 //
 // CHECK-LABEL:   func.func @sparse_collapse(
-// CHECK-SAME:    %[[S:.*]]:
+// CHECK-SAME:    %[[S:[a-zA-Z0-9_]*]]:
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
@@ -99,7 +99,7 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
 //      CHECK-ROUND:  return %[[E]] : tensor<?x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
 //
 // CHECK-LABEL:   func.func @dynamic_sparse_expand(
-// CHECK-SAME:    %[[S:.*]]:
+// CHECK-SAME:    %[[S:[a-zA-Z0-9_]*]]:
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
@@ -142,7 +142,7 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
 //      CHECK-ROUND:  return %[[C]] : tensor<?xf64, #sparse_tensor.encoding<{{{.*}}}>>
 //
 // CHECK-LABEL:   func.func @dynamic_sparse_collapse(
-// CHECK-SAME:    %[[S:.*]]:
+// CHECK-SAME:    %[[S:[a-zA-Z0-9_]*]]:
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
index 9368cc71c5faa42..e0111c89df65a2d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
@@ -4,7 +4,7 @@
 #SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
 
 // CHECK:         func.func @sparse_reshape(
-// CHECK-SAME:    %[[S:.*]]:
+// CHECK-SAME:    %[[S:[a-zA-Z0-9_]*]]:
 // CHECK-DAG:     %[[C25:.*]] = arith.constant 25 : index
 // CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index e1048edce184a51..6d15363fb17118d 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -13,95 +13,93 @@ def run(f):
 # CHECK-LABEL: TEST: testEncodingAttr1D
 @run
 def testEncodingAttr1D():
-    with Context() as ctx:
-        parsed = Attribute.parse(
-            "#sparse_tensor.encoding<{"
-            "  map = (d0) -> (d0 : compressed),"
-            "  posWidth = 16,"
-            "  crdWidth = 32"
-            "}>"
-        )
-        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 16, crdWidth = 32 }>
-        print(parsed)
-
-        casted = st.EncodingAttr(parsed)
-        # CHECK: equal: True
-        print(f"equal: {casted == parsed}")
-
-        # CHECK: lvl_types: [<DimLevelType.compressed: 8>]
-        print(f"lvl_types: {casted.lvl_types}")
-        # CHECK: dim_to_lvl: None
-        print(f"dim_to_lvl: {casted.dim_to_lvl}")
-        # CHECK: pos_width: 16
-        print(f"pos_width: {casted.pos_width}")
-        # CHECK: crd_width: 32
-        print(f"crd_width: {casted.crd_width}")
-
-        created = st.EncodingAttr.get(casted.lvl_types, None, 0, 0)
-        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
-        print(created)
-        # CHECK: created_equal: False
-        print(f"created_equal: {created == casted}")
-
-        # Verify that the factory creates an instance of the proper type.
-        # CHECK: is_proper_instance: True
-        print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
-        # CHECK: created_pos_width: 0
-        print(f"created_pos_width: {created.pos_width}")
+  with Context() as ctx:
+    parsed = Attribute.parse(
+        "#sparse_tensor.encoding<{"
+        "  map = (d0) -> (d0 : compressed),"
+        "  posWidth = 16,"
+        "  crdWidth = 32"
+        "}>"
+    )
+    # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
+    print(parsed)
+
+    casted = st.EncodingAttr(parsed)
+    # CHECK: equal: True
+    print(f"equal: {casted == parsed}")
+
+    # CHECK: lvl_types: [<DimLevelType.compressed: 8>]
+    print(f"lvl_types: {casted.lvl_types}")
+    # CHECK: dim_to_lvl: None
+    print(f"dim_to_lvl: {casted.dim_to_lvl}")
+    # CHECK: pos_width: 16
+    print(f"pos_width: {casted.pos_width}")
+    # CHECK: crd_width: 32
+    print(f"crd_width: {casted.crd_width}")
+
+    created = st.EncodingAttr.get(casted.lvl_types, None, 0, 0)
+    # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+    print(created)
+    # CHECK: created_equal: False
+    print(f"created_equal: {created == casted}")
+
+    # Verify that the factory creates an instance of the proper type.
+    # CHECK: is_proper_instance: True
+    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+    # CHECK: created_pos_width: 0
+    print(f"created_pos_width: {created.pos_width}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttr2D
 @run
 def testEncodingAttr2D():
-    with Context() as ctx:
-        parsed = Attribute.parse(
-            "#sparse_tensor.encoding<{"
-            "  map = (d0, d1) -> (d1 : dense, d0 : compressed),"
-            "  posWidth = 8,"
-            "  crdWidth = 32"
-            "}>"
-        )
-        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimToLvl = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
-        print(parsed)
-
-        casted = st.EncodingAttr(parsed)
-        # CHECK: equal: True
-        print(f"equal: {casted == parsed}")
-
-        # CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
-        print(f"lvl_types: {casted.lvl_types}")
-        # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
-        print(f"dim_to_lvl: {casted.dim_to_lvl}")
-        # CHECK: pos_width: 8
-        print(f"pos_width: {casted.pos_width}")
-        # CHECK: crd_width: 32
-        print(f"crd_width: {casted.crd_width}")
-
-        created = st.EncodingAttr.get(
-            casted.lvl_types, casted.dim_to_lvl, 8, 32
-        )
-        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimToLvl = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
-        print(created)
-        # CHECK: created_equal: True
-        print(f"created_equal: {created == casted}")
+  with Context() as ctx:
+    parsed = Attribute.parse(
+        "#sparse_tensor.encoding<{"
+        "  map = (d0, d1) -> (d1 : dense, d0 : compressed),"
+        "  posWidth = 8,"
+        "  crdWidth = 32"
+        "}>"
+    )
+    # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+    print(parsed)
+
+    casted = st.EncodingAttr(parsed)
+    # CHECK: equal: True
+    print(f"equal: {casted == parsed}")
+
+    # CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
+    print(f"lvl_types: {casted.lvl_types}")
+    # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
+    print(f"dim_to_lvl: {casted.dim_to_lvl}")
+    # CHECK: pos_width: 8
+    print(f"pos_width: {casted.pos_width}")
+    # CHECK: crd_width: 32
+    print(f"crd_width: {casted.crd_width}")
+
+    created = st.EncodingAttr.get(casted.lvl_types, casted.dim_to_lvl, 8, 32)
+    # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+    print(created)
+    # CHECK: created_equal: True
+    print(f"created_equal: {created == casted}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttrOnTensorType
 @run
 def testEncodingAttrOnTensorType():
-    with Context() as ctx, Location.unknown():
-        encoding = st.EncodingAttr(
-            Attribute.parse(
-                "#sparse_tensor.encoding<{"
-                "  map = (d0) -> (d0 : compressed), "
-                "  posWidth = 64,"
-                "  crdWidth = 32"
-                "}>"
-            )
+  with Context() as ctx, Location.unknown():
+    encoding = st.EncodingAttr(
+        Attribute.parse(
+            "#sparse_tensor.encoding<{"
+            "  map = (d0) -> (d0 : compressed), "
+            "  posWidth = 64,"
+            "  crdWidth = 32"
+            "}>"
         )
-        tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
-        # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>>
-        print(tt)
-        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>
-        print(tt.encoding)
-        assert tt.encoding == encoding
+    )
+    tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
+    # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>>
+    print(tt)
+    # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
+    print(tt.encoding)
+    assert tt.encoding == encoding

>From 2be69066192995ff171e08a54f7c7fdd3e35ab44 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 3 Oct 2023 18:39:17 +0000
Subject: [PATCH 2/5] Restore 4-space indentation in sparse_tensor Python test

---
 .../python/dialects/sparse_tensor/dialect.py  | 158 +++++++++---------
 1 file changed, 79 insertions(+), 79 deletions(-)

diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 6d15363fb17118d..d80b878323377a4 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -13,93 +13,93 @@ def run(f):
 # CHECK-LABEL: TEST: testEncodingAttr1D
 @run
 def testEncodingAttr1D():
-  with Context() as ctx:
-    parsed = Attribute.parse(
-        "#sparse_tensor.encoding<{"
-        "  map = (d0) -> (d0 : compressed),"
-        "  posWidth = 16,"
-        "  crdWidth = 32"
-        "}>"
-    )
-    # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
-    print(parsed)
-
-    casted = st.EncodingAttr(parsed)
-    # CHECK: equal: True
-    print(f"equal: {casted == parsed}")
-
-    # CHECK: lvl_types: [<DimLevelType.compressed: 8>]
-    print(f"lvl_types: {casted.lvl_types}")
-    # CHECK: dim_to_lvl: None
-    print(f"dim_to_lvl: {casted.dim_to_lvl}")
-    # CHECK: pos_width: 16
-    print(f"pos_width: {casted.pos_width}")
-    # CHECK: crd_width: 32
-    print(f"crd_width: {casted.crd_width}")
-
-    created = st.EncodingAttr.get(casted.lvl_types, None, 0, 0)
-    # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
-    print(created)
-    # CHECK: created_equal: False
-    print(f"created_equal: {created == casted}")
-
-    # Verify that the factory creates an instance of the proper type.
-    # CHECK: is_proper_instance: True
-    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
-    # CHECK: created_pos_width: 0
-    print(f"created_pos_width: {created.pos_width}")
+    with Context() as ctx:
+        parsed = Attribute.parse(
+            "#sparse_tensor.encoding<{"
+            "  map = (d0) -> (d0 : compressed),"
+            "  posWidth = 16,"
+            "  crdWidth = 32"
+            "}>"
+        )
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
+        print(parsed)
+
+        casted = st.EncodingAttr(parsed)
+        # CHECK: equal: True
+        print(f"equal: {casted == parsed}")
+
+        # CHECK: lvl_types: [<DimLevelType.compressed: 8>]
+        print(f"lvl_types: {casted.lvl_types}")
+        # CHECK: dim_to_lvl: None
+        print(f"dim_to_lvl: {casted.dim_to_lvl}")
+        # CHECK: pos_width: 16
+        print(f"pos_width: {casted.pos_width}")
+        # CHECK: crd_width: 32
+        print(f"crd_width: {casted.crd_width}")
+
+        created = st.EncodingAttr.get(casted.lvl_types, None, 0, 0)
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+        print(created)
+        # CHECK: created_equal: False
+        print(f"created_equal: {created == casted}")
+
+        # Verify that the factory creates an instance of the proper type.
+        # CHECK: is_proper_instance: True
+        print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+        # CHECK: created_pos_width: 0
+        print(f"created_pos_width: {created.pos_width}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttr2D
 @run
 def testEncodingAttr2D():
-  with Context() as ctx:
-    parsed = Attribute.parse(
-        "#sparse_tensor.encoding<{"
-        "  map = (d0, d1) -> (d1 : dense, d0 : compressed),"
-        "  posWidth = 8,"
-        "  crdWidth = 32"
-        "}>"
-    )
-    # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
-    print(parsed)
-
-    casted = st.EncodingAttr(parsed)
-    # CHECK: equal: True
-    print(f"equal: {casted == parsed}")
-
-    # CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
-    print(f"lvl_types: {casted.lvl_types}")
-    # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
-    print(f"dim_to_lvl: {casted.dim_to_lvl}")
-    # CHECK: pos_width: 8
-    print(f"pos_width: {casted.pos_width}")
-    # CHECK: crd_width: 32
-    print(f"crd_width: {casted.crd_width}")
-
-    created = st.EncodingAttr.get(casted.lvl_types, casted.dim_to_lvl, 8, 32)
-    # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
-    print(created)
-    # CHECK: created_equal: True
-    print(f"created_equal: {created == casted}")
+    with Context() as ctx:
+        parsed = Attribute.parse(
+            "#sparse_tensor.encoding<{"
+            "  map = (d0, d1) -> (d1 : dense, d0 : compressed),"
+            "  posWidth = 8,"
+            "  crdWidth = 32"
+            "}>"
+        )
+        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+        print(parsed)
+
+        casted = st.EncodingAttr(parsed)
+        # CHECK: equal: True
+        print(f"equal: {casted == parsed}")
+
+        # CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
+        print(f"lvl_types: {casted.lvl_types}")
+        # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
+        print(f"dim_to_lvl: {casted.dim_to_lvl}")
+        # CHECK: pos_width: 8
+        print(f"pos_width: {casted.pos_width}")
+        # CHECK: crd_width: 32
+        print(f"crd_width: {casted.crd_width}")
+
+        created = st.EncodingAttr.get(casted.lvl_types, casted.dim_to_lvl, 8, 32)
+        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+        print(created)
+        # CHECK: created_equal: True
+        print(f"created_equal: {created == casted}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttrOnTensorType
 @run
 def testEncodingAttrOnTensorType():
-  with Context() as ctx, Location.unknown():
-    encoding = st.EncodingAttr(
-        Attribute.parse(
-            "#sparse_tensor.encoding<{"
-            "  map = (d0) -> (d0 : compressed), "
-            "  posWidth = 64,"
-            "  crdWidth = 32"
-            "}>"
+    with Context() as ctx, Location.unknown():
+        encoding = st.EncodingAttr(
+            Attribute.parse(
+                "#sparse_tensor.encoding<{"
+                "  map = (d0) -> (d0 : compressed), "
+                "  posWidth = 64,"
+                "  crdWidth = 32"
+                "}>"
+            )
         )
-    )
-    tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
-    # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>>
-    print(tt)
-    # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
-    print(tt.encoding)
-    assert tt.encoding == encoding
+        tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
+        # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>>
+        print(tt)
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
+        print(tt.encoding)
+        assert tt.encoding == encoding

>From c7ee65a28b79ffdd45d068638775d5bcf7c20c29 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 3 Oct 2023 22:44:39 +0000
Subject: [PATCH 3/5] Sanitize level-type strings in generated insert function names

---
 .../Transforms/SparseTensorCodegen.cpp        | 20 +++++++++++++++++--
 mlir/test/Dialect/SparseTensor/codegen.mlir   |  8 ++++----
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index f02276fba0d526b..a470de8a72bed16 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -472,8 +472,11 @@ class SparseInsertGenerator
     llvm::raw_svector_ostream nameOstream(nameBuffer);
     nameOstream << kInsertFuncNamePrefix;
     const Level lvlRank = stt.getLvlRank();
-    for (Level l = 0; l < lvlRank; l++)
-      nameOstream << toMLIRString(stt.getLvlType(l)) << "_";
+    for (Level l = 0; l < lvlRank; l++) {
+      std::string lvlType = toMLIRString(stt.getLvlType(l));
+      replaceWithUnderscore(lvlType);
+      nameOstream << lvlType << "_";
+    }
     // Static dim sizes are used in the generated code while dynamic sizes are
     // loaded from the dimSizes buffer. This is the reason for adding the shape
     // to the function name.
@@ -489,6 +492,19 @@ class SparseInsertGenerator
 
 private:
   TensorType rtp;
+  void replaceWithUnderscore(std::string &lvlType) {
+    for (auto it = lvlType.begin(); it != lvlType.end();) {
+      if (*it == '(') {
+        *it = '_';
+      } else if (*it == ')' || *it == ' ') {
+        it = lvlType.erase(it);
+        continue;
+      } else if (*it == ',') {
+        *it = '_';
+      }
+      it++;
+    }
+  }
 };
 
 /// Generations insertion finalization code.
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index c3b16807a7c18a6..6ba4769402d15cb 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -507,7 +507,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
   return %1 : tensor<8x8xf64, #CSR>
 }
 
-// CHECK-LABEL: func.func private @"_insert_dense_compressed(nonordered)_8_8_f64_0_0"(
+// CHECK-LABEL: func.func private @_insert_dense_compressed_nonordered_8_8_f64_0_0(
 //  CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
 //  CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
 //  CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -533,7 +533,7 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
 //       CHECK:     %[[A13:.*]]:4 = scf.for %[[A14:.*]] = %[[A11]] to %[[A7]] step %[[A12]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       %[[A19:.*]] = memref.load %[[A6]]{{\[}}%[[A14]]] : memref<?xindex>
 //       CHECK:       %[[A20:.*]] = memref.load %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
-//       CHECK:       %[[A21:.*]]:4 = func.call @"_insert_dense_compressed(nonordered)_8_8_f64_0_0"(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
+//       CHECK:       %[[A21:.*]]:4 = func.call @_insert_dense_compressed_nonordered_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       memref.store %[[A10]], %[[A4]]{{\[}}%[[A19]]] : memref<?xf64>
 //       CHECK:       memref.store %[[A9]], %[[A5]]{{\[}}%[[A19]]] : memref<?xi1>
 //       CHECK:       scf.yield %[[A21]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
@@ -611,7 +611,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
   return %1 : tensor<128xf64, #SparseVector>
 }
 
-// CHECK-LABEL: func.func private @"_insert_compressed(nonunique)_singleton_5_6_f64_0_0"(
+// CHECK-LABEL: func.func private @_insert_compressed_nonunique_singleton_5_6_f64_0_0(
 //  CHECK-SAME: %[[A1:.*0]]: memref<?xindex>,
 //  CHECK-SAME: %[[A2:.*1]]: memref<?xindex>,
 //  CHECK-SAME: %[[A3:.*2]]: memref<?xf64>,
@@ -627,7 +627,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
 //  CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier
 //  CHECK-SAME: %[[A4:.*4]]: index,
 //  CHECK-SAME: %[[A5:.*5]]: f64)
-//       CHECK: %[[R:.*]]:4 = call @"_insert_compressed(nonunique)_singleton_5_6_f64_0_0"(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
+//       CHECK: %[[R:.*]]:4 = call @_insert_compressed_nonunique_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
 //       CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
 func.func @sparse_insert_coo(%arg0: tensor<5x6xf64, #Coo>, %arg1: index, %arg2: f64) -> tensor<5x6xf64, #Coo> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>

>From 2329e0df37e9ae6d36f57de8113028f43f162ddc Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 3 Oct 2023 23:01:57 +0000
Subject: [PATCH 4/5] make replace function more compact

---
 .../Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp   | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index a470de8a72bed16..0d076f6ef9d10ab 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -494,13 +494,11 @@ class SparseInsertGenerator
   TensorType rtp;
   void replaceWithUnderscore(std::string &lvlType) {
     for (auto it = lvlType.begin(); it != lvlType.end();) {
-      if (*it == '(') {
+      if (*it == '(' || *it == ',') {
         *it = '_';
       } else if (*it == ')' || *it == ' ') {
         it = lvlType.erase(it);
         continue;
-      } else if (*it == ',') {
-        *it = '_';
       }
       it++;
     }

>From 0d628f56229a8e2225b3222c60aa04c549d3a08c Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Wed, 4 Oct 2023 15:35:30 +0000
Subject: [PATCH 5/5] address review comments

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  8 ++++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 42 ++++++++++++-------
 .../Transforms/SparseTensorCodegen.cpp        | 19 ++++-----
 3 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 4e38f314a27391d..a3fe938a4af3d89 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -422,6 +422,14 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     std::optional<uint64_t> getStaticLvlSliceOffset(::mlir::sparse_tensor::Level lvl) const;
     std::optional<uint64_t> getStaticLvlSliceSize(::mlir::sparse_tensor::Level lvl) const;
     std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;
+
+    //
+    // Printing methods.
+    //
+
+    void printSymbol(AffineMap &map, AsmPrinter &printer) const;
+    void printDimension(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
+    void printLevel(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::DimLevelType> lvlTypes) const;
   }];
 
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4c8dccdda6c0c7c..fa4c366d03bf43f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -587,14 +587,27 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 
 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
   auto map = static_cast<AffineMap>(getDimToLvl());
-  auto lvlTypes = getLvlTypes();
   // Empty affine map indicates identity map
   if (!map) {
     map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
   }
-  // Modified version of AsmPrinter::Impl::printAffineMap.
   printer << "<{ map = ";
-  // Symbolic identifiers.
+  printSymbol(map, printer);
+  printer << '(';
+  printDimension(map, printer, getDimSlices());
+  printer << ") -> (";
+  printLevel(map, printer, getLvlTypes());
+  printer << ')';
+  // Print remaining members only for non-default values.
+  if (getPosWidth())
+    printer << ", posWidth = " << getPosWidth();
+  if (getCrdWidth())
+    printer << ", crdWidth = " << getCrdWidth();
+  printer << " }>";
+}
+
+void SparseTensorEncodingAttr::printSymbol(AffineMap &map,
+                                           AsmPrinter &printer) const {
   if (map.getNumSymbols() != 0) {
     printer << '[';
     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
@@ -603,9 +616,11 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
       printer << 's' << map.getNumSymbols() - 1;
     printer << ']';
   }
-  // Dimension identifiers.
-  printer << '(';
-  auto dimSlices = getDimSlices();
+}
+
+void SparseTensorEncodingAttr::printDimension(
+    AffineMap &map, AsmPrinter &printer,
+    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
   if (!dimSlices.empty()) {
     for (unsigned i = 0; i < map.getNumDims() - 1; ++i)
       printer << 'd' << i << " : " << dimSlices[i] << ", ";
@@ -618,9 +633,11 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
     if (map.getNumDims() >= 1)
       printer << 'd' << map.getNumDims() - 1;
   }
-  printer << ')';
-  // Level format and properties.
-  printer << " -> (";
+}
+
+void SparseTensorEncodingAttr::printLevel(
+    AffineMap &map, AsmPrinter &printer,
+    ArrayRef<DimLevelType> lvlTypes) const {
   for (unsigned i = 0; i < map.getNumResults() - 1; ++i) {
     map.getResult(i).print(printer.getStream());
     printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
@@ -630,13 +647,6 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
     map.getResult(lastIndex).print(printer.getStream());
     printer << " : " << toMLIRString(lvlTypes[lastIndex]);
   }
-  printer << ')';
-  // Print remaining members only for non-default values.
-  if (getPosWidth())
-    printer << ", posWidth = " << getPosWidth();
-  if (getCrdWidth())
-    printer << ", crdWidth = " << getCrdWidth();
-  printer << " }>";
 }
 
 LogicalResult
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0d076f6ef9d10ab..80abc3d602a0cf1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -474,7 +474,13 @@ class SparseInsertGenerator
     const Level lvlRank = stt.getLvlRank();
     for (Level l = 0; l < lvlRank; l++) {
       std::string lvlType = toMLIRString(stt.getLvlType(l));
-      replaceWithUnderscore(lvlType);
+      // Replace/remove punctuations in level properties.
+      std::replace_if(
+          lvlType.begin(), lvlType.end(),
+          [](char c) { return c == '(' || c == ','; }, '_');
+      lvlType.erase(std::remove_if(lvlType.begin(), lvlType.end(),
+                                   [](char c) { return c == ')' || c == ' '; }),
+                    lvlType.end());
       nameOstream << lvlType << "_";
     }
     // Static dim sizes are used in the generated code while dynamic sizes are
@@ -492,17 +498,6 @@ class SparseInsertGenerator
 
 private:
   TensorType rtp;
-  void replaceWithUnderscore(std::string &lvlType) {
-    for (auto it = lvlType.begin(); it != lvlType.end();) {
-      if (*it == '(' || *it == ',') {
-        *it = '_';
-      } else if (*it == ')' || *it == ' ') {
-        it = lvlType.erase(it);
-        continue;
-      }
-      it++;
-    }
-  }
 };
 
 /// Generations insertion finalization code.



More information about the cfe-commits mailing list