[Mlir-commits] [mlir] 96a2391 - [mlir][sparse] complete migration to sparse tensor type

Aart Bik llvmlistbot at llvm.org
Mon May 10 12:55:36 PDT 2021


Author: Aart Bik
Date: 2021-05-10T12:55:22-07:00
New Revision: 96a23911f6d72cc1ef0788b34caa553f1ce99c5d

URL: https://github.com/llvm/llvm-project/commit/96a23911f6d72cc1ef0788b34caa553f1ce99c5d
DIFF: https://github.com/llvm/llvm-project/commit/96a23911f6d72cc1ef0788b34caa553f1ce99c5d.diff

LOG: [mlir][sparse] complete migration to sparse tensor type

A very elaborate, but also very fun, revision because all
puzzle pieces are finally "falling into place".

1. replaces linalg annotations + flags with proper sparse tensor types
2. adds rigorous verification of sparse tensor types and sparse primitives
3. removes glue and clutter on opaque pointers in favor of sparse tensor types
4. migrates all tests to use sparse tensor types

NOTE: next CL will remove *all* obsoleted sparse code in Linalg
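
To make the shape of the change concrete, a minimal before/after sketch
(the kernel context is illustrative; types and op forms are taken from the
diff below):

```mlir
// Before: sparse tensors hid behind an opaque pointer, bridged by glue ops.
!SparseTensor = type !llvm.ptr<i8>
%0 = sparse_tensor.fromPtr %arg0 : !SparseTensor to tensor<64x64xf64>

// After: the sparse storage scheme is part of a first-class tensor type.
#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<64x64xf64, #CSR>
```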

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D102095

Added: 
    mlir/test/Dialect/SparseTensor/invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/ExecutionEngine/SparseUtils.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir
    mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
    mlir/test/Dialect/SparseTensor/sparse_1d.mlir
    mlir/test/Dialect/SparseTensor/sparse_2d.mlir
    mlir/test/Dialect/SparseTensor/sparse_3d.mlir
    mlir/test/Dialect/SparseTensor/sparse_lower.mlir
    mlir/test/Dialect/SparseTensor/sparse_nd.mlir
    mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
    mlir/test/Dialect/SparseTensor/sparse_storage.mlir
    mlir/test/Dialect/SparseTensor/sparse_vector.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir

Removed: 
    mlir/test/Dialect/SparseTensor/sparse_invalid.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 7a8a249a7a959..52539c469c60a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -24,4 +24,12 @@
 
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.h.inc"
 
+namespace mlir {
+namespace sparse_tensor {
+/// Convenience method to get a sparse encoding attribute from a type.
+/// Returns a null attribute for any type without an encoding.
+SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
+} // namespace sparse_tensor
+} // namespace mlir
+
 #endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index db8506416f208..976065aaf5fc9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -13,36 +13,42 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+//===----------------------------------------------------------------------===//
 // Base class.
+//===----------------------------------------------------------------------===//
+
 class SparseTensor_Op<string mnemonic, list<OpTrait> traits = []>
   : Op<SparseTensor_Dialect, mnemonic, traits> {
   let printer = [{ return ::print(p, *this); }];
-  let verifier = ?;
+  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-// TODO: remove me
-def SparseTensor_FromPointerOp : SparseTensor_Op<"fromPtr">,
-    Arguments<(ins AnyType:$ptr)>,
-    Results<(outs AnyTensor:$result)> {
-  let summary = "Views an opaque sparse tensor pointer as a tensor";
-  let description = [{
-     Lacking a first class citizen type for sparse tensors, this operation
-     forms the glue between a sparse storage scheme (behind an opaque
-     pointer) and the (dense) tensors used in the kernel definitions.
-     This operation merely provides a way to assign a proper tensor
-     type and shape to the incoming opaque pointer. It disappears
-     completely during lowering.
+//===----------------------------------------------------------------------===//
+// Operations.
+//===----------------------------------------------------------------------===//
 
-     Example:
+def SparseTensor_NewOp : SparseTensor_Op<"new", []>,
+    Arguments<(ins AnyType:$source)>,
+    Results<(outs TensorOf<[AnyType]>:$result)> {
+  let summary = "Constructs a new sparse tensor";
+  let description = [{
+    Constructs a sparse tensor value with contents taken from an opaque
+    pointer provided by `source`. For targets that have access to a file
+    system, for example, this pointer may be a filename (or file) of a sparse
+    tensor in a particular external storage format. The form of the operation
+    is kept deliberately very general to allow for alternative implementations
+    in the future, such as pointers to buffers or runnable initialization
+    code. The operation is provided as an anchor that materializes a fully
+    typed sparse tensor value into a computation.
+
+    Example:
 
     ```mlir
-     !SparseTensor = type !llvm.ptr<i8>
-
-     %0 = sparse_tensor.fromPtr %arg0 : !SparseTensor to tensor<64x64xf64>
+    sparse_tensor.new %source : !Source to tensor<1024x1024xf64, #CSR>
     ```
   }];
-  let assemblyFormat = "$ptr attr-dict `:` type($ptr) `to` type($result)";
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
 }
 
 def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
@@ -51,7 +57,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
   let summary = "Extract pointers array at given dimension from a tensor";
   let description = [{
      Returns the pointers array of the sparse storage scheme at the
-     given dimension for the given tensor. This is similar to the
+     given dimension for the given sparse tensor. This is similar to the
      `buffer_cast` operation in the sense that it provides a bridge
      between a tensor world view and a bufferized world view. Unlike the
      `buffer_cast` operation, however, this sparse operation actually
@@ -61,7 +67,8 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
      Example:
 
     ```mlir
-    %1 = sparse_tensor.pointers %0, %c1 : tensor<64x64xf64> to memref<?xindex>
+    %1 = sparse_tensor.pointers %0, %c1
+       : tensor<64x64xf64, #CSR> to memref<?xindex>
     ```
   }];
   let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)"
@@ -74,7 +81,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
   let summary = "Extract indices array at given dimension from a tensor";
   let description = [{
      Returns the indices array of the sparse storage scheme at the
-     given dimension for the given tensor. This is similar to the
+     given dimension for the given sparse tensor. This is similar to the
      `buffer_cast` operation in the sense that it provides a bridge
      between a tensor world view and a bufferized world view. Unlike the
      `buffer_cast` operation, however, this sparse operation actually
@@ -84,7 +91,8 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
      Example:
 
     ```mlir
-    %1 = sparse_tensor.indices %0, %c1 : tensor<64x64xf64> to memref<?xindex>
+    %1 = sparse_tensor.indices %0, %c1
+       : tensor<64x64xf64, #CSR> to memref<?xindex>
     ```
   }];
   let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)"
@@ -97,8 +105,8 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
   let summary = "Extract numerical values array from a tensor";
   let description = [{
      Returns the values array of the sparse storage scheme for the given
-     tensor, independent of the actual dimension. This is similar to the
-     `buffer_cast` operation in the sense that it provides a bridge
+     sparse tensor, independent of the actual dimension. This is similar to
+     the `buffer_cast` operation in the sense that it provides a bridge
      between a tensor world view and a bufferized world view. Unlike the
      `buffer_cast` operation, however, this sparse operation actually
      lowers into a call into a support library to obtain access to the
@@ -107,7 +115,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
      Example:
 
     ```mlir
-    %1 = sparse_tensor.values %0 : tensor<64x64xf64> to memref<?xf64>
+    %1 = sparse_tensor.values %0 : tensor<64x64xf64, #CSR> to memref<?xf64>
     ```
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index bfc1a31a98298..12720a38f1b62 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -18,6 +18,9 @@
 
 namespace mlir {
 
+// Forward declaration.
+class TypeConverter;
+
 /// Defines a parallelization strategy. Any independent loop is a candidate
 /// for parallelization. The loop is made parallel if (1) allowed by the
 /// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse
@@ -42,32 +45,18 @@ enum class SparseVectorizationStrategy {
   kAnyStorageInnerLoop
 };
 
-/// Defines a type for "pointer" and "index" storage in the sparse storage
-/// scheme, with a choice between the native platform-dependent index width
-/// or any of 64-/32-/16-/8-bit integers. A narrow width obviously reduces
-/// the memory footprint of the sparse storage scheme, but the width should
-/// suffice to define the total required range (viz. the maximum number of
-/// stored entries per indirection level for the "pointers" and the maximum
-/// value of each tensor index over all dimensions for the "indices").
-enum class SparseIntType { kNative, kI64, kI32, kI16, kI8 };
-
 /// Sparsification options.
 struct SparsificationOptions {
   SparsificationOptions(SparseParallelizationStrategy p,
-                        SparseVectorizationStrategy v, unsigned vl,
-                        SparseIntType pt, SparseIntType it, bool fo)
+                        SparseVectorizationStrategy v, unsigned vl, bool fo)
       : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
-        ptrType(pt), indType(it), fastOutput(fo) {}
+        fastOutput(fo) {}
   SparsificationOptions()
       : SparsificationOptions(SparseParallelizationStrategy::kNone,
-                              SparseVectorizationStrategy::kNone, 1u,
-                              SparseIntType::kNative, SparseIntType::kNative,
-                              false) {}
+                              SparseVectorizationStrategy::kNone, 1u, false) {}
   SparseParallelizationStrategy parallelizationStrategy;
   SparseVectorizationStrategy vectorizationStrategy;
   unsigned vectorLength;
-  SparseIntType ptrType;
-  SparseIntType indType;
   bool fastOutput; // experimental: fast output buffers
 };
 
@@ -77,7 +66,8 @@ void populateSparsificationPatterns(
     const SparsificationOptions &options = SparsificationOptions());
 
 /// Sets up sparse tensor conversion rules.
-void populateSparseTensorConversionPatterns(RewritePatternSet &patterns);
+void populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
+                                            RewritePatternSet &patterns);
 
 std::unique_ptr<Pass> createSparsificationPass();
 std::unique_ptr<Pass> createSparseTensorConversionPass();

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index a41ed2ecd2416..f569e308c7ebc 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
-
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
@@ -17,7 +17,7 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 
 //===----------------------------------------------------------------------===//
-// TensorDialect Attribute Methods
+// TensorDialect Attribute Methods.
 //===----------------------------------------------------------------------===//
 
 #define GET_ATTRDEF_CLASSES
@@ -178,8 +178,73 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
   return success();
 }
 
+SparseTensorEncodingAttr
+mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
+  if (auto ttp = type.dyn_cast<RankedTensorType>())
+    return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// TensorDialect Operations.
+//===----------------------------------------------------------------------===//
+
+static LogicalResult isInBounds(Value dim, Value tensor) {
+  if (auto constantOp = dim.getDefiningOp<ConstantOp>()) {
+    unsigned d = constantOp.getValue().cast<IntegerAttr>().getInt();
+    if (d >= tensor.getType().cast<RankedTensorType>().getRank())
+      return failure();
+  }
+  return success(); // in bounds, or symbolic
+}
+
+static LogicalResult isMatchingWidth(Value result, unsigned width) {
+  Type etp = result.getType().cast<MemRefType>().getElementType();
+  if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
+    return success();
+  return failure();
+}
+
+static LogicalResult verify(NewOp op) {
+  if (!getSparseTensorEncoding(op.getResult().getType()))
+    return op.emitError("expected a sparse tensor result");
+  return success();
+}
+
+static LogicalResult verify(ToPointersOp op) {
+  if (failed(isInBounds(op.dim(), op.tensor())))
+    return op.emitError("requested pointers dimension out of bounds");
+  if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
+    if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
+      return op.emitError("unexpected type for pointers");
+    return success();
+  }
+  return op.emitError("expected a sparse tensor to get pointers");
+}
+
+static LogicalResult verify(ToIndicesOp op) {
+  if (failed(isInBounds(op.dim(), op.tensor())))
+    return op.emitError("requested indices dimension out of bounds");
+  if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
+    if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
+      return op.emitError("unexpected type for indices");
+    return success();
+  }
+  return op.emitError("expected a sparse tensor to get indices");
+}
+
+static LogicalResult verify(ToValuesOp op) {
+  if (!getSparseTensorEncoding(op.tensor().getType()))
+    return op.emitError("expected a sparse tensor to get values");
+  RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
+  MemRefType mtp = op.result().getType().cast<MemRefType>();
+  if (ttp.getElementType() != mtp.getElementType())
+    return op.emitError("unexpected mismatch in element types");
+  return success();
+}
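+
+// As an illustration of the new rules, a pointers access whose memref width
+// disagrees with the declared pointerBitWidth is now rejected, e.g. (this
+// mirrors the invalid.mlir test added below):
+//
+//   #SparseVector = #sparse_tensor.encoding<{
+//     dimLevelType = ["compressed"], pointerBitWidth = 32 }>
+//   %0 = sparse_tensor.pointers %arg0, %c0
+//      : tensor<128xf64, #SparseVector> to memref<?xindex>
+//   // error: unexpected type for pointers (i32 declared, index requested)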
+
 //===----------------------------------------------------------------------===//
-// TensorDialect Methods
+// TensorDialect Methods.
 //===----------------------------------------------------------------------===//
 
 void SparseTensorDialect::initialize() {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index faf1133b1996f..71515fecb0606 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1,4 +1,4 @@
-//===- SparseTensorLowering.cpp - Sparse tensor primitives lowering -------===//
+//===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Lower sparse tensor primitives to calls into a runtime support library.
-// Note that this is a current implementation choice to keep the lowering
-// simple. In principle, these primitives could also be lowered to actual
+// Convert sparse tensor primitives to calls into a runtime support library.
+// Note that this is a current implementation choice to keep the conversion
+// simple. In principle, these primitives could also be converted to actual
 // elaborate IR code that implements the primitives on the selected sparse
 // tensor storage schemes.
 //
@@ -22,9 +22,24 @@
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
+using namespace mlir::sparse_tensor;
 
 namespace {
 
+/// Returns internal type encoding for overhead storage.
+static unsigned getOverheadTypeEncoding(unsigned width) {
+  switch (width) {
+  default:
+    return 1;
+  case 32:
+    return 2;
+  case 16:
+    return 3;
+  case 8:
+    return 4;
+  }
+}
+
 /// Returns function reference (first hit also inserts into module).
 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
                                  ValueRange operands) {
@@ -41,14 +56,14 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
   return SymbolRefAttr::get(context, name);
 }
 
-/// Sparse conversion rule to remove opaque pointer cast.
-class SparseTensorFromPointerConverter
-    : public OpConversionPattern<sparse_tensor::FromPointerOp> {
+/// Sparse conversion rule for returns.
+class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
+public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(sparse_tensor::FromPointerOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOp(op, operands[0]);
+    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
     return success();
   }
 };
@@ -71,18 +86,77 @@ class SparseTensorToDimSizeConverter
   }
 };
 
+/// Sparse conversion rule for the new operator.
+class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(NewOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type resType = op.getType();
+    Type eltType = resType.cast<ShapedType>().getElementType();
+    MLIRContext *context = op->getContext();
+    SmallVector<Value, 5> params;
+    // Sparse encoding.
+    auto enc = getSparseTensorEncoding(resType);
+    if (!enc)
+      return failure();
+    // User pointer.
+    params.push_back(operands[0]);
+    // Sparsity annotations.
+    SmallVector<bool, 4> attrs;
+    unsigned sz = enc.getDimLevelType().size();
+    for (unsigned i = 0; i < sz; i++)
+      attrs.push_back(enc.getDimLevelType()[i] ==
+                      SparseTensorEncodingAttr::DimLevelType::Compressed);
+    auto elts = DenseElementsAttr::get(
+        RankedTensorType::get({sz}, rewriter.getIntegerType(1)), attrs);
+    params.push_back(rewriter.create<ConstantOp>(loc, elts));
+    // Secondary and primary type encodings.
+    unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
+    unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
+    unsigned primary;
+    if (eltType.isF64())
+      primary = 1;
+    else if (eltType.isF32())
+      primary = 2;
+    else if (eltType.isInteger(32))
+      primary = 3;
+    else if (eltType.isInteger(16))
+      primary = 4;
+    else if (eltType.isInteger(8))
+      primary = 5;
+    else
+      return failure();
+    params.push_back(
+        rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
+    params.push_back(
+        rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
+    params.push_back(
+        rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
+    // Generate the call to create new tensor.
+    Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
+    StringRef name = "newSparseTensor";
+    rewriter.replaceOpWithNewOp<CallOp>(
+        op, ptrType, getFunc(op, name, ptrType, params), params);
+    return success();
+  }
+};
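
For orientation, this is roughly the IR the pattern above emits for a 1-D
compressed f64 vector with default (0) bit widths, so overhead encodings 1/1
and primary encoding 1 per the mappings above; the later bufferization of the
annotations constant is elided in this sketch:

```mlir
%ann = constant dense<true> : tensor<1xi1>  // one compressed dimension
%ptr = constant 1 : i64                     // pointer overhead type encoding
%ind = constant 1 : i64                     // index overhead type encoding
%val = constant 1 : i64                     // primary type encoding (f64)
%t = call @newSparseTensor(%arg0, %ann, %ptr, %ind, %val)
   : (!llvm.ptr<i8>, tensor<1xi1>, i64, i64, i64) -> !llvm.ptr<i8>
```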
+
 /// Sparse conversion rule for pointer accesses.
 class SparseTensorToPointersConverter
-    : public OpConversionPattern<sparse_tensor::ToPointersOp> {
+    : public OpConversionPattern<ToPointersOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(sparse_tensor::ToPointersOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     Type eltType = resType.cast<ShapedType>().getElementType();
     StringRef name;
-    if (eltType.isIndex() || eltType.isInteger(64))
+    if (eltType.isIndex())
+      name = "sparsePointers";
+    else if (eltType.isInteger(64))
       name = "sparsePointers64";
     else if (eltType.isInteger(32))
       name = "sparsePointers32";
@@ -99,17 +173,18 @@ class SparseTensorToPointersConverter
 };
 
 /// Sparse conversion rule for index accesses.
-class SparseTensorToIndicesConverter
-    : public OpConversionPattern<sparse_tensor::ToIndicesOp> {
+class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(sparse_tensor::ToIndicesOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     Type eltType = resType.cast<ShapedType>().getElementType();
     StringRef name;
-    if (eltType.isIndex() || eltType.isInteger(64))
+    if (eltType.isIndex())
+      name = "sparseIndices";
+    else if (eltType.isInteger(64))
       name = "sparseIndices64";
     else if (eltType.isInteger(32))
       name = "sparseIndices32";
@@ -126,12 +201,11 @@ class SparseTensorToIndicesConverter
 };
 
 /// Sparse conversion rule for value accesses.
-class SparseTensorToValuesConverter
-    : public OpConversionPattern<sparse_tensor::ToValuesOp> {
+class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(sparse_tensor::ToValuesOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     Type eltType = resType.cast<ShapedType>().getElementType();
@@ -158,8 +232,10 @@ class SparseTensorToValuesConverter
 
 /// Populates the given patterns list with conversion rules required for
 /// the sparsification of linear algebra operations.
-void mlir::populateSparseTensorConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<SparseTensorFromPointerConverter, SparseTensorToDimSizeConverter,
-               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
-               SparseTensorToValuesConverter>(patterns.getContext());
+void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
+                                                  RewritePatternSet &patterns) {
+  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
+               SparseTensorNewConverter, SparseTensorToPointersConverter,
+               SparseTensorToIndicesConverter, SparseTensorToValuesConverter>(
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index d54b2eff25afc..641ba4af4363b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -10,9 +10,11 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
+using namespace mlir::sparse_tensor;
 
 namespace {
 
@@ -43,14 +45,6 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
   Option<int32_t> vectorLength{
       *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
 
-  Option<int32_t> ptrType{*this, "ptr-type",
-                          llvm::cl::desc("Set the pointer type"),
-                          llvm::cl::init(0)};
-
-  Option<int32_t> indType{*this, "ind-type",
-                          llvm::cl::desc("Set the index type"),
-                          llvm::cl::init(0)};
-
   Option<bool> fastOutput{*this, "fast-output",
                           llvm::cl::desc("Allows fast output buffers"),
                           llvm::cl::init(false)};
@@ -83,29 +77,12 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
     }
   }
 
-  /// Returns the requested integer type.
-  SparseIntType typeOption(int32_t option) {
-    switch (option) {
-    default:
-      return SparseIntType::kNative;
-    case 1:
-      return SparseIntType::kI64;
-    case 2:
-      return SparseIntType::kI32;
-    case 3:
-      return SparseIntType::kI16;
-    case 4:
-      return SparseIntType::kI8;
-    }
-  }
-
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     // Translate strategy flags to strategy options.
     SparsificationOptions options(parallelOption(), vectorOption(),
-                                  vectorLength, typeOption(ptrType),
-                                  typeOption(indType), fastOutput);
+                                  vectorLength, fastOutput);
     // Apply rewriting.
     populateSparsificationPatterns(patterns, options);
     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -113,19 +90,41 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
   }
 };
 
+class SparseTensorTypeConverter : public TypeConverter {
+public:
+  SparseTensorTypeConverter() {
+    addConversion([](Type type) { return type; });
+    addConversion(convertSparseTensorTypes);
+  }
+  // Maps each sparse tensor type to an opaque pointer.
+  static Optional<Type> convertSparseTensorTypes(Type type) {
+    if (getSparseTensorEncoding(type) != nullptr)
+      return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
+    return llvm::None;
+  }
+};
+
 struct SparseTensorConversionPass
     : public SparseTensorConversionBase<SparseTensorConversionPass> {
   void runOnOperation() override {
     auto *ctx = &getContext();
-    RewritePatternSet conversionPatterns(ctx);
+    RewritePatternSet patterns(ctx);
+    SparseTensorTypeConverter converter;
     ConversionTarget target(*ctx);
-    target
-        .addIllegalOp<sparse_tensor::FromPointerOp, sparse_tensor::ToPointersOp,
-                      sparse_tensor::ToIndicesOp, sparse_tensor::ToValuesOp>();
-    target.addLegalOp<CallOp>();
-    populateSparseTensorConversionPatterns(conversionPatterns);
+    target.addIllegalOp<NewOp, ToPointersOp, ToIndicesOp, ToValuesOp>();
+    target.addDynamicallyLegalOp<FuncOp>(
+        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+    target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
+      return converter.isSignatureLegal(op.getCalleeType());
+    });
+    target.addDynamicallyLegalOp<ReturnOp>(
+        [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
+    target.addLegalOp<ConstantOp>();
+    populateFuncOpTypeConversionPattern(patterns, converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    populateSparseTensorConversionPatterns(converter, patterns);
     if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(conversionPatterns))))
+                                      std::move(patterns))))
       signalPassFailure();
   }
 };
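
The effect of the new type converter on a public function boundary, in a
minimal sketch (mirroring the conversion test below):

```mlir
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

// Before --sparse-tensor-conversion: a sparse tensor type in the signature.
func @sparse_dim(%arg0: tensor<?xf64, #SparseVector>) -> index {
  %c0 = constant 0 : index
  %0 = memref.dim %arg0, %c0 : tensor<?xf64, #SparseVector>
  return %0 : index
}
// Afterwards the signature becomes (%arg0: !llvm.ptr<i8>) -> index and the
// dim op lowers to a call into the runtime support library (@sparseDimSize).
```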

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d2c85841773ae..2cc6825af79f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1,4 +1,4 @@
-//===- Sparsification.cpp - Implementation of linalg sparsification -------===//
+//===- Sparsification.cpp - Implementation of sparsification --------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements lowering annotated linalg dialect to sparse code.
+// This file implements lowering sparse tensor types to actual sparse code.
 //
 // The concept of letting a compiler generate sparse code automatically was
 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
@@ -49,14 +49,16 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/TensorEncoding.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
+using namespace mlir::sparse_tensor;
 
 namespace {
 
 enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
-enum class Dim { kSparse, kDense, kUndef };
+enum class Dim { kSparse, kDense, kSingle, kUndef };
 
 /// Tensor expression. Represents an MLIR expression in tensor index notation.
 /// For tensors, e0 denotes the tensor index. For invariants, the IR value is
@@ -270,11 +272,6 @@ class Merger {
     return false;
   }
 
-  /// Returns true if tensor has any sparse dimension.
-  bool isSparseTensor(unsigned t) const {
-    return llvm::any_of(dims[t], [](Dim d) { return d == Dim::kSparse; });
-  }
-
   /// Setter
   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
 
@@ -296,7 +293,7 @@ class Merger {
 
 // Code generation.
 struct CodeGen {
-  CodeGen(mlir::SparsificationOptions o, unsigned numTensors, unsigned numLoops)
+  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
         pointers(numTensors, std::vector<Value>(numLoops)),
         indices(numTensors, std::vector<Value>(numLoops)),
@@ -305,7 +302,7 @@ struct CodeGen {
         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
         curVecLength(1), curVecMask() {}
   /// Sparsification options.
-  mlir::SparsificationOptions options;
+  SparsificationOptions options;
   /// Universal dense indices and upper bounds (by index). The loops array
   /// is updated with the value of the universal dense index in the current
   /// loop. The sizes array is set once with the inferred dimension sizes.
@@ -336,37 +333,33 @@ struct CodeGen {
 
 } // namespace
 
-/// Helper method to inspect sparse annotations in the linalg operation.
+/// Helper method to translate a dim level type to its internal representation.
+static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
+  if (enc) {
+    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
+    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
+      return Dim::kSparse;
+    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
+      return Dim::kSingle;
+  }
+  return Dim::kDense;
+}
+
+/// Helper method to inspect sparse encodings in the tensor types.
 /// Fills the per-dimension sparsity information for all tensors.
 static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
   unsigned numTensors = op.getNumShapedOperands();
-  ArrayAttr sparseAttr = op.sparseAttr();
   for (unsigned t = 0; t < numTensors; t++) {
     auto map = op.getIndexingMap(t);
-    auto dimAttr = sparseAttr[t].cast<ArrayAttr>();
-    // For each tensor, we accept a per-dimension Sparse or Dense annotation.
-    // This is translated to the loop index that indexes that dimension.
     unsigned rank = op.getShapedType(t).getRank();
+    auto enc = getSparseTensorEncoding(op.getShapedType(t));
     for (unsigned d = 0; d < rank; d++) {
       unsigned idx = map.getDimPosition(d);
-      if (isSparseDim(dimAttr[d])) {
-        merger.setDim(t, idx, Dim::kSparse);
-      } else {
-        assert(isDenseDim(dimAttr[d]));
-        merger.setDim(t, idx, Dim::kDense);
-      }
+      merger.setDim(t, idx, toDim(enc, d));
     }
   }
 }
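
A kernel in the new style then looks as follows, closely mirroring the updated
#trait1 test below; sparsity is read off the operand types rather than a
`sparse = [...]` attribute (the function name is illustrative):

```mlir
#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

#trait1 = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) OP b"
}

func @add_s(%arga: tensor<32xf32, #SV>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
  %0 = linalg.generic #trait1
     ins(%arga: tensor<32xf32, #SV>)
    outs(%argx: tensor<32xf32>) {
      ^bb(%a: f32, %x: f32):
        %1 = addf %a, %argb : f32
        linalg.yield %1 : f32
  } -> tensor<32xf32>
  return %0 : tensor<32xf32>
}
```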
 
-/// Returns true if tensor was set up with sparse storage scheme.
-static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
-  if (tensor < op.getNumInputs())
-    return isa_and_nonnull<sparse_tensor::FromPointerOp>(
-        op.getInput(tensor).getDefiningOp());
-  return false;
-}
-
 /// A DFS helper to compute a topological sort. Note that recursion is
 /// bounded by the number of implicit loops, which is always small.
 /// Returns false when a cycle is detected.
@@ -404,7 +397,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
     auto map = op.getIndexingMap(t);
     assert(map.getNumDims() == n);
     // Skip dense tensor constraints when sparse only is requested.
-    if (sparseOnly && !merger.isSparseTensor(t) && !linkedSparse(op, t))
+    if (sparseOnly && !getSparseTensorEncoding(op.getShapedType(t)))
       continue;
     // At the moment, we take the index variables in the tensor access
     // expression in the order in which they appear (conceptually a
@@ -507,20 +500,10 @@ static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
 }
 
 /// Maps sparse integer option to actual integral storage type.
-static Type genIntType(PatternRewriter &rewriter, SparseIntType tp) {
-  switch (tp) {
-  case SparseIntType::kNative:
+static Type genIntType(PatternRewriter &rewriter, unsigned width) {
+  if (width == 0)
     return rewriter.getIndexType();
-  case SparseIntType::kI64:
-    return rewriter.getIntegerType(64);
-  case SparseIntType::kI32:
-    return rewriter.getIntegerType(32);
-  case SparseIntType::kI16:
-    return rewriter.getIntegerType(16);
-  case SparseIntType::kI8:
-    return rewriter.getIntegerType(8);
-  }
-  llvm_unreachable("unexpected SparseIntType");
+  return rewriter.getIntegerType(width);
 }
 
 /// Generates buffer for the output tensor.
@@ -563,25 +546,24 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
     auto tensorType = op.getShapedType(t);
     auto shape = tensorType.getShape();
     auto map = op.getIndexingMap(t);
+    auto enc = getSparseTensorEncoding(tensorType);
     // Scan all dimensions of current tensor.
-    bool dense = !linkedSparse(op, t);
     args.clear();
     for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
       unsigned i = map.getDimPosition(d);
       // Handle sparse storage schemes.
       if (merger.isDim(t, i, Dim::kSparse)) {
-        dense = false;
         auto dynShape = {ShapedType::kDynamicSize};
         auto ptrTp = MemRefType::get(
-            dynShape, genIntType(rewriter, codegen.options.ptrType));
+            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
         auto indTp = MemRefType::get(
-            dynShape, genIntType(rewriter, codegen.options.indType));
+            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
         Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
-        codegen.pointers[t][i] = rewriter.create<sparse_tensor::ToPointersOp>(
-            loc, ptrTp, tensor, dim);
-        codegen.indices[t][i] = rewriter.create<sparse_tensor::ToIndicesOp>(
-            loc, indTp, tensor, dim);
+        codegen.pointers[t][i] =
+            rewriter.create<ToPointersOp>(loc, ptrTp, tensor, dim);
+        codegen.indices[t][i] =
+            rewriter.create<ToIndicesOp>(loc, indTp, tensor, dim);
       }
       // Find lower and upper bound in current dimension.
       Value up;
@@ -596,7 +578,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
     // Perform the required bufferization. All dense inputs materialize
     // from the input tensor. The dense output tensor needs special
     // handling. Sparse inputs use a sparse primitive to obtain the values.
-    if (dense) {
+    if (!enc) {
       auto denseTp = MemRefType::get(shape, tensorType.getElementType());
       if (t < numInputs)
         codegen.buffers[t] =
@@ -607,8 +589,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
     } else {
       auto dynShape = {ShapedType::kDynamicSize};
       auto sparseTp = MemRefType::get(dynShape, tensorType.getElementType());
-      codegen.buffers[t] =
-          rewriter.create<sparse_tensor::ToValuesOp>(loc, sparseTp, tensor);
+      codegen.buffers[t] = rewriter.create<ToValuesOp>(loc, sparseTp, tensor);
     }
   }
 }
@@ -704,12 +685,11 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
   SmallVector<Value, 4> args;
   unsigned tensor = merger.exp(exp).e0;
   auto map = op.getIndexingMap(tensor);
-  bool sparse = linkedSparse(op, tensor);
+  auto enc = getSparseTensorEncoding(op.getShapedType(tensor));
   for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
     unsigned idx = map.getDimPosition(i);
     args.push_back(codegen.loops[idx]); // universal dense index
-    if (sparse || merger.isDim(tensor, idx, Dim::kSparse)) {
-      sparse = true;
+    if (enc) {
       args.clear();
       args.push_back(codegen.pidxs[tensor][idx]); // position index
     }
@@ -1000,7 +980,7 @@ static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                              unsigned idx) {
   unsigned numTensors = op.getNumShapedOperands();
   for (unsigned t = 0; t < numTensors; t++) {
-    if (!merger.isSparseTensor(t) && !linkedSparse(op, t)) {
+    if (!getSparseTensorEncoding(op.getShapedType(t))) {
       auto map = op.getIndexingMap(t);
       unsigned r = map.getNumResults();
       for (unsigned i = 0; i < r; i++) {
@@ -1363,8 +1343,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
                                 PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
     // information for all tensors to loop indices in the kernel.
-    if (!op.hasSparseSemantics())
-      return failure();
     assert(op.getNumOutputs() == 1);
     unsigned numTensors = op.getNumShapedOperands();
     unsigned numLoops = op.iterator_types().getValue().size();

diff  --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
index 8f0dd538126bc..8bcebba15ae98 100644
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -513,7 +513,7 @@ void *newSparseTensor(char *filename, bool *abase, bool *adata, uint64_t aoff,
                       uint64_t asize, uint64_t astride, uint64_t ptrTp,
                       uint64_t indTp, uint64_t valTp) {
   assert(astride == 1);
-  bool *sparsity = abase + aoff;
+  bool *sparsity = adata + aoff;
 
   // The most common cases: 64-bit or 32-bit overhead, double/float values.
   CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
@@ -553,10 +553,12 @@ uint64_t sparseDimSize(void *tensor, uint64_t d) {
   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
 }
 
+IMPL2(MemRef1DU64, sparsePointers, uint64_t, getPointers)
 IMPL2(MemRef1DU64, sparsePointers64, uint64_t, getPointers)
 IMPL2(MemRef1DU32, sparsePointers32, uint32_t, getPointers)
 IMPL2(MemRef1DU16, sparsePointers16, uint16_t, getPointers)
 IMPL2(MemRef1DU8, sparsePointers8, uint8_t, getPointers)
+IMPL2(MemRef1DU64, sparseIndices, uint64_t, getIndices)
 IMPL2(MemRef1DU64, sparseIndices64, uint64_t, getIndices)
 IMPL2(MemRef1DU32, sparseIndices32, uint32_t, getIndices)
 IMPL2(MemRef1DU16, sparseIndices16, uint16_t, getIndices)

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 508b29a2d157e..54bfa745dff4b 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -1,52 +1,104 @@
-// RUN: mlir-opt --sparse-tensor-conversion %s | FileCheck %s
+// RUN: mlir-opt %s --sparse-tensor-conversion | FileCheck %s
 
-!SparseTensor = type !llvm.ptr<i8>
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+#SparseVector64 = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"],
+  pointerBitWidth = 64,
+  indexBitWidth = 64
+}>
+
+#SparseVector32 = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
+// CHECK-LABEL: func @sparse_dim(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//       CHECK: %[[C:.*]] = constant 0 : index
+//       CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
+//       CHECK: return %[[D]] : index
+func @sparse_dim(%arg0: tensor<?xf64, #SparseVector>) -> index {
+  %c = constant 0 : index
+  %0 = memref.dim %arg0, %c : tensor<?xf64, #SparseVector>
+  return %0 : index
+}
+
+// CHECK-LABEL: func @sparse_new(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]]
+//       CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
+  %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
+  return %0 : tensor<128xf64, #SparseVector>
+}
 
 // CHECK-LABEL: func @sparse_pointers(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-//       CHECK: %[[C:.*]] = constant 1 : index
-//       CHECK: %[[T:.*]] = call @sparsePointers64(%[[A]], %[[C]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+//       CHECK: %[[C:.*]] = constant 0 : index
+//       CHECK: %[[T:.*]] = call @sparsePointers(%[[A]], %[[C]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 //       CHECK: return %[[T]] : memref<?xindex>
-func @sparse_pointers(%arg0: !SparseTensor) -> memref<?xindex> {
-  %a = sparse_tensor.fromPtr %arg0 : !SparseTensor to tensor<128xf64>
-  %c = constant 1 : index
-  %0 = sparse_tensor.pointers %a, %c : tensor<128xf64> to memref<?xindex>
+func @sparse_pointers(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
+  %c = constant 0 : index
+  %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64, #SparseVector> to memref<?xindex>
   return %0 : memref<?xindex>
 }
 
+// CHECK-LABEL: func @sparse_pointers64(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//       CHECK: %[[C:.*]] = constant 0 : index
+//       CHECK: %[[T:.*]] = call @sparsePointers64(%[[A]], %[[C]]) : (!llvm.ptr<i8>, index) -> memref<?xi64>
+//       CHECK: return %[[T]] : memref<?xi64>
+func @sparse_pointers64(%arg0: tensor<128xf64, #SparseVector64>) -> memref<?xi64> {
+  %c = constant 0 : index
+  %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64, #SparseVector64> to memref<?xi64>
+  return %0 : memref<?xi64>
+}
+
 // CHECK-LABEL: func @sparse_pointers32(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-//       CHECK: %[[C:.*]] = constant 1 : index
+//       CHECK: %[[C:.*]] = constant 0 : index
 //       CHECK: %[[T:.*]] = call @sparsePointers32(%[[A]], %[[C]]) : (!llvm.ptr<i8>, index) -> memref<?xi32>
 //       CHECK: return %[[T]] : memref<?xi32>
-func @sparse_pointers32(%arg0: !SparseTensor) -> memref<?xi32> {
-  %a = sparse_tensor.fromPtr %arg0 : !SparseTensor to tensor<128xf64>
-  %c = constant 1 : index
-  %0 = sparse_tensor.pointers %a, %c : tensor<128xf64> to memref<?xi32>
+func @sparse_pointers32(%arg0: tensor<128xf64, #SparseVector32>) -> memref<?xi32> {
+  %c = constant 0 : index
+  %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64, #SparseVector32> to memref<?xi32>
   return %0 : memref<?xi32>
 }
 
 // CHECK-LABEL: func @sparse_indices(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-//       CHECK: %[[C:.*]] = constant 1 : index
-//       CHECK: %[[T:.*]] = call @sparseIndices64(%[[A]], %[[C]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+//       CHECK: %[[C:.*]] = constant 0 : index
+//       CHECK: %[[T:.*]] = call @sparseIndices(%[[A]], %[[C]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 //       CHECK: return %[[T]] : memref<?xindex>
-func @sparse_indices(%arg0: !SparseTensor) -> memref<?xindex> {
-  %a = sparse_tensor.fromPtr %arg0 : !SparseTensor to tensor<128xf64>
-  %c = constant 1 : index
-  %0 = sparse_tensor.indices %a, %c : tensor<128xf64> to memref<?xindex>
+func @sparse_indices(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
+  %c = constant 0 : index
+  %0 = sparse_tensor.indices %arg0, %c : tensor<128xf64, #SparseVector> to memref<?xindex>
   return %0 : memref<?xindex>
 }
 
+// CHECK-LABEL: func @sparse_indices64(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//       CHECK: %[[C:.*]] = constant 0 : index
+//       CHECK: %[[T:.*]] = call @sparseIndices64(%[[A]], %[[C]]) : (!llvm.ptr<i8>, index) -> memref<?xi64>
+//       CHECK: return %[[T]] : memref<?xi64>
+func @sparse_indices64(%arg0: tensor<128xf64, #SparseVector64>) -> memref<?xi64> {
+  %c = constant 0 : index
+  %0 = sparse_tensor.indices %arg0, %c : tensor<128xf64, #SparseVector64> to memref<?xi64>
+  return %0 : memref<?xi64>
+}
+
 // CHECK-LABEL: func @sparse_indices32(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-//       CHECK: %[[C:.*]] = constant 1 : index
+//       CHECK: %[[C:.*]] = constant 0 : index
 //       CHECK: %[[T:.*]] = call @sparseIndices32(%[[A]], %[[C]]) : (!llvm.ptr<i8>, index) -> memref<?xi32>
 //       CHECK: return %[[T]] : memref<?xi32>
-func @sparse_indices32(%arg0: !SparseTensor) -> memref<?xi32> {
-  %a = sparse_tensor.fromPtr %arg0 : !SparseTensor to tensor<128xf64>
-  %c = constant 1 : index
-  %0 = sparse_tensor.indices %a, %c : tensor<128xf64> to memref<?xi32>
+func @sparse_indices32(%arg0: tensor<128xf64, #SparseVector32>) -> memref<?xi32> {
+  %c = constant 0 : index
+  %0 = sparse_tensor.indices %arg0, %c : tensor<128xf64, #SparseVector32> to memref<?xi32>
   return %0 : memref<?xi32>
 }
 
@@ -54,9 +106,8 @@ func @sparse_indices32(%arg0: !SparseTensor) -> memref<?xi32> {
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[T:.*]] = call @sparseValuesF64(%[[A]]) : (!llvm.ptr<i8>) -> memref<?xf64>
 //       CHECK: return %[[T]] : memref<?xf64>
-func @sparse_valuesf64(%arg0: !SparseTensor) -> memref<?xf64> {
-  %a = sparse_tensor.fromPtr %arg0 : !SparseTensor to tensor<128xf64>
-  %0 = sparse_tensor.values %a : tensor<128xf64> to memref<?xf64>
+func @sparse_valuesf64(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xf64> {
+  %0 = sparse_tensor.values %arg0 : tensor<128xf64, #SparseVector> to memref<?xf64>
   return %0 : memref<?xf64>
 }
 
@@ -64,8 +115,34 @@ func @sparse_valuesf64(%arg0: !SparseTensor) -> memref<?xf64> {
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[T:.*]] = call @sparseValuesF32(%[[A]]) : (!llvm.ptr<i8>) -> memref<?xf32>
 //       CHECK: return %[[T]] : memref<?xf32>
-func @sparse_valuesf32(%arg0: !SparseTensor) -> memref<?xf32> {
-  %a = sparse_tensor.fromPtr %arg0 : !SparseTensor to tensor<128xf32>
-  %0 = sparse_tensor.values %a : tensor<128xf32> to memref<?xf32>
+func @sparse_valuesf32(%arg0: tensor<128xf32, #SparseVector>) -> memref<?xf32> {
+  %0 = sparse_tensor.values %arg0: tensor<128xf32, #SparseVector> to memref<?xf32>
   return %0 : memref<?xf32>
 }
+
+// CHECK-LABEL: func @sparse_valuesi32(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//       CHECK: %[[T:.*]] = call @sparseValuesI32(%[[A]]) : (!llvm.ptr<i8>) -> memref<?xi32>
+//       CHECK: return %[[T]] : memref<?xi32>
+func @sparse_valuesi32(%arg0: tensor<128xi32, #SparseVector>) -> memref<?xi32> {
+  %0 = sparse_tensor.values %arg0: tensor<128xi32, #SparseVector> to memref<?xi32>
+  return %0 : memref<?xi32>
+}
+
+// CHECK-LABEL: func @sparse_valuesi16(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//       CHECK: %[[T:.*]] = call @sparseValuesI16(%[[A]]) : (!llvm.ptr<i8>) -> memref<?xi16>
+//       CHECK: return %[[T]] : memref<?xi16>
+func @sparse_valuesi16(%arg0: tensor<128xi16, #SparseVector>) -> memref<?xi16> {
+  %0 = sparse_tensor.values %arg0: tensor<128xi16, #SparseVector> to memref<?xi16>
+  return %0 : memref<?xi16>
+}
+
+// CHECK-LABEL: func @sparse_valuesi8(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//       CHECK: %[[T:.*]] = call @sparseValuesI8(%[[A]]) : (!llvm.ptr<i8>) -> memref<?xi8>
+//       CHECK: return %[[T]] : memref<?xi8>
+func @sparse_valuesi8(%arg0: tensor<128xi8, #SparseVector>) -> memref<?xi8> {
+  %0 = sparse_tensor.values %arg0: tensor<128xi8, #SparseVector> to memref<?xi8>
+  return %0 : memref<?xi8>
+}

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
new file mode 100644
index 0000000000000..d7ad3ca0d57d6
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
+  // expected-error at +1 {{expected a sparse tensor result}}
+  %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<32xf32>
+  return %0 : tensor<32xf32>
+}
+
+// -----
+
+func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
+  %c = constant 0 : index
+  // expected-error at +1 {{expected a sparse tensor to get pointers}}
+  %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64> to memref<?xindex>
+  return %0 : memref<?xindex>
+}
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"], pointerBitWidth=32}>
+
+func @mismatch_pointers_types(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
+  %c = constant 0 : index
+  // expected-error at +1 {{unexpected type for pointers}}
+  %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64, #SparseVector> to memref<?xindex>
+  return %0 : memref<?xindex>
+}
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+func @pointers_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
+  %c = constant 1 : index
+  // expected-error at +1 {{requested pointers dimension out of bounds}}
+  %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64, #SparseVector> to memref<?xindex>
+  return %0 : memref<?xindex>
+}
+
+// -----
+
+func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
+  %c = constant 1 : index
+  // expected-error at +1 {{expected a sparse tensor to get indices}}
+  %0 = sparse_tensor.indices %arg0, %c : tensor<10x10xi32> to memref<?xindex>
+  return %0 : memref<?xindex>
+}
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+func @mismatch_indices_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<?xi32> {
+  %c = constant 0 : index
+  // expected-error at +1 {{unexpected type for indices}}
+  %0 = sparse_tensor.indices %arg0, %c : tensor<?xf64, #SparseVector> to memref<?xi32>
+  return %0 : memref<?xi32>
+}
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+func @indices_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
+  %c = constant 1 : index
+  // expected-error at +1 {{requested indices dimension out of bounds}}
+  %0 = sparse_tensor.indices %arg0, %c : tensor<128xf64, #SparseVector> to memref<?xindex>
+  return %0 : memref<?xindex>
+}
+
+// -----
+
+func @invalid_values_dense(%arg0: tensor<1024xf32>) -> memref<?xf32> {
+  // expected-error at +1 {{expected a sparse tensor to get values}}
+  %0 = sparse_tensor.values %arg0 : tensor<1024xf32> to memref<?xf32>
+  return %0 : memref<?xf32>
+}
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+func @mismatch_values_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<?xf32> {
+  // expected-error at +1 {{unexpected mismatch in element types}}
+  %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf32>
+  return %0 : memref<?xf32>
+}

diff  --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index af2368b86c956..2c60c63705333 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt <%s -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 // -----
 

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index ae9481914a44f..98e6da651bfcc 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -1,49 +1,55 @@
-// RUN: mlir-opt -split-input-file %s | FileCheck %s
+// RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s
 
-!SparseTensor = type !llvm.ptr<i8>
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
-// CHECK-LABEL: func @sparse_tensor(
-//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-//       CHECK: %[[T:.*]] = sparse_tensor.fromPtr %[[A]] : !llvm.ptr<i8> to tensor<128xf64>
-//       CHECK: return %[[T]] : tensor<128xf64>
-func @sparse_tensor(%arg0: !SparseTensor) -> tensor<128xf64> {
-  %0 = sparse_tensor.fromPtr %arg0 : !SparseTensor to tensor<128xf64>
-  return %0 : tensor<128xf64>
+// CHECK-LABEL: func @sparse_new(
+// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//       CHECK: %[[T:.*]] = sparse_tensor.new %[[A]] : !llvm.ptr<i8> to tensor<128xf64, #{{.*}}>
+//       CHECK: return %[[T]] : tensor<128xf64, #{{.*}}>
+func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
+  %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
+  return %0 : tensor<128xf64, #SparseVector>
 }
 
 // -----
 
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
 // CHECK-LABEL: func @sparse_pointers(
-//  CHECK-SAME: %[[A:.*]]: tensor<128xf64>)
-//       CHECK: %[[C:.*]] = constant 1 : index
-//       CHECK: %[[T:.*]] = sparse_tensor.pointers %[[A]], %[[C]] : tensor<128xf64> to memref<?xindex>
+//  CHECK-SAME: %[[A:.*]]: tensor<128xf64, #{{.*}}>)
+//       CHECK: %[[C:.*]] = constant 0 : index
+//       CHECK: %[[T:.*]] = sparse_tensor.pointers %[[A]], %[[C]] : tensor<128xf64, #{{.*}}> to memref<?xindex>
 //       CHECK: return %[[T]] : memref<?xindex>
-func @sparse_pointers(%arg0: tensor<128xf64>) -> memref<?xindex> {
-  %c = constant 1 : index
-  %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64> to memref<?xindex>
+func @sparse_pointers(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
+  %c = constant 0 : index
+  %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64, #SparseVector> to memref<?xindex>
   return %0 : memref<?xindex>
 }
 
 // -----
 
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
 // CHECK-LABEL: func @sparse_indices(
-//  CHECK-SAME: %[[A:.*]]: tensor<128xf64>)
-//       CHECK: %[[C:.*]] = constant 1 : index
-//       CHECK: %[[T:.*]] = sparse_tensor.indices %[[A]], %[[C]] : tensor<128xf64> to memref<?xindex>
+//  CHECK-SAME: %[[A:.*]]: tensor<128xf64, #{{.*}}>)
+//       CHECK: %[[C:.*]] = constant 0 : index
+//       CHECK: %[[T:.*]] = sparse_tensor.indices %[[A]], %[[C]] : tensor<128xf64, #{{.*}}> to memref<?xindex>
 //       CHECK: return %[[T]] : memref<?xindex>
-func @sparse_indices(%arg0: tensor<128xf64>) -> memref<?xindex> {
-  %c = constant 1 : index
-  %0 = sparse_tensor.indices %arg0, %c : tensor<128xf64> to memref<?xindex>
+func @sparse_indices(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
+  %c = constant 0 : index
+  %0 = sparse_tensor.indices %arg0, %c : tensor<128xf64, #SparseVector> to memref<?xindex>
   return %0 : memref<?xindex>
 }
 
 // -----
 
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
 // CHECK-LABEL: func @sparse_values(
-//  CHECK-SAME: %[[A:.*]]: tensor<128xf64>)
-//       CHECK: %[[T:.*]] = sparse_tensor.values %[[A]] : tensor<128xf64> to memref<?xf64>
+//  CHECK-SAME: %[[A:.*]]: tensor<128xf64, #{{.*}}>)
+//       CHECK: %[[T:.*]] = sparse_tensor.values %[[A]] : tensor<128xf64, #{{.*}}> to memref<?xf64>
 //       CHECK: return %[[T]] : memref<?xf64>
-func @sparse_values(%arg0: tensor<128xf64>) -> memref<?xf64> {
-  %0 = sparse_tensor.values %arg0 : tensor<128xf64> to memref<?xf64>
+func @sparse_values(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xf64> {
+  %0 = sparse_tensor.values %arg0 : tensor<128xf64, #SparseVector> to memref<?xf64>
   return %0 : memref<?xf64>
 }
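
One detail worth noting in @sparse_pointers and @sparse_indices above: the dimension constant changed from 1 to 0, presumably because the newly added verification requires the dimension operand to be in range for the tensor's rank, and a 1-D tensor only has dimension 0. On a rank-2 tensor, dimension 1 would again be legal; a hypothetical sketch (not part of this patch):

#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>

func @csr_pointers(%arg0: tensor<?x?xf64, #CSR>) -> memref<?xindex> {
  %c1 = constant 1 : index  // dimension 1 is in range for a rank-2 tensor
  %0 = sparse_tensor.pointers %arg0, %c1 : tensor<?x?xf64, #CSR> to memref<?xindex>
  return %0 : memref<?xindex>
}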

diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 70a5ef36b3497..2df47548d226e 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -1,9 +1,11 @@
-// RUN: mlir-opt <%s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: func private @sparse_1d_tensor(
 // CHECK-SAME: tensor<32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>)
 func private @sparse_1d_tensor(tensor<32xf64, #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>>)
 
+// -----
+
 #CSR = #sparse_tensor.encoding<{
   dimLevelType = [ "dense", "compressed" ],
   dimOrdering = affine_map<(i,j) -> (i,j)>,

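The #CSR attribute above is cut off by the hunk context; for orientation, a complete encoding of this shape (illustrative field values, not copied from the elided remainder of the test) combines the level types, a dimension ordering, and the two bit widths that the printer shows as 0 when left at their defaults:

#CSR_example = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],  // row dimension dense, columns compressed
  dimOrdering = affine_map<(i,j) -> (i,j)>,  // identity order: row-major CSR
  pointerBitWidth = 64,                      // storage width of the pointers array
  indexBitWidth = 64                         // storage width of the indices array
}>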
diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
index 9ed062cf757f4..59e0637a140d5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
@@ -1,41 +1,40 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // RUN: mlir-opt %s -sparsification | FileCheck %s
 
-#trait_d = {
+#DV = #sparse_tensor.encoding<{ dimLevelType = [ "dense"      ] }>
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
+#trait1 = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
     affine_map<(i) -> (i)>   // x (out)
   ],
-  sparse = [
-    [ "D" ],  // a
-    [ "D" ]   // x
-  ],
   iterator_types = ["parallel"],
   doc = "x(i) = a(i) OP b"
 }
 
 // CHECK-LABEL:   func @add_d(
-// CHECK-SAME:                %[[VAL_0:.*]]: tensor<32xf32>,
+// CHECK-SAME:                %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                %[[VAL_1:.*]]: f32,
 // CHECK-SAME:                %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_8:.*]] = memref.alloc() : memref<32xf32>
 // CHECK:           linalg.copy(%[[VAL_7]], %[[VAL_8]]) : memref<32xf32>, memref<32xf32>
 // CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK:             %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<32xf32>
+// CHECK:             %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xf32>
 // CHECK:             %[[VAL_11:.*]] = addf %[[VAL_10]], %[[VAL_1]] : f32
 // CHECK:             memref.store %[[VAL_11]], %[[VAL_8]]{{\[}}%[[VAL_9]]] : memref<32xf32>
 // CHECK:           }
 // CHECK:           %[[VAL_12:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xf32>
 // CHECK:           return %[[VAL_12]] : tensor<32xf32>
 // CHECK:         }
-func @add_d(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_d
-     ins(%arga: tensor<32xf32>)
+func @add_d(%arga: tensor<32xf32, #DV>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf32, #DV>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %x: f32):
         %0 = addf %a, %argb : f32
@@ -45,27 +44,27 @@ func @add_d(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<
 }
 
 // CHECK-LABEL:   func @mul_d(
-// CHECK-SAME:                %[[VAL_0:.*]]: tensor<32xf32>,
+// CHECK-SAME:                %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                %[[VAL_1:.*]]: f32,
 // CHECK-SAME:                %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_8:.*]] = memref.alloc() : memref<32xf32>
 // CHECK:           linalg.copy(%[[VAL_7]], %[[VAL_8]]) : memref<32xf32>, memref<32xf32>
 // CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK:             %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<32xf32>
+// CHECK:             %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xf32>
 // CHECK:             %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_1]] : f32
 // CHECK:             memref.store %[[VAL_11]], %[[VAL_8]]{{\[}}%[[VAL_9]]] : memref<32xf32>
 // CHECK:           }
 // CHECK:           %[[VAL_12:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xf32>
 // CHECK:           return %[[VAL_12]] : tensor<32xf32>
 // CHECK:         }
-func @mul_d(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_d
-     ins(%arga: tensor<32xf32>)
+func @mul_d(%arga: tensor<32xf32, #DV>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf32, #DV>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %x: f32):
         %0 = mulf %a, %argb : f32
@@ -74,30 +73,17 @@ func @mul_d(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<
   return %0 : tensor<32xf32>
 }
 
-#trait_s = {
-  indexing_maps = [
-    affine_map<(i) -> (i)>,  // a
-    affine_map<(i) -> (i)>   // x (out)
-  ],
-  sparse = [
-    [ "S" ],  // a
-    [ "D" ]   // x
-  ],
-  iterator_types = ["parallel"],
-  doc = "x(i) = a(i) OP b"
-}
-
 // CHECK-LABEL:   func @add_s(
-// CHECK-SAME:                %[[VAL_0:.*]]: tensor<32xf32>,
+// CHECK-SAME:                %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                %[[VAL_1:.*]]: f32,
 // CHECK-SAME:                %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant true
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32> to memref<?xf32>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.alloc() : memref<32xf32>
 // CHECK:           linalg.copy(%[[VAL_10]], %[[VAL_11]]) : memref<32xf32>, memref<32xf32>
@@ -132,9 +118,9 @@ func @mul_d(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<
 // CHECK:           %[[VAL_30:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf32>
 // CHECK:           return %[[VAL_30]] : tensor<32xf32>
 // CHECK:         }
-func @add_s(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_s
-     ins(%arga: tensor<32xf32>)
+func @add_s(%arga: tensor<32xf32, #SV>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf32, #SV>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %x: f32):
         %0 = addf %a, %argb : f32
@@ -144,13 +130,13 @@ func @add_s(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<
 }
 
 // CHECK-LABEL:   func @repeated_add_s(
-// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<32xf32>,
+// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                         %[[VAL_1:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_2:.*]] = constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = constant 1 : index
-// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32> to memref<?xf32>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf32>
 // CHECK:           %[[VAL_8:.*]] = memref.alloc() : memref<32xf32>
 // CHECK:           linalg.copy(%[[VAL_7]], %[[VAL_8]]) : memref<32xf32>, memref<32xf32>
@@ -170,9 +156,9 @@ func @add_s(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<
 // CHECK:           %[[VAL_20:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xf32>
 // CHECK:           return %[[VAL_20]] : tensor<32xf32>
 // CHECK:         }
-func @repeated_add_s(%arga: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_s
-     ins(%arga: tensor<32xf32>)
+func @repeated_add_s(%arga: tensor<32xf32, #SV>, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf32, #SV>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %x: f32):
         %0 = addf %a, %a : f32  // same tensor
@@ -184,14 +170,14 @@ func @repeated_add_s(%arga: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32x
 }
 
 // CHECK-LABEL:   func @mul_s(
-// CHECK-SAME:                %[[VAL_0:.*]]: tensor<32xf32>,
+// CHECK-SAME:                %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                %[[VAL_1:.*]]: f32,
 // CHECK-SAME:                %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32> to memref<?xf32>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.alloc() : memref<32xf32>
 // CHECK:           linalg.copy(%[[VAL_8]], %[[VAL_9]]) : memref<32xf32>, memref<32xf32>
@@ -206,9 +192,9 @@ func @repeated_add_s(%arga: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32x
 // CHECK:           %[[VAL_16:.*]] = memref.tensor_load %[[VAL_9]] : memref<32xf32>
 // CHECK:           return %[[VAL_16]] : tensor<32xf32>
 // CHECK:         }
-func @mul_s(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_s
-     ins(%arga: tensor<32xf32>)
+func @mul_s(%arga: tensor<32xf32, #SV>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf32, #SV>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %x: f32):
         %0 = mulf %a, %argb : f32
@@ -217,35 +203,30 @@ func @mul_s(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<
   return %0 : tensor<32xf32>
 }
 
-#trait_dd = {
+#trait2 = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
     affine_map<(i) -> (i)>,  // b
     affine_map<(i) -> (i)>   // x (out)
   ],
-  sparse = [
-    [ "D" ],  // a
-    [ "D" ],  // b
-    [ "D" ]   // x
-  ],
   iterator_types = ["parallel"],
   doc = "x(i) = a(i) OP b(i)"
 }
 
 // CHECK-LABEL:   func @add_dd(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32xf32>) -> tensor<32xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf32>
 // CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.alloc() : memref<32xf32>
 // CHECK:           linalg.copy(%[[VAL_8]], %[[VAL_9]]) : memref<32xf32>, memref<32xf32>
 // CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<32xf32>
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf32>
 // CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_10]]] : memref<32xf32>
 // CHECK:             %[[VAL_13:.*]] = addf %[[VAL_11]], %[[VAL_12]] : f32
 // CHECK:             memref.store %[[VAL_13]], %[[VAL_9]]{{\[}}%[[VAL_10]]] : memref<32xf32>
@@ -253,9 +234,9 @@ func @mul_s(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<
 // CHECK:           %[[VAL_14:.*]] = memref.tensor_load %[[VAL_9]] : memref<32xf32>
 // CHECK:           return %[[VAL_14]] : tensor<32xf32>
 // CHECK:         }
-func @add_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_dd
-     ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+func @add_dd(%arga: tensor<32xf32, #DV>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xf32, #DV>, tensor<32xf32>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -265,19 +246,19 @@ func @add_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 }
 
 // CHECK-LABEL:   func @mul_dd(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32xf32>) -> tensor<32xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf32>
 // CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.alloc() : memref<32xf32>
 // CHECK:           linalg.copy(%[[VAL_8]], %[[VAL_9]]) : memref<32xf32>, memref<32xf32>
 // CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<32xf32>
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf32>
 // CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_10]]] : memref<32xf32>
 // CHECK:             %[[VAL_13:.*]] = mulf %[[VAL_11]], %[[VAL_12]] : f32
 // CHECK:             memref.store %[[VAL_13]], %[[VAL_9]]{{\[}}%[[VAL_10]]] : memref<32xf32>
@@ -285,9 +266,9 @@ func @add_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 // CHECK:           %[[VAL_14:.*]] = memref.tensor_load %[[VAL_9]] : memref<32xf32>
 // CHECK:           return %[[VAL_14]] : tensor<32xf32>
 // CHECK:         }
-func @mul_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_dd
-     ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+func @mul_dd(%arga: tensor<32xf32, #DV>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xf32, #DV>, tensor<32xf32>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -296,33 +277,18 @@ func @mul_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
   return %0 : tensor<32xf32>
 }
 
-#trait_ds = {
-  indexing_maps = [
-    affine_map<(i) -> (i)>,  // a
-    affine_map<(i) -> (i)>,  // b
-    affine_map<(i) -> (i)>   // x (out)
-  ],
-  sparse = [
-    [ "D" ],  // a
-    [ "S" ],  // b
-    [ "D" ]   // x
-  ],
-  iterator_types = ["parallel"],
-  doc = "x(i) = a(i) OP b(i)"
-}
-
 // CHECK-LABEL:   func @add_ds(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32xf32>) -> tensor<32xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32xf32>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant true
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
 // CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32xf32>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32> to memref<?xf32>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.alloc() : memref<32xf32>
 // CHECK:           linalg.copy(%[[VAL_11]], %[[VAL_12]]) : memref<32xf32>, memref<32xf32>
@@ -360,9 +326,9 @@ func @mul_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 // CHECK:           %[[VAL_34:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf32>
 // CHECK:           return %[[VAL_34]] : tensor<32xf32>
 // CHECK:         }
-func @add_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_ds
-     ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+func @add_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32, #SV>, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xf32>, tensor<32xf32, #SV>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -372,15 +338,15 @@ func @add_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 }
 
 // CHECK-LABEL:   func @mul_ds(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32xf32>) -> tensor<32xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32xf32>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
 // CHECK:           %[[VAL_5:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32xf32>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32> to memref<?xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.alloc() : memref<32xf32>
 // CHECK:           linalg.copy(%[[VAL_9]], %[[VAL_10]]) : memref<32xf32>, memref<32xf32>
@@ -396,9 +362,9 @@ func @add_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 // CHECK:           %[[VAL_18:.*]] = memref.tensor_load %[[VAL_10]] : memref<32xf32>
 // CHECK:           return %[[VAL_18]] : tensor<32xf32>
 // CHECK:         }
-func @mul_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_ds
-     ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+func @mul_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32, #SV>, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xf32>, tensor<32xf32, #SV>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -407,32 +373,17 @@ func @mul_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
   return %0 : tensor<32xf32>
 }
 
-#trait_sd = {
-  indexing_maps = [
-    affine_map<(i) -> (i)>,  // a
-    affine_map<(i) -> (i)>,  // b
-    affine_map<(i) -> (i)>   // x (out)
-  ],
-  sparse = [
-    [ "S" ],  // a
-    [ "D" ],  // b
-    [ "D" ]   // x
-  ],
-  iterator_types = ["parallel"],
-  doc = "x(i) = a(i) OP b(i)"
-}
-
 // CHECK-LABEL:   func @add_sd(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32xf32>) -> tensor<32xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant true
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32> to memref<?xf32>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.alloc() : memref<32xf32>
@@ -471,9 +422,9 @@ func @mul_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 // CHECK:           %[[VAL_34:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf32>
 // CHECK:           return %[[VAL_34]] : tensor<32xf32>
 // CHECK:         }
-func @add_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_sd
-     ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+func @add_sd(%arga: tensor<32xf32, #SV>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xf32, #SV>, tensor<32xf32>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -483,14 +434,14 @@ func @add_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 }
 
 // CHECK-LABEL:   func @mul_sd(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32xf32>) -> tensor<32xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32> to memref<?xf32>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.alloc() : memref<32xf32>
@@ -507,9 +458,9 @@ func @add_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 // CHECK:           %[[VAL_18:.*]] = memref.tensor_load %[[VAL_10]] : memref<32xf32>
 // CHECK:           return %[[VAL_18]] : tensor<32xf32>
 // CHECK:         }
-func @mul_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_sd
-     ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+func @mul_sd(%arga: tensor<32xf32, #SV>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xf32, #SV>, tensor<32xf32>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -518,33 +469,18 @@ func @mul_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
   return %0 : tensor<32xf32>
 }
 
-#trait_ss = {
-  indexing_maps = [
-    affine_map<(i) -> (i)>,  // a
-    affine_map<(i) -> (i)>,  // b
-    affine_map<(i) -> (i)>   // x (out)
-  ],
-  sparse = [
-    [ "S" ],  // a
-    [ "S" ],  // b
-    [ "D" ]   // x
-  ],
-  iterator_types = ["parallel"],
-  doc = "x(i) = a(i) OP b(i)"
-}
-
 // CHECK-LABEL:   func @add_ss(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32xf32>,
+// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32> to memref<?xf32>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32> to memref<?xf32>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.alloc() : memref<32xf32>
 // CHECK:           linalg.copy(%[[VAL_11]], %[[VAL_12]]) : memref<32xf32>, memref<32xf32>
@@ -606,9 +542,9 @@ func @mul_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 // CHECK:           %[[VAL_53:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf32>
 // CHECK:           return %[[VAL_53]] : tensor<32xf32>
 // CHECK:         }
-func @add_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_ss
-     ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+func @add_ss(%arga: tensor<32xf32, #SV>, %argb: tensor<32xf32, #SV>, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xf32, #SV>, tensor<32xf32, #SV>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -618,17 +554,17 @@ func @add_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 }
 
 // CHECK-LABEL:   func @mul_ss(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32xf32>,
+// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32> to memref<?xf32>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32> to memref<?xf32>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.alloc() : memref<32xf32>
 // CHECK:           linalg.copy(%[[VAL_11]], %[[VAL_12]]) : memref<32xf32>, memref<32xf32>
@@ -668,9 +604,9 @@ func @add_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 // CHECK:           %[[VAL_41:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf32>
 // CHECK:           return %[[VAL_41]] : tensor<32xf32>
 // CHECK:         }
-func @mul_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  %0 = linalg.generic #trait_ss
-     ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>)
+func @mul_ss(%arga: tensor<32xf32, #SV>, %argb: tensor<32xf32, #SV>, %argx: tensor<32xf32>) -> tensor<32xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xf32, #SV>, tensor<32xf32, #SV>)
     outs(%argx: tensor<32xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -679,34 +615,19 @@ func @mul_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
   return %0 : tensor<32xf32>
 }
 
-#trait_two_way_inv = {
-  indexing_maps = [
-    affine_map<(i) -> (i)>, // a
-    affine_map<(i) -> (i)>, // b
-    affine_map<(i) -> (i)>  // x (out)
-  ],
-  sparse = [
-    [ "S" ], // a
-    [ "S" ], // b
-    [ "D" ]  // x
-  ],
-  iterator_types = ["parallel"],
-  doc = "x(i) = a(i) * c + b(i) * c"
-}
-
 // CHECK-LABEL:   func @two_way_inv(
-// CHECK-SAME:                      %[[VAL_0:.*0]]: tensor<16xf32>,
-// CHECK-SAME:                      %[[VAL_1:.*1]]: tensor<16xf32>,
+// CHECK-SAME:                      %[[VAL_0:.*0]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                      %[[VAL_1:.*1]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                      %[[VAL_2:.*2]]: f32,
 // CHECK-SAME:                      %[[VAL_3:.*3]]: tensor<16xf32>) -> tensor<16xf32> {
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16xf32> to memref<?xf32>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_3]] : memref<16xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.alloc() : memref<16xf32>
 // CHECK:           linalg.copy(%[[VAL_12]], %[[VAL_13]]) : memref<16xf32>, memref<16xf32>
@@ -774,9 +695,10 @@ func @mul_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>, %argx: tensor<32xf32>
 // CHECK:           %[[VAL_60:.*]] = memref.tensor_load %[[VAL_13]] : memref<16xf32>
 // CHECK:           return %[[VAL_60]] : tensor<16xf32>
 // CHECK:         }
-func @two_way_inv(%arga: tensor<16xf32>, %argb: tensor<16xf32>, %argc: f32, %argx: tensor<16xf32>) -> tensor<16xf32> {
-  %0 = linalg.generic #trait_two_way_inv
-    ins(%arga, %argb: tensor<16xf32>, tensor<16xf32>)
+func @two_way_inv(%arga: tensor<16xf32, #SV>, %argb: tensor<16xf32, #SV>, %argc: f32, %argx: tensor<16xf32>) -> tensor<16xf32> {
+  // Kernel "x(i) = a(i) * c + b(i) * c".
+  %0 = linalg.generic #trait2
+    ins(%arga, %argb: tensor<16xf32, #SV>, tensor<16xf32, #SV>)
     outs(%argx: tensor<16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %argc : f32
@@ -788,18 +710,18 @@ func @two_way_inv(%arga: tensor<16xf32>, %argb: tensor<16xf32>, %argc: f32, %arg
 }
 
 // CHECK-LABEL:   func @two_way_inv_alt(
-// CHECK-SAME:                          %[[VAL_0:.*0]]: tensor<16xf32>,
-// CHECK-SAME:                          %[[VAL_1:.*1]]: tensor<16xf32>,
+// CHECK-SAME:                          %[[VAL_0:.*0]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                          %[[VAL_1:.*1]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                          %[[VAL_2:.*2]]: f32,
 // CHECK-SAME:                          %[[VAL_3:.*3]]: tensor<16xf32>) -> tensor<16xf32> {
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16xf32> to memref<?xf32>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_3]] : memref<16xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.alloc() : memref<16xf32>
 // CHECK:           linalg.copy(%[[VAL_12]], %[[VAL_13]]) : memref<16xf32>, memref<16xf32>
@@ -866,11 +788,11 @@ func @two_way_inv(%arga: tensor<16xf32>, %argb: tensor<16xf32>, %argc: f32, %arg
 // CHECK:           %[[VAL_59:.*]] = memref.tensor_load %[[VAL_13]] : memref<16xf32>
 // CHECK:           return %[[VAL_59]] : tensor<16xf32>
 // CHECK:         }
-func @two_way_inv_alt(%arga: tensor<16xf32>,
-                      %argb: tensor<16xf32>, %argc: f32, %argx: tensor<16xf32>) -> tensor<16xf32> {
+func @two_way_inv_alt(%arga: tensor<16xf32, #SV>,
+                      %argb: tensor<16xf32, #SV>, %argc: f32, %argx: tensor<16xf32>) -> tensor<16xf32> {
   // Same kernel, but now expressed as "x(i) = (a(i) + b(i)) * c".
-  %0 = linalg.generic #trait_two_way_inv
-    ins(%arga, %argb: tensor<16xf32>, tensor<16xf32>)
+  %0 = linalg.generic #trait2
+    ins(%arga, %argb: tensor<16xf32, #SV>, tensor<16xf32, #SV>)
     outs(%argx: tensor<16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -885,21 +807,17 @@ func @two_way_inv_alt(%arga: tensor<16xf32>,
     affine_map<(i) -> (i)>,  // a
     affine_map<(i) -> ()>    // x (scalar out)
   ],
-  sparse = [
-    [ "S" ],  // a
-    [  ]      // x
-  ],
   iterator_types = ["reduction"],
   doc = "x += SUM_i a(i)"
 }
 
 // CHECK-LABEL:   func @sum_reduction(
-// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<?xf32>,
+// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                        %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
 // CHECK:           %[[VAL_2:.*]] = constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = constant 1 : index
-// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?xf32> to memref<?xf32>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_6:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
 // CHECK:           %[[VAL_7:.*]] = memref.alloc() : memref<f32>
 // CHECK:           linalg.copy(%[[VAL_6]], %[[VAL_7]]) : memref<f32>, memref<f32>
@@ -915,9 +833,9 @@ func @two_way_inv_alt(%arga: tensor<16xf32>,
 // CHECK:           %[[VAL_17:.*]] = memref.tensor_load %[[VAL_7]] : memref<f32>
 // CHECK:           return %[[VAL_17]] : tensor<f32>
 // CHECK:         }
-func @sum_reduction(%arga: tensor<?xf32>, %argx: tensor<f32>) -> tensor<f32> {
+func @sum_reduction(%arga: tensor<?xf32, #SV>, %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_sum_reduction
-    ins(%arga: tensor<?xf32>)
+    ins(%arga: tensor<?xf32, #SV>)
     outs(%argx: tensor<f32>) {
       ^bb(%a: f32, %x: f32):
         %0 = addf %x, %a : f32
@@ -926,33 +844,28 @@ func @sum_reduction(%arga: tensor<?xf32>, %argx: tensor<f32>) -> tensor<f32> {
   return %0 : tensor<f32>
 }
 
-#trait_sum_reduction_ss = {
+#trait_sum_reduction2 = {
   indexing_maps = [
     affine_map<(i) -> (i)>, // a
     affine_map<(i) -> (i)>, // b
     affine_map<(i)-> ()>    // x (scalar out)
   ],
-  sparse = [
-    [ "S" ],  // a
-    [ "S" ],  // b
-    [     ]   // x
-  ],
   iterator_types = ["reduction"],
   doc = "x += SUM_i a(i) + b(i)"
 }
 
 // CHECK-LABEL:   func @sum_reduction_ss(
-// CHECK-SAME:                           %[[VAL_0:.*0]]: tensor<16xf32>,
-// CHECK-SAME:                           %[[VAL_1:.*1]]: tensor<16xf32>,
+// CHECK-SAME:                           %[[VAL_0:.*0]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                           %[[VAL_1:.*1]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                           %[[VAL_2:.*2]]: tensor<f32>) -> tensor<f32> {
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16xf32> to memref<?xf32>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<f32>
 // CHECK:           %[[VAL_12:.*]] = memref.alloc() : memref<f32>
 // CHECK:           linalg.copy(%[[VAL_11]], %[[VAL_12]]) : memref<f32>, memref<f32>
@@ -1022,13 +935,13 @@ func @sum_reduction(%arga: tensor<?xf32>, %argx: tensor<f32>) -> tensor<f32> {
 // CHECK:           %[[VAL_66:.*]] = memref.tensor_load %[[VAL_12]] : memref<f32>
 // CHECK:           return %[[VAL_66]] : tensor<f32>
 // CHECK:         }
-func @sum_reduction_ss(%arga: tensor<16xf32>,
-                       %argb: tensor<16xf32>,
+func @sum_reduction_ss(%arga: tensor<16xf32, #SV>,
+                       %argb: tensor<16xf32, #SV>,
                        %argx: tensor<f32>) -> tensor<f32> {
   // Just for testing. This case would be better expressed
   // as two separate reduction kernels.
-  %0 = linalg.generic #trait_sum_reduction_ss
-    ins(%arga, %argb: tensor<16xf32>, tensor<16xf32>)
+  %0 = linalg.generic #trait_sum_reduction2
+    ins(%arga, %argb: tensor<16xf32, #SV>, tensor<16xf32, #SV>)
     outs(%argx: tensor<f32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -1038,37 +951,31 @@ func @sum_reduction_ss(%arga: tensor<16xf32>,
   return %0 : tensor<f32>
 }
 
-#trait_sum_reduction_inv_ss = {
+#trait_sum_reduction_inv = {
   indexing_maps = [
     affine_map<(i) -> (i)>, // a
     affine_map<(i) -> ()>,  // b
     affine_map<(i) -> (i)>, // c
     affine_map<(i) -> ()>   // x (out)
   ],
-  sparse = [
-    [ "S" ], // a
-    [     ], // b
-    [ "S" ], // c
-    [     ]  // x
-  ],
   iterator_types = ["reduction"],
   doc = "x += SUM_i a(i) * b + c(i)"
 }
 
 // CHECK-LABEL:   func @sum_reduction_inv(
-// CHECK-SAME:                            %[[VAL_0:.*0]]: tensor<16xf32>,
+// CHECK-SAME:                            %[[VAL_0:.*0]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                            %[[VAL_1:.*1]]: tensor<f32>,
-// CHECK-SAME:                            %[[VAL_2:.*2]]: tensor<16xf32>,
+// CHECK-SAME:                            %[[VAL_2:.*2]]: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                            %[[VAL_3:.*3]]: tensor<f32>) -> tensor<f32> {
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_2]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_2]], %[[VAL_4]] : tensor<16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_2]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_2]], %[[VAL_4]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_3]] : memref<f32>
 // CHECK:           %[[VAL_14:.*]] = memref.alloc() : memref<f32>
 // CHECK:           linalg.copy(%[[VAL_13]], %[[VAL_14]]) : memref<f32>, memref<f32>
@@ -1142,14 +1049,14 @@ func @sum_reduction_ss(%arga: tensor<16xf32>,
 // CHECK:           %[[VAL_72:.*]] = memref.tensor_load %[[VAL_14]] : memref<f32>
 // CHECK:           return %[[VAL_72]] : tensor<f32>
 // CHECK:         }
-func @sum_reduction_inv(%arga: tensor<16xf32>,
+func @sum_reduction_inv(%arga: tensor<16xf32, #SV>,
                         %argb: tensor<f32>,
-                        %argc: tensor<16xf32>,
+                        %argc: tensor<16xf32, #SV>,
                         %argx: tensor<f32>) -> tensor<f32> {
   // Just for testing. This case would be better expressed
   // as two separate reduction kernels.
-  %0 = linalg.generic #trait_sum_reduction_inv_ss
-    ins(%arga, %argb, %argc : tensor<16xf32>, tensor<f32>, tensor<16xf32>)
+  %0 = linalg.generic #trait_sum_reduction_inv
+    ins(%arga, %argb, %argc : tensor<16xf32, #SV>, tensor<f32>, tensor<16xf32, #SV>)
     outs(%argx: tensor<f32>) {
       ^bb(%a: f32, %b: f32, %c: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -1168,34 +1075,27 @@ func @sum_reduction_inv(%arga: tensor<16xf32>,
     affine_map<(i) -> (i)>,  // D
     affine_map<(i) -> (i)>   // X (out)
   ],
-  sparse = [
-    ["D"], // A
-    ["S"], // B
-    ["D"], // C
-    ["S"], // D
-    ["D"]  // X
-  ],
   iterator_types = ["parallel"],
   doc = "X(i) = A(i) + B(i) + C(i) + D(i)"
 }
 
 // CHECK-LABEL:   func @four_tensors_op(
 // CHECK-SAME:                          %[[VAL_0:.*0]]: tensor<?xf64>,
-// CHECK-SAME:                          %[[VAL_1:.*1]]: tensor<?xf64>,
+// CHECK-SAME:                          %[[VAL_1:.*1]]: tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                          %[[VAL_2:.*2]]: tensor<?xf64>,
-// CHECK-SAME:                          %[[VAL_3:.*3]]: tensor<?xf64>,
-// CHECK-SAME:                          %[[VAL_4:.*4]]: tensor<?xf64>) -> tensor<?xf64> {
+// CHECK-SAME:                          %[[VAL_3:.*3]]: tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                          %[[VAL_4:.*]]: tensor<?xf64>) -> tensor<?xf64> {
 // CHECK:           %[[VAL_5:.*]] = constant 0 : index
 // CHECK:           %[[VAL_6:.*]] = constant true
 // CHECK:           %[[VAL_7:.*]] = constant 1 : index
 // CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_0]] : memref<?xf64>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf64> to memref<?xf64>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_5]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_5]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<?xf64>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_3]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_3]], %[[VAL_5]] : tensor<?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_3]] : tensor<?xf64> to memref<?xf64>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_3]], %[[VAL_5]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_3]], %[[VAL_5]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_3]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
 // CHECK:           %[[VAL_16:.*]] = memref.dim %[[VAL_4]], %[[VAL_5]] : tensor<?xf64>
 // CHECK:           %[[VAL_17:.*]] = memref.buffer_cast %[[VAL_4]] : memref<?xf64>
 // CHECK:           %[[VAL_18:.*]] = memref.alloc(%[[VAL_16]]) : memref<?xf64>
@@ -1331,12 +1231,12 @@ func @sum_reduction_inv(%arga: tensor<16xf32>,
 // CHECK:           return %[[VAL_115]] : tensor<?xf64>
 // CHECK:         }
 func @four_tensors_op(%arga: tensor<?xf64>,
-                      %argb: tensor<?xf64>,
+                      %argb: tensor<?xf64, #SV>,
                       %argc: tensor<?xf64>,
-                      %argd: tensor<?xf64>,
+                      %argd: tensor<?xf64, #SV>,
                       %argx: tensor<?xf64>) -> tensor<?xf64> {
   %r = linalg.generic #trait_four_tensors
-    ins(%arga, %argb, %argc, %argd: tensor<?xf64>, tensor<?xf64>, tensor<?xf64>, tensor<?xf64>)
+    ins(%arga, %argb, %argc, %argd: tensor<?xf64>, tensor<?xf64, #SV>, tensor<?xf64>, tensor<?xf64, #SV>)
     outs(%argx: tensor<?xf64>) {
       ^bb(%a: f64, %b: f64, %c: f64, %d: f64, %x: f64):
         %0 = addf %a, %b : f64
@@ -1354,32 +1254,26 @@ func @four_tensors_op(%arga: tensor<?xf64>,
     affine_map<(i) -> (i)>,
     affine_map<(i) -> ()>
   ],
-  sparse = [
-    ["S"],
-    ["S"],
-    ["S"],
-    []
-  ],
   iterator_types = ["reduction"],
   doc = "x += a(i) + b(i) + c(i)"
 }
 
 // CHECK-LABEL:   func @red3s(
-// CHECK-SAME:                %[[VAL_0:.*0]]: tensor<?xf64>,
-// CHECK-SAME:                %[[VAL_1:.*1]]: tensor<?xf64>,
-// CHECK-SAME:                %[[VAL_2:.*2]]: tensor<?xf64>,
+// CHECK-SAME:                %[[VAL_0:.*0]]: tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                %[[VAL_1:.*1]]: tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                %[[VAL_2:.*2]]: tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                %[[VAL_3:.*3]]: tensor<f64>) -> tensor<f64> {
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?xf64> to memref<?xf64>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf64> to memref<?xf64>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_2]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_2]], %[[VAL_4]] : tensor<?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<?xf64> to memref<?xf64>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_2]], %[[VAL_4]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_2]], %[[VAL_4]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
 // CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_3]] : memref<f64>
 // CHECK:           %[[VAL_16:.*]] = memref.alloc() : memref<f64>
 // CHECK:           linalg.copy(%[[VAL_15]], %[[VAL_16]]) : memref<f64>, memref<f64>
@@ -1655,11 +1549,11 @@ func @four_tensors_op(%arga: tensor<?xf64>,
 // CHECK:           %[[VAL_229:.*]] = memref.tensor_load %[[VAL_16]] : memref<f64>
 // CHECK:           return %[[VAL_229]] : tensor<f64>
 // CHECK:         }
-func @red3s(%arga: tensor<?xf64>,
-            %argb: tensor<?xf64>,
-	    %argc: tensor<?xf64>, %argx: tensor<f64>) ->tensor<f64>{
+func @red3s(%arga: tensor<?xf64, #SV>,
+            %argb: tensor<?xf64, #SV>,
+            %argc: tensor<?xf64, #SV>, %argx: tensor<f64>) -> tensor<f64> {
  %0 = linalg.generic #trait_red3s
-   ins(%arga, %argb, %argc: tensor<?xf64>, tensor<?xf64>, tensor<?xf64>)
+   ins(%arga, %argb, %argc: tensor<?xf64, #SV>, tensor<?xf64, #SV>, tensor<?xf64, #SV>)
    outs(%argx: tensor<f64>) {
     ^bb(%a: f64, %b: f64, %c: f64, %x: f64):
         %0 = addf %x, %a : f64

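For reference, the shape of the migration in each of these test files is the same: the per-operand "sparse" arrays disappear from the linalg trait, and the sparsity moves onto the tensor type itself as a #sparse_tensor.encoding attribute; the sparse_tensor.pointers, sparse_tensor.indices, and sparse_tensor.values primitives in the CHECK lines then extract the per-dimension position, coordinate, and value buffers directly from such a typed tensor. Below is a minimal self-contained kernel in the migrated style, mirroring the tests above (illustrative only, not part of this patch; the names #trait_scale and @scale are placeholders):

#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

#trait_scale = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * 2.0"
}

// The operand's sparsity is carried entirely by its type; the trait
// no longer contains any "sparse = [...]" annotation.
func @scale(%arga: tensor<16xf32, #SV>, %argx: tensor<16xf32>) -> tensor<16xf32> {
  %c = constant 2.0 : f32
  %0 = linalg.generic #trait_scale
    ins(%arga: tensor<16xf32, #SV>)
    outs(%argx: tensor<16xf32>) {
      ^bb(%a: f32, %x: f32):
        %1 = mulf %a, %c : f32
        linalg.yield %1 : f32
  } -> tensor<16xf32>
  return %0 : tensor<16xf32>
}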
diff  --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index 80febcad1c0b1..91ed85723bd9f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1,48 +1,50 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // RUN: mlir-opt %s -sparsification | FileCheck %s
 
-#trait_dd = {
+#Tdd = #sparse_tensor.encoding<{ dimLevelType = [ "dense",      "dense"      ] }>
+#Tds = #sparse_tensor.encoding<{ dimLevelType = [ "dense",      "compressed" ] }>
+#Tsd = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense"      ] }>
+#Tss = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+#trait2 = {
   indexing_maps = [
     affine_map<(i,j) -> (i,j)>,  // A
     affine_map<(i,j) -> (i,j)>,  // B
     affine_map<(i,j) -> (i,j)>   // X (out)
   ],
-  sparse = [
-    [ "D", "D" ],  // A
-    [ "D", "D" ],  // B
-    [ "D", "D" ]   // X
-  ],
   iterator_types = ["parallel", "parallel"],
   doc = "X(i,j) = A(i,j) OP B(i,j)"
 }
 
 // CHECK-LABEL:   func @add_dd(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 16 : index
 // CHECK:           %[[VAL_5:.*]] = constant 0 : index
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32x16xf32>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.alloc() : memref<32x16xf32>
 // CHECK:           linalg.copy(%[[VAL_9]], %[[VAL_10]]) : memref<32x16xf32>, memref<32x16xf32>
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
 // CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK:               %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
-// CHECK:               %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
-// CHECK:               %[[VAL_15:.*]] = addf %[[VAL_13]], %[[VAL_14]] : f32
-// CHECK:               store %[[VAL_15]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
+// CHECK:               %[[VAL_13:.*]] = muli %[[VAL_11]], %[[VAL_4]] : index
+// CHECK:               %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK:               %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref<?xf32>
+// CHECK:               %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
+// CHECK:               %[[VAL_17:.*]] = addf %[[VAL_15]], %[[VAL_16]] : f32
+// CHECK:               memref.store %[[VAL_17]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_16:.*]] = memref.tensor_load %[[VAL_10]] : memref<32x16xf32>
-// CHECK:           return %[[VAL_16]] : tensor<32x16xf32>
+// CHECK:           %[[VAL_18:.*]] = memref.tensor_load %[[VAL_10]] : memref<32x16xf32>
+// CHECK:           return %[[VAL_18]] : tensor<32x16xf32>
 // CHECK:         }
-func @add_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_dd
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @add_dd(%arga: tensor<32x16xf32, #Tdd>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tdd>, tensor<32x16xf32>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
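Note on the add_dd/mul_dd kernels above: even an all-"dense" #sparse_tensor.encoding lowers the annotated operand to a flat one-dimensional values buffer, so the generated code linearizes the subscripts row-major, computing i * 16 (%[[VAL_13]]) plus j (%[[VAL_14]]) before loading from the memref<?xf32>. For a 32x16 tensor, element (i, j) = (2, 5) therefore sits at flat offset 2 * 16 + 5 = 37. The four encodings #Tdd, #Tds, #Tsd, and #Tss at the top of this file spell out the per-dimension storage scheme that drives these variants: "dense" materializes the full range of a dimension, while "compressed" stores only the populated positions through a pointers/indices pair.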
@@ -52,32 +54,34 @@ func @add_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 }
 
 // CHECK-LABEL:   func @mul_dd(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 16 : index
 // CHECK:           %[[VAL_5:.*]] = constant 0 : index
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32x16xf32>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.alloc() : memref<32x16xf32>
 // CHECK:           linalg.copy(%[[VAL_9]], %[[VAL_10]]) : memref<32x16xf32>, memref<32x16xf32>
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
 // CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK:               %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
-// CHECK:               %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
-// CHECK:               %[[VAL_15:.*]] = mulf %[[VAL_13]], %[[VAL_14]] : f32
-// CHECK:               store %[[VAL_15]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
+// CHECK:               %[[VAL_13:.*]] = muli %[[VAL_11]], %[[VAL_4]] : index
+// CHECK:               %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK:               %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref<?xf32>
+// CHECK:               %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
+// CHECK:               %[[VAL_17:.*]] = mulf %[[VAL_15]], %[[VAL_16]] : f32
+// CHECK:               memref.store %[[VAL_17]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<32x16xf32>
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_16:.*]] = memref.tensor_load %[[VAL_10]] : memref<32x16xf32>
-// CHECK:           return %[[VAL_16]] : tensor<32x16xf32>
+// CHECK:           %[[VAL_18:.*]] = memref.tensor_load %[[VAL_10]] : memref<32x16xf32>
+// CHECK:           return %[[VAL_18]] : tensor<32x16xf32>
 // CHECK:         }
-func @mul_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_dd
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @mul_dd(%arga: tensor<32x16xf32, #Tdd>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tdd>, tensor<32x16xf32>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -86,33 +90,18 @@ func @mul_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
   return %0 : tensor<32x16xf32>
 }
 
-#trait_ds = {
-  indexing_maps = [
-    affine_map<(i,j) -> (i,j)>,  // A
-    affine_map<(i,j) -> (i,j)>,  // B
-    affine_map<(i,j) -> (i,j)>   // X (out)
-  ],
-  sparse = [
-    [ "D", "S" ],  // A
-    [ "D", "D" ],  // B
-    [ "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel"],
-  doc = "X(i,j) = A(i,j) OP B(i,j)"
-}
-
 // CHECK-LABEL:   func @add_ds(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 16 : index
 // CHECK:           %[[VAL_5:.*]] = constant 0 : index
 // CHECK:           %[[VAL_6:.*]] = constant true
 // CHECK:           %[[VAL_7:.*]] = constant 1 : index
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.alloc() : memref<32x16xf32>
@@ -132,11 +121,11 @@ func @mul_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref<?xf32>
 // CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]], %[[VAL_23]]] : memref<32x16xf32>
 // CHECK:                 %[[VAL_28:.*]] = addf %[[VAL_26]], %[[VAL_27]] : f32
-// CHECK:                 store %[[VAL_28]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_23]]] : memref<32x16xf32>
+// CHECK:                 memref.store %[[VAL_28]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_23]]] : memref<32x16xf32>
 // CHECK:               } else {
 // CHECK:                 scf.if %[[VAL_6]] {
 // CHECK:                   %[[VAL_29:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]], %[[VAL_23]]] : memref<32x16xf32>
-// CHECK:                   store %[[VAL_29]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_23]]] : memref<32x16xf32>
+// CHECK:                   memref.store %[[VAL_29]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_23]]] : memref<32x16xf32>
 // CHECK:                 } else {
 // CHECK:                 }
 // CHECK:               }
@@ -148,15 +137,15 @@ func @mul_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:             }
 // CHECK:             scf.for %[[VAL_34:.*]] = %[[VAL_35:.*]]#1 to %[[VAL_4]] step %[[VAL_7]] {
 // CHECK:               %[[VAL_36:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]], %[[VAL_34]]] : memref<32x16xf32>
-// CHECK:               store %[[VAL_36]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_34]]] : memref<32x16xf32>
+// CHECK:               memref.store %[[VAL_36]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_34]]] : memref<32x16xf32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_37:.*]] = memref.tensor_load %[[VAL_13]] : memref<32x16xf32>
 // CHECK:           return %[[VAL_37]] : tensor<32x16xf32>
 // CHECK:         }
-func @add_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_ds
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @add_ds(%arga: tensor<32x16xf32, #Tds>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tds>, tensor<32x16xf32>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -166,15 +155,15 @@ func @add_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 }
 
 // CHECK-LABEL:   func @mul_ds(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.alloc() : memref<32x16xf32>
@@ -188,15 +177,15 @@ func @add_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xf32>
 // CHECK:               %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_17]]] : memref<32x16xf32>
 // CHECK:               %[[VAL_20:.*]] = mulf %[[VAL_18]], %[[VAL_19]] : f32
-// CHECK:               store %[[VAL_20]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_17]]] : memref<32x16xf32>
+// CHECK:               memref.store %[[VAL_20]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_17]]] : memref<32x16xf32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_21:.*]] = memref.tensor_load %[[VAL_11]] : memref<32x16xf32>
 // CHECK:           return %[[VAL_21]] : tensor<32x16xf32>
 // CHECK:         }
-func @mul_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_ds
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @mul_ds(%arga: tensor<32x16xf32, #Tds>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tds>, tensor<32x16xf32>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -205,33 +194,18 @@ func @mul_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
   return %0 : tensor<32x16xf32>
 }
 
-#trait_sd = {
-  indexing_maps = [
-    affine_map<(i,j) -> (i,j)>,  // A
-    affine_map<(i,j) -> (i,j)>,  // B
-    affine_map<(i,j) -> (i,j)>   // X (out)
-  ],
-  sparse = [
-    [ "S", "D" ],  // A
-    [ "D", "D" ],  // B
-    [ "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel"],
-  doc = "X(i,j) = A(i,j) OP B(i,j)"
-}
-
 // CHECK-LABEL:   func @add_sd(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 16 : index
 // CHECK:           %[[VAL_5:.*]] = constant true
 // CHECK:           %[[VAL_6:.*]] = constant 0 : index
 // CHECK:           %[[VAL_7:.*]] = constant 1 : index
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.alloc() : memref<32x16xf32>
@@ -252,13 +226,13 @@ func @mul_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xf32>
 // CHECK:                 %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_24]]] : memref<32x16xf32>
 // CHECK:                 %[[VAL_29:.*]] = addf %[[VAL_27]], %[[VAL_28]] : f32
-// CHECK:                 store %[[VAL_29]], %[[VAL_13]]{{\[}}%[[VAL_21]], %[[VAL_24]]] : memref<32x16xf32>
+// CHECK:                 memref.store %[[VAL_29]], %[[VAL_13]]{{\[}}%[[VAL_21]], %[[VAL_24]]] : memref<32x16xf32>
 // CHECK:               }
 // CHECK:             } else {
 // CHECK:               scf.if %[[VAL_5]] {
 // CHECK:                 scf.for %[[VAL_30:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
 // CHECK:                   %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref<32x16xf32>
-// CHECK:                   store %[[VAL_31]], %[[VAL_13]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref<32x16xf32>
+// CHECK:                   memref.store %[[VAL_31]], %[[VAL_13]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref<32x16xf32>
 // CHECK:                 }
 // CHECK:               } else {
 // CHECK:               }
@@ -272,15 +246,15 @@ func @mul_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:           scf.for %[[VAL_36:.*]] = %[[VAL_37:.*]]#1 to %[[VAL_3]] step %[[VAL_7]] {
 // CHECK:             scf.for %[[VAL_38:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
 // CHECK:               %[[VAL_39:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_36]], %[[VAL_38]]] : memref<32x16xf32>
-// CHECK:               store %[[VAL_39]], %[[VAL_13]]{{\[}}%[[VAL_36]], %[[VAL_38]]] : memref<32x16xf32>
+// CHECK:               memref.store %[[VAL_39]], %[[VAL_13]]{{\[}}%[[VAL_36]], %[[VAL_38]]] : memref<32x16xf32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_40:.*]] = memref.tensor_load %[[VAL_13]] : memref<32x16xf32>
 // CHECK:           return %[[VAL_40]] : tensor<32x16xf32>
 // CHECK:         }
-func @add_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_sd
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @add_sd(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tsd>, tensor<32x16xf32>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -290,15 +264,15 @@ func @add_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 }
 
 // CHECK-LABEL:   func @mul_sd(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 16 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.alloc() : memref<32x16xf32>
@@ -313,15 +287,15 @@ func @add_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:               %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf32>
 // CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]], %[[VAL_16]]] : memref<32x16xf32>
 // CHECK:               %[[VAL_21:.*]] = mulf %[[VAL_19]], %[[VAL_20]] : f32
-// CHECK:               store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_15]], %[[VAL_16]]] : memref<32x16xf32>
+// CHECK:               memref.store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_15]], %[[VAL_16]]] : memref<32x16xf32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_22:.*]] = memref.tensor_load %[[VAL_11]] : memref<32x16xf32>
 // CHECK:           return %[[VAL_22]] : tensor<32x16xf32>
 // CHECK:         }
-func @mul_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_sd
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @mul_sd(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tsd>, tensor<32x16xf32>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -330,35 +304,20 @@ func @mul_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
   return %0 : tensor<32x16xf32>
 }
 
-#trait_ss = {
-  indexing_maps = [
-    affine_map<(i,j) -> (i,j)>,  // A
-    affine_map<(i,j) -> (i,j)>,  // B
-    affine_map<(i,j) -> (i,j)>   // X (out)
-  ],
-  sparse = [
-    [ "S", "S" ],  // A
-    [ "D", "D" ],  // B
-    [ "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel"],
-  doc = "X(i,j) = A(i,j) OP B(i,j)"
-}
-
 // CHECK-LABEL:   func @add_ss(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 16 : index
 // CHECK:           %[[VAL_5:.*]] = constant true
 // CHECK:           %[[VAL_6:.*]] = constant 0 : index
 // CHECK:           %[[VAL_7:.*]] = constant 1 : index
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16xf32>
 // CHECK:           %[[VAL_14:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           %[[VAL_15:.*]] = memref.alloc() : memref<32x16xf32>
@@ -387,11 +346,11 @@ func @mul_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:                   %[[VAL_37:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_33]]] : memref<?xf32>
 // CHECK:                   %[[VAL_38:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_23]], %[[VAL_34]]] : memref<32x16xf32>
 // CHECK:                   %[[VAL_39:.*]] = addf %[[VAL_37]], %[[VAL_38]] : f32
-// CHECK:                   store %[[VAL_39]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_34]]] : memref<32x16xf32>
+// CHECK:                   memref.store %[[VAL_39]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_34]]] : memref<32x16xf32>
 // CHECK:                 } else {
 // CHECK:                   scf.if %[[VAL_5]] {
 // CHECK:                     %[[VAL_40:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_23]], %[[VAL_34]]] : memref<32x16xf32>
-// CHECK:                     store %[[VAL_40]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_34]]] : memref<32x16xf32>
+// CHECK:                     memref.store %[[VAL_40]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_34]]] : memref<32x16xf32>
 // CHECK:                   } else {
 // CHECK:                   }
 // CHECK:                 }
@@ -403,13 +362,13 @@ func @mul_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:               }
 // CHECK:               scf.for %[[VAL_45:.*]] = %[[VAL_46:.*]]#1 to %[[VAL_4]] step %[[VAL_7]] {
 // CHECK:                 %[[VAL_47:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_23]], %[[VAL_45]]] : memref<32x16xf32>
-// CHECK:                 store %[[VAL_47]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_45]]] : memref<32x16xf32>
+// CHECK:                 memref.store %[[VAL_47]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_45]]] : memref<32x16xf32>
 // CHECK:               }
 // CHECK:             } else {
 // CHECK:               scf.if %[[VAL_5]] {
 // CHECK:                 scf.for %[[VAL_48:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
 // CHECK:                   %[[VAL_49:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_23]], %[[VAL_48]]] : memref<32x16xf32>
-// CHECK:                   store %[[VAL_49]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_48]]] : memref<32x16xf32>
+// CHECK:                   memref.store %[[VAL_49]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_48]]] : memref<32x16xf32>
 // CHECK:                 }
 // CHECK:               } else {
 // CHECK:               }
@@ -423,15 +382,15 @@ func @mul_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:           scf.for %[[VAL_54:.*]] = %[[VAL_55:.*]]#1 to %[[VAL_3]] step %[[VAL_7]] {
 // CHECK:             scf.for %[[VAL_56:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
 // CHECK:               %[[VAL_57:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_54]], %[[VAL_56]]] : memref<32x16xf32>
-// CHECK:               store %[[VAL_57]], %[[VAL_15]]{{\[}}%[[VAL_54]], %[[VAL_56]]] : memref<32x16xf32>
+// CHECK:               memref.store %[[VAL_57]], %[[VAL_15]]{{\[}}%[[VAL_54]], %[[VAL_56]]] : memref<32x16xf32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_58:.*]] = memref.tensor_load %[[VAL_15]] : memref<32x16xf32>
 // CHECK:           return %[[VAL_58]] : tensor<32x16xf32>
 // CHECK:         }
-func @add_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_ss
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @add_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tss>, tensor<32x16xf32>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -441,16 +400,16 @@ func @add_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 }
 
 // CHECK-LABEL:   func @mul_ss(
-// CHECK-SAME:                 %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_1:.*1]]: tensor<32x16xf32>,
-// CHECK-SAME:                 %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32x16xf32>,
+// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.alloc() : memref<32x16xf32>
@@ -467,15 +426,15 @@ func @add_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<?xf32>
 // CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]], %[[VAL_21]]] : memref<32x16xf32>
 // CHECK:               %[[VAL_24:.*]] = mulf %[[VAL_22]], %[[VAL_23]] : f32
-// CHECK:               store %[[VAL_24]], %[[VAL_12]]{{\[}}%[[VAL_16]], %[[VAL_21]]] : memref<32x16xf32>
+// CHECK:               memref.store %[[VAL_24]], %[[VAL_12]]{{\[}}%[[VAL_16]], %[[VAL_21]]] : memref<32x16xf32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_25:.*]] = memref.tensor_load %[[VAL_12]] : memref<32x16xf32>
 // CHECK:           return %[[VAL_25]] : tensor<32x16xf32>
 // CHECK:         }
-func @mul_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_ss
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @mul_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tss>, tensor<32x16xf32>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -484,37 +443,22 @@ func @mul_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
   return %0 : tensor<32x16xf32>
 }
 
-#trait_ss_ss = {
-  indexing_maps = [
-    affine_map<(i,j) -> (i,j)>,  // A
-    affine_map<(i,j) -> (i,j)>,  // B
-    affine_map<(i,j) -> (i,j)>   // X (out)
-  ],
-  sparse = [
-    [ "S", "S" ],  // A
-    [ "S", "S" ],  // B
-    [ "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel"],
-  doc = "X(i,j) = A(i,j) OP B(i,j)"
-}
-
 // CHECK-LABEL:   func @add_ss_ss(
-// CHECK-SAME:                    %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                    %[[VAL_1:.*1]]: tensor<32x16xf32>,
+// CHECK-SAME:                    %[[VAL_0:.*0]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                    %[[VAL_1:.*1]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                    %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32> to memref<?xf32>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           %[[VAL_16:.*]] = memref.alloc() : memref<32x16xf32>
 // CHECK:           linalg.copy(%[[VAL_15]], %[[VAL_16]]) : memref<32x16xf32>, memref<32x16xf32>
@@ -561,17 +505,17 @@ func @mul_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:                   %[[VAL_57:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_48]]] : memref<?xf32>
 // CHECK:                   %[[VAL_58:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_49]]] : memref<?xf32>
 // CHECK:                   %[[VAL_59:.*]] = addf %[[VAL_57]], %[[VAL_58]] : f32
-// CHECK:                   store %[[VAL_59]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
+// CHECK:                   memref.store %[[VAL_59]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
 // CHECK:                 } else {
 // CHECK:                   %[[VAL_60:.*]] = cmpi eq, %[[VAL_50]], %[[VAL_53]] : index
 // CHECK:                   scf.if %[[VAL_60]] {
 // CHECK:                     %[[VAL_61:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_48]]] : memref<?xf32>
-// CHECK:                     store %[[VAL_61]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
+// CHECK:                     memref.store %[[VAL_61]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
 // CHECK:                   } else {
 // CHECK:                     %[[VAL_62:.*]] = cmpi eq, %[[VAL_51]], %[[VAL_53]] : index
 // CHECK:                     scf.if %[[VAL_62]] {
 // CHECK:                       %[[VAL_63:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_49]]] : memref<?xf32>
-// CHECK:                       store %[[VAL_63]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
+// CHECK:                       memref.store %[[VAL_63]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
 // CHECK:                     } else {
 // CHECK:                     }
 // CHECK:                   }
@@ -587,12 +531,12 @@ func @mul_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:               scf.for %[[VAL_70:.*]] = %[[VAL_71:.*]]#0 to %[[VAL_38]] step %[[VAL_4]] {
 // CHECK:                 %[[VAL_72:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_70]]] : memref<?xindex>
 // CHECK:                 %[[VAL_73:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_70]]] : memref<?xf32>
-// CHECK:                 store %[[VAL_73]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_72]]] : memref<32x16xf32>
+// CHECK:                 memref.store %[[VAL_73]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_72]]] : memref<32x16xf32>
 // CHECK:               }
 // CHECK:               scf.for %[[VAL_74:.*]] = %[[VAL_75:.*]]#1 to %[[VAL_41]] step %[[VAL_4]] {
 // CHECK:                 %[[VAL_76:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_74]]] : memref<?xindex>
 // CHECK:                 %[[VAL_77:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_74]]] : memref<?xf32>
-// CHECK:                 store %[[VAL_77]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_76]]] : memref<32x16xf32>
+// CHECK:                 memref.store %[[VAL_77]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_76]]] : memref<32x16xf32>
 // CHECK:               }
 // CHECK:             } else {
 // CHECK:               %[[VAL_78:.*]] = cmpi eq, %[[VAL_29]], %[[VAL_32]] : index
@@ -603,7 +547,7 @@ func @mul_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:                 scf.for %[[VAL_82:.*]] = %[[VAL_79]] to %[[VAL_81]] step %[[VAL_4]] {
 // CHECK:                   %[[VAL_83:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_82]]] : memref<?xindex>
 // CHECK:                   %[[VAL_84:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_82]]] : memref<?xf32>
-// CHECK:                   store %[[VAL_84]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_83]]] : memref<32x16xf32>
+// CHECK:                   memref.store %[[VAL_84]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_83]]] : memref<32x16xf32>
 // CHECK:                 }
 // CHECK:               } else {
 // CHECK:                 %[[VAL_85:.*]] = cmpi eq, %[[VAL_30]], %[[VAL_32]] : index
@@ -614,7 +558,7 @@ func @mul_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:                   scf.for %[[VAL_89:.*]] = %[[VAL_86]] to %[[VAL_88]] step %[[VAL_4]] {
 // CHECK:                     %[[VAL_90:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_89]]] : memref<?xindex>
 // CHECK:                     %[[VAL_91:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_89]]] : memref<?xf32>
-// CHECK:                     store %[[VAL_91]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_90]]] : memref<32x16xf32>
+// CHECK:                     memref.store %[[VAL_91]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_90]]] : memref<32x16xf32>
 // CHECK:                   }
 // CHECK:                 } else {
 // CHECK:                 }
@@ -636,7 +580,7 @@ func @mul_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:             scf.for %[[VAL_104:.*]] = %[[VAL_101]] to %[[VAL_103]] step %[[VAL_4]] {
 // CHECK:               %[[VAL_105:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_104]]] : memref<?xindex>
 // CHECK:               %[[VAL_106:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_104]]] : memref<?xf32>
-// CHECK:               store %[[VAL_106]], %[[VAL_16]]{{\[}}%[[VAL_100]], %[[VAL_105]]] : memref<32x16xf32>
+// CHECK:               memref.store %[[VAL_106]], %[[VAL_16]]{{\[}}%[[VAL_100]], %[[VAL_105]]] : memref<32x16xf32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           scf.for %[[VAL_107:.*]] = %[[VAL_108:.*]]#1 to %[[VAL_20]] step %[[VAL_4]] {
@@ -647,15 +591,15 @@ func @mul_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<3
 // CHECK:             scf.for %[[VAL_113:.*]] = %[[VAL_110]] to %[[VAL_112]] step %[[VAL_4]] {
 // CHECK:               %[[VAL_114:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_113]]] : memref<?xindex>
 // CHECK:               %[[VAL_115:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_113]]] : memref<?xf32>
-// CHECK:               store %[[VAL_115]], %[[VAL_16]]{{\[}}%[[VAL_109]], %[[VAL_114]]] : memref<32x16xf32>
+// CHECK:               memref.store %[[VAL_115]], %[[VAL_16]]{{\[}}%[[VAL_109]], %[[VAL_114]]] : memref<32x16xf32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_116:.*]] = memref.tensor_load %[[VAL_16]] : memref<32x16xf32>
 // CHECK:           return %[[VAL_116]] : tensor<32x16xf32>
 // CHECK:         }
-func @add_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_ss_ss
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #Tss>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tss>, tensor<32x16xf32, #Tss>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -665,21 +609,21 @@ func @add_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tenso
 }
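// The @add_ss_ss test above now carries its sparsity on the operand types
// through the #Tss alias, and the two compressed operands are co-iterated
// with scf.while, advancing whichever side holds the smaller index (the
// cmpi/select/scf.condition pattern in the CHECK lines). A minimal sketch
// of the alias, assuming it matches the encoding spelled out in the CHECK
// lines (the printed bit widths of 0 are the defaults, i.e. native width):
#Tss = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed", "compressed" ]
}>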
 
 // CHECK-LABEL:   func @mul_ss_ss(
-// CHECK-SAME:                    %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                    %[[VAL_1:.*1]]: tensor<32x16xf32>,
+// CHECK-SAME:                    %[[VAL_0:.*0]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                    %[[VAL_1:.*1]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                    %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32> to memref<?xf32>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32> to memref<?xf32>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
 // CHECK:           %[[VAL_16:.*]] = memref.alloc() : memref<32x16xf32>
 // CHECK:           linalg.copy(%[[VAL_15]], %[[VAL_16]]) : memref<32x16xf32>, memref<32x16xf32>
@@ -726,7 +670,7 @@ func @add_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tenso
 // CHECK:                   %[[VAL_57:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_48]]] : memref<?xf32>
 // CHECK:                   %[[VAL_58:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_49]]] : memref<?xf32>
 // CHECK:                   %[[VAL_59:.*]] = mulf %[[VAL_57]], %[[VAL_58]] : f32
-// CHECK:                   store %[[VAL_59]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
+// CHECK:                   memref.store %[[VAL_59]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
 // CHECK:                 } else {
 // CHECK:                 }
 // CHECK:                 %[[VAL_60:.*]] = cmpi eq, %[[VAL_50]], %[[VAL_53]] : index
@@ -750,9 +694,9 @@ func @add_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tenso
 // CHECK:           %[[VAL_72:.*]] = memref.tensor_load %[[VAL_16]] : memref<32x16xf32>
 // CHECK:           return %[[VAL_72]] : tensor<32x16xf32>
 // CHECK:         }
-func @mul_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_ss_ss
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @mul_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #Tss>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tss>, tensor<32x16xf32, #Tss>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -761,178 +705,105 @@ func @mul_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tenso
   return %0 : tensor<32x16xf32>
 }
 
-#trait_sd_ds = {
-  indexing_maps = [
-    affine_map<(i,j) -> (i,j)>,  // A
-    affine_map<(i,j) -> (i,j)>,  // B
-    affine_map<(i,j) -> (i,j)>   // X (out)
-  ],
-  sparse = [
-    [ "S", "D" ],  // A
-    [ "D", "S" ],  // B
-    [ "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel"],
-  doc = "X(i,j) = A(i,j) OP B(i,j)"
-}
-
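// With per-operand sparse annotations no longer part of the trait, the
// removed #trait_sd_ds above collapses into the shared #trait2 referenced
// by these kernels; a sketch of its presumed shape, reconstructed from the
// removed trait minus the sparse block:
#trait2 = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A
    affine_map<(i,j) -> (i,j)>,  // B
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) OP B(i,j)"
}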
 // CHECK-LABEL:   func @add_sd_ds(
-// CHECK-SAME:                    %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                    %[[VAL_1:.*1]]: tensor<32x16xf32>,
-// CHECK-SAME:                    %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
-// CHECK:           %[[VAL_3:.*]] = constant 0 : index
-// CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32> to memref<?xf32>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32> to memref<?xf32>
-// CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           %[[VAL_16:.*]] = memref.alloc() : memref<32x16xf32>
-// CHECK:           linalg.copy(%[[VAL_15]], %[[VAL_16]]) : memref<32x16xf32>, memref<32x16xf32>
-// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_21:.*]]:2 = scf.while (%[[VAL_22:.*]] = %[[VAL_17]], %[[VAL_23:.*]] = %[[VAL_19]]) : (index, index) -> (index, index) {
-// CHECK:             %[[VAL_24:.*]] = cmpi ult, %[[VAL_22]], %[[VAL_18]] : index
-// CHECK:             %[[VAL_25:.*]] = cmpi ult, %[[VAL_23]], %[[VAL_20]] : index
-// CHECK:             %[[VAL_26:.*]] = and %[[VAL_24]], %[[VAL_25]] : i1
-// CHECK:             scf.condition(%[[VAL_26]]) %[[VAL_22]], %[[VAL_23]] : index, index
+// CHECK-SAME:                    %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                    %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                    %[[VAL_2:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
+// CHECK:           %[[VAL_3:.*]] = constant 32 : index
+// CHECK:           %[[VAL_4:.*]] = constant 16 : index
+// CHECK:           %[[VAL_5:.*]] = constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = constant true
+// CHECK:           %[[VAL_7:.*]] = constant 1 : index
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_14:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
+// CHECK:           %[[VAL_15:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK:           linalg.copy(%[[VAL_14]], %[[VAL_15]]) : memref<32x16xf32>, memref<32x16xf32>
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           %[[VAL_18:.*]]:2 = scf.while (%[[VAL_19:.*]] = %[[VAL_16]], %[[VAL_20:.*]] = %[[VAL_5]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_21:.*]] = cmpi ult, %[[VAL_19]], %[[VAL_17]] : index
+// CHECK:             scf.condition(%[[VAL_21]]) %[[VAL_19]], %[[VAL_20]] : index, index
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_27:.*]]: index, %[[VAL_28:.*]]: index):
-// CHECK:             %[[VAL_29:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK:             %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref<?xindex>
-// CHECK:             %[[VAL_31:.*]] = cmpi ult, %[[VAL_30]], %[[VAL_29]] : index
-// CHECK:             %[[VAL_32:.*]] = select %[[VAL_31]], %[[VAL_30]], %[[VAL_29]] : index
-// CHECK:             %[[VAL_33:.*]] = cmpi eq, %[[VAL_29]], %[[VAL_32]] : index
-// CHECK:             %[[VAL_34:.*]] = cmpi eq, %[[VAL_30]], %[[VAL_32]] : index
-// CHECK:             %[[VAL_35:.*]] = and %[[VAL_33]], %[[VAL_34]] : i1
-// CHECK:             scf.if %[[VAL_35]] {
-// CHECK:               %[[VAL_36:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK:               %[[VAL_37:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index
-// CHECK:               %[[VAL_38:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_37]]] : memref<?xindex>
-// CHECK:               %[[VAL_39:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref<?xindex>
-// CHECK:               %[[VAL_40:.*]] = addi %[[VAL_28]], %[[VAL_4]] : index
-// CHECK:               %[[VAL_41:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_40]]] : memref<?xindex>
-// CHECK:               %[[VAL_42:.*]]:2 = scf.while (%[[VAL_43:.*]] = %[[VAL_36]], %[[VAL_44:.*]] = %[[VAL_39]]) : (index, index) -> (index, index) {
-// CHECK:                 %[[VAL_45:.*]] = cmpi ult, %[[VAL_43]], %[[VAL_38]] : index
-// CHECK:                 %[[VAL_46:.*]] = cmpi ult, %[[VAL_44]], %[[VAL_41]] : index
-// CHECK:                 %[[VAL_47:.*]] = and %[[VAL_45]], %[[VAL_46]] : i1
-// CHECK:                 scf.condition(%[[VAL_47]]) %[[VAL_43]], %[[VAL_44]] : index, index
+// CHECK:           ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index):
+// CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK:             %[[VAL_25:.*]] = cmpi eq, %[[VAL_24]], %[[VAL_23]] : index
+// CHECK:             scf.if %[[VAL_25]] {
+// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK:               %[[VAL_27:.*]] = addi %[[VAL_23]], %[[VAL_7]] : index
+// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref<?xindex>
+// CHECK:               %[[VAL_29:.*]]:2 = scf.while (%[[VAL_30:.*]] = %[[VAL_26]], %[[VAL_31:.*]] = %[[VAL_5]]) : (index, index) -> (index, index) {
+// CHECK:                 %[[VAL_32:.*]] = cmpi ult, %[[VAL_30]], %[[VAL_28]] : index
+// CHECK:                 scf.condition(%[[VAL_32]]) %[[VAL_30]], %[[VAL_31]] : index, index
 // CHECK:               } do {
-// CHECK:               ^bb0(%[[VAL_48:.*]]: index, %[[VAL_49:.*]]: index):
-// CHECK:                 %[[VAL_50:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_48]]] : memref<?xindex>
-// CHECK:                 %[[VAL_51:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_49]]] : memref<?xindex>
-// CHECK:                 %[[VAL_52:.*]] = cmpi ult, %[[VAL_51]], %[[VAL_50]] : index
-// CHECK:                 %[[VAL_53:.*]] = select %[[VAL_52]], %[[VAL_51]], %[[VAL_50]] : index
-// CHECK:                 %[[VAL_54:.*]] = cmpi eq, %[[VAL_50]], %[[VAL_53]] : index
-// CHECK:                 %[[VAL_55:.*]] = cmpi eq, %[[VAL_51]], %[[VAL_53]] : index
-// CHECK:                 %[[VAL_56:.*]] = and %[[VAL_54]], %[[VAL_55]] : i1
-// CHECK:                 scf.if %[[VAL_56]] {
-// CHECK:                   %[[VAL_57:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_48]]] : memref<?xf32>
-// CHECK:                   %[[VAL_58:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_49]]] : memref<?xf32>
-// CHECK:                   %[[VAL_59:.*]] = addf %[[VAL_57]], %[[VAL_58]] : f32
-// CHECK:                   store %[[VAL_59]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
+// CHECK:               ^bb0(%[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index):
+// CHECK:                 %[[VAL_35:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_33]]] : memref<?xindex>
+// CHECK:                 %[[VAL_36:.*]] = muli %[[VAL_22]], %[[VAL_4]] : index
+// CHECK:                 %[[VAL_37:.*]] = addi %[[VAL_36]], %[[VAL_34]] : index
+// CHECK:                 %[[VAL_38:.*]] = cmpi eq, %[[VAL_35]], %[[VAL_34]] : index
+// CHECK:                 scf.if %[[VAL_38]] {
+// CHECK:                   %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_37]]] : memref<?xf32>
+// CHECK:                   %[[VAL_40:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:                   %[[VAL_41:.*]] = addf %[[VAL_39]], %[[VAL_40]] : f32
+// CHECK:                   memref.store %[[VAL_41]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_34]]] : memref<32x16xf32>
 // CHECK:                 } else {
-// CHECK:                   %[[VAL_60:.*]] = cmpi eq, %[[VAL_50]], %[[VAL_53]] : index
-// CHECK:                   scf.if %[[VAL_60]] {
-// CHECK:                     %[[VAL_61:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_48]]] : memref<?xf32>
-// CHECK:                     store %[[VAL_61]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
+// CHECK:                   scf.if %[[VAL_6]] {
+// CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_37]]] : memref<?xf32>
+// CHECK:                     memref.store %[[VAL_42]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_34]]] : memref<32x16xf32>
 // CHECK:                   } else {
-// CHECK:                     %[[VAL_62:.*]] = cmpi eq, %[[VAL_51]], %[[VAL_53]] : index
-// CHECK:                     scf.if %[[VAL_62]] {
-// CHECK:                       %[[VAL_63:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_49]]] : memref<?xf32>
-// CHECK:                       store %[[VAL_63]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
-// CHECK:                     } else {
-// CHECK:                     }
 // CHECK:                   }
 // CHECK:                 }
-// CHECK:                 %[[VAL_64:.*]] = cmpi eq, %[[VAL_50]], %[[VAL_53]] : index
-// CHECK:                 %[[VAL_65:.*]] = addi %[[VAL_48]], %[[VAL_4]] : index
-// CHECK:                 %[[VAL_66:.*]] = select %[[VAL_64]], %[[VAL_65]], %[[VAL_48]] : index
-// CHECK:                 %[[VAL_67:.*]] = cmpi eq, %[[VAL_51]], %[[VAL_53]] : index
-// CHECK:                 %[[VAL_68:.*]] = addi %[[VAL_49]], %[[VAL_4]] : index
-// CHECK:                 %[[VAL_69:.*]] = select %[[VAL_67]], %[[VAL_68]], %[[VAL_49]] : index
-// CHECK:                 scf.yield %[[VAL_66]], %[[VAL_69]] : index, index
-// CHECK:               }
-// CHECK:               scf.for %[[VAL_70:.*]] = %[[VAL_71:.*]]#0 to %[[VAL_38]] step %[[VAL_4]] {
-// CHECK:                 %[[VAL_72:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_70]]] : memref<?xindex>
-// CHECK:                 %[[VAL_73:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_70]]] : memref<?xf32>
-// CHECK:                 store %[[VAL_73]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_72]]] : memref<32x16xf32>
+// CHECK:                 %[[VAL_43:.*]] = cmpi eq, %[[VAL_35]], %[[VAL_34]] : index
+// CHECK:                 %[[VAL_44:.*]] = addi %[[VAL_33]], %[[VAL_7]] : index
+// CHECK:                 %[[VAL_45:.*]] = select %[[VAL_43]], %[[VAL_44]], %[[VAL_33]] : index
+// CHECK:                 %[[VAL_46:.*]] = addi %[[VAL_34]], %[[VAL_7]] : index
+// CHECK:                 scf.yield %[[VAL_45]], %[[VAL_46]] : index, index
 // CHECK:               }
-// CHECK:               scf.for %[[VAL_74:.*]] = %[[VAL_75:.*]]#1 to %[[VAL_41]] step %[[VAL_4]] {
-// CHECK:                 %[[VAL_76:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_74]]] : memref<?xindex>
-// CHECK:                 %[[VAL_77:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_74]]] : memref<?xf32>
-// CHECK:                 store %[[VAL_77]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_76]]] : memref<32x16xf32>
+// CHECK:               scf.for %[[VAL_47:.*]] = %[[VAL_48:.*]]#1 to %[[VAL_4]] step %[[VAL_7]] {
+// CHECK:                 %[[VAL_49:.*]] = muli %[[VAL_22]], %[[VAL_4]] : index
+// CHECK:                 %[[VAL_50:.*]] = addi %[[VAL_49]], %[[VAL_47]] : index
+// CHECK:                 %[[VAL_51:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_50]]] : memref<?xf32>
+// CHECK:                 memref.store %[[VAL_51]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_47]]] : memref<32x16xf32>
 // CHECK:               }
 // CHECK:             } else {
-// CHECK:               %[[VAL_78:.*]] = cmpi eq, %[[VAL_29]], %[[VAL_32]] : index
-// CHECK:               scf.if %[[VAL_78]] {
-// CHECK:                 %[[VAL_79:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK:                 %[[VAL_80:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index
-// CHECK:                 %[[VAL_81:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_80]]] : memref<?xindex>
-// CHECK:                 scf.for %[[VAL_82:.*]] = %[[VAL_79]] to %[[VAL_81]] step %[[VAL_4]] {
-// CHECK:                   %[[VAL_83:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_82]]] : memref<?xindex>
-// CHECK:                   %[[VAL_84:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_82]]] : memref<?xf32>
-// CHECK:                   store %[[VAL_84]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_83]]] : memref<32x16xf32>
+// CHECK:               scf.if %[[VAL_6]] {
+// CHECK:                 %[[VAL_52:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK:                 %[[VAL_53:.*]] = addi %[[VAL_23]], %[[VAL_7]] : index
+// CHECK:                 %[[VAL_54:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_53]]] : memref<?xindex>
+// CHECK:                 scf.for %[[VAL_55:.*]] = %[[VAL_52]] to %[[VAL_54]] step %[[VAL_7]] {
+// CHECK:                   %[[VAL_56:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_55]]] : memref<?xindex>
+// CHECK:                   %[[VAL_57:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_55]]] : memref<?xf32>
+// CHECK:                   memref.store %[[VAL_57]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_56]]] : memref<32x16xf32>
 // CHECK:                 }
 // CHECK:               } else {
-// CHECK:                 %[[VAL_85:.*]] = cmpi eq, %[[VAL_30]], %[[VAL_32]] : index
-// CHECK:                 scf.if %[[VAL_85]] {
-// CHECK:                   %[[VAL_86:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref<?xindex>
-// CHECK:                   %[[VAL_87:.*]] = addi %[[VAL_28]], %[[VAL_4]] : index
-// CHECK:                   %[[VAL_88:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_87]]] : memref<?xindex>
-// CHECK:                   scf.for %[[VAL_89:.*]] = %[[VAL_86]] to %[[VAL_88]] step %[[VAL_4]] {
-// CHECK:                     %[[VAL_90:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_89]]] : memref<?xindex>
-// CHECK:                     %[[VAL_91:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_89]]] : memref<?xf32>
-// CHECK:                     store %[[VAL_91]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_90]]] : memref<32x16xf32>
-// CHECK:                   }
-// CHECK:                 } else {
-// CHECK:                 }
 // CHECK:               }
 // CHECK:             }
-// CHECK:             %[[VAL_92:.*]] = cmpi eq, %[[VAL_29]], %[[VAL_32]] : index
-// CHECK:             %[[VAL_93:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_94:.*]] = select %[[VAL_92]], %[[VAL_93]], %[[VAL_27]] : index
-// CHECK:             %[[VAL_95:.*]] = cmpi eq, %[[VAL_30]], %[[VAL_32]] : index
-// CHECK:             %[[VAL_96:.*]] = addi %[[VAL_28]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_97:.*]] = select %[[VAL_95]], %[[VAL_96]], %[[VAL_28]] : index
-// CHECK:             scf.yield %[[VAL_94]], %[[VAL_97]] : index, index
-// CHECK:           }
-// CHECK:           scf.for %[[VAL_98:.*]] = %[[VAL_99:.*]]#0 to %[[VAL_18]] step %[[VAL_4]] {
-// CHECK:             %[[VAL_100:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_98]]] : memref<?xindex>
-// CHECK:             %[[VAL_101:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_98]]] : memref<?xindex>
-// CHECK:             %[[VAL_102:.*]] = addi %[[VAL_98]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_103:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_102]]] : memref<?xindex>
-// CHECK:             scf.for %[[VAL_104:.*]] = %[[VAL_101]] to %[[VAL_103]] step %[[VAL_4]] {
-// CHECK:               %[[VAL_105:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_104]]] : memref<?xindex>
-// CHECK:               %[[VAL_106:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_104]]] : memref<?xf32>
-// CHECK:               store %[[VAL_106]], %[[VAL_16]]{{\[}}%[[VAL_100]], %[[VAL_105]]] : memref<32x16xf32>
-// CHECK:             }
+// CHECK:             %[[VAL_58:.*]] = cmpi eq, %[[VAL_24]], %[[VAL_23]] : index
+// CHECK:             %[[VAL_59:.*]] = addi %[[VAL_22]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_60:.*]] = select %[[VAL_58]], %[[VAL_59]], %[[VAL_22]] : index
+// CHECK:             %[[VAL_61:.*]] = addi %[[VAL_23]], %[[VAL_7]] : index
+// CHECK:             scf.yield %[[VAL_60]], %[[VAL_61]] : index, index
 // CHECK:           }
-// CHECK:           scf.for %[[VAL_107:.*]] = %[[VAL_108:.*]]#1 to %[[VAL_20]] step %[[VAL_4]] {
-// CHECK:             %[[VAL_109:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_107]]] : memref<?xindex>
-// CHECK:             %[[VAL_110:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_107]]] : memref<?xindex>
-// CHECK:             %[[VAL_111:.*]] = addi %[[VAL_107]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_112:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_111]]] : memref<?xindex>
-// CHECK:             scf.for %[[VAL_113:.*]] = %[[VAL_110]] to %[[VAL_112]] step %[[VAL_4]] {
-// CHECK:               %[[VAL_114:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_113]]] : memref<?xindex>
-// CHECK:               %[[VAL_115:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_113]]] : memref<?xf32>
-// CHECK:               store %[[VAL_115]], %[[VAL_16]]{{\[}}%[[VAL_109]], %[[VAL_114]]] : memref<32x16xf32>
+// CHECK:           scf.for %[[VAL_62:.*]] = %[[VAL_63:.*]]#1 to %[[VAL_3]] step %[[VAL_7]] {
+// CHECK:             %[[VAL_64:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_62]]] : memref<?xindex>
+// CHECK:             %[[VAL_65:.*]] = addi %[[VAL_62]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_66:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_65]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_67:.*]] = %[[VAL_64]] to %[[VAL_66]] step %[[VAL_7]] {
+// CHECK:               %[[VAL_68:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_67]]] : memref<?xindex>
+// CHECK:               %[[VAL_69:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_67]]] : memref<?xf32>
+// CHECK:               memref.store %[[VAL_69]], %[[VAL_15]]{{\[}}%[[VAL_62]], %[[VAL_68]]] : memref<32x16xf32>
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_116:.*]] = memref.tensor_load %[[VAL_16]] : memref<32x16xf32>
-// CHECK:           return %[[VAL_116]] : tensor<32x16xf32>
+// CHECK:           %[[VAL_70:.*]] = memref.tensor_load %[[VAL_15]] : memref<32x16xf32>
+// CHECK:           return %[[VAL_70]] : tensor<32x16xf32>
 // CHECK:         }
-func @add_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_ss_ss
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @add_sd_ds(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32, #Tds>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tsd>, tensor<32x16xf32, #Tds>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -942,94 +813,44 @@ func @add_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tenso
 }
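// The sd/ds kernels mix level types per dimension. Sketches of the #Tsd and
// #Tds aliases used above, assuming they match the encodings in the CHECK
// lines:
#Tsd = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed", "dense" ]
}>
#Tds = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ]
}>
// A dense level stores no indices, so the generated code addresses the
// values array by linearizing the coordinates, as the muli/addi pairs in
// the CHECK lines above show. A hypothetical, self-contained sketch of that
// addressing pattern (all names invented; 16 is the dense dimension size):
func @linearize(%values: memref<?xf32>, %i: index, %j: index) -> f32 {
  %c16 = constant 16 : index
  %off = muli %i, %c16 : index   // start of segment i in the values array
  %pos = addi %off, %j : index   // linearized position i * 16 + j
  %val = memref.load %values[%pos] : memref<?xf32>
  return %val : f32
}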
 
 // CHECK-LABEL:   func @mul_sd_ds(
-// CHECK-SAME:                    %[[VAL_0:.*0]]: tensor<32x16xf32>,
-// CHECK-SAME:                    %[[VAL_1:.*1]]: tensor<32x16xf32>,
-// CHECK-SAME:                    %[[VAL_2:.*2]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
-// CHECK:           %[[VAL_3:.*]] = constant 0 : index
-// CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32> to memref<?xf32>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<32x16xf32> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32> to memref<?xf32>
-// CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
-// CHECK:           %[[VAL_16:.*]] = memref.alloc() : memref<32x16xf32>
-// CHECK:           linalg.copy(%[[VAL_15]], %[[VAL_16]]) : memref<32x16xf32>, memref<32x16xf32>
-// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_21:.*]]:2 = scf.while (%[[VAL_22:.*]] = %[[VAL_17]], %[[VAL_23:.*]] = %[[VAL_19]]) : (index, index) -> (index, index) {
-// CHECK:             %[[VAL_24:.*]] = cmpi ult, %[[VAL_22]], %[[VAL_18]] : index
-// CHECK:             %[[VAL_25:.*]] = cmpi ult, %[[VAL_23]], %[[VAL_20]] : index
-// CHECK:             %[[VAL_26:.*]] = and %[[VAL_24]], %[[VAL_25]] : i1
-// CHECK:             scf.condition(%[[VAL_26]]) %[[VAL_22]], %[[VAL_23]] : index, index
-// CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_27:.*]]: index, %[[VAL_28:.*]]: index):
-// CHECK:             %[[VAL_29:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK:             %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref<?xindex>
-// CHECK:             %[[VAL_31:.*]] = cmpi ult, %[[VAL_30]], %[[VAL_29]] : index
-// CHECK:             %[[VAL_32:.*]] = select %[[VAL_31]], %[[VAL_30]], %[[VAL_29]] : index
-// CHECK:             %[[VAL_33:.*]] = cmpi eq, %[[VAL_29]], %[[VAL_32]] : index
-// CHECK:             %[[VAL_34:.*]] = cmpi eq, %[[VAL_30]], %[[VAL_32]] : index
-// CHECK:             %[[VAL_35:.*]] = and %[[VAL_33]], %[[VAL_34]] : i1
-// CHECK:             scf.if %[[VAL_35]] {
-// CHECK:               %[[VAL_36:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK:               %[[VAL_37:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index
-// CHECK:               %[[VAL_38:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_37]]] : memref<?xindex>
-// CHECK:               %[[VAL_39:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref<?xindex>
-// CHECK:               %[[VAL_40:.*]] = addi %[[VAL_28]], %[[VAL_4]] : index
-// CHECK:               %[[VAL_41:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_40]]] : memref<?xindex>
-// CHECK:               %[[VAL_42:.*]]:2 = scf.while (%[[VAL_43:.*]] = %[[VAL_36]], %[[VAL_44:.*]] = %[[VAL_39]]) : (index, index) -> (index, index) {
-// CHECK:                 %[[VAL_45:.*]] = cmpi ult, %[[VAL_43]], %[[VAL_38]] : index
-// CHECK:                 %[[VAL_46:.*]] = cmpi ult, %[[VAL_44]], %[[VAL_41]] : index
-// CHECK:                 %[[VAL_47:.*]] = and %[[VAL_45]], %[[VAL_46]] : i1
-// CHECK:                 scf.condition(%[[VAL_47]]) %[[VAL_43]], %[[VAL_44]] : index, index
-// CHECK:               } do {
-// CHECK:               ^bb0(%[[VAL_48:.*]]: index, %[[VAL_49:.*]]: index):
-// CHECK:                 %[[VAL_50:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_48]]] : memref<?xindex>
-// CHECK:                 %[[VAL_51:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_49]]] : memref<?xindex>
-// CHECK:                 %[[VAL_52:.*]] = cmpi ult, %[[VAL_51]], %[[VAL_50]] : index
-// CHECK:                 %[[VAL_53:.*]] = select %[[VAL_52]], %[[VAL_51]], %[[VAL_50]] : index
-// CHECK:                 %[[VAL_54:.*]] = cmpi eq, %[[VAL_50]], %[[VAL_53]] : index
-// CHECK:                 %[[VAL_55:.*]] = cmpi eq, %[[VAL_51]], %[[VAL_53]] : index
-// CHECK:                 %[[VAL_56:.*]] = and %[[VAL_54]], %[[VAL_55]] : i1
-// CHECK:                 scf.if %[[VAL_56]] {
-// CHECK:                   %[[VAL_57:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_48]]] : memref<?xf32>
-// CHECK:                   %[[VAL_58:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_49]]] : memref<?xf32>
-// CHECK:                   %[[VAL_59:.*]] = mulf %[[VAL_57]], %[[VAL_58]] : f32
-// CHECK:                   store %[[VAL_59]], %[[VAL_16]]{{\[}}%[[VAL_32]], %[[VAL_53]]] : memref<32x16xf32>
-// CHECK:                 } else {
-// CHECK:                 }
-// CHECK:                 %[[VAL_60:.*]] = cmpi eq, %[[VAL_50]], %[[VAL_53]] : index
-// CHECK:                 %[[VAL_61:.*]] = addi %[[VAL_48]], %[[VAL_4]] : index
-// CHECK:                 %[[VAL_62:.*]] = select %[[VAL_60]], %[[VAL_61]], %[[VAL_48]] : index
-// CHECK:                 %[[VAL_63:.*]] = cmpi eq, %[[VAL_51]], %[[VAL_53]] : index
-// CHECK:                 %[[VAL_64:.*]] = addi %[[VAL_49]], %[[VAL_4]] : index
-// CHECK:                 %[[VAL_65:.*]] = select %[[VAL_63]], %[[VAL_64]], %[[VAL_49]] : index
-// CHECK:                 scf.yield %[[VAL_62]], %[[VAL_65]] : index, index
-// CHECK:               }
-// CHECK:             } else {
+// CHECK-SAME:                    %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                    %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                    %[[VAL_2:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
+// CHECK:           %[[VAL_3:.*]] = constant 16 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_5]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_5]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf32>
+// CHECK:           %[[VAL_13:.*]] = memref.alloc() : memref<32x16xf32>
+// CHECK:           linalg.copy(%[[VAL_12]], %[[VAL_13]]) : memref<32x16xf32>, memref<32x16xf32>
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
+// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK:             %[[VAL_19:.*]] = addi %[[VAL_17]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_5]] {
+// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xindex>
+// CHECK:               %[[VAL_23:.*]] = muli %[[VAL_16]], %[[VAL_3]] : index
+// CHECK:               %[[VAL_24:.*]] = addi %[[VAL_23]], %[[VAL_22]] : index
+// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]]] : memref<?xf32>
+// CHECK:               %[[VAL_27:.*]] = mulf %[[VAL_25]], %[[VAL_26]] : f32
+// CHECK:               memref.store %[[VAL_27]], %[[VAL_13]]{{\[}}%[[VAL_17]], %[[VAL_22]]] : memref<32x16xf32>
 // CHECK:             }
-// CHECK:             %[[VAL_66:.*]] = cmpi eq, %[[VAL_29]], %[[VAL_32]] : index
-// CHECK:             %[[VAL_67:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_68:.*]] = select %[[VAL_66]], %[[VAL_67]], %[[VAL_27]] : index
-// CHECK:             %[[VAL_69:.*]] = cmpi eq, %[[VAL_30]], %[[VAL_32]] : index
-// CHECK:             %[[VAL_70:.*]] = addi %[[VAL_28]], %[[VAL_4]] : index
-// CHECK:             %[[VAL_71:.*]] = select %[[VAL_69]], %[[VAL_70]], %[[VAL_28]] : index
-// CHECK:             scf.yield %[[VAL_68]], %[[VAL_71]] : index, index
 // CHECK:           }
-// CHECK:           %[[VAL_72:.*]] = memref.tensor_load %[[VAL_16]] : memref<32x16xf32>
-// CHECK:           return %[[VAL_72]] : tensor<32x16xf32>
+// CHECK:           %[[VAL_28:.*]] = memref.tensor_load %[[VAL_13]] : memref<32x16xf32>
+// CHECK:           return %[[VAL_28]] : tensor<32x16xf32>
 // CHECK:         }
-func @mul_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  %0 = linalg.generic #trait_ss_ss
-     ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>)
+func @mul_sd_ds(%arga: tensor<32x16xf32, #Tsd>, %argb: tensor<32x16xf32, #Tds>, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32x16xf32, #Tsd>, tensor<32x16xf32, #Tds>)
     outs(%argx: tensor<32x16xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -1044,25 +865,20 @@ func @mul_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tenso
     affine_map<(i,j) -> (j)>,    // b
     affine_map<(i,j) -> (i)>     // x (out)
   ],
-  sparse = [
-    [ "D", "S" ],  // A
-    [ "D" ],       // b
-    [ "D" ]        // x
-  ],
   iterator_types = ["parallel", "reduction"],
   doc = "x(i) += SUM_j A(i,j) * b(j)"
 }
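// For a "dense, compressed" (CSR-like) operand only the inner dimension
// carries overhead storage, so @matvec below queries pointers and indices
// at dimension 1 only. A sketch under that assumption, reusing the #Tds
// alias sketched above:
func @query_csr(%argA: tensor<16x32xf32, #Tds>) {
  %c1 = constant 1 : index
  %ptrs = sparse_tensor.pointers %argA, %c1 : tensor<16x32xf32, #Tds> to memref<?xindex>
  %inds = sparse_tensor.indices %argA, %c1 : tensor<16x32xf32, #Tds> to memref<?xindex>
  %vals = sparse_tensor.values %argA : tensor<16x32xf32, #Tds> to memref<?xf32>
  return
}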
 
 // CHECK-LABEL:   func @matvec(
-// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<16x32xf32>,
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<16x32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xf32>,
 // CHECK-SAME:                 %[[VAL_2:.*]]: tensor<16xf32>) -> tensor<16xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 16 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<16x32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<16x32xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16x32xf32> to memref<?xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<16x32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<16x32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<16x32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<16xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.alloc() : memref<16xf32>
@@ -1080,14 +896,14 @@ func @mul_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>, %argx: tenso
 // CHECK:               %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_19]] : f32
 // CHECK:               scf.yield %[[VAL_24]] : f32
 // CHECK:             }
-// CHECK:             store %[[VAL_25:.*]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32>
+// CHECK:             memref.store %[[VAL_25:.*]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32>
 // CHECK:           }
 // CHECK:           %[[VAL_26:.*]] = memref.tensor_load %[[VAL_11]] : memref<16xf32>
 // CHECK:           return %[[VAL_26]] : tensor<16xf32>
 // CHECK:         }
-func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> {
+func @matvec(%argA: tensor<16x32xf32, #Tds>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> {
   %0 = linalg.generic #trait_matvec
-       ins(%argA, %argb: tensor<16x32xf32>, tensor<32xf32>)
+       ins(%argA, %argb: tensor<16x32xf32, #Tds>, tensor<32xf32>)
       outs(%argx: tensor<16xf32>) {
     ^bb(%A: f32, %b: f32, %x: f32):
       %0 = mulf %A, %b : f32
@@ -1102,22 +918,18 @@ func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf
     affine_map<(i,j) -> (i,j)>, // A
     affine_map<(i,j) -> ()>     // x (scalar out)
   ],
-  sparse = [
-    [ "D", "S" ], // A
-    [ ]           // x
-  ],
   iterator_types = ["reduction", "reduction"],
   doc = "x += SUM_ij A(i,j)"
 }
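// Because a full reduction never needs the stored column positions, the
// kernel below queries pointers and values but no indices, and threads the
// running sum through scf.for iter_args instead of storing per element.
// A hypothetical sketch of that reduction idiom (names invented):
func @reduce_segment(%vals: memref<?xf32>, %lo: index, %hi: index, %init: f32) -> f32 {
  %c1 = constant 1 : index
  %sum = scf.for %k = %lo to %hi step %c1 iter_args(%acc = %init) -> (f32) {
    %v = memref.load %vals[%k] : memref<?xf32>
    %s = addf %acc, %v : f32
    scf.yield %s : f32
  }
  return %sum : f32
}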
 
 // CHECK-LABEL:   func @sum_reduction(
-// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<10x20xf32>,
+// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                        %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
 // CHECK:           %[[VAL_2:.*]] = constant 10 : index
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<10x20xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32> to memref<?xf32>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
 // CHECK:           %[[VAL_8:.*]] = memref.alloc() : memref<f32>
 // CHECK:           linalg.copy(%[[VAL_7]], %[[VAL_8]]) : memref<f32>, memref<f32>
@@ -1131,14 +943,14 @@ func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf
 // CHECK:               %[[VAL_18:.*]] = addf %[[VAL_16]], %[[VAL_17]] : f32
 // CHECK:               scf.yield %[[VAL_18]] : f32
 // CHECK:             }
-// CHECK:             store %[[VAL_19:.*]], %[[VAL_8]][] : memref<f32>
+// CHECK:             memref.store %[[VAL_19:.*]], %[[VAL_8]][] : memref<f32>
 // CHECK:           }
 // CHECK:           %[[VAL_20:.*]] = memref.tensor_load %[[VAL_8]] : memref<f32>
 // CHECK:           return %[[VAL_20]] : tensor<f32>
 // CHECK:         }
-func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor<f32>) -> tensor<f32> {
+func @sum_reduction(%arga: tensor<10x20xf32, #Tds>, %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_sum_reduction
-     ins(%arga: tensor<10x20xf32>)
+     ins(%arga: tensor<10x20xf32, #Tds>)
     outs(%argx: tensor<f32>) {
       ^bb(%a: f32, %x: f32):
         %0 = addf %x, %a : f32
@@ -1152,23 +964,19 @@ func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor<f32>) -> tensor<f32>
     affine_map<(i,j) -> (i,j)>,  // A
     affine_map<(i,j) -> (i,j)>   // X (out)
   ],
-  sparse = [
-    [ "D", "S" ],  // A
-    [ "D", "D" ]   // X
-  ],
   iterator_types = ["parallel", "parallel"],
   doc = "X(i,j) = A(i,j) * SCALE"
 }
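// With dynamic shapes, loop bounds for the dense dimensions are probed from
// the dense output argument rather than the sparse input (at this revision
// memref.dim also accepted tensor operands, as the CHECK lines below show).
// A sketch:
func @probe_dims(%argx: tensor<?x?xf64>) -> (index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %d0 = memref.dim %argx, %c0 : tensor<?x?xf64>
  %d1 = memref.dim %argx, %c1 : tensor<?x?xf64>
  return %d0, %d1 : index, index
}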
 
 // CHECK-LABEL:   func @scale(
-// CHECK-SAME:                %[[VAL_0:.*]]: tensor<?x?xf64>,
+// CHECK-SAME:                %[[VAL_0:.*]]: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                %[[VAL_1:.*]]: tensor<?x?xf64>) -> tensor<?x?xf64> {
-// CHECK-DAG:           %[[VAL_3:.*]] = constant 0 : index
-// CHECK-DAG:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK-DAG:           %[[VAL_2:.*]] = constant 2.000000e+00 : f64
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64> to memref<?xf64>
+// CHECK:           %[[VAL_2:.*]] = constant 2.000000e+00 : f64
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
 // CHECK:           %[[VAL_8:.*]] = memref.dim %[[VAL_1]], %[[VAL_3]] : tensor<?x?xf64>
 // CHECK:           %[[VAL_9:.*]] = memref.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?xf64>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<?x?xf64>
@@ -1182,16 +990,16 @@ func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor<f32>) -> tensor<f32>
 // CHECK:               %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_16]]] : memref<?xindex>
 // CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xf64>
 // CHECK:               %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_2]] : f64
-// CHECK:               store %[[VAL_19]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_17]]] : memref<?x?xf64>
+// CHECK:               memref.store %[[VAL_19]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_17]]] : memref<?x?xf64>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_20:.*]] = memref.tensor_load %[[VAL_11]] : memref<?x?xf64>
 // CHECK:           return %[[VAL_20]] : tensor<?x?xf64>
 // CHECK:         }
-func @scale(%arga: tensor<?x?xf64>, %argx: tensor<?x?xf64>) -> tensor<?x?xf64> {
+func @scale(%arga: tensor<?x?xf64, #Tds>, %argx: tensor<?x?xf64>) -> tensor<?x?xf64> {
   %0 = constant 2.0 : f64
   %1 = linalg.generic #trait_scale
-     ins(%arga: tensor<?x?xf64>)
+     ins(%arga: tensor<?x?xf64, #Tds>)
     outs(%argx: tensor<?x?xf64>) {
       ^bb(%a: f64, %x: f64):
         %2 = mulf %a, %0 : f64
@@ -1207,28 +1015,22 @@ func @scale(%arga: tensor<?x?xf64>, %argx: tensor<?x?xf64>) -> tensor<?x?xf64> {
     affine_map<(i,j,k) -> (k,j)>,  // B
     affine_map<(i,j,k) -> (i,j)>   // X (out)
   ],
-  sparse = [
-    [ "S", "S" ],  // S
-    [ "D", "D" ],  // A
-    [ "D", "D" ],  // B
-    [ "D", "D" ]   // X
-  ],
   iterator_types = ["parallel", "parallel", "reduction"],
   doc = "X(i,j) += S(i,j) SUM_k A(i,k) B(k,j)"
 }
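// Sampling with a doubly compressed S confines the contraction to the
// stored entries of S: the outer loops below run over S's pointer ranges
// rather than the full i and j extents, so work scales with nnz(S) instead
// of the dense iteration space. A hypothetical sketch of that loop skeleton
// (all names invented):
func @visit_samples(%ptr0: memref<?xindex>, %idx0: memref<?xindex>,
                    %ptr1: memref<?xindex>, %idx1: memref<?xindex>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %lo0 = memref.load %ptr0[%c0] : memref<?xindex>
  %hi0 = memref.load %ptr0[%c1] : memref<?xindex>
  scf.for %ii = %lo0 to %hi0 step %c1 {
    %i = memref.load %idx0[%ii] : memref<?xindex>    // row of a stored sample
    %lo1 = memref.load %ptr1[%ii] : memref<?xindex>
    %ii1 = addi %ii, %c1 : index
    %hi1 = memref.load %ptr1[%ii1] : memref<?xindex>
    scf.for %jj = %lo1 to %hi1 step %c1 {
      %j = memref.load %idx1[%jj] : memref<?xindex>  // column of a stored sample
      // the k-reduction for X(i, j) runs only for this (i, j) sample
    }
  }
  return
}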
 
 // CHECK-LABEL:   func @sampled_dense_dense(
-// CHECK-SAME:                              %[[VAL_0:.*0]]: tensor<?x?xf32>,
+// CHECK-SAME:                              %[[VAL_0:.*0]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                              %[[VAL_1:.*1]]: tensor<?x?xf32>,
 // CHECK-SAME:                              %[[VAL_2:.*2]]: tensor<?x?xf32>,
 // CHECK-SAME:                              %[[VAL_3:.*3]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32> to memref<?xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<?x?xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_2]] : memref<?x?xf32>
@@ -1254,19 +1056,19 @@ func @scale(%arga: tensor<?x?xf64>, %argx: tensor<?x?xf64>) -> tensor<?x?xf64> {
 // CHECK:                 %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32
 // CHECK:                 %[[VAL_33:.*]] = mulf %[[VAL_30]], %[[VAL_32]] : f32
 // CHECK:                 %[[VAL_34:.*]] = addf %[[VAL_29]], %[[VAL_33]] : f32
-// CHECK:                 store %[[VAL_34]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref<?x?xf32>
+// CHECK:                 memref.store %[[VAL_34]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref<?x?xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_35:.*]] = memref.tensor_load %[[VAL_17]] : memref<?x?xf32>
 // CHECK:           return %[[VAL_35]] : tensor<?x?xf32>
 // CHECK:         }
-func @sampled_dense_dense(%args: tensor<?x?xf32>,
+func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
                           %arga: tensor<?x?xf32>,
                           %argb: tensor<?x?xf32>,
                           %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic #trait_sampled_dense_dense
-     ins(%args, %arga, %argb: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+     ins(%args, %arga, %argb: tensor<?x?xf32, #Tss>, tensor<?x?xf32>, tensor<?x?xf32>)
     outs(%argx: tensor<?x?xf32>) {
       ^bb(%s: f32, %a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -1286,39 +1088,31 @@ func @sampled_dense_dense(%args: tensor<?x?xf32>,
     affine_map<(i,j) -> ()>,     // e
     affine_map<(i,j) -> (i)>     // x (out)
   ],
-  sparse = [
-    [ "S", "S" ], // A
-    [ "D", "S" ], // B
-    [ "D", "S" ], // C
-    [ "D"  ],     // d
-    [      ],     // e
-    [ "D"  ]      // x
-  ],
   iterator_types = ["parallel", "reduction"],
   doc = "x(i) = SUM_j A(i,j) * B(i,j) * d(i) * e + C(i,j)"
 }
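The kernel name advertises its invariants: the affine map for e, (i,j) -> (), makes it a rank-0 scalar, and d(i) does not depend on the reduction index j, so both are presumably hoisted out of the j-loops by the sparsifier; in the CHECK lines below the scalar is materialized once as memref<f32> via memref.buffer_cast and reused in every branch.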
 
 // CHECK-LABEL:   func @sum_kernel_with_inv(
-// CHECK-SAME:                              %[[VAL_0:.*0]]: tensor<?x?xf32>,
-// CHECK-SAME:                              %[[VAL_1:.*1]]: tensor<?x?xf32>,
-// CHECK-SAME:                              %[[VAL_2:.*2]]: tensor<?x?xf32>,
+// CHECK-SAME:                              %[[VAL_0:.*0]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                              %[[VAL_1:.*1]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                              %[[VAL_2:.*2]]: tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                              %[[VAL_3:.*3]]: tensor<?xf32>,
 // CHECK-SAME:                              %[[VAL_4:.*4]]: tensor<f32>,
 // CHECK-SAME:                              %[[VAL_5:.*5]]: tensor<?xf32>) -> tensor<?xf32> {
 // CHECK:           %[[VAL_6:.*]] = constant 0 : index
 // CHECK:           %[[VAL_7:.*]] = constant true
 // CHECK:           %[[VAL_8:.*]] = constant 1 : index
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32> to memref<?xf32>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_8]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_8]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?xf32> to memref<?xf32>
-// CHECK:           %[[VAL_17:.*]] = sparse_tensor.pointers %[[VAL_2]], %[[VAL_8]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_18:.*]] = sparse_tensor.indices %[[VAL_2]], %[[VAL_8]] : tensor<?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_19:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<?x?xf32> to memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_17:.*]] = sparse_tensor.pointers %[[VAL_2]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = sparse_tensor.indices %[[VAL_2]], %[[VAL_8]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_19:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_20:.*]] = memref.buffer_cast %[[VAL_3]] : memref<?xf32>
 // CHECK:           %[[VAL_21:.*]] = memref.buffer_cast %[[VAL_4]] : memref<f32>
 // CHECK:           %[[VAL_22:.*]] = memref.dim %[[VAL_5]], %[[VAL_6]] : tensor<?xf32>
@@ -1377,7 +1171,7 @@ func @sampled_dense_dense(%args: tensor<?x?xf32>,
 // CHECK:                   %[[VAL_76:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_57]]] : memref<?xf32>
 // CHECK:                   %[[VAL_77:.*]] = addf %[[VAL_75]], %[[VAL_76]] : f32
 // CHECK:                   %[[VAL_78:.*]] = addf %[[VAL_70]], %[[VAL_77]] : f32
-// CHECK:                   store %[[VAL_78]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:                   memref.store %[[VAL_78]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
 // CHECK:                 } else {
 // CHECK:                   %[[VAL_79:.*]] = cmpi eq, %[[VAL_58]], %[[VAL_64]] : index
 // CHECK:                   %[[VAL_80:.*]] = cmpi eq, %[[VAL_59]], %[[VAL_64]] : index
@@ -1390,14 +1184,14 @@ func @sampled_dense_dense(%args: tensor<?x?xf32>,
 // CHECK:                     %[[VAL_86:.*]] = mulf %[[VAL_85]], %[[VAL_36]] : f32
 // CHECK:                     %[[VAL_87:.*]] = mulf %[[VAL_86]], %[[VAL_25]] : f32
 // CHECK:                     %[[VAL_88:.*]] = addf %[[VAL_82]], %[[VAL_87]] : f32
-// CHECK:                     store %[[VAL_88]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:                     memref.store %[[VAL_88]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
 // CHECK:                   } else {
 // CHECK:                     %[[VAL_89:.*]] = cmpi eq, %[[VAL_62]], %[[VAL_64]] : index
 // CHECK:                     scf.if %[[VAL_89]] {
 // CHECK:                       %[[VAL_90:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
 // CHECK:                       %[[VAL_91:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_57]]] : memref<?xf32>
 // CHECK:                       %[[VAL_92:.*]] = addf %[[VAL_90]], %[[VAL_91]] : f32
-// CHECK:                       store %[[VAL_92]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:                       memref.store %[[VAL_92]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
 // CHECK:                     } else {
 // CHECK:                     }
 // CHECK:                   }
@@ -1435,7 +1229,7 @@ func @sampled_dense_dense(%args: tensor<?x?xf32>,
 // CHECK:                   %[[VAL_122:.*]] = mulf %[[VAL_121]], %[[VAL_36]] : f32
 // CHECK:                   %[[VAL_123:.*]] = mulf %[[VAL_122]], %[[VAL_25]] : f32
 // CHECK:                   %[[VAL_124:.*]] = addf %[[VAL_118]], %[[VAL_123]] : f32
-// CHECK:                   store %[[VAL_124]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:                   memref.store %[[VAL_124]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
 // CHECK:                 } else {
 // CHECK:                 }
 // CHECK:                 %[[VAL_125:.*]] = cmpi eq, %[[VAL_111]], %[[VAL_114]] : index
@@ -1452,7 +1246,7 @@ func @sampled_dense_dense(%args: tensor<?x?xf32>,
 // CHECK:                 %[[VAL_137:.*]] = addf %[[VAL_135]], %[[VAL_136]] : f32
 // CHECK:                 scf.yield %[[VAL_137]] : f32
 // CHECK:               }
-// CHECK:               store %[[VAL_138:.*]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:               memref.store %[[VAL_138:.*]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
 // CHECK:             } else {
 // CHECK:               scf.if %[[VAL_7]] {
 // CHECK:                 %[[VAL_139:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_33]]] : memref<?xindex>
@@ -1464,7 +1258,7 @@ func @sampled_dense_dense(%args: tensor<?x?xf32>,
 // CHECK:                   %[[VAL_147:.*]] = addf %[[VAL_145]], %[[VAL_146]] : f32
 // CHECK:                   scf.yield %[[VAL_147]] : f32
 // CHECK:                 }
-// CHECK:                 store %[[VAL_148:.*]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
+// CHECK:                 memref.store %[[VAL_148:.*]], %[[VAL_24]]{{\[}}%[[VAL_33]]] : memref<?xf32>
 // CHECK:               } else {
 // CHECK:               }
 // CHECK:             }
@@ -1484,21 +1278,21 @@ func @sampled_dense_dense(%args: tensor<?x?xf32>,
 // CHECK:               %[[VAL_163:.*]] = addf %[[VAL_161]], %[[VAL_162]] : f32
 // CHECK:               scf.yield %[[VAL_163]] : f32
 // CHECK:             }
-// CHECK:             store %[[VAL_164:.*]], %[[VAL_24]]{{\[}}%[[VAL_153]]] : memref<?xf32>
+// CHECK:             memref.store %[[VAL_164:.*]], %[[VAL_24]]{{\[}}%[[VAL_153]]] : memref<?xf32>
 // CHECK:           }
 // CHECK:           %[[VAL_165:.*]] = memref.tensor_load %[[VAL_24]] : memref<?xf32>
 // CHECK:           return %[[VAL_165]] : tensor<?xf32>
 // CHECK:         }
-func @sum_kernel_with_inv(%arga: tensor<?x?xf32>,
-                          %argb: tensor<?x?xf32>,
-                          %argc: tensor<?x?xf32>,
+func @sum_kernel_with_inv(%arga: tensor<?x?xf32, #Tss>,
+                          %argb: tensor<?x?xf32, #Tds>,
+                          %argc: tensor<?x?xf32, #Tds>,
                           %argd: tensor<?xf32>,
                           %arge: tensor<f32>,
                           %argx: tensor<?xf32>) -> tensor<?xf32> {
   %0 = linalg.generic #trait_sum_kernel_with_inv
-    ins(%arga, %argb, %argc, %argd, %arge : tensor<?x?xf32>,
-                                            tensor<?x?xf32>,
-                                            tensor<?x?xf32>,
+    ins(%arga, %argb, %argc, %argd, %arge : tensor<?x?xf32, #Tss>,
+                                            tensor<?x?xf32, #Tds>,
+                                            tensor<?x?xf32, #Tds>,
                                             tensor<?xf32>,
                                             tensor<f32>)
     outs(%argx: tensor<?xf32>) {

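The pattern above repeats throughout the migration: each linalg trait loses its per-operand sparse = [...] array, and the sparsity annotation moves into the tensor type itself through an encoding alias. As a minimal sketch, the two 2-d aliases used above can be reconstructed from the expanded attributes in the CHECK lines; their actual definitions sit at the top of sparse_2d.mlir, with pointerBitWidth/indexBitWidth omitted and defaulting to 0 (native width):

  #Tds = #sparse_tensor.encoding<{ dimLevelType = [ "dense",      "compressed" ] }>
  #Tss = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
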
diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
index 270b11e220ed7..1c636fee6dc8b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
@@ -1,51 +1,59 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // RUN: mlir-opt %s -sparsification | FileCheck %s
 
-#trait_ddd = {
+#Tddd = #sparse_tensor.encoding<{ dimLevelType = [ "dense",      "dense",      "dense"      ] }>
+#Tdds = #sparse_tensor.encoding<{ dimLevelType = [ "dense",      "dense",      "compressed" ] }>
+#Tdsd = #sparse_tensor.encoding<{ dimLevelType = [ "dense",      "compressed", "dense"      ] }>
+#Tdss = #sparse_tensor.encoding<{ dimLevelType = [ "dense",      "compressed", "compressed" ] }>
+#Tsdd = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense",      "dense"      ] }>
+#Tsds = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense",      "compressed" ] }>
+#Tssd = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense"      ] }>
+#Tsss = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>
+
+#trait3 = {
   indexing_maps = [
     affine_map<(i,j,k) -> (i,j,k)>,  // A
     affine_map<(i,j,k) -> (i,j,k)>,  // B
     affine_map<(i,j,k) -> (i,j,k)>   // X (out)
   ],
-  sparse = [
-    [ "D", "D", "D" ],  // A
-    [ "D", "D", "D" ],  // B
-    [ "D", "D", "D" ]   // X
-  ],
   iterator_types = ["parallel", "parallel", "parallel"],
   doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)"
 }
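
Note that the per-format traits this file used to carry (#trait_ddd, #trait_dds, #trait_dsd, #trait_dss, #trait_sdd, removed in the hunks below) collapse into this single #trait3: their indexing maps, iterator types, and doc strings were identical, and the only per-format information, the sparse = [...] arrays, now lives in the #Tddd ... #Tsss type aliases above. The same kernel body can thus be instantiated for any format just by changing the operand type, e.g. tensor<32x16x8xf32, #Tdds> versus tensor<32x16x8xf32, #Tsss>.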
 
 // CHECK-LABEL:   func @add_ddd(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 16 : index
 // CHECK:           %[[VAL_5:.*]] = constant 8 : index
 // CHECK:           %[[VAL_6:.*]] = constant 0 : index
 // CHECK:           %[[VAL_7:.*]] = constant 1 : index
-// CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32x16x8xf32>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.alloc() : memref<32x16x8xf32>
 // CHECK:           linalg.copy(%[[VAL_10]], %[[VAL_11]]) : memref<32x16x8xf32>, memref<32x16x8xf32>
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
 // CHECK:             scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
-// CHECK:               scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
-// CHECK:                 %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_14]]] : memref<32x16x8xf32>
-// CHECK:                 %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_14]]] : memref<32x16x8xf32>
-// CHECK:                 %[[VAL_17:.*]] = addf %[[VAL_15]], %[[VAL_16]] : f32
-// CHECK:                 store %[[VAL_17]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_14]]] : memref<32x16x8xf32>
+// CHECK:               %[[VAL_14:.*]] = muli %[[VAL_12]], %[[VAL_4]] : index
+// CHECK:               %[[VAL_15:.*]] = addi %[[VAL_14]], %[[VAL_13]] : index
+// CHECK:               scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
+// CHECK:                 %[[VAL_17:.*]] = muli %[[VAL_15]], %[[VAL_5]] : index
+// CHECK:                 %[[VAL_18:.*]] = addi %[[VAL_17]], %[[VAL_16]] : index
+// CHECK:                 %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf32>
+// CHECK:                 %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
+// CHECK:                 %[[VAL_21:.*]] = addf %[[VAL_19]], %[[VAL_20]] : f32
+// CHECK:                 memref.store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_18:.*]] = memref.tensor_load %[[VAL_11]] : memref<32x16x8xf32>
-// CHECK:           return %[[VAL_18]] : tensor<32x16x8xf32>
+// CHECK:           %[[VAL_22:.*]] = memref.tensor_load %[[VAL_11]] : memref<32x16x8xf32>
+// CHECK:           return %[[VAL_22]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @add_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_ddd
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @add_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tddd>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -55,35 +63,39 @@ func @add_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 }
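
Note the changed loop body for the all-dense case: the values of the annotated operand now come from a flat memref<?xf32>, and the muli/addi pairs linearize (i, j, k) row-major over the 32x16x8 shape, i.e. offset = (i * 16 + j) * 8 + k. For example, element (2, 3, 4) is loaded from offset (2 * 16 + 3) * 8 + 4 = 284.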
 
 // CHECK-LABEL:   func @mul_ddd(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 16 : index
 // CHECK:           %[[VAL_5:.*]] = constant 8 : index
 // CHECK:           %[[VAL_6:.*]] = constant 0 : index
 // CHECK:           %[[VAL_7:.*]] = constant 1 : index
-// CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32x16x8xf32>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.alloc() : memref<32x16x8xf32>
 // CHECK:           linalg.copy(%[[VAL_10]], %[[VAL_11]]) : memref<32x16x8xf32>, memref<32x16x8xf32>
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
 // CHECK:             scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
-// CHECK:               scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
-// CHECK:                 %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_14]]] : memref<32x16x8xf32>
-// CHECK:                 %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_14]]] : memref<32x16x8xf32>
-// CHECK:                 %[[VAL_17:.*]] = mulf %[[VAL_15]], %[[VAL_16]] : f32
-// CHECK:                 store %[[VAL_17]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_14]]] : memref<32x16x8xf32>
+// CHECK:               %[[VAL_14:.*]] = muli %[[VAL_12]], %[[VAL_4]] : index
+// CHECK:               %[[VAL_15:.*]] = addi %[[VAL_14]], %[[VAL_13]] : index
+// CHECK:               scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
+// CHECK:                 %[[VAL_17:.*]] = muli %[[VAL_15]], %[[VAL_5]] : index
+// CHECK:                 %[[VAL_18:.*]] = addi %[[VAL_17]], %[[VAL_16]] : index
+// CHECK:                 %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf32>
+// CHECK:                 %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
+// CHECK:                 %[[VAL_21:.*]] = mulf %[[VAL_19]], %[[VAL_20]] : f32
+// CHECK:                 memref.store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_18:.*]] = memref.tensor_load %[[VAL_11]] : memref<32x16x8xf32>
-// CHECK:           return %[[VAL_18]] : tensor<32x16x8xf32>
+// CHECK:           %[[VAL_22:.*]] = memref.tensor_load %[[VAL_11]] : memref<32x16x8xf32>
+// CHECK:           return %[[VAL_22]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @mul_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_ddd
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @mul_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tddd>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -92,25 +104,10 @@ func @mul_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
   return %0 : tensor<32x16x8xf32>
 }
 
-#trait_dds = {
-  indexing_maps = [
-    affine_map<(i,j,k) -> (i,j,k)>,  // A
-    affine_map<(i,j,k) -> (i,j,k)>,  // B
-    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
-  ],
-  sparse = [
-    [ "D", "D", "S" ],  // A
-    [ "D", "D", "D" ],  // B
-    [ "D", "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel", "parallel"],
-  doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)"
-}
-
 // CHECK-LABEL:   func @add_dds(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 2 : index
 // CHECK:           %[[VAL_4:.*]] = constant 32 : index
 // CHECK:           %[[VAL_5:.*]] = constant 16 : index
@@ -118,9 +115,9 @@ func @mul_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:           %[[VAL_7:.*]] = constant 0 : index
 // CHECK:           %[[VAL_8:.*]] = constant true
 // CHECK:           %[[VAL_9:.*]] = constant 1 : index
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_14:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_15:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -143,11 +140,11 @@ func @mul_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                   %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xf32>
 // CHECK:                   %[[VAL_32:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
 // CHECK:                   %[[VAL_33:.*]] = addf %[[VAL_31]], %[[VAL_32]] : f32
-// CHECK:                   store %[[VAL_33]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
+// CHECK:                   memref.store %[[VAL_33]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
 // CHECK:                 } else {
 // CHECK:                   scf.if %[[VAL_8]] {
 // CHECK:                     %[[VAL_34:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
-// CHECK:                     store %[[VAL_34]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
+// CHECK:                     memref.store %[[VAL_34]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
 // CHECK:                   } else {
 // CHECK:                   }
 // CHECK:                 }
@@ -159,16 +156,16 @@ func @mul_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:               }
 // CHECK:               scf.for %[[VAL_39:.*]] = %[[VAL_40:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                 %[[VAL_41:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_39]]] : memref<32x16x8xf32>
-// CHECK:                 store %[[VAL_41]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_39]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_41]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_39]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_42:.*]] = memref.tensor_load %[[VAL_15]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_42]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @add_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_dds
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @add_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tdds>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -178,17 +175,17 @@ func @add_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 }
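
For the dds format only the innermost dimension is compressed, so the kernel queries sparse_tensor.pointers and sparse_tensor.indices at dimension 2 only. As an illustrative sketch of one compressed level (entries chosen arbitrarily, not taken from the test):

  pointers = [ 0, 2, 2, 3, ... ]   // stored k's of dense prefix p live in [pointers[p], pointers[p+1])
  indices  = [ 1, 4, 2, ... ]      // the k coordinate of each stored entry
  values   = [ a, b, c, ... ]      // the corresponding stored values

The generated loops at the innermost level walk exactly these position ranges.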
 
 // CHECK-LABEL:   func @mul_dds(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 2 : index
 // CHECK:           %[[VAL_4:.*]] = constant 32 : index
 // CHECK:           %[[VAL_5:.*]] = constant 16 : index
 // CHECK:           %[[VAL_6:.*]] = constant 0 : index
 // CHECK:           %[[VAL_7:.*]] = constant 1 : index
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -205,16 +202,16 @@ func @add_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xf32>
 // CHECK:                 %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_22]]] : memref<32x16x8xf32>
 // CHECK:                 %[[VAL_25:.*]] = mulf %[[VAL_23]], %[[VAL_24]] : f32
-// CHECK:                 store %[[VAL_25]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_22]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_25]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_22]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_26:.*]] = memref.tensor_load %[[VAL_13]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_26]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @mul_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_dds
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @mul_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tdds>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -223,34 +220,19 @@ func @mul_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
   return %0 : tensor<32x16x8xf32>
 }
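
The asymmetry between the add_* and mul_* kernels is intentional and shows in the generated code above: unstored entries are implicit zeros, so for multiplication x = a * b vanishes wherever a is unstored and the kernel only iterates the stored range of positions. Addition has no such absorbing element (x = a + b must still equal b at unstored entries), which is why the add kernels co-iterate, branch on whether the current index is stored, and finish with scf.for clean-up loops that copy the remaining dense values of b.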
 
-#trait_dsd = {
-  indexing_maps = [
-    affine_map<(i,j,k) -> (i,j,k)>,  // A
-    affine_map<(i,j,k) -> (i,j,k)>,  // B
-    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
-  ],
-  sparse = [
-    [ "D", "S", "D" ],  // A
-    [ "D", "D", "D" ],  // B
-    [ "D", "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel", "parallel"],
-  doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)"
-}
-
 // CHECK-LABEL:   func @add_dsd(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 16 : index
 // CHECK:           %[[VAL_5:.*]] = constant 8 : index
 // CHECK:           %[[VAL_6:.*]] = constant true
 // CHECK:           %[[VAL_7:.*]] = constant 0 : index
 // CHECK:           %[[VAL_8:.*]] = constant 1 : index
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_14:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -273,13 +255,13 @@ func @mul_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                   %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xf32>
 // CHECK:                   %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_27]]] : memref<32x16x8xf32>
 // CHECK:                   %[[VAL_32:.*]] = addf %[[VAL_30]], %[[VAL_31]] : f32
-// CHECK:                   store %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_27]]] : memref<32x16x8xf32>
+// CHECK:                   memref.store %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_27]]] : memref<32x16x8xf32>
 // CHECK:                 }
 // CHECK:               } else {
 // CHECK:                 scf.if %[[VAL_6]] {
 // CHECK:                   scf.for %[[VAL_33:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
 // CHECK:                     %[[VAL_34:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_33]]] : memref<32x16x8xf32>
-// CHECK:                     store %[[VAL_34]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_33]]] : memref<32x16x8xf32>
+// CHECK:                     memref.store %[[VAL_34]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_33]]] : memref<32x16x8xf32>
 // CHECK:                   }
 // CHECK:                 } else {
 // CHECK:                 }
@@ -293,16 +275,16 @@ func @mul_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:             scf.for %[[VAL_39:.*]] = %[[VAL_40:.*]]#1 to %[[VAL_4]] step %[[VAL_8]] {
 // CHECK:               scf.for %[[VAL_41:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
 // CHECK:                 %[[VAL_42:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_39]], %[[VAL_41]]] : memref<32x16x8xf32>
-// CHECK:                 store %[[VAL_42]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_39]], %[[VAL_41]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_42]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_39]], %[[VAL_41]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_43:.*]] = memref.tensor_load %[[VAL_14]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_43]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @add_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_dsd
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @add_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tdsd>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -312,16 +294,16 @@ func @add_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 }
 
 // CHECK-LABEL:   func @mul_dsd(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 8 : index
 // CHECK:           %[[VAL_5:.*]] = constant 0 : index
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -338,16 +320,16 @@ func @add_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xf32>
 // CHECK:                 %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_13]], %[[VAL_18]], %[[VAL_19]]] : memref<32x16x8xf32>
 // CHECK:                 %[[VAL_24:.*]] = mulf %[[VAL_22]], %[[VAL_23]] : f32
-// CHECK:                 store %[[VAL_24]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_18]], %[[VAL_19]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_24]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_18]], %[[VAL_19]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_25:.*]] = memref.tensor_load %[[VAL_12]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_25]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @mul_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_dsd
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @mul_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tdsd>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -356,25 +338,10 @@ func @mul_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
   return %0 : tensor<32x16x8xf32>
 }
 
-#trait_dss = {
-  indexing_maps = [
-    affine_map<(i,j,k) -> (i,j,k)>,  // A
-    affine_map<(i,j,k) -> (i,j,k)>,  // B
-    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
-  ],
-  sparse = [
-    [ "D", "S", "S" ],  // A
-    [ "D", "D", "D" ],  // B
-    [ "D", "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel", "parallel"],
-  doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)"
-}
-
 // CHECK-LABEL:   func @add_dss(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 2 : index
 // CHECK:           %[[VAL_4:.*]] = constant 32 : index
 // CHECK:           %[[VAL_5:.*]] = constant 16 : index
@@ -382,11 +349,11 @@ func @mul_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:           %[[VAL_7:.*]] = constant true
 // CHECK:           %[[VAL_8:.*]] = constant 0 : index
 // CHECK:           %[[VAL_9:.*]] = constant 1 : index
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_16:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_17:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -417,11 +384,11 @@ func @mul_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                     %[[VAL_41:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_37]]] : memref<?xf32>
 // CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
 // CHECK:                     %[[VAL_43:.*]] = addf %[[VAL_41]], %[[VAL_42]] : f32
-// CHECK:                     store %[[VAL_43]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
+// CHECK:                     memref.store %[[VAL_43]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
 // CHECK:                   } else {
 // CHECK:                     scf.if %[[VAL_7]] {
 // CHECK:                       %[[VAL_44:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
-// CHECK:                       store %[[VAL_44]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
+// CHECK:                       memref.store %[[VAL_44]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
 // CHECK:                     } else {
 // CHECK:                     }
 // CHECK:                   }
@@ -433,13 +400,13 @@ func @mul_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 }
 // CHECK:                 scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                   %[[VAL_51:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_49]]] : memref<32x16x8xf32>
-// CHECK:                   store %[[VAL_51]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_49]]] : memref<32x16x8xf32>
+// CHECK:                   memref.store %[[VAL_51]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_49]]] : memref<32x16x8xf32>
 // CHECK:                 }
 // CHECK:               } else {
 // CHECK:                 scf.if %[[VAL_7]] {
 // CHECK:                   scf.for %[[VAL_52:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                     %[[VAL_53:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_52]]] : memref<32x16x8xf32>
-// CHECK:                     store %[[VAL_53]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_52]]] : memref<32x16x8xf32>
+// CHECK:                     memref.store %[[VAL_53]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_52]]] : memref<32x16x8xf32>
 // CHECK:                   }
 // CHECK:                 } else {
 // CHECK:                 }
@@ -453,16 +420,16 @@ func @mul_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:             scf.for %[[VAL_58:.*]] = %[[VAL_59:.*]]#1 to %[[VAL_5]] step %[[VAL_9]] {
 // CHECK:               scf.for %[[VAL_60:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                 %[[VAL_61:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_58]], %[[VAL_60]]] : memref<32x16x8xf32>
-// CHECK:                 store %[[VAL_61]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_58]], %[[VAL_60]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_61]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_58]], %[[VAL_60]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_62:.*]] = memref.tensor_load %[[VAL_17]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_62]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @add_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_dss
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @add_dss(%arga: tensor<32x16x8xf32, #Tdss>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tdss>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -472,18 +439,18 @@ func @add_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 }
 
 // CHECK-LABEL:   func @mul_dss(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 2 : index
 // CHECK:           %[[VAL_4:.*]] = constant 32 : index
 // CHECK:           %[[VAL_5:.*]] = constant 0 : index
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_14:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -502,16 +469,16 @@ func @add_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref<?xf32>
 // CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_20]], %[[VAL_25]]] : memref<32x16x8xf32>
 // CHECK:                 %[[VAL_28:.*]] = mulf %[[VAL_26]], %[[VAL_27]] : f32
-// CHECK:                 store %[[VAL_28]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_20]], %[[VAL_25]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_28]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_20]], %[[VAL_25]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_29:.*]] = memref.tensor_load %[[VAL_14]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_29]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @mul_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_dss
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @mul_dss(%arga: tensor<32x16x8xf32, #Tdss>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tdss>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -520,34 +487,19 @@ func @mul_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
   return %0 : tensor<32x16x8xf32>
 }
 
-#trait_sdd = {
-  indexing_maps = [
-    affine_map<(i,j,k) -> (i,j,k)>,  // A
-    affine_map<(i,j,k) -> (i,j,k)>,  // B
-    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
-  ],
-  sparse = [
-    [ "S", "D", "D" ],  // A
-    [ "D", "D", "D" ],  // B
-    [ "D", "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel", "parallel"],
-  doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)"
-}
-
 // CHECK-LABEL:   func @add_sdd(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 16 : index
 // CHECK:           %[[VAL_5:.*]] = constant 8 : index
 // CHECK:           %[[VAL_6:.*]] = constant true
 // CHECK:           %[[VAL_7:.*]] = constant 0 : index
 // CHECK:           %[[VAL_8:.*]] = constant 1 : index
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_14:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -571,7 +523,7 @@ func @mul_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                   %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xf32>
 // CHECK:                   %[[VAL_32:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]] : memref<32x16x8xf32>
 // CHECK:                   %[[VAL_33:.*]] = addf %[[VAL_31]], %[[VAL_32]] : f32
-// CHECK:                   store %[[VAL_33]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]] : memref<32x16x8xf32>
+// CHECK:                   memref.store %[[VAL_33]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]] : memref<32x16x8xf32>
 // CHECK:                 }
 // CHECK:               }
 // CHECK:             } else {
@@ -579,7 +531,7 @@ func @mul_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 scf.for %[[VAL_34:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
 // CHECK:                   scf.for %[[VAL_35:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
 // CHECK:                     %[[VAL_36:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]], %[[VAL_34]], %[[VAL_35]]] : memref<32x16x8xf32>
-// CHECK:                     store %[[VAL_36]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_34]], %[[VAL_35]]] : memref<32x16x8xf32>
+// CHECK:                     memref.store %[[VAL_36]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_34]], %[[VAL_35]]] : memref<32x16x8xf32>
 // CHECK:                   }
 // CHECK:                 }
 // CHECK:               } else {
@@ -595,16 +547,16 @@ func @mul_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:             scf.for %[[VAL_43:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
 // CHECK:               scf.for %[[VAL_44:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
 // CHECK:                 %[[VAL_45:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_41]], %[[VAL_43]], %[[VAL_44]]] : memref<32x16x8xf32>
-// CHECK:                 store %[[VAL_45]], %[[VAL_14]]{{\[}}%[[VAL_41]], %[[VAL_43]], %[[VAL_44]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_45]], %[[VAL_14]]{{\[}}%[[VAL_41]], %[[VAL_43]], %[[VAL_44]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_46:.*]] = memref.tensor_load %[[VAL_14]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_46]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @add_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_sdd
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @add_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tsdd>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -614,16 +566,16 @@ func @add_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 }
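
One structural contrast is visible throughout the updated CHECK lines: the add
kernels materialize a `constant true`, co-iterate, and branch through scf.if
chains so that positions with no stored entry in the sparse operand still
receive the dense operand's value, while the mul kernels compile to plain
scf.for nests over the stored entries only, because a product with an implicit
zero contributes nothing.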
 
 // CHECK-LABEL:   func @mul_sdd(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 16 : index
 // CHECK:           %[[VAL_4:.*]] = constant 8 : index
 // CHECK:           %[[VAL_5:.*]] = constant 0 : index
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -641,16 +593,16 @@ func @add_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 %[[VAL_23:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref<?xf32>
 // CHECK:                 %[[VAL_24:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_20]]] : memref<32x16x8xf32>
 // CHECK:                 %[[VAL_25:.*]] = mulf %[[VAL_23]], %[[VAL_24]] : f32
-// CHECK:                 store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_20]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_20]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_26:.*]] = memref.tensor_load %[[VAL_12]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_26]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @mul_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_sdd
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @mul_sdd(%arga: tensor<32x16x8xf32, #Tsdd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tsdd>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -659,25 +611,10 @@ func @mul_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
   return %0 : tensor<32x16x8xf32>
 }
 
-#trait_sds = {
-  indexing_maps = [
-    affine_map<(i,j,k) -> (i,j,k)>,  // A
-    affine_map<(i,j,k) -> (i,j,k)>,  // B
-    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
-  ],
-  sparse = [
-    [ "S", "D", "S" ],  // A
-    [ "D", "D", "D" ],  // B
-    [ "D", "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel", "parallel"],
-  doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)"
-}
-
 // CHECK-LABEL:   func @add_sds(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 2 : index
 // CHECK:           %[[VAL_4:.*]] = constant 32 : index
 // CHECK:           %[[VAL_5:.*]] = constant 16 : index
@@ -685,11 +622,11 @@ func @mul_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:           %[[VAL_7:.*]] = constant true
 // CHECK:           %[[VAL_8:.*]] = constant 0 : index
 // CHECK:           %[[VAL_9:.*]] = constant 1 : index
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_16:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_17:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -721,11 +658,11 @@ func @mul_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_38]]] : memref<?xf32>
 // CHECK:                     %[[VAL_43:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
 // CHECK:                     %[[VAL_44:.*]] = addf %[[VAL_42]], %[[VAL_43]] : f32
-// CHECK:                     store %[[VAL_44]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
+// CHECK:                     memref.store %[[VAL_44]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
 // CHECK:                   } else {
 // CHECK:                     scf.if %[[VAL_7]] {
 // CHECK:                       %[[VAL_45:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
-// CHECK:                       store %[[VAL_45]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
+// CHECK:                       memref.store %[[VAL_45]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
 // CHECK:                     } else {
 // CHECK:                     }
 // CHECK:                   }
@@ -737,7 +674,7 @@ func @mul_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 }
 // CHECK:                 scf.for %[[VAL_50:.*]] = %[[VAL_51:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                   %[[VAL_52:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_50]]] : memref<32x16x8xf32>
-// CHECK:                   store %[[VAL_52]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_50]]] : memref<32x16x8xf32>
+// CHECK:                   memref.store %[[VAL_52]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_50]]] : memref<32x16x8xf32>
 // CHECK:                 }
 // CHECK:               }
 // CHECK:             } else {
@@ -745,7 +682,7 @@ func @mul_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 scf.for %[[VAL_53:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
 // CHECK:                   scf.for %[[VAL_54:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                     %[[VAL_55:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_53]], %[[VAL_54]]] : memref<32x16x8xf32>
-// CHECK:                     store %[[VAL_55]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_53]], %[[VAL_54]]] : memref<32x16x8xf32>
+// CHECK:                     memref.store %[[VAL_55]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_53]], %[[VAL_54]]] : memref<32x16x8xf32>
 // CHECK:                   }
 // CHECK:                 }
 // CHECK:               } else {
@@ -761,16 +698,16 @@ func @mul_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:             scf.for %[[VAL_62:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
 // CHECK:               scf.for %[[VAL_63:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                 %[[VAL_64:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_60]], %[[VAL_62]], %[[VAL_63]]] : memref<32x16x8xf32>
-// CHECK:                 store %[[VAL_64]], %[[VAL_17]]{{\[}}%[[VAL_60]], %[[VAL_62]], %[[VAL_63]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_64]], %[[VAL_17]]{{\[}}%[[VAL_60]], %[[VAL_62]], %[[VAL_63]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_65:.*]] = memref.tensor_load %[[VAL_17]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_65]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @add_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_sds
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @add_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tsds>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -780,18 +717,18 @@ func @add_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 }
 
 // CHECK-LABEL:   func @mul_sds(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 2 : index
 // CHECK:           %[[VAL_4:.*]] = constant 16 : index
 // CHECK:           %[[VAL_5:.*]] = constant 0 : index
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_14:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -811,16 +748,16 @@ func @add_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref<?xf32>
 // CHECK:                 %[[VAL_28:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_18]], %[[VAL_19]], %[[VAL_26]]] : memref<32x16x8xf32>
 // CHECK:                 %[[VAL_29:.*]] = mulf %[[VAL_27]], %[[VAL_28]] : f32
-// CHECK:                 store %[[VAL_29]], %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_19]], %[[VAL_26]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_29]], %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_19]], %[[VAL_26]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_30:.*]] = memref.tensor_load %[[VAL_14]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_30]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @mul_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_sds
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @mul_sds(%arga: tensor<32x16x8xf32, #Tsds>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tsds>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -829,36 +766,21 @@ func @mul_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
   return %0 : tensor<32x16x8xf32>
 }
 
-#trait_ssd = {
-  indexing_maps = [
-    affine_map<(i,j,k) -> (i,j,k)>,  // A
-    affine_map<(i,j,k) -> (i,j,k)>,  // B
-    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
-  ],
-  sparse = [
-    [ "S", "S", "D" ],  // A
-    [ "D", "D", "D" ],  // B
-    [ "D", "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel", "parallel"],
-  doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)"
-}
-
 // CHECK-LABEL:   func @add_ssd(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK:           %[[VAL_4:.*]] = constant 16 : index
 // CHECK:           %[[VAL_5:.*]] = constant 8 : index
 // CHECK:           %[[VAL_6:.*]] = constant true
 // CHECK:           %[[VAL_7:.*]] = constant 0 : index
 // CHECK:           %[[VAL_8:.*]] = constant 1 : index
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_14:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_16:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -890,13 +812,13 @@ func @mul_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                     %[[VAL_41:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_40]]] : memref<?xf32>
 // CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_38]]] : memref<32x16x8xf32>
 // CHECK:                     %[[VAL_43:.*]] = addf %[[VAL_41]], %[[VAL_42]] : f32
-// CHECK:                     store %[[VAL_43]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_38]]] : memref<32x16x8xf32>
+// CHECK:                     memref.store %[[VAL_43]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_38]]] : memref<32x16x8xf32>
 // CHECK:                   }
 // CHECK:                 } else {
 // CHECK:                   scf.if %[[VAL_6]] {
 // CHECK:                     scf.for %[[VAL_44:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
 // CHECK:                       %[[VAL_45:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_44]]] : memref<32x16x8xf32>
-// CHECK:                       store %[[VAL_45]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_44]]] : memref<32x16x8xf32>
+// CHECK:                       memref.store %[[VAL_45]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_44]]] : memref<32x16x8xf32>
 // CHECK:                     }
 // CHECK:                   } else {
 // CHECK:                   }
@@ -910,7 +832,7 @@ func @mul_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:               scf.for %[[VAL_50:.*]] = %[[VAL_51:.*]]#1 to %[[VAL_4]] step %[[VAL_8]] {
 // CHECK:                 scf.for %[[VAL_52:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
 // CHECK:                   %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_50]], %[[VAL_52]]] : memref<32x16x8xf32>
-// CHECK:                   store %[[VAL_53]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_50]], %[[VAL_52]]] : memref<32x16x8xf32>
+// CHECK:                   memref.store %[[VAL_53]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_50]], %[[VAL_52]]] : memref<32x16x8xf32>
 // CHECK:                 }
 // CHECK:               }
 // CHECK:             } else {
@@ -918,7 +840,7 @@ func @mul_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 scf.for %[[VAL_54:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
 // CHECK:                   scf.for %[[VAL_55:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
 // CHECK:                     %[[VAL_56:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_54]], %[[VAL_55]]] : memref<32x16x8xf32>
-// CHECK:                     store %[[VAL_56]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_54]], %[[VAL_55]]] : memref<32x16x8xf32>
+// CHECK:                     memref.store %[[VAL_56]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_54]], %[[VAL_55]]] : memref<32x16x8xf32>
 // CHECK:                   }
 // CHECK:                 }
 // CHECK:               } else {
@@ -934,16 +856,16 @@ func @mul_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:             scf.for %[[VAL_63:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
 // CHECK:               scf.for %[[VAL_64:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
 // CHECK:                 %[[VAL_65:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_61]], %[[VAL_63]], %[[VAL_64]]] : memref<32x16x8xf32>
-// CHECK:                 store %[[VAL_65]], %[[VAL_16]]{{\[}}%[[VAL_61]], %[[VAL_63]], %[[VAL_64]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_65]], %[[VAL_16]]{{\[}}%[[VAL_61]], %[[VAL_63]], %[[VAL_64]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_66:.*]] = memref.tensor_load %[[VAL_16]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_66]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @add_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_ssd
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @add_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tssd>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -953,17 +875,17 @@ func @add_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 }
 
 // CHECK-LABEL:   func @mul_ssd(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 8 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -983,16 +905,16 @@ func @add_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
 // CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_23]]] : memref<32x16x8xf32>
 // CHECK:                 %[[VAL_28:.*]] = mulf %[[VAL_26]], %[[VAL_27]] : f32
-// CHECK:                 store %[[VAL_28]], %[[VAL_13]]{{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_23]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_28]], %[[VAL_13]]{{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_23]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_29:.*]] = memref.tensor_load %[[VAL_13]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_29]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @mul_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_ssd
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @mul_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tssd>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -1001,25 +923,10 @@ func @mul_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
   return %0 : tensor<32x16x8xf32>
 }
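
For readers new to the primitives exercised above: sparse_tensor.pointers and
sparse_tensor.indices take the sparse tensor plus a dimension index and yield
that dimension's overhead storage, while sparse_tensor.values yields the
primary storage with all stored entries. A minimal standalone sketch
(hypothetical function, assuming a #Tsss alias in the style shown earlier):

  func @peek(%arga: tensor<32x16x8xf32, #Tsss>) -> memref<?xf32> {
    %c0 = constant 0 : index
    // Pointer and index arrays of the outermost (compressed) dimension.
    %p0 = sparse_tensor.pointers %arga, %c0 : tensor<32x16x8xf32, #Tsss> to memref<?xindex>
    %i0 = sparse_tensor.indices %arga, %c0 : tensor<32x16x8xf32, #Tsss> to memref<?xindex>
    // Primary storage with the stored f32 entries.
    %v = sparse_tensor.values %arga : tensor<32x16x8xf32, #Tsss> to memref<?xf32>
    return %v : memref<?xf32>
  }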
 
-#trait_sss = {
-  indexing_maps = [
-    affine_map<(i,j,k) -> (i,j,k)>,  // A
-    affine_map<(i,j,k) -> (i,j,k)>,  // B
-    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
-  ],
-  sparse = [
-    [ "S", "S", "S" ],  // A
-    [ "D", "D", "D" ],  // B
-    [ "D", "D", "D" ]   // X
-  ],
-  iterator_types = ["parallel", "parallel", "parallel"],
-  doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)"
-}
-
 // CHECK-LABEL:   func @add_sss(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 2 : index
 // CHECK:           %[[VAL_4:.*]] = constant 32 : index
 // CHECK:           %[[VAL_5:.*]] = constant 16 : index
@@ -1027,13 +934,13 @@ func @mul_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:           %[[VAL_7:.*]] = constant true
 // CHECK:           %[[VAL_8:.*]] = constant 0 : index
 // CHECK:           %[[VAL_9:.*]] = constant 1 : index
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_8]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_9]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_17:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_18:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_19:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -1073,11 +980,11 @@ func @mul_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                       %[[VAL_52:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_48]]] : memref<?xf32>
 // CHECK:                       %[[VAL_53:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
 // CHECK:                       %[[VAL_54:.*]] = addf %[[VAL_52]], %[[VAL_53]] : f32
-// CHECK:                       store %[[VAL_54]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
+// CHECK:                       memref.store %[[VAL_54]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
 // CHECK:                     } else {
 // CHECK:                       scf.if %[[VAL_7]] {
 // CHECK:                         %[[VAL_55:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
-// CHECK:                         store %[[VAL_55]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
+// CHECK:                         memref.store %[[VAL_55]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
 // CHECK:                       } else {
 // CHECK:                       }
 // CHECK:                     }
@@ -1089,13 +996,13 @@ func @mul_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                   }
 // CHECK:                   scf.for %[[VAL_60:.*]] = %[[VAL_61:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                     %[[VAL_62:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_60]]] : memref<32x16x8xf32>
-// CHECK:                     store %[[VAL_62]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_60]]] : memref<32x16x8xf32>
+// CHECK:                     memref.store %[[VAL_62]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_60]]] : memref<32x16x8xf32>
 // CHECK:                   }
 // CHECK:                 } else {
 // CHECK:                   scf.if %[[VAL_7]] {
 // CHECK:                     scf.for %[[VAL_63:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                       %[[VAL_64:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_63]]] : memref<32x16x8xf32>
-// CHECK:                       store %[[VAL_64]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_63]]] : memref<32x16x8xf32>
+// CHECK:                       memref.store %[[VAL_64]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_63]]] : memref<32x16x8xf32>
 // CHECK:                     }
 // CHECK:                   } else {
 // CHECK:                   }
@@ -1109,7 +1016,7 @@ func @mul_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:               scf.for %[[VAL_69:.*]] = %[[VAL_70:.*]]#1 to %[[VAL_5]] step %[[VAL_9]] {
 // CHECK:                 scf.for %[[VAL_71:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                   %[[VAL_72:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_69]], %[[VAL_71]]] : memref<32x16x8xf32>
-// CHECK:                   store %[[VAL_72]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_69]], %[[VAL_71]]] : memref<32x16x8xf32>
+// CHECK:                   memref.store %[[VAL_72]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_69]], %[[VAL_71]]] : memref<32x16x8xf32>
 // CHECK:                 }
 // CHECK:               }
 // CHECK:             } else {
@@ -1117,7 +1024,7 @@ func @mul_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 scf.for %[[VAL_73:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
 // CHECK:                   scf.for %[[VAL_74:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                     %[[VAL_75:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_73]], %[[VAL_74]]] : memref<32x16x8xf32>
-// CHECK:                     store %[[VAL_75]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_73]], %[[VAL_74]]] : memref<32x16x8xf32>
+// CHECK:                     memref.store %[[VAL_75]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_73]], %[[VAL_74]]] : memref<32x16x8xf32>
 // CHECK:                   }
 // CHECK:                 }
 // CHECK:               } else {
@@ -1133,16 +1040,16 @@ func @mul_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:             scf.for %[[VAL_82:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
 // CHECK:               scf.for %[[VAL_83:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
 // CHECK:                 %[[VAL_84:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_80]], %[[VAL_82]], %[[VAL_83]]] : memref<32x16x8xf32>
-// CHECK:                 store %[[VAL_84]], %[[VAL_19]]{{\[}}%[[VAL_80]], %[[VAL_82]], %[[VAL_83]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_84]], %[[VAL_19]]{{\[}}%[[VAL_80]], %[[VAL_82]], %[[VAL_83]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_85:.*]] = memref.tensor_load %[[VAL_19]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_85]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @add_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_sss
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @add_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tsss>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = addf %a, %b : f32
@@ -1152,19 +1059,19 @@ func @add_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 }
 
 // CHECK-LABEL:   func @mul_sss(
-// CHECK-SAME:                  %[[VAL_0:.*0]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_1:.*1]]: tensor<32x16x8xf32>,
-// CHECK-SAME:                  %[[VAL_2:.*2]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME:                  %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 2 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32> to memref<?xindex>
-// CHECK:           %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32> to memref<?xf32>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_14:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16x8xf32>
 // CHECK:           %[[VAL_15:.*]] = memref.alloc() : memref<32x16x8xf32>
@@ -1186,16 +1093,16 @@ func @add_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                 %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref<?xf32>
 // CHECK:                 %[[VAL_31:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]], %[[VAL_24]], %[[VAL_29]]] : memref<32x16x8xf32>
 // CHECK:                 %[[VAL_32:.*]] = mulf %[[VAL_30]], %[[VAL_31]] : f32
-// CHECK:                 store %[[VAL_32]], %[[VAL_15]]{{\[}}%[[VAL_19]], %[[VAL_24]], %[[VAL_29]]] : memref<32x16x8xf32>
+// CHECK:                 memref.store %[[VAL_32]], %[[VAL_15]]{{\[}}%[[VAL_19]], %[[VAL_24]], %[[VAL_29]]] : memref<32x16x8xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_33:.*]] = memref.tensor_load %[[VAL_15]] : memref<32x16x8xf32>
 // CHECK:           return %[[VAL_33]] : tensor<32x16x8xf32>
 // CHECK:         }
-func @mul_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
-  %0 = linalg.generic #trait_sss
-     ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>)
+func @mul_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
+  %0 = linalg.generic #trait3
+     ins(%arga, %argb: tensor<32x16x8xf32, #Tsss>, tensor<32x16x8xf32>)
     outs(%argx: tensor<32x16x8xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -1211,27 +1118,21 @@ func @mul_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
     affine_map<(i,j,k,l) -> (l,j)>,    // D
     affine_map<(i,j,k,l) -> (i,j)>     // A (out)
   ],
-  sparse = [
-    [ "D", "D", "S" ],  // B
-    [ "D", "D" ],       // C
-    [ "D", "D" ],       // D
-    [ "D", "D" ]        // A
-  ],
   iterator_types = ["parallel", "parallel", "reduction", "reduction"],
   doc = "A(i,j) += SUM_k,l B(i,k,l) * C(k,j) * D(l,j)"
 }
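
The kernel_3d test below pairs a dense output with one sparse input (#Tdds)
under dynamic shapes. As an aside, the two bit-width fields that the CHECK
lines print as 0 presumably select the native index width by default;
narrower overhead storage could be requested explicitly, along these lines
(a hypothetical #Tdds32 variant, not used by this test):

  #Tdds32 = #sparse_tensor.encoding<{
    dimLevelType = [ "dense", "dense", "compressed" ],
    pointerBitWidth = 32,
    indexBitWidth = 32
  }>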
 
 // CHECK-LABEL:   func @kernel_3d(
 // CHECK-SAME:                    %[[VAL_0:.*0]]: tensor<?x?xf32>,
-// CHECK-SAME:                    %[[VAL_1:.*1]]: tensor<?x?x?xf32>,
+// CHECK-SAME:                    %[[VAL_1:.*1]]: tensor<?x?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                    %[[VAL_2:.*2]]: tensor<?x?xf32>,
 // CHECK-SAME:                    %[[VAL_3:.*3]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
 // CHECK:           %[[VAL_4:.*]] = constant 2 : index
 // CHECK:           %[[VAL_5:.*]] = constant 0 : index
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xf32> to memref<?xindex>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xf32> to memref<?xf32>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : tensor<?x?xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<?x?xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_3]] : memref<?x?xf32>
@@ -1257,7 +1158,7 @@ func @mul_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:                   %[[VAL_31:.*]] = mulf %[[VAL_29]], %[[VAL_30]] : f32
 // CHECK:                   %[[VAL_32:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_17]], %[[VAL_27]]] : memref<?x?xf32>
 // CHECK:                   %[[VAL_33:.*]] = addf %[[VAL_31]], %[[VAL_32]] : f32
-// CHECK:                   store %[[VAL_33]], %[[VAL_16]]{{\[}}%[[VAL_17]], %[[VAL_27]]] : memref<?x?xf32>
+// CHECK:                   memref.store %[[VAL_33]], %[[VAL_16]]{{\[}}%[[VAL_17]], %[[VAL_27]]] : memref<?x?xf32>
 // CHECK:                 }
 // CHECK:               }
 // CHECK:             }
@@ -1266,11 +1167,11 @@ func @mul_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>, %argx: ten
 // CHECK:           return %[[VAL_34]] : tensor<?x?xf32>
 // CHECK:         }
 func @kernel_3d(%arga: tensor<?x?xf32>,
-                %argb: tensor<?x?x?xf32>,
+                %argb: tensor<?x?x?xf32, #Tdds>,
                 %argc: tensor<?x?xf32>,
                 %argd: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic #trait_kernel_3d
-       ins(%argb, %argc, %argd: tensor<?x?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+       ins(%argb, %argc, %argd: tensor<?x?x?xf32, #Tdds>, tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arga: tensor<?x?xf32>) {
     ^bb(%b: f32, %c: f32, %d: f32, %a: f32):
       %0 = mulf %b, %c : f32
@@ -1286,24 +1187,20 @@ func @kernel_3d(%arga: tensor<?x?xf32>,
     affine_map<(i,j,k) -> (i,j,k)>,  // A
     affine_map<(i,j,k) -> ()>        // x (scalar out)
   ],
-  sparse = [
-    [ "S", "S", "S" ],  // A
-    [ ]                 // x
-  ],
   iterator_types = ["reduction", "reduction", "reduction"],
   doc = "x += SUM_ijk A(i,j,k)"
 }
 
 // CHECK-LABEL:   func @sum_reduction(
-// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<10x20x30xf32>,
+// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<10x20x30xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                        %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
 // CHECK:           %[[VAL_2:.*]] = constant 2 : index
 // CHECK:           %[[VAL_3:.*]] = constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<10x20x30xf32> to memref<?xindex>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<10x20x30xf32> to memref<?xindex>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<10x20x30xf32> to memref<?xindex>
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20x30xf32> to memref<?xf32>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20x30xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
 // CHECK:           %[[VAL_10:.*]] = memref.alloc() : memref<f32>
 // CHECK:           linalg.copy(%[[VAL_9]], %[[VAL_10]]) : memref<f32>, memref<f32>
@@ -1323,15 +1220,15 @@ func @kernel_3d(%arga: tensor<?x?xf32>,
 // CHECK:                 %[[VAL_26:.*]] = addf %[[VAL_24]], %[[VAL_25]] : f32
 // CHECK:                 scf.yield %[[VAL_26]] : f32
 // CHECK:               }
-// CHECK:               store %[[VAL_27:.*]], %[[VAL_10]][] : memref<f32>
+// CHECK:               memref.store %[[VAL_27:.*]], %[[VAL_10]][] : memref<f32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_28:.*]] = memref.tensor_load %[[VAL_10]] : memref<f32>
 // CHECK:           return %[[VAL_28]] : tensor<f32>
 // CHECK:         }
-func @sum_reduction(%arga: tensor<10x20x30xf32>, %argx: tensor<f32>) -> tensor<f32> {
+func @sum_reduction(%arga: tensor<10x20x30xf32, #Tsss>, %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_sum_reduction
-     ins(%arga: tensor<10x20x30xf32>)
+     ins(%arga: tensor<10x20x30xf32, #Tsss>)
     outs(%argx: tensor<f32>) {
       ^bb(%a: f32, %x: f32):
         %0 = addf %x, %a : f32
@@ -1346,11 +1243,6 @@ func @sum_reduction(%arga: tensor<10x20x30xf32>, %argx: tensor<f32>) -> tensor<f
     affine_map<(i,j,k) -> (i)>,      // b
     affine_map<(i,j,k) -> ()>        // x (scalar out)
   ],
-  sparse = [
-    [ "D", "D", "D" ], // A
-    [ "D" ],           // b
-    [ ]                // x
-  ],
   iterator_types = ["reduction", "reduction", "reduction"],
   doc = "x += SUM_i A(i,j,k) * b(i)"
 }
@@ -1380,7 +1272,7 @@ func @sum_reduction(%arga: tensor<10x20x30xf32>, %argx: tensor<f32>) -> tensor<f
 // CHECK:                 %[[VAL_22:.*]] = addf %[[VAL_19]], %[[VAL_21]] : f32
 // CHECK:                 scf.yield %[[VAL_22]] : f32
 // CHECK:               }
-// CHECK:               store %[[VAL_23:.*]], %[[VAL_12]][] : memref<f32>
+// CHECK:               memref.store %[[VAL_23:.*]], %[[VAL_12]][] : memref<f32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_24:.*]] = memref.tensor_load %[[VAL_12]] : memref<f32>
@@ -1407,12 +1299,6 @@ func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
     affine_map<(i,j,k) -> (k)>,      // c
     affine_map<(i,j,k) -> (i,j,k)>   // X (out)
   ],
-  sparse = [
-    [ "D" ],           // a
-    [ "D" ],           // b
-    [ "D" ],           // c
-    [ "D", "D", "D" ]  // X
-  ],
   iterator_types = ["parallel", "parallel", "parallel"],
   doc = "X(i,j,k) = a(i) * b(j) * c(k)"
 }
@@ -1441,7 +1327,7 @@ func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
 // CHECK:                 %[[VAL_19:.*]] = mulf %[[VAL_15]], %[[VAL_17]] : f32
 // CHECK:                 %[[VAL_20:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_18]]] : memref<30xf32>
 // CHECK:                 %[[VAL_21:.*]] = mulf %[[VAL_19]], %[[VAL_20]] : f32
-// CHECK:                 store %[[VAL_21]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_16]], %[[VAL_18]]] : memref<10x20x30xf32>
+// CHECK:                 memref.store %[[VAL_21]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_16]], %[[VAL_18]]] : memref<10x20x30xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:           }

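Note on the deletion below: every case in sparse_invalid.mlir exercised verification of the
old trait-level "sparse" arrays, which no longer exist, so the file is deleted rather than
migrated. Ill-formed formats are now rejected by the verifier of the sparse tensor type
itself (exercised by the added invalid.mlir and the existing invalid_encoding.mlir). As a
hypothetical sketch, not taken from this commit, an encoding that reuses the old one-letter
shorthands now fails at type-verification time, since only the full level-type names are
valid dimLevelType values:

// Hypothetical ill-formed type: "D" and "S" were trait-era shorthands,
// not valid dimLevelType values, so the encoding verifier rejects it.
#Bad = #sparse_tensor.encoding<{ dimLevelType = [ "D", "S" ] }>
func private @bad_arg(tensor<32x16xf32, #Bad>)
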
diff --git a/mlir/test/Dialect/SparseTensor/sparse_invalid.mlir b/mlir/test/Dialect/SparseTensor/sparse_invalid.mlir
deleted file mode 100644
index 4029964737d87..0000000000000
--- a/mlir/test/Dialect/SparseTensor/sparse_invalid.mlir
+++ /dev/null
@@ -1,186 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics
-
-#trait_memref = {
-  indexing_maps = [
-    affine_map<(i) -> (i)>,  // a
-    affine_map<(i) -> (i)>   // x (out)
-  ],
-  sparse = [
-    [ "S" ],  // a
-    [ "D" ]   // x
-  ],
-  iterator_types = ["parallel"],
-  doc = "x(i) = a(i) + b"
-}
-
-func @invalid_memref(%arga: memref<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  // expected-error @+1 {{'linalg.generic' op expected sparse annotations on tensors only}}
-  %0 = linalg.generic #trait_memref
-     ins(%arga: memref<32xf32>)
-    outs(%argx: tensor<32xf32>) {
-      ^bb(%a: f32, %x: f32):
-        %0 = addf %a, %argb : f32
-        linalg.yield %0 : f32
-  } -> tensor<32xf32>
-  return %0 : tensor<32xf32>
-}
-
-// -----
-
-#trait_too_many = {
-  indexing_maps = [
-    affine_map<(i) -> (i)>,  // a
-    affine_map<(i) -> (i)>   // x (out)
-  ],
-  sparse = [
-    [ "S" ],  // a
-    [ "S" ],  // b
-    [ "D" ]   // x
-  ],
-  iterator_types = ["parallel"],
-  doc = "x(i) = a(i) + b"
-}
-
-func @invalid_too_many(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  // expected-error @+1 {{'linalg.generic' op expected one sparse annotation for each tensor}}
-  %0 = linalg.generic #trait_too_many
-     ins(%arga: tensor<32xf32>)
-    outs(%argx: tensor<32xf32>) {
-      ^bb(%a: f32, %x: f32):
-        %0 = addf %a, %argb : f32
-        linalg.yield %0 : f32
-  } -> tensor<32xf32>
-  return %0 : tensor<32xf32>
-}
-
-// -----
-
-#trait_no_array = {
-  indexing_maps = [
-    affine_map<(i) -> (i)>,  // a
-    affine_map<(i) -> (i)>   // x (out)
-  ],
-  sparse = [ 1, 2 ],
-  iterator_types = ["parallel"],
-  doc = "x(i) = a(i) + b"
-}
-
-func @invalid_no_array(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  // expected-error @+1 {{'linalg.generic' op expected sparse annotation array for tensor 0}}
-  %0 = linalg.generic #trait_no_array
-     ins(%arga: tensor<32xf32>)
-    outs(%argx: tensor<32xf32>) {
-      ^bb(%a: f32, %x: f32):
-        %0 = addf %a, %argb : f32
-        linalg.yield %0 : f32
-  } -> tensor<32xf32>
-  return %0 : tensor<32xf32>
-}
-
-// -----
-
-#trait_wrong_rank = {
-  indexing_maps = [
-    affine_map<(i) -> (i)>,  // a
-    affine_map<(i) -> (i)>   // x (out)
-  ],
-  sparse = [
-    [ "S" ],
-    [ "D", "D" ]
-  ],
-  iterator_types = ["parallel"],
-  doc = "x(i) = a(i) + b"
-}
-
-func @invalid_wrong_rank(%arga: tensor<32xf32>, %argb: f32, %argx: tensor<32xf32>) -> tensor<32xf32> {
-  // expected-error @+1 {{'linalg.generic' op expected sparse annotation with rank 1 for tensor 1}}
-  %0 = linalg.generic #trait_wrong_rank
-     ins(%arga: tensor<32xf32>)
-    outs(%argx: tensor<32xf32>) {
-      ^bb(%a: f32, %x: f32):
-        %0 = addf %a, %argb : f32
-        linalg.yield %0 : f32
-  } -> tensor<32xf32>
-  return %0 : tensor<32xf32>
-}
-
-// -----
-
-#trait_no_string = {
-  indexing_maps = [
-    affine_map<(i,j) -> (i,j)>,  // a
-    affine_map<(i,j) -> (i,j)>   // x (out)
-  ],
-  sparse = [
-    [ "S", 1 ],
-    [ "D", "D" ]
-  ],
-  iterator_types = ["parallel","parallel"],
-  doc = "x(i,j) = a(i,j) + b"
-}
-
-func @invalid_no_string(%arga: tensor<32x16xf32>, %argb: f32, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  // expected-error @+1 {{'linalg.generic' op expected sparse annotation at position 1 for tensor 0}}
-  %0 = linalg.generic #trait_no_string
-     ins(%arga: tensor<32x16xf32>)
-    outs(%argx: tensor<32x16xf32>) {
-      ^bb(%a: f32, %x: f32):
-        %0 = addf %a, %argb : f32
-        linalg.yield %0 : f32
-  } -> tensor<32x16xf32>
-  return %0 : tensor<32x16xf32>
-}
-
-// -----
-
-#trait_wrong_symbol = {
-  indexing_maps = [
-    affine_map<(i,j) -> (i,j)>,  // a
-    affine_map<(i,j) -> (i,j)>   // x (out)
-  ],
-  sparse = [
-    [ "S", "S" ],
-    [ "D", "X" ]
-  ],
-  iterator_types = ["parallel","parallel"],
-  doc = "x(i,j) = a(i,j) + b"
-}
-
-func @invalid_wrong_symbol(%arga: tensor<32x16xf32>, %argb: f32, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  // expected-error @+1 {{'linalg.generic' op expected sparse annotation at position 1 for tensor 1}}
-  %0 = linalg.generic #trait_wrong_symbol
-     ins(%arga: tensor<32x16xf32>)
-    outs(%argx: tensor<32x16xf32>) {
-      ^bb(%a: f32, %x: f32):
-        %0 = addf %a, %argb : f32
-        linalg.yield %0 : f32
-  } -> tensor<32x16xf32>
-  return %0 : tensor<32x16xf32>
-}
-
-// -----
-
-#trait_no_sparse_output = {
-  indexing_maps = [
-    affine_map<(i,j) -> (i,j)>,  // a
-    affine_map<(i,j) -> (i,j)>   // x (out)
-  ],
-  sparse = [
-    [ "S", "S" ],
-    [ "D", "S" ]
-  ],
-  iterator_types = ["parallel","parallel"],
-  doc = "x(i,j) = a(i,j) + b"
-}
-
-func @invalid_no_sparse_output(%arga: tensor<32x16xf32>, %argb: f32, %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
-  // expected-error @+1 {{'linalg.generic' op sparse output tensors not supported (yet)}}
-  %0 = linalg.generic #trait_no_sparse_output
-     ins(%arga: tensor<32x16xf32>)
-    outs(%argx: tensor<32x16xf32>) {
-      ^bb(%a: f32, %x: f32):
-        %0 = addf %a, %argb : f32
-        linalg.yield %0 : f32
-  } -> tensor<32x16xf32>
-  return %0 : tensor<32x16xf32>
-}

diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
index 54179a3395f16..7daee6a0ba50f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
@@ -1,21 +1,19 @@
-// RUN: mlir-opt %s -sparsification | \
-// RUN:   FileCheck %s --check-prefix=CHECK-HIR
+// RUN: mlir-opt %s -sparsification | FileCheck %s --check-prefix=CHECK-HIR
 //
-// RUN: mlir-opt %s -sparsification \
-// RUN:   --sparse-tensor-conversion --convert-linalg-to-loops | \
-// RUN:   FileCheck %s --check-prefix=CHECK-MIR
+// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion                 \
+// RUN: --convert-linalg-to-loops | FileCheck %s --check-prefix=CHECK-MIR
 //
-// RUN: mlir-opt %s -sparsification \
-// RUN:   --sparse-tensor-conversion --convert-linalg-to-loops \
-// RUN:   --func-bufferize --tensor-constant-bufferize \
-// RUN:   --tensor-bufferize --finalizing-bufferize  | \
-// RUN:   FileCheck %s --check-prefix=CHECK-LIR
+// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion                 \
+// RUN: --convert-linalg-to-loops --func-bufferize --tensor-constant-bufferize \
+// RUN: --tensor-bufferize --finalizing-bufferize |                            \
+// RUN: FileCheck %s --check-prefix=CHECK-LIR
 //
-// RUN: mlir-opt %s -sparsification="fast-output" \
-// RUN:   --sparse-tensor-conversion --convert-linalg-to-loops \
-// RUN:   --func-bufferize --tensor-constant-bufferize \
-// RUN:   --tensor-bufferize --finalizing-bufferize  | \
-// RUN:   FileCheck %s --check-prefix=CHECK-FAST
+// RUN: mlir-opt %s -sparsification="fast-output" --sparse-tensor-conversion   \
+// RUN: --convert-linalg-to-loops --func-bufferize --tensor-constant-bufferize \
+// RUN: --tensor-bufferize --finalizing-bufferize |                            \
+// RUN: FileCheck %s --check-prefix=CHECK-FAST
+
+#CSR = #sparse_tensor.encoding<{dimLevelType = [ "dense", "compressed" ]}>
 
 #trait_matvec = {
   indexing_maps = [
@@ -24,30 +22,19 @@
     affine_map<(i,j) -> (i)>     // x (out)
   ],
   iterator_types = ["parallel","reduction"],
-  sparse = [
-    [ "D", "S" ],  // A
-    [ "D" ],       // b
-    [ "D" ]        // x (out)
-  ],
-  sparse_dim_map = [
-    affine_map<(i,j) -> (j,i)>,  // A: column-wise
-    affine_map<(i)   -> (i)>,    // x
-    affine_map<(i)   -> (i)>     // b
-  ],
   doc = "x(i) += A(i,j) * b(j)"
 }
 
 // CHECK-HIR-LABEL:   func @matvec(
-// CHECK-HIR-SAME:                 %[[VAL_0:.*]]: !llvm.ptr<i8>,
+// CHECK-HIR-SAME:                 %[[VAL_0:.*]]: tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>>,
 // CHECK-HIR-SAME:                 %[[VAL_1:.*]]: tensor<64xf64>,
 // CHECK-HIR-SAME:                 %[[VAL_2:.*]]: tensor<64xf64>) -> tensor<64xf64> {
 // CHECK-HIR:           %[[VAL_3:.*]] = constant 64 : index
 // CHECK-HIR:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK-HIR:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK-HIR:           %[[VAL_6:.*]] = sparse_tensor.fromPtr %[[VAL_0]] : !llvm.ptr<i8> to tensor<64x64xf64>
-// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_6]], %[[VAL_5]] : tensor<64x64xf64> to memref<?xindex>
-// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_6]], %[[VAL_5]] : tensor<64x64xf64> to memref<?xindex>
-// CHECK-HIR:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_6]] : tensor<64x64xf64> to memref<?xf64>
+// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #{{.*}}> to memref<?xindex>
+// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #{{.*}}> to memref<?xindex>
+// CHECK-HIR:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<64x64xf64, #{{.*}}> to memref<?xf64>
 // CHECK-HIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
 // CHECK-HIR:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64>
 // CHECK-HIR:           %[[VAL_12:.*]] = memref.alloc() : memref<64xf64>
@@ -78,8 +65,8 @@
 // CHECK-MIR:           %[[VAL_3:.*]] = constant 64 : index
 // CHECK-MIR:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK-MIR:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK-MIR:           %[[VAL_6:.*]] = call @sparsePointers64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
-// CHECK-MIR:           %[[VAL_7:.*]] = call @sparseIndices64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK-MIR:           %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK-MIR:           %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-MIR:           %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
 // CHECK-MIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
 // CHECK-MIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64>
@@ -114,8 +101,8 @@
 // CHECK-LIR:           %[[VAL_3:.*]] = constant 64 : index
 // CHECK-LIR:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK-LIR:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK-LIR:           %[[VAL_6:.*]] = call @sparsePointers64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
-// CHECK-LIR:           %[[VAL_7:.*]] = call @sparseIndices64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK-LIR:           %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK-LIR:           %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-LIR:           %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
 // CHECK-LIR:           %[[VAL_9:.*]] = memref.alloc() : memref<64xf64>
 // CHECK-LIR:           scf.for %[[VAL_10:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
@@ -147,8 +134,8 @@
 // CHECK-FAST:           %[[VAL_3:.*]] = constant 64 : index
 // CHECK-FAST:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK-FAST:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK-FAST:           %[[VAL_6:.*]] = call @sparsePointers64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
-// CHECK-FAST:           %[[VAL_7:.*]] = call @sparseIndices64(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK-FAST:           %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK-FAST:           %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-FAST:           %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
 // CHECK-FAST:           scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-FAST:             %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xindex>
@@ -168,12 +155,9 @@
 // CHECK-FAST:           return %[[VAL_2]] : memref<64xf64>
 // CHECK-FAST:         }
 
-!SparseTensor = type !llvm.ptr<i8>
-
-func @matvec(%argA: !SparseTensor, %argb: tensor<64xf64>, %argx: tensor<64xf64>) -> tensor<64xf64> {
-  %arga = sparse_tensor.fromPtr %argA : !SparseTensor to tensor<64x64xf64>
+func @matvec(%arga: tensor<64x64xf64, #CSR>, %argb: tensor<64xf64>, %argx: tensor<64xf64>) -> tensor<64xf64> {
   %0 = linalg.generic #trait_matvec
-      ins(%arga, %argb : tensor<64x64xf64>, tensor<64xf64>)
+      ins(%arga, %argb : tensor<64x64xf64, #CSR>, tensor<64xf64>)
       outs(%argx: tensor<64xf64>) {
     ^bb(%A: f64, %b: f64, %x: f64):
       %0 = mulf %A, %b : f64

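The sparse_lower.mlir changes above show the new end-to-end idiom in its simplest form:
the storage format now travels with the tensor type, so the trait-level "sparse" and
"sparse_dim_map" arrays, the opaque !llvm.ptr<i8> alias, and the sparse_tensor.fromPtr
glue all disappear. A minimal self-contained sketch of the pattern (the names, shapes,
and scaling kernel here are illustrative, not part of this commit):

#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

#trait_scale = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) * 2.0"
}

// The encoding rides along on the operand type; the op itself
// carries no sparsity annotations anymore.
func @scale(%arga: tensor<32x64xf64, #CSR>, %argx: tensor<32x64xf64>) -> tensor<32x64xf64> {
  %c = constant 2.0 : f64
  %0 = linalg.generic #trait_scale
     ins(%arga: tensor<32x64xf64, #CSR>)
    outs(%argx: tensor<32x64xf64>) {
      ^bb(%a: f64, %x: f64):
        %1 = mulf %a, %c : f64
        linalg.yield %1 : f64
  } -> tensor<32x64xf64>
  return %0 : tensor<32x64xf64>
}
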
diff --git a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
index 7d926d33a3863..82c0d06e05688 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_nd.mlir
@@ -3,26 +3,27 @@
 
 // Example with cyclic iteration graph with sparse and dense constraints,
 // but an acyclic iteration graph using sparse constraints only.
+
+#SparseTensor = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "dense", "dense", "compressed",
+                   "compressed", "dense", "dense", "dense" ]
+}>
+
 #trait_mul = {
   indexing_maps = [
     affine_map<(i,j,k,l,m,n,o,p) -> (i,j,k,l,m,n,o,p)>,  // A
     affine_map<(i,j,k,l,m,n,o,p) -> (p,o,n,m,l,k,j,i)>,  // B
     affine_map<(i,j,k,l,m,n,o,p) -> (i,j,k,l,m,n,o,p)>   // X
   ],
-  sparse = [
-    [ "D", "D", "D", "D", "D", "D", "D", "D" ],  // a
-    [ "D", "D", "D", "S", "S", "D", "D", "D" ],  // b
-    [ "D", "D", "D", "D", "D", "D", "D", "D" ]   // x
-  ],
   iterator_types = ["parallel", "parallel", "parallel", "parallel",
                     "parallel", "parallel", "parallel", "parallel"],
-  doc = "X(i,j,k,l,m,n,o,p) = A(i,j,k,l,m,n,o,p)  * B(p,o,n,m,l,k,j,i)"
+  doc = "X(i,j,k,l,m,n,o,p) = A(i,j,k,l,m,n,o,p) * B(p,o,n,m,l,k,j,i)"
 }
 
 // CHECK-LABEL:   func @mul(
-// CHECK-SAME:              %[[VAL_0:.*0]]: tensor<10x20x30x40x50x60x70x80xf32>,
-// CHECK-SAME:              %[[VAL_1:.*1]]: tensor<80x70x60x50x40x30x20x10xf32>,
-// CHECK-SAME:              %[[VAL_2:.*2]]: tensor<10x20x30x40x50x60x70x80xf32>) -> tensor<10x20x30x40x50x60x70x80xf32> {
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<10x20x30x40x50x60x70x80xf32>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<80x70x60x50x40x30x20x10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense", "compressed", "compressed", "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<10x20x30x40x50x60x70x80xf32>) -> tensor<10x20x30x40x50x60x70x80xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 3 : index
 // CHECK:           %[[VAL_4:.*]] = constant 4 : index
 // CHECK:           %[[VAL_5:.*]] = constant 10 : index
@@ -34,11 +35,11 @@
 // CHECK:           %[[VAL_11:.*]] = constant 0 : index
 // CHECK:           %[[VAL_12:.*]] = constant 1 : index
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_0]] : memref<10x20x30x40x50x60x70x80xf32>
-// CHECK:           %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
-// CHECK:           %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
-// CHECK:           %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
-// CHECK:           %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
-// CHECK:           %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xf32>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<80x70x60x50x40x30x20x10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense", "compressed", "compressed", "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<80x70x60x50x40x30x20x10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense", "compressed", "compressed", "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense", "compressed", "compressed", "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense", "compressed", "compressed", "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<80x70x60x50x40x30x20x10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense", "compressed", "compressed", "dense", "dense", "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_19:.*]] = memref.buffer_cast %[[VAL_2]] : memref<10x20x30x40x50x60x70x80xf32>
 // CHECK:           %[[VAL_20:.*]] = memref.alloc() : memref<10x20x30x40x50x60x70x80xf32>
 // CHECK:           linalg.copy(%[[VAL_19]], %[[VAL_20]]) : memref<10x20x30x40x50x60x70x80xf32>, memref<10x20x30x40x50x60x70x80xf32>
@@ -84,12 +85,12 @@
 // CHECK:           return %[[VAL_50]] : tensor<10x20x30x40x50x60x70x80xf32>
 // CHECK:         }
 func @mul(%arga: tensor<10x20x30x40x50x60x70x80xf32>,
-          %argb: tensor<80x70x60x50x40x30x20x10xf32>,
+          %argb: tensor<80x70x60x50x40x30x20x10xf32, #SparseTensor>,
           %argx: tensor<10x20x30x40x50x60x70x80xf32>)
 	      -> tensor<10x20x30x40x50x60x70x80xf32> {
   %0 = linalg.generic #trait_mul
     ins(%arga, %argb: tensor<10x20x30x40x50x60x70x80xf32>,
-                      tensor<80x70x60x50x40x30x20x10xf32>)
+                      tensor<80x70x60x50x40x30x20x10xf32, #SparseTensor>)
     outs(%argx: tensor<10x20x30x40x50x60x70x80xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32

diff --git a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
index 31395d72b9cb8..fc575141e311a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
@@ -9,15 +9,19 @@
 // RUN: mlir-opt %s -sparsification="parallelization-strategy=4" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-PAR4
 
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ]
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ]
+}>
+
 #trait_dd = {
   indexing_maps = [
     affine_map<(i,j) -> (i,j)>,  // A
     affine_map<(i,j) -> (i,j)>   // X (out)
   ],
-  sparse = [
-    [ "D", "D" ],  // A
-    [ "D", "D" ]   // X
-  ],
   iterator_types = ["parallel", "parallel"],
   doc = "X(i,j) = A(i,j) * SCALE"
 }
@@ -64,10 +68,6 @@ func @scale_dd(%scale: f32, %arga: tensor<?x?xf32>, %argx: tensor<?x?xf32>) -> t
     affine_map<(i,j) -> (i,j)>,  // A
     affine_map<(i,j) -> (i,j)>   // X (out)
   ],
-  sparse = [
-    [ "S", "S" ],  // A
-    [ "D", "D" ]   // X
-  ],
   iterator_types = ["parallel", "parallel"],
   doc = "X(i,j) = A(i,j) * SCALE"
 }
@@ -98,9 +98,9 @@ func @scale_dd(%scale: f32, %arga: tensor<?x?xf32>, %argx: tensor<?x?xf32>) -> t
 // CHECK-PAR4:           scf.parallel
 // CHECK-PAR4:         return
 //
-func @scale_ss(%scale: f32, %arga: tensor<?x?xf32>, %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @scale_ss(%scale: f32, %arga: tensor<?x?xf32, #SparseMatrix>, %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic #trait_ss
-     ins(%arga: tensor<?x?xf32>)
+     ins(%arga: tensor<?x?xf32, #SparseMatrix>)
     outs(%argx: tensor<?x?xf32>) {
       ^bb(%a: f32, %x: f32):
         %0 = mulf %a, %scale : f32
@@ -115,11 +115,6 @@ func @scale_ss(%scale: f32, %arga: tensor<?x?xf32>, %argx: tensor<?x?xf32>) -> t
     affine_map<(i,j) -> (j)>,    // b
     affine_map<(i,j) -> (i)>     // x (out)
   ],
-  sparse = [
-    [ "D", "S" ],  // A
-    [ "D" ],       // b
-    [ "D" ]        // x
-  ],
   iterator_types = ["parallel", "reduction"],
   doc = "x(i) += A(i,j) * b(j)"
 }
@@ -150,9 +145,9 @@ func @scale_ss(%scale: f32, %arga: tensor<?x?xf32>, %argx: tensor<?x?xf32>) -> t
 // CHECK-PAR4:           scf.for
 // CHECK-PAR4:         return
 //
-func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> {
+func @matvec(%argA: tensor<16x32xf32, #CSR>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> {
   %0 = linalg.generic #trait_matvec
-      ins(%argA, %argb : tensor<16x32xf32>, tensor<32xf32>)
+      ins(%argA, %argb : tensor<16x32xf32, #CSR>, tensor<32xf32>)
      outs(%argx: tensor<16xf32>) {
     ^bb(%A: f32, %b: f32, %x: f32):
       %0 = mulf %A, %b : f32

diff --git a/mlir/test/Dialect/SparseTensor/sparse_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_storage.mlir
index b6d4adff69131..2f768bfd2d383 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_storage.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_storage.mlir
@@ -1,142 +1,74 @@
-// RUN: mlir-opt %s -sparsification="ptr-type=1 ind-type=1" | \
-// RUN:   FileCheck %s --check-prefix=CHECK-TYPE0
-// RUN: mlir-opt %s -sparsification="ptr-type=1 ind-type=2" | \
-// RUN:   FileCheck %s --check-prefix=CHECK-TYPE1
-// RUN: mlir-opt %s -sparsification="ptr-type=2 ind-type=1" | \
-// RUN:   FileCheck %s --check-prefix=CHECK-TYPE2
-// RUN: mlir-opt %s -sparsification="ptr-type=2 ind-type=2" | \
-// RUN:   FileCheck %s --check-prefix=CHECK-TYPE3
-// RUN: mlir-opt %s -sparsification="ptr-type=3 ind-type=3" | \
-// RUN:   FileCheck %s --check-prefix=CHECK-TYPE4
-// RUN: mlir-opt %s -sparsification="ptr-type=4 ind-type=4" | \
-// RUN:   FileCheck %s --check-prefix=CHECK-TYPE5
+// RUN: mlir-opt %s -sparsification | FileCheck %s
 
-#trait_mul_1d = {
+#SparseVector64 = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed" ],
+  pointerBitWidth = 64,
+  indexBitWidth = 64
+}>
+
+#SparseVector32 = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed" ],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
+#trait_mul = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
     affine_map<(i) -> (i)>,  // b
     affine_map<(i) -> (i)>   // x (out)
   ],
-  sparse = [
-    [ "S" ],  // a
-    [ "D" ],  // b
-    [ "D" ]   // x
-  ],
   iterator_types = ["parallel"],
   doc = "x(i) = a(i) * b(i)"
 }
 
-// CHECK-TYPE0-LABEL: func @mul_dd(
-// CHECK-TYPE0: %[[C0:.*]] = constant 0 : index
-// CHECK-TYPE0: %[[C1:.*]] = constant 1 : index
-// CHECK-TYPE0: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi64>
-// CHECK-TYPE0: %[[B0:.*]] = index_cast %[[P0]] : i64 to index
-// CHECK-TYPE0: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi64>
-// CHECK-TYPE0: %[[B1:.*]] = index_cast %[[P1]] : i64 to index
-// CHECK-TYPE0: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
-// CHECK-TYPE0:   %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi64>
-// CHECK-TYPE0:   %[[INDC:.*]] = index_cast %[[IND0]] : i64 to index
-// CHECK-TYPE0:   %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
-// CHECK-TYPE0:   %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE0:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
-// CHECK-TYPE0:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE0: }
-
-// CHECK-TYPE1-LABEL: func @mul_dd(
-// CHECK-TYPE1: %[[C0:.*]] = constant 0 : index
-// CHECK-TYPE1: %[[C1:.*]] = constant 1 : index
-// CHECK-TYPE1: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi64>
-// CHECK-TYPE1: %[[B0:.*]] = index_cast %[[P0]] : i64 to index
-// CHECK-TYPE1: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi64>
-// CHECK-TYPE1: %[[B1:.*]] = index_cast %[[P1]] : i64 to index
-// CHECK-TYPE1: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
-// CHECK-TYPE1:   %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi32>
-// CHECK-TYPE1:   %[[ZEXT:.*]] = zexti %[[IND0]] : i32 to i64
-// CHECK-TYPE1:   %[[INDC:.*]] = index_cast %[[ZEXT]] : i64 to index
-// CHECK-TYPE1:   %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
-// CHECK-TYPE1:   %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE1:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
-// CHECK-TYPE1:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE1: }
-
-// CHECK-TYPE2-LABEL: func @mul_dd(
-// CHECK-TYPE2: %[[C0:.*]] = constant 0 : index
-// CHECK-TYPE2: %[[C1:.*]] = constant 1 : index
-// CHECK-TYPE2: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi32>
-// CHECK-TYPE2: %[[Z0:.*]] = zexti %[[P0]] : i32 to i64
-// CHECK-TYPE2: %[[B0:.*]] = index_cast %[[Z0]] : i64 to index
-// CHECK-TYPE2: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi32>
-// CHECK-TYPE2: %[[Z1:.*]] = zexti %[[P1]] : i32 to i64
-// CHECK-TYPE2: %[[B1:.*]] = index_cast %[[Z1]] : i64 to index
-// CHECK-TYPE2: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
-// CHECK-TYPE2:   %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi64>
-// CHECK-TYPE2:   %[[INDC:.*]] = index_cast %[[IND0]] : i64 to index
-// CHECK-TYPE2:   %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
-// CHECK-TYPE2:   %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE2:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
-// CHECK-TYPE2:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE2: }
-
-// CHECK-TYPE3-LABEL: func @mul_dd(
-// CHECK-TYPE3: %[[C0:.*]] = constant 0 : index
-// CHECK-TYPE3: %[[C1:.*]] = constant 1 : index
-// CHECK-TYPE3: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi32>
-// CHECK-TYPE3: %[[Z0:.*]] = zexti %[[P0]] : i32 to i64
-// CHECK-TYPE3: %[[B0:.*]] = index_cast %[[Z0]] : i64 to index
-// CHECK-TYPE3: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi32>
-// CHECK-TYPE3: %[[Z1:.*]] = zexti %[[P1]] : i32 to i64
-// CHECK-TYPE3: %[[B1:.*]] = index_cast %[[Z1]] : i64 to index
-// CHECK-TYPE3: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
-// CHECK-TYPE3:   %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi32>
-// CHECK-TYPE3:   %[[ZEXT:.*]] = zexti %[[IND0]] : i32 to i64
-// CHECK-TYPE3:   %[[INDC:.*]] = index_cast %[[ZEXT]] : i64 to index
-// CHECK-TYPE3:   %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
-// CHECK-TYPE3:   %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE3:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
-// CHECK-TYPE3:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE3: }
-
-// CHECK-TYPE4-LABEL: func @mul_dd(
-// CHECK-TYPE4: %[[C0:.*]] = constant 0 : index
-// CHECK-TYPE4: %[[C1:.*]] = constant 1 : index
-// CHECK-TYPE4: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi16>
-// CHECK-TYPE4: %[[Z0:.*]] = zexti %[[P0]] : i16 to i64
-// CHECK-TYPE4: %[[B0:.*]] = index_cast %[[Z0]] : i64 to index
-// CHECK-TYPE4: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi16>
-// CHECK-TYPE4: %[[Z1:.*]] = zexti %[[P1]] : i16 to i64
-// CHECK-TYPE4: %[[B1:.*]] = index_cast %[[Z1]] : i64 to index
-// CHECK-TYPE4: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
-// CHECK-TYPE4:   %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi16>
-// CHECK-TYPE4:   %[[ZEXT:.*]] = zexti %[[IND0]] : i16 to i64
-// CHECK-TYPE4:   %[[INDC:.*]] = index_cast %[[ZEXT]] : i64 to index
-// CHECK-TYPE4:   %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
-// CHECK-TYPE4:   %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE4:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
-// CHECK-TYPE4:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE4: }
-
-// CHECK-TYPE5-LABEL: func @mul_dd(
-// CHECK-TYPE5: %[[C0:.*]] = constant 0 : index
-// CHECK-TYPE5: %[[C1:.*]] = constant 1 : index
-// CHECK-TYPE5: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi8>
-// CHECK-TYPE5: %[[Z0:.*]] = zexti %[[P0]] : i8 to i64
-// CHECK-TYPE5: %[[B0:.*]] = index_cast %[[Z0]] : i64 to index
-// CHECK-TYPE5: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi8>
-// CHECK-TYPE5: %[[Z1:.*]] = zexti %[[P1]] : i8 to i64
-// CHECK-TYPE5: %[[B1:.*]] = index_cast %[[Z1]] : i64 to index
-// CHECK-TYPE5: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
-// CHECK-TYPE5:   %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi8>
-// CHECK-TYPE5:   %[[ZEXT:.*]] = zexti %[[IND0]] : i8 to i64
-// CHECK-TYPE5:   %[[INDC:.*]] = index_cast %[[ZEXT]] : i64 to index
-// CHECK-TYPE5:   %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
-// CHECK-TYPE5:   %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE5:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
-// CHECK-TYPE5:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
-// CHECK-TYPE5: }
+// CHECK-LABEL: func @mul64(
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi64>
+// CHECK: %[[B0:.*]] = index_cast %[[P0]] : i64 to index
+// CHECK: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi64>
+// CHECK: %[[B1:.*]] = index_cast %[[P1]] : i64 to index
+// CHECK: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK:   %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi64>
+// CHECK:   %[[INDC:.*]] = index_cast %[[IND0]] : i64 to index
+// CHECK:   %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK:   %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK: }
+func @mul64(%arga: tensor<32xf64, #SparseVector64>, %argb: tensor<32xf64>, %argx: tensor<32xf64>) -> tensor<32xf64> {
+  %0 = linalg.generic #trait_mul
+     ins(%arga, %argb: tensor<32xf64, #SparseVector64>, tensor<32xf64>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %b: f64, %x: f64):
+        %0 = mulf %a, %b : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
 
-func @mul_dd(%arga: tensor<32xf64>, %argb: tensor<32xf64>, %argx: tensor<32xf64>) -> tensor<32xf64> {
-  %0 = linalg.generic #trait_mul_1d
-     ins(%arga, %argb: tensor<32xf64>, tensor<32xf64>)
+// CHECK-LABEL: func @mul32(
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi32>
+// CHECK: %[[Z0:.*]] = zexti %[[P0]] : i32 to i64
+// CHECK: %[[B0:.*]] = index_cast %[[Z0]] : i64 to index
+// CHECK: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi32>
+// CHECK: %[[Z1:.*]] = zexti %[[P1]] : i32 to i64
+// CHECK: %[[B1:.*]] = index_cast %[[Z1]] : i64 to index
+// CHECK: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK:   %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi32>
+// CHECK:   %[[ZEXT:.*]] = zexti %[[IND0]] : i32 to i64
+// CHECK:   %[[INDC:.*]] = index_cast %[[ZEXT]] : i64 to index
+// CHECK:   %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK:   %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK: }
+func @mul32(%arga: tensor<32xf64, #SparseVector32>, %argb: tensor<32xf64>, %argx: tensor<32xf64>) -> tensor<32xf64> {
+  %0 = linalg.generic #trait_mul
+     ins(%arga, %argb: tensor<32xf64, #SparseVector32>, tensor<32xf64>)
     outs(%argx: tensor<32xf64>) {
       ^bb(%a: f64, %b: f64, %x: f64):
         %0 = mulf %a, %b : f64

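The sparse_storage.mlir changes above replace the old ptr-type/ind-type pass flags with
bit widths on the encoding itself, and the checks show the consequence directly: in
@mul32 a 32-bit pointer is loaded as i32, widened with zexti to i64, and converted with
index_cast before it can be used as an index, whereas a width of 0 (the default printed
in earlier checks) denotes the native index width and needs no widening chain. A
hypothetical 16-bit variant, not part of this commit, would follow the same recipe with
i16 loads:

#SparseVector16 = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed" ],
  pointerBitWidth = 16,
  indexBitWidth = 16
}>

#trait_mul_16 = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b(i)"
}

// Same kernel as @mul64/@mul32 above; only the encoding (and hence the
// generated load/zexti/index_cast chain) changes, now over i16 storage.
func @mul16(%arga: tensor<32xf64, #SparseVector16>, %argb: tensor<32xf64>, %argx: tensor<32xf64>) -> tensor<32xf64> {
  %0 = linalg.generic #trait_mul_16
     ins(%arga, %argb: tensor<32xf64, #SparseVector16>, tensor<32xf64>)
    outs(%argx: tensor<32xf64>) {
      ^bb(%a: f64, %b: f64, %x: f64):
        %1 = mulf %a, %b : f64
        linalg.yield %1 : f64
  } -> tensor<32xf64>
  return %0 : tensor<32xf64>
}
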
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 310076d96df3a..29dd1f9bb9eae 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -1,21 +1,29 @@
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=0 ptr-type=2 ind-type=2 vl=16" | \
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=0 vl=16" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC0
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=1 ptr-type=2 ind-type=2 vl=16" | \
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=1 vl=16" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC1
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 ptr-type=2 ind-type=2 vl=16" | \
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC2
-// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 ptr-type=0 ind-type=0 vl=16" | \
-// RUN:   FileCheck %s --check-prefix=CHECK-VEC3
+
+#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed" ],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
 
 #trait_scale_d = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
     affine_map<(i) -> (i)>   // x (out)
   ],
-  sparse = [
-    [ "D" ],  // a
-    [ "D" ]   // x
-  ],
   iterator_types = ["parallel"],
   doc = "x(i) = a(i) * b"
 }
@@ -26,7 +34,7 @@
 // CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
 // CHECK-VEC0-DAG:   %[[c1024:.*]] = constant 1024 : index
 // CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] {
-// CHECK-VEC0:         %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32>
+// CHECK-VEC0:         %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
 // CHECK-VEC0:         %[[m:.*]] = mulf %[[l]], %{{.*}} : f32
 // CHECK-VEC0:         store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>
 // CHECK-VEC0:       }
@@ -37,7 +45,7 @@
 // CHECK-VEC1-DAG:   %[[c16:.*]] = constant 16 : index
 // CHECK-VEC1-DAG:   %[[c1024:.*]] = constant 1024 : index
 // CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
-// CHECK-VEC1:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC1:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC1:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
 // CHECK-VEC1:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
 // CHECK-VEC1:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -49,31 +57,19 @@
 // CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
 // CHECK-VEC2-DAG:   %[[c1024:.*]] = constant 1024 : index
 // CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
-// CHECK-VEC2:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC2:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC2:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
 // CHECK-VEC2:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
 // CHECK-VEC2:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
-// CHECK-VEC3-LABEL: func @scale_d
-// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
-// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
-// CHECK-VEC3-DAG:   %[[c1024:.*]] = constant 1024 : index
-// CHECK-VEC3:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
-// CHECK-VEC3:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
-// CHECK-VEC3:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
-// CHECK-VEC3:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
-// CHECK-VEC3:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
-// CHECK-VEC3:       }
-// CHECK-VEC3:       return
-//
-func @scale_d(%arga: tensor<1024xf32>, %scale: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
+func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
   %0 = linalg.generic #trait_scale_d
-    ins(%arga: tensor<1024xf32>)
+    ins(%arga: tensor<1024xf32, #DenseVector>)
     outs(%argx: tensor<1024xf32>) {
       ^bb(%a: f32, %x: f32):
-        %0 = mulf %a, %scale : f32
+        %0 = mulf %a, %b : f32
         linalg.yield %0 : f32
   } -> tensor<1024xf32>
   return %0 : tensor<1024xf32>
@@ -85,11 +81,6 @@ func @scale_d(%arga: tensor<1024xf32>, %scale: f32, %argx: tensor<1024xf32>) ->
     affine_map<(i) -> (i)>,  // b
     affine_map<(i) -> (i)>   // x (out)
   ],
-  sparse = [
-    [ "S" ],  // a
-    [ "D" ],  // b
-    [ "D" ]   // x
-  ],
   iterator_types = ["parallel"],
   doc = "x(i) = a(i) * b(i)"
 }
@@ -157,81 +148,9 @@ func @scale_d(%arga: tensor<1024xf32>, %scale: f32, %argx: tensor<1024xf32>) ->
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
-// CHECK-VEC3-LABEL: func @mul_s
-// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
-// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
-// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
-// CHECK-VEC3:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xindex>
-// CHECK-VEC3:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xindex>
-// CHECK-VEC3:       scf.for %[[i:.*]] = %[[p]] to %[[r]] step %[[c16]] {
-// CHECK-VEC3:         %[[sub:.*]] = subi %[[r]], %[[i]] : index
-// CHECK-VEC3:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
-// CHECK-VEC3:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
-// CHECK-VEC3:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC3:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC3:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC3:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
-// CHECK-VEC3:       }
-// CHECK-VEC3:       return
-//
-func @mul_s(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
-  %0 = linalg.generic #trait_mul_s
-    ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>)
-    outs(%argx: tensor<1024xf32>) {
-      ^bb(%a: f32, %b: f32, %x: f32):
-        %0 = mulf %a, %b : f32
-        linalg.yield %0 : f32
-  } -> tensor<1024xf32>
-  return %0 : tensor<1024xf32>
-}
-
-//
-// CHECK-VEC2-LABEL: func @mul_s_alt
-// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
-// CHECK-VEC2-DAG:   %[[c1:.*]] = constant 1 : index
-// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
-// CHECK-VEC2:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
-// CHECK-VEC2:       %[[a:.*]] = zexti %[[p]] : i32 to i64
-// CHECK-VEC2:       %[[q:.*]] = index_cast %[[a]] : i64 to index
-// CHECK-VEC2:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
-// CHECK-VEC2:       %[[b:.*]] = zexti %[[r]] : i32 to i64
-// CHECK-VEC2:       %[[s:.*]] = index_cast %[[b]] : i64 to index
-// CHECK-VEC2:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
-// CHECK-VEC2:         %[[sub:.*]] = subi %[[s]], %[[i]] : index
-// CHECK-VEC2:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
-// CHECK-VEC2:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
-// CHECK-VEC2:         %[[zi:.*]] = zexti %[[li]] : vector<16xi32> to vector<16xi64>
-// CHECK-VEC2:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC2:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC2:         vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
-// CHECK-VEC2:       }
-// CHECK-VEC2:       return
-//
-// CHECK-VEC3-LABEL: func @mul_s_alt
-// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
-// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
-// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
-// CHECK-VEC3:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xindex>
-// CHECK-VEC3:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xindex>
-// CHECK-VEC3:       scf.for %[[i:.*]] = %[[p]] to %[[r]] step %[[c16]] {
-// CHECK-VEC3:         %[[sub:.*]] = subi %[[r]], %[[i]] : index
-// CHECK-VEC3:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
-// CHECK-VEC3:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
-// CHECK-VEC3:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC3:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC3:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC3:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
-// CHECK-VEC3:       }
-// CHECK-VEC3:       return
-//
-//
-!SparseTensor = type !llvm.ptr<i8>
-func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
-  %arga = sparse_tensor.fromPtr %argA : !SparseTensor to tensor<1024xf32>
-  %argb = sparse_tensor.fromPtr %argB : !SparseTensor to tensor<1024xf32>
+func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
   %0 = linalg.generic #trait_mul_s
-    ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>)
+    ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
     outs(%argx: tensor<1024xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -246,11 +165,6 @@ func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf
     affine_map<(i) -> (i)>,  // b
     affine_map<(i) -> ()>    // x (out)
   ],
-  sparse = [
-    [ "D" ],  // a
-    [ "D" ],  // b
-    [     ]   // x
-  ],
   iterator_types = ["reduction"],
   doc = "x += a(i) * b(i)"
 }
@@ -261,7 +175,7 @@ func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf
 // CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
 // CHECK-VEC0-DAG:   %[[c1024:.*]] = constant 1024 : index
 // CHECK-VEC0:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) {
-// CHECK-VEC0:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32>
+// CHECK-VEC0:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
 // CHECK-VEC0:         %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32>
 // CHECK-VEC0:         %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
 // CHECK-VEC0:         %[[a:.*]] = addf %[[red_in]], %[[m]] : f32
@@ -275,7 +189,7 @@ func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf
 // CHECK-VEC1-DAG:   %[[c1024:.*]] = constant 1024 : index
 // CHECK-VEC1-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC1:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
-// CHECK-VEC1:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC1:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC1:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
 // CHECK-VEC1:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
 // CHECK-VEC1:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
@@ -290,7 +204,7 @@ func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf
 // CHECK-VEC2-DAG:   %[[c1024:.*]] = constant 1024 : index
 // CHECK-VEC2-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC2:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
-// CHECK-VEC2:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC2:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC2:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
 // CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
 // CHECK-VEC2:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
@@ -299,55 +213,9 @@ func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf
 // CHECK-VEC2:       %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
 // CHECK-VEC2:       return
 //
-// CHECK-VEC3-LABEL: func @reduction_d
-// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
-// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
-// CHECK-VEC3-DAG:   %[[c1024:.*]] = constant 1024 : index
-// CHECK-VEC3-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-VEC3:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
-// CHECK-VEC3:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
-// CHECK-VEC3:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
-// CHECK-VEC3:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC3:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
-// CHECK-VEC3:         scf.yield %[[a]] : vector<16xf32>
-// CHECK-VEC3:       }
-// CHECK-VEC3:       %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
-// CHECK-VEC3:       return
-//
-func @reduction_d(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
+func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_reduction_d
-    ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>)
-    outs(%argx: tensor<f32>) {
-      ^bb(%a: f32, %b: f32, %x: f32):
-        %0 = mulf %a, %b : f32
-        %1 = addf %x, %0 : f32
-        linalg.yield %1 : f32
-  } -> tensor<f32>
-  return %0 : tensor<f32>
-}
-
-//
-// CHECK-VEC1-LABEL: func @reduction_17
-// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
-// CHECK-VEC1-DAG:   %[[c16:.*]] = constant 16 : index
-// CHECK-VEC1-DAG:   %[[c17:.*]] = constant 17 : index
-// CHECK-VEC1-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-VEC1:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c17]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
-// CHECK-VEC1:         %[[sub:.*]] = subi %[[c17]], %[[i]] : index
-// CHECK-VEC1:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
-// CHECK-VEC1:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<17xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC1:         %[[lb:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<17xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC1:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC1:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
-// CHECK-VEC1:         %[[s:.*]] = select %[[mask]], %[[a]], %[[red_in]] : vector<16xi1>, vector<16xf32>
-// CHECK-VEC1:         scf.yield %[[s]] : vector<16xf32>
-// CHECK-VEC1:       }
-// CHECK-VEC1:       %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
-// CHECK-VEC1:       return
-//
-func @reduction_17(%arga: tensor<17xf32>, %argb: tensor<17xf32>, %argx: tensor<f32>) -> tensor<f32> {
-  %0 = linalg.generic #trait_reduction_d
-    ins(%arga, %argb: tensor<17xf32>, tensor<17xf32>)
+    ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32>)
     outs(%argx: tensor<f32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -359,17 +227,12 @@ func @reduction_17(%arga: tensor<17xf32>, %argb: tensor<17xf32>, %argx: tensor<f
 
 #trait_mul_ds = {
   indexing_maps = [
-    affine_map<(i,j) -> (i,j)>,  // a
-    affine_map<(i,j) -> (i,j)>,  // b
-    affine_map<(i,j) -> (i,j)>   // x (out)
-  ],
-  sparse = [
-    [ "D", "S" ],  // a
-    [ "D", "D" ],  // b
-    [ "D", "D" ]   // x
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (i,j)>,  // B
+    affine_map<(i,j) -> (i,j)>   // X (out)
   ],
   iterator_types = ["parallel", "parallel"],
-  doc = "x(i,j) = a(i,j) * b(i,j)"
+  doc = "X(i,j) = A(i,j) * B(i,j)"
 }
 
 //
@@ -447,30 +310,9 @@ func @reduction_17(%arga: tensor<17xf32>, %argb: tensor<17xf32>, %argx: tensor<f
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
-// CHECK-VEC3-LABEL: func @mul_ds
-// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
-// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
-// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
-// CHECK-VEC3-DAG:   %[[c512:.*]] = constant 512 : index
-// CHECK-VEC3:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
-// CHECK-VEC3:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
-// CHECK-VEC3:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
-// CHECK-VEC3:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xindex>
-// CHECK-VEC3:         scf.for %[[j:.*]] = %[[p]] to %[[r]] step %[[c16]] {
-// CHECK-VEC3:           %[[sub:.*]] = subi %[[r]], %[[j]] : index
-// CHECK-VEC3:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
-// CHECK-VEC3:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
-// CHECK-VEC3:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC3:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC3:           %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC3:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
-// CHECK-VEC3:         }
-// CHECK-VEC3:       }
-// CHECK-VEC3:       return
-//
-func @mul_ds(%arga: tensor<512x1024xf32>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
+func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
   %0 = linalg.generic #trait_mul_ds
-    ins(%arga, %argb: tensor<512x1024xf32>, tensor<512x1024xf32>)
+    ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>)
     outs(%argx: tensor<512x1024xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
@@ -478,4 +320,3 @@ func @mul_ds(%arga: tensor<512x1024xf32>, %argb: tensor<512x1024xf32>, %argx: te
   } -> tensor<512x1024xf32>
   return %0 : tensor<512x1024xf32>
 }
-

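One subtle change in the sparse_vector.mlir checks above is worth spelling out: even an
all-dense annotation such as #DenseVector routes an operand through the sparse storage
scheme, so its values are read from a linearized buffer. That is why the loads for %arga
changed from memref<1024xf32> to memref<?xf32> while the unannotated operands keep their
static memref types. In isolation (a sketch reusing %arga and #DenseVector from the test
above, not a line from this commit):

// Values of an annotated tensor live in linearized sparse storage,
// hence the dynamically sized memref even for an all-dense vector.
%v = sparse_tensor.values %arga : tensor<1024xf32, #DenseVector> to memref<?xf32>
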
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir
index 5a73924c84250..712b6c533026b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir
@@ -1,17 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:   --sparsification="ptr-type=4 ind-type=4" --sparse-tensor-conversion \
-// RUN:   --convert-linalg-to-loops --convert-vector-to-scf --convert-scf-to-std \
-// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
-// RUN:   --std-bufferize --finalizing-bufferize  \
-// RUN:   --convert-vector-to-llvm --convert-std-to-llvm | \
-// RUN: TENSOR0="%mlir_integration_test_dir/data/wide.mtx" \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void  \
-// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-//
-// RUN: mlir-opt %s \
-// RUN:   --sparsification="vectorization-strategy=2 ptr-type=4 ind-type=4 vl=16" --sparse-tensor-conversion \
+// RUN:   --sparsification --sparse-tensor-conversion \
 // RUN:   --convert-linalg-to-loops --convert-vector-to-scf --convert-scf-to-std \
 // RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
 // RUN:   --std-bufferize --finalizing-bufferize  \
@@ -22,11 +10,13 @@
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
-//
-// Use descriptive names for opaque pointers.
-//
-!Filename     = type !llvm.ptr<i8>
-!SparseTensor = type !llvm.ptr<i8>
+!Filename = type !llvm.ptr<i8>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  pointerBitWidth = 8,
+  indexBitWidth = 8
+}>
 
 #matvec = {
   indexing_maps = [
@@ -34,11 +24,6 @@
     affine_map<(i,j) -> (j)>,   // b
     affine_map<(i,j) -> (i)>    // x (out)
   ],
-  sparse = [
-    [ "D", "S" ], // A
-    [ "D"      ], // b
-    [ "D"      ]  // x
-  ],
   iterator_types = ["parallel", "reduction"],
   doc = "X(i) += A(i,j) * B(j)"
 }
@@ -50,15 +35,14 @@
 //
 module {
   //
-  // The kernel expressed as an annotated Linalg op. The kernel multiplies
-  // a sparse matrix A with a dense vector b into a dense vector x.
+  // A kernel that multiplies a sparse matrix A with a dense vector b
+  // into a dense vector x.
   //
-  func @kernel_matvec(%argA: !SparseTensor,
+  func @kernel_matvec(%arga: tensor<?x?xi32, #SparseMatrix>,
                       %argb: tensor<?xi32>,
                       %argx: tensor<?xi32>) -> tensor<?xi32> {
-    %arga = sparse_tensor.fromPtr %argA : !SparseTensor to tensor<?x?xi32>
     %0 = linalg.generic #matvec
-      ins(%arga, %argb: tensor<?x?xi32>, tensor<?xi32>)
+      ins(%arga, %argb: tensor<?x?xi32, #SparseMatrix>, tensor<?xi32>)
       outs(%argx: tensor<?xi32>) {
       ^bb(%a: i32, %b: i32, %x: i32):
         %0 = muli %a, %b : i32
@@ -68,12 +52,7 @@ module {
     return %0 : tensor<?xi32>
   }
 
-  //
-  // Runtime support library that is called directly from here.
-  //
   func private @getTensorFilename(index) -> (!Filename)
-  func private @newSparseTensor(!Filename, memref<?xi1>, index, index, index) -> (!SparseTensor)
-  func private @delSparseTensor(!SparseTensor) -> ()
 
   //
   // Main driver that reads matrix from file and calls the sparse kernel.
@@ -82,27 +61,12 @@ module {
     %i0 = constant 0 : i32
     %c0 = constant 0 : index
     %c1 = constant 1 : index
-    %c2 = constant 2 : index
     %c4 = constant 4 : index
     %c256 = constant 256 : index
 
-    // Mark inner dimension of the matrix as sparse and encode the
-    // storage scheme types (this must match the metadata in the
-    // alias above and compiler switches). In this case, we test
-    // that 8-bit indices and pointers work correctly on a matrix
-    // with i32 elements.
-    %annotations = memref.alloc(%c2) : memref<?xi1>
-    %sparse = constant true
-    %dense = constant false
-    memref.store %dense, %annotations[%c0] : memref<?xi1>
-    memref.store %sparse, %annotations[%c1] : memref<?xi1>
-    %u8 = constant 4 : index
-    %i32 = constant 3 : index
-
     // Read the sparse matrix from file, construct sparse storage.
     %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
-    %a = call @newSparseTensor(%fileName, %annotations, %u8, %u8, %i32)
-      : (!Filename, memref<?xi1>, index, index, index) -> (!SparseTensor)
+    %a = sparse_tensor.new %fileName : !llvm.ptr<i8> to tensor<?x?xi32, #SparseMatrix>
 
     // Initialize dense vectors.
     %bdata = memref.alloc(%c256) : memref<?xi32>
@@ -120,7 +84,7 @@ module {
 
     // Call kernel.
     %0 = call @kernel_matvec(%a, %b, %x)
-      : (!SparseTensor, tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+      : (tensor<?x?xi32, #SparseMatrix>, tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
 
     // Print the result for verification.
     //
@@ -131,7 +95,6 @@ module {
     vector.print %v : vector<4xi32>
 
     // Release the resources.
-    call @delSparseTensor(%a) : (!SparseTensor) -> ()
     memref.dealloc %bdata : memref<?xi32>
     memref.dealloc %xdata : memref<?xi32>
 

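The sparse_matvec migration above shows the crux of the whole revision: the 8-bit pointer/index storage that used to be requested with the "ptr-type=4 ind-type=4" pass flags (and mirrored by hand in the runtime annotation array) is now carried by the tensor type itself. A minimal sketch of the resulting idiom; the alias, attribute, and function names are illustrative, not part of this commit:

!Filename = type !llvm.ptr<i8>

#CSR8 = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  pointerBitWidth = 8,
  indexBitWidth = 8
}>

func @sketch(%file: !Filename) -> tensor<?x?xi32, #CSR8> {
  // The storage scheme travels with the type, so no compiler
  // switches or runtime annotation arrays are needed anymore.
  %t = sparse_tensor.new %file : !llvm.ptr<i8> to tensor<?x?xi32, #CSR8>
  return %t : tensor<?x?xi32, #CSR8>
}
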
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir
index ee4d56c547cde..17d37e1839908 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s \
-// RUN:   --sparsification="ptr-type=2 ind-type=2 fast-output" --sparse-tensor-conversion \
+// RUN:   --sparsification="fast-output" --sparse-tensor-conversion \
 // RUN:   --convert-linalg-to-loops --convert-vector-to-scf --convert-scf-to-std \
 // RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
 // RUN:   --std-bufferize --finalizing-bufferize  \
@@ -10,11 +10,13 @@
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
-//
-// Use descriptive names for opaque pointers.
-//
-!Filename     = type !llvm.ptr<i8>
-!SparseTensor = type !llvm.ptr<i8>
+!Filename = type !llvm.ptr<i8>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
 
 #trait_sampled_dense_dense = {
   indexing_maps = [
@@ -23,12 +25,6 @@
     affine_map<(i,j,k) -> (k,j)>,  // B
     affine_map<(i,j,k) -> (i,j)>   // X (out)
   ],
-  sparse = [
-    [ "S", "S" ],  // S
-    [ "D", "D" ],  // A
-    [ "D", "D" ],  // B
-    [ "D", "D" ]   // X
-  ],
   iterator_types = ["parallel", "parallel", "reduction"],
   doc = "X(i,j) += S(i,j) SUM_k A(i,k) B(k,j)"
 }
@@ -40,16 +36,14 @@
 //
 module {
   //
-  // The kernel expressed as an annotated Linalg op. The kernel
-  // computes a sampled matrix matrix multiplication.
+  // A kernel that computes a sampled matrix matrix multiplication.
   //
-  func @sampled_dense_dense(%argS: !SparseTensor,
+  func @sampled_dense_dense(%args: tensor<?x?xf32, #SparseMatrix>,
                             %arga: tensor<?x?xf32>,
                             %argb: tensor<?x?xf32>,
                             %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
-    %args = sparse_tensor.fromPtr %argS : !SparseTensor to tensor<?x?xf32>
     %0 = linalg.generic #trait_sampled_dense_dense
-      ins(%args, %arga, %argb: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+      ins(%args, %arga, %argb: tensor<?x?xf32, #SparseMatrix>, tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%argx: tensor<?x?xf32>) {
         ^bb(%s: f32, %a: f32, %b: f32, %x: f32):
           %0 = mulf %a, %b : f32
@@ -60,12 +54,7 @@ module {
     return %0 : tensor<?x?xf32>
   }
 
-  //
-  // Runtime support library that is called directly from here.
-  //
   func private @getTensorFilename(index) -> (!Filename)
-  func private @newSparseTensor(!Filename, memref<?xi1>, index, index, index) -> (!SparseTensor)
-  func private @delSparseTensor(!SparseTensor) -> ()
 
   //
   // Main driver that reads matrix from file and calls the sparse kernel.
@@ -74,20 +63,9 @@ module {
     %d0 = constant 0.0 : f32
     %c0 = constant 0 : index
     %c1 = constant 1 : index
-    %c2 = constant 2 : index
     %c5 = constant 5 : index
     %c10 = constant 10 : index
 
-    // Mark both dimensions of the matrix as sparse and encode the
-    // storage scheme types (this must match the metadata in the
-    // trait and compiler switches).
-    %annotations = memref.alloc(%c2) : memref<?xi1>
-    %sparse = constant true
-    memref.store %sparse, %annotations[%c0] : memref<?xi1>
-    memref.store %sparse, %annotations[%c1] : memref<?xi1>
-    %i32 = constant 2 : index
-    %f32 = constant 2 : index
-
     // Setup memory for the dense matrices and initialize.
     %adata = memref.alloc(%c5, %c10) : memref<?x?xf32>
     %bdata = memref.alloc(%c10, %c5) : memref<?x?xf32>
@@ -108,13 +86,14 @@ module {
     %b = memref.tensor_load %bdata : memref<?x?xf32>
     %x = memref.tensor_load %xdata : memref<?x?xf32>
 
-    // Read the sparse matrix from file, construct sparse storage
-    // according to <sparse,sparse> in memory, and call the kernel.
+    // Read the sparse matrix from file, construct sparse storage.
     %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
-    %s = call @newSparseTensor(%fileName, %annotations, %i32, %i32, %f32)
-      : (!Filename, memref<?xi1>, index, index, index) -> (!SparseTensor)
+    %s = sparse_tensor.new %fileName : !llvm.ptr<i8> to tensor<?x?xf32, #SparseMatrix>
+
+    // Call the kernel.
     %0 = call @sampled_dense_dense(%s, %a, %b, %x)
-       : (!SparseTensor, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+       : (tensor<?x?xf32, #SparseMatrix>,
+          tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 
     // Print the result for verification.
     //
@@ -131,7 +110,6 @@ module {
     }
 
     // Release the resources.
-    call @delSparseTensor(%s) : (!SparseTensor) -> ()
     memref.dealloc %adata : memref<?x?xf32>
     memref.dealloc %bdata : memref<?x?xf32>
     memref.dealloc %xdata : memref<?x?xf32>

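With the encoding carried by the type, a sparse kernel's signature and its linalg.generic operands must agree, which is what the rigorous verification added in this revision (see the new invalid.mlir test) can now check. A hedged sketch of the kind of IR the verifier rejects; the diagnostic wording is an assumption for illustration, not copied from the commit:

#DCSR = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed", "compressed" ],
  pointerBitWidth = 32,
  indexBitWidth = 32
}>

func @invalid_new(%file: !llvm.ptr<i8>) -> tensor<?x?xf32> {
  // Assumed diagnostic: sparse_tensor.new requires a result type
  // annotated with a #sparse_tensor.encoding attribute.
  %t = sparse_tensor.new %file : !llvm.ptr<i8> to tensor<?x?xf32>
  return %t : tensor<?x?xf32>
}

The valid form would spell the result as tensor<?x?xf32, #DCSR>.
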
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
index 4e1d44dcc2b85..aad77c7de4d7c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
@@ -10,21 +10,17 @@
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
-//
-// Use descriptive names for opaque pointers.
-//
-!Filename     = type !llvm.ptr<i8>
-!SparseTensor = type !llvm.ptr<i8>
+!Filename = type !llvm.ptr<i8>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ]
+}>
 
 #trait_sum_reduce = {
   indexing_maps = [
     affine_map<(i,j) -> (i,j)>, // A
     affine_map<(i,j) -> ()>     // x (out)
   ],
-  sparse = [
-    [ "S", "S" ], // A
-    [          ]  // x
-  ],
   iterator_types = ["reduction", "reduction"],
   doc = "x += A(i,j)"
 }
@@ -36,14 +32,12 @@
 //
 module {
   //
-  // The kernel expressed as an annotated Linalg op. The kernel
-  // sum reduces a matrix to a single scalar.
+  // A kernel that sum-reduces a matrix to a single scalar.
   //
-  func @kernel_sum_reduce(%argA: !SparseTensor,
+  func @kernel_sum_reduce(%arga: tensor<?x?xf64, #SparseMatrix>,
                           %argx: tensor<f64>) -> tensor<f64> {
-    %arga = sparse_tensor.fromPtr %argA : !SparseTensor to tensor<?x?xf64>
     %0 = linalg.generic #trait_sum_reduce
-      ins(%arga: tensor<?x?xf64>)
+      ins(%arga: tensor<?x?xf64, #SparseMatrix>)
       outs(%argx: tensor<f64>) {
       ^bb(%a: f64, %x: f64):
         %0 = addf %x, %a : f64
@@ -52,12 +46,7 @@ module {
     return %0 : tensor<f64>
   }
 
-  //
-  // Runtime support library that is called directly from here.
-  //
   func private @getTensorFilename(index) -> (!Filename)
-  func private @newSparseTensor(!Filename, memref<?xi1>, index, index, index) -> (!SparseTensor)
-  func private @delSparseTensor(!SparseTensor) -> ()
 
   //
   // Main driver that reads matrix from file and calls the sparse kernel.
@@ -65,18 +54,6 @@ module {
   func @entry() {
     %d0 = constant 0.0 : f64
     %c0 = constant 0 : index
-    %c1 = constant 1 : index
-    %c2 = constant 2 : index
-
-    // Mark both dimensions of the matrix as sparse and encode the
-    // storage scheme types (this must match the metadata in the
-    // trait and compiler switches).
-    %annotations = memref.alloc(%c2) : memref<?xi1>
-    %sparse = constant true
-    memref.store %sparse, %annotations[%c0] : memref<?xi1>
-    memref.store %sparse, %annotations[%c1] : memref<?xi1>
-    %i64 = constant 1 : index
-    %f64 = constant 1 : index
 
     // Setup memory for a single reduction scalar,
     // initialized to zero.
@@ -84,13 +61,13 @@ module {
     memref.store %d0, %xdata[] : memref<f64>
     %x = memref.tensor_load %xdata : memref<f64>
 
-    // Read the sparse matrix from file, construct sparse storage
-    // according to <sparse,sparse> in memory, and call the kernel.
+    // Read the sparse matrix from file, construct sparse storage.
     %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
-    %a = call @newSparseTensor(%fileName, %annotations, %i64, %i64, %f64)
-      : (!Filename, memref<?xi1>, index, index, index) -> (!SparseTensor)
+    %a = sparse_tensor.new %fileName : !llvm.ptr<i8> to tensor<?x?xf64, #SparseMatrix>
+
+    // Call the kernel.
     %0 = call @kernel_sum_reduce(%a, %x)
-      : (!SparseTensor, tensor<f64>) -> tensor<f64>
+      : (tensor<?x?xf64, #SparseMatrix>, tensor<f64>) -> tensor<f64>
 
     // Print the result for verification.
     //
@@ -101,7 +78,6 @@ module {
     vector.print %v : f64
 
     // Release the resources.
-    call @delSparseTensor(%a) : (!SparseTensor) -> ()
     memref.dealloc %xdata : memref<f64>
 
     return


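Finally, note that the sparse_sum encoding spells out only the dimension level types; pointerBitWidth and indexBitWidth are omitted and presumably then default to the native index width. A sketch under that assumption:

#SparseMatrix = #sparse_tensor.encoding<{
  // pointerBitWidth/indexBitWidth omitted: the native index
  // width is assumed to be used for the overhead storage.
  dimLevelType = [ "compressed", "compressed" ]
}>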