[Mlir-commits] [mlir] b8a021d - [mlir][sparse] support for negation and subtractions

Aart Bik llvmlistbot at llvm.org
Fri Jul 2 15:55:20 PDT 2021


Author: Aart Bik
Date: 2021-07-02T15:55:05-07:00
New Revision: b8a021dbe322c4fae196318df7c0ebb2dd0f4a31

URL: https://github.com/llvm/llvm-project/commit/b8a021dbe322c4fae196318df7c0ebb2dd0f4a31
DIFF: https://github.com/llvm/llvm-project/commit/b8a021dbe322c4fae196318df7c0ebb2dd0f4a31.diff

LOG: [mlir][sparse] support for negation and subtractions

This revision extends the sparse compiler's support beyond fp/int addition and multiplication to also cover fp/int negation and subtraction, thereby widening the class of sparse kernels that can be compiled.
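
In lattice terms, subtraction is disjunctive (the result can be nonzero when
either operand is stored) but, unlike addition, not commutative, which is why
this patch introduces the kZero leaf expression and the mapZero() helper. A
condensed restatement of the rule implemented below, in the x(i) = a(i) OP b(i)
notation used by the tests:

  x(i) = a(i) - b(i) splits into three lattice cases:
    a(i) and b(i) both stored : x(i) = a(i) - b(i)
    only a(i) stored          : x(i) = a(i)
    only b(i) stored          : x(i) = 0 - b(i)     (the kZero operand)

  Unary negation x(i) = -a(i) is likewise built as the subtraction 0 - a(i),
  i.e. negf becomes kSubF with a kZero left child (there is no integer
  negation op in std yet, hence the TODO in the diff).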

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D105306

Added: 
    mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
    mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
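
As a quick reference before the full diff, here is one of the new kernels in
condensed form (copied, minus its CHECK lines, from the added sparse_fp_ops.mlir
test). After -sparsification, the negation is lowered to a subf of the constant
0.0 and each stored value, looping only over the nonzeros of %arga:

  #SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

  #trait1 = {
    indexing_maps = [
      affine_map<(i) -> (i)>,  // a
      affine_map<(i) -> (i)>   // x (out)
    ],
    iterator_types = ["parallel"],
    doc = "x(i) = OP a(i)"
  }

  func @neg(%arga: tensor<32xf64, #SV>,
            %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
    %0 = linalg.generic #trait1
       ins(%arga: tensor<32xf64, #SV>)
      outs(%argx: tensor<32xf64>) {
        ^bb(%a: f64, %x: f64):
          %0 = negf %a : f64
          linalg.yield %0 : f64
    } -> tensor<32xf64>
    return %0 : tensor<32xf64>
  }
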
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 4141c68a5e37..d7496b30a1c3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -20,34 +20,33 @@
 namespace mlir {
 namespace sparse_tensor {
 
-/// Tensor expression kind.
-enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
-
 /// Dimension level type for a tensor (undef means index does not appear).
 enum class Dim { kSparse, kDense, kSingle, kUndef };
 
-/// Children expressions of a binary TensorExp.
+/// Tensor expression kind.
+enum class Kind {
+  // Leaf.
+  kTensor,
+  kInvariant,
+  kZero,
+  // Operation.
+  kMulF,
+  kMulI,
+  kAddF,
+  kAddI,
+  kSubF,
+  kSubI
+};
+
+/// Children subexpressions of tensor operations.
 struct Children {
   unsigned e0;
   unsigned e1;
 };
 
 /// Tensor expression. Represents a MLIR expression in tensor index notation.
-/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
-/// stored directly. For binary operations, e0 and e1 denote the index of the
-/// children tensor expressions.
 struct TensorExp {
-  TensorExp(Kind k, unsigned x, unsigned y, Value v) : kind(k), val(v) {
-    assert((kind == Kind::kTensor && x != -1u && y == -1u && !val) ||
-           (kind == Kind::kInvariant && x == -1u && y == -1u && val) ||
-           (kind >= Kind::kMulF && x != -1u && y != -1u && !val));
-    if (kind == Kind::kTensor) {
-      tensor = x;
-    } else if (kind >= Kind::kMulF) {
-      children.e0 = x;
-      children.e1 = y;
-    }
-  }
+  TensorExp(Kind k, unsigned x, unsigned y, Value v);
 
   /// Tensor expression kind.
   Kind kind;
@@ -56,7 +55,7 @@ struct TensorExp {
     /// Expressions representing tensors simply have a tensor number.
     unsigned tensor;
 
-    /// Binary operations hold the indices of their child expressions.
+    /// Tensor operations hold the indices of their children.
     Children children;
   };
 
@@ -69,10 +68,8 @@ struct TensorExp {
 /// loop indices (encoded in a bitvector) and the index of the corresponding
 /// tensor expression.
 struct LatPoint {
-  LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
-    bits.set(b);
-  }
-  LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
+  LatPoint(unsigned n, unsigned e, unsigned b);
+  LatPoint(const llvm::BitVector &b, unsigned e);
 
   /// Conjunction of tensor loop indices as bitvector. This represents
   /// all indices involved in the tensor expression
@@ -103,7 +100,8 @@ class Merger {
         dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
 
   /// Adds a tensor expression. Returns its index.
-  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
+  unsigned addExp(Kind k, unsigned e0 = -1u, unsigned e1 = -1u,
+                  Value v = Value());
   unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
 
   /// Adds an iteration lattice point. Returns its index.
@@ -126,6 +124,12 @@ class Merger {
   /// Returns the index of the new set.
   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
 
+  /// Maps a zero operand over a lattice set, i.e. each lattice point on an
+  /// expression E is simply copied over, but with 0 OP E as new expression.
+  /// This is useful to deal with disjunctive, but non-commutative operators.
+  /// Returns the index of the new set.
+  unsigned mapZero(Kind kind, unsigned s0);
+
   /// Optimizes the iteration lattice points in the given set. This
   /// method should be called right before code generation to avoid
   /// generating redundant loops and conditions.
@@ -135,7 +139,7 @@ class Merger {
   /// within the given set using just two basic rules:
   /// (1) multiple dense conditions are reduced to single dense, and
   /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
-  llvm::BitVector simplifyCond(unsigned s, unsigned p0);
+  llvm::BitVector simplifyCond(unsigned s0, unsigned p0);
 
   /// Returns true if Li > Lj.
   bool latGT(unsigned i, unsigned j) const;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 813fe683ae61..775f3a140f82 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -624,24 +624,36 @@ static void genReductionEnd(Merger &merger, CodeGen &codegen,
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
+  Location loc = op.getLoc();
   if (merger.exp(exp).kind == Kind::kTensor)
     return genTensorLoad(merger, codegen, rewriter, op, exp);
-  else if (merger.exp(exp).kind == Kind::kInvariant)
+  if (merger.exp(exp).kind == Kind::kInvariant)
+    return genInvariantValue(merger, codegen, rewriter, exp);
+  if (merger.exp(exp).kind == Kind::kZero) {
+    Type tp = op.getOutputTensorTypes()[0].getElementType();
+    merger.exp(exp).val =
+        rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
     return genInvariantValue(merger, codegen, rewriter, exp);
+  }
   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
   switch (merger.exp(exp).kind) {
   case Kind::kTensor:
   case Kind::kInvariant:
+  case Kind::kZero:
     llvm_unreachable("handled above");
   case Kind::kMulF:
-    return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
+    return rewriter.create<MulFOp>(loc, v0, v1);
   case Kind::kMulI:
-    return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
+    return rewriter.create<MulIOp>(loc, v0, v1);
   case Kind::kAddF:
-    return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
+    return rewriter.create<AddFOp>(loc, v0, v1);
   case Kind::kAddI:
-    return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
+    return rewriter.create<AddIOp>(loc, v0, v1);
+  case Kind::kSubF:
+    return rewriter.create<SubFOp>(loc, v0, v1);
+  case Kind::kSubI:
+    return rewriter.create<SubIOp>(loc, v0, v1);
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -671,7 +683,8 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
       merger.exp(exp).val =
           hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
     }
-  } else if (merger.exp(exp).kind != Kind::kInvariant) {
+  } else if (merger.exp(exp).kind != Kind::kInvariant &&
+             merger.exp(exp).kind != Kind::kZero) {
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 6150c15a0ad1..2a1ad9ad56df 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -14,6 +14,39 @@
 namespace mlir {
 namespace sparse_tensor {
 
+//
+// Constructors.
+//
+
+TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
+    : kind(k), val(v) {
+  switch (kind) {
+  case Kind::kTensor:
+    assert(x != -1u && y == -1u && !v);
+    tensor = x;
+    break;
+  case Kind::kInvariant:
+    assert(x == -1u && y == -1u && v);
+    break;
+  case Kind::kZero:
+    assert(x == -1u && y == -1u && !v);
+    break;
+  default:
+    assert(x != -1u && y != -1u && !v);
+    children.e0 = x;
+    children.e1 = y;
+    break;
+  }
+}
+
+LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
+    : bits(n, false), simple(), exp(e) {
+  bits.set(b);
+}
+
+LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
+    : bits(b), simple(), exp(e) {}
+
 //
 // Lattice methods.
 //
@@ -56,13 +89,28 @@ unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
 
 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
   unsigned s = takeConj(kind, s0, s1);
+  // Followed by all in s0 and s1.
   for (unsigned p : latSets[s0])
     latSets[s].push_back(p);
+  if (Kind::kSubF <= kind && kind <= Kind::kSubI)
+    s1 = mapZero(kind, s1);
   for (unsigned p : latSets[s1])
     latSets[s].push_back(p);
   return s;
 }
 
+unsigned Merger::mapZero(Kind kind, unsigned s0) {
+  assert(Kind::kSubF <= kind && kind <= Kind::kSubI);
+  unsigned s = addSet();
+  unsigned z = addExp(Kind::kZero);
+  for (unsigned p : latSets[s0]) {
+    unsigned e = addExp(kind, z, latPoints[p].exp);
+    latPoints.push_back(LatPoint(latPoints[p].bits, e));
+    latSets[s].push_back(latPoints.size() - 1);
+  }
+  return s;
+}
+
 unsigned Merger::optimizeSet(unsigned s0) {
   unsigned s = addSet();
   assert(latSets[s0].size() != 0);
@@ -93,11 +141,11 @@ unsigned Merger::optimizeSet(unsigned s0) {
   return s;
 }
 
-llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
+llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
   // First determine if this lattice point is a *singleton*, i.e.,
   // the last point in a lattice, no other is less than this one.
   bool isSingleton = true;
-  for (unsigned p1 : latSets[s]) {
+  for (unsigned p1 : latSets[s0]) {
     if (p0 != p1 && latGT(p0, p1)) {
       isSingleton = false;
       break;
@@ -148,6 +196,23 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
 // Print methods (for debugging).
 //
 
+static char kindToOpSymbol(Kind kind) {
+  switch (kind) {
+  case Kind::kMulF:
+  case Kind::kMulI:
+    return '*';
+  case Kind::kAddF:
+  case Kind::kAddI:
+    return '+';
+  case Kind::kSubF:
+  case Kind::kSubI:
+    return '-';
+  default:
+    break;
+  }
+  llvm_unreachable("unexpected kind");
+}
+
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
@@ -160,22 +225,15 @@ void Merger::dumpExp(unsigned e) const {
   case Kind::kInvariant:
     llvm::dbgs() << "invariant";
     break;
-  default:
-  case Kind::kMulI:
-    llvm::dbgs() << "(";
-    dumpExp(tensorExps[e].children.e0);
-    llvm::dbgs() << " * ";
-    dumpExp(tensorExps[e].children.e1);
-    llvm::dbgs() << ")";
+  case Kind::kZero:
+    llvm::dbgs() << "zero";
     break;
-  case Kind::kAddF:
-  case Kind::kAddI:
+  default:
     llvm::dbgs() << "(";
     dumpExp(tensorExps[e].children.e0);
-    llvm::dbgs() << " + ";
+    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
     dumpExp(tensorExps[e].children.e1);
     llvm::dbgs() << ")";
-    break;
   }
 }
 
@@ -184,7 +242,7 @@ void Merger::dumpLat(unsigned p) const {
   dumpBits(latPoints[p].bits);
   llvm::dbgs() << " :";
   dumpBits(latPoints[p].simple);
-  llvm::dbgs() << " / ";
+  llvm::dbgs() << " : ";
   dumpExp(latPoints[p].exp);
   llvm::dbgs() << " )\n";
 }
@@ -230,28 +288,34 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
 
 unsigned Merger::buildLattices(unsigned e, unsigned idx) {
   Kind kind = tensorExps[e].kind;
-  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
+  switch (kind) {
+  case Kind::kTensor:
+  case Kind::kInvariant:
+  case Kind::kZero: {
     // Either the index is really used in the tensor expression, or it is
     // set to the undefined index in that dimension. An invariant expression
     // is set to a synthetic tensor with undefined indices only.
     unsigned s = addSet();
-    unsigned t =
-        kind == Kind::kTensor ? tensorExps[e].children.e0 : syntheticTensor;
+    unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
     latSets[s].push_back(addLat(t, idx, e));
     return s;
   }
-  unsigned s0 = buildLattices(tensorExps[e].children.e0, idx);
-  unsigned s1 = buildLattices(tensorExps[e].children.e1, idx);
-  switch (kind) {
-  case Kind::kTensor:
-  case Kind::kInvariant:
-    llvm_unreachable("handled above");
   case Kind::kMulF:
   case Kind::kMulI:
-    return takeConj(kind, s0, s1);
+    return takeConj(kind, // take binary conjunction
+                    buildLattices(tensorExps[e].children.e0, idx),
+                    buildLattices(tensorExps[e].children.e1, idx));
+  case Kind::kSubF:
+  case Kind::kSubI:
+    if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero)
+      return mapZero(kind, // maps to 0-y with just y's lattices
+                     buildLattices(tensorExps[e].children.e1, idx));
+    LLVM_FALLTHROUGH;
   case Kind::kAddF:
   case Kind::kAddI:
-    return takeDisj(kind, s0, s1);
+    return takeDisj(kind, // take binary disjunction
+                    buildLattices(tensorExps[e].children.e0, idx),
+                    buildLattices(tensorExps[e].children.e1, idx));
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -281,7 +345,18 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
   Operation *def = val.getDefiningOp();
   if (def->getBlock() != &op.region().front())
     return addExp(Kind::kInvariant, val);
-  // Construct binary operations if subexpressions could be built.
+  // Construct unary operations if subexpression can be built.
+  if (def->getNumOperands() == 1) {
+    auto x = buildTensorExp(op, def->getOperand(0));
+    if (x.hasValue()) {
+      unsigned e0 = addExp(Kind::kZero);
+      unsigned e1 = x.getValue();
+      if (isa<NegFOp>(def))
+        return addExp(Kind::kSubF, e0, e1);
+      // TODO: no negi in std?
+    }
+  }
+  // Construct binary operations if subexpressions can be built.
   if (def->getNumOperands() == 2) {
     auto x = buildTensorExp(op, def->getOperand(0));
     auto y = buildTensorExp(op, def->getOperand(1));
@@ -296,6 +371,10 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
         return addExp(Kind::kAddF, e0, e1);
       if (isa<AddIOp>(def))
         return addExp(Kind::kAddI, e0, e1);
+      if (isa<SubFOp>(def))
+        return addExp(Kind::kSubF, e0, e1);
+      if (isa<SubIOp>(def))
+        return addExp(Kind::kSubI, e0, e1);
     }
   }
   // Cannot build.

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
new file mode 100644
index 000000000000..86a009183c05
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -0,0 +1,215 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
+#trait1 = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = OP a(i)"
+}
+
+#trait2 = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> (i)>,  // b
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) OP b(i)"
+}
+
+// CHECK-LABEL:   func @neg(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK:           %[[VAL_2:.*]] = constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0.000000e+00 : f64
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf64>
+// CHECK:             %[[VAL_14:.*]] = subf %[[VAL_4]], %[[VAL_13]] : f64
+// CHECK:             memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xf64>
+// CHECK:           return %[[VAL_15]] : tensor<32xf64>
+// CHECK:         }
+func @neg(%arga: tensor<32xf64, #SV>,
+          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf64, #SV>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %x: f64):
+        %0 = negf %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL:   func @add(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64>,
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK:           %[[VAL_3:.*]] = constant 32 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant true
+// CHECK:           %[[VAL_6:.*]] = constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK:             scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:             %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             scf.if %[[VAL_21]] {
+// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK:               %[[VAL_24:.*]] = addf %[[VAL_22]], %[[VAL_23]] : f64
+// CHECK:               memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_5]] {
+// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK:                 memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_26:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_27:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index
+// CHECK:             %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK:             scf.yield %[[VAL_28]], %[[VAL_29]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_30:.*]] = %[[VAL_31:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK:             %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<32xf64>
+// CHECK:             memref.store %[[VAL_32]], %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<32xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_33:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
+// CHECK:           return %[[VAL_33]] : tensor<32xf64>
+// CHECK:         }
+func @add(%arga: tensor<32xf64, #SV>,
+          %argb: tensor<32xf64>,
+          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xf64, #SV>, tensor<32xf64>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %b: f64, %x: f64):
+        %0 = addf %a, %b : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL:   func @sub(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64>,
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK:           %[[VAL_3:.*]] = constant 32 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant true
+// CHECK:           %[[VAL_6:.*]] = constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = constant 0.000000e+00 : f64
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_18:.*]] = cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
+// CHECK:             scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
+// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:             %[[VAL_22:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK:             scf.if %[[VAL_22]] {
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xf64>
+// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK:               %[[VAL_25:.*]] = subf %[[VAL_23]], %[[VAL_24]] : f64
+// CHECK:               memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_5]] {
+// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK:                 %[[VAL_27:.*]] = subf %[[VAL_7]], %[[VAL_26]] : f64
+// CHECK:                 memref.store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_28:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK:             %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_31:.*]] = addi %[[VAL_20]], %[[VAL_6]] : index
+// CHECK:             scf.yield %[[VAL_30]], %[[VAL_31]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_32:.*]] = %[[VAL_33:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK:             %[[VAL_34:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_32]]] : memref<32xf64>
+// CHECK:             %[[VAL_35:.*]] = subf %[[VAL_7]], %[[VAL_34]] : f64
+// CHECK:             memref.store %[[VAL_35]], %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<32xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_36:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf64>
+// CHECK:           return %[[VAL_36]] : tensor<32xf64>
+// CHECK:         }
+func @sub(%arga: tensor<32xf64, #SV>,
+          %argb: tensor<32xf64>,
+          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xf64, #SV>, tensor<32xf64>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %b: f64, %x: f64):
+        %0 = subf %a, %b : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL:   func @mul(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64>,
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xf64>
+// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<32xf64>
+// CHECK:             %[[VAL_16:.*]] = mulf %[[VAL_14]], %[[VAL_15]] : f64
+// CHECK:             memref.store %[[VAL_16]], %[[VAL_9]]{{\[}}%[[VAL_13]]] : memref<32xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_17:.*]] = memref.tensor_load %[[VAL_9]] : memref<32xf64>
+// CHECK:           return %[[VAL_17]] : tensor<32xf64>
+// CHECK:         }
+func @mul(%arga: tensor<32xf64, #SV>,
+          %argb: tensor<32xf64>,
+          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xf64, #SV>, tensor<32xf64>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %b: f64, %x: f64):
+        %0 = mulf %a, %b : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
new file mode 100644
index 000000000000..f306b6624099
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -0,0 +1,173 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
+#trait2 = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> (i)>,  // b
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) OP b(i)"
+}
+
+// CHECK-LABEL:   func @add(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:           %[[VAL_3:.*]] = constant 32 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant true
+// CHECK:           %[[VAL_6:.*]] = constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK:             scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:             %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             scf.if %[[VAL_21]] {
+// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xi64>
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:               %[[VAL_24:.*]] = addi %[[VAL_22]], %[[VAL_23]] : i64
+// CHECK:               memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_5]] {
+// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:                 memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_26:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_27:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index
+// CHECK:             %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK:             scf.yield %[[VAL_28]], %[[VAL_29]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_30:.*]] = %[[VAL_31:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK:             %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK:             memref.store %[[VAL_32]], %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK:           }
+// CHECK:           %[[VAL_33:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xi64>
+// CHECK:           return %[[VAL_33]] : tensor<32xi64>
+// CHECK:         }
+func @add(%arga: tensor<32xi64, #SV>,
+          %argb: tensor<32xi64>,
+          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %b: i64, %x: i64):
+        %0 = addi %a, %b : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL:   func @sub(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:           %[[VAL_3:.*]] = constant 32 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant true
+// CHECK:           %[[VAL_6:.*]] = constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = constant 0 : i64
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_18:.*]] = cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
+// CHECK:             scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
+// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:             %[[VAL_22:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK:             scf.if %[[VAL_22]] {
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xi64>
+// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xi64>
+// CHECK:               %[[VAL_25:.*]] = subi %[[VAL_23]], %[[VAL_24]] : i64
+// CHECK:               memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xi64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_5]] {
+// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xi64>
+// CHECK:                 %[[VAL_27:.*]] = subi %[[VAL_7]], %[[VAL_26]] : i64
+// CHECK:                 memref.store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xi64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_28:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK:             %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_31:.*]] = addi %[[VAL_20]], %[[VAL_6]] : index
+// CHECK:             scf.yield %[[VAL_30]], %[[VAL_31]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_32:.*]] = %[[VAL_33:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK:             %[[VAL_34:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_32]]] : memref<32xi64>
+// CHECK:             %[[VAL_35:.*]] = subi %[[VAL_7]], %[[VAL_34]] : i64
+// CHECK:             memref.store %[[VAL_35]], %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<32xi64>
+// CHECK:           }
+// CHECK:           %[[VAL_36:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xi64>
+// CHECK:           return %[[VAL_36]] : tensor<32xi64>
+// CHECK:         }
+func @sub(%arga: tensor<32xi64, #SV>,
+          %argb: tensor<32xi64>,
+          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %b: i64, %x: i64):
+        %0 = subi %a, %b : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL:   func @mul(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xi64>
+// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<32xi64>
+// CHECK:             %[[VAL_16:.*]] = muli %[[VAL_14]], %[[VAL_15]] : i64
+// CHECK:             memref.store %[[VAL_16]], %[[VAL_9]]{{\[}}%[[VAL_13]]] : memref<32xi64>
+// CHECK:           }
+// CHECK:           %[[VAL_17:.*]] = memref.tensor_load %[[VAL_9]] : memref<32xi64>
+// CHECK:           return %[[VAL_17]] : tensor<32xi64>
+// CHECK:         }
+func @mul(%arga: tensor<32xi64, #SV>,
+          %argb: tensor<32xi64>,
+          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %b: i64, %x: i64):
+        %0 = muli %a, %b : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
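
To reproduce the FileCheck runs for the new tests directly (one possible
invocation, expanding the RUN line at the top of each file; paths assume an
llvm-project checkout with mlir-opt and FileCheck on PATH):

  mlir-opt mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir -sparsification \
    | FileCheck mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir

and analogously for sparse_int_ops.mlir; running the files through llvm-lit
works as well.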


        

