[Mlir-commits] [mlir] b8a021d - [mlir][sparse] support for negation and subtractions
Aart Bik
llvmlistbot at llvm.org
Fri Jul 2 15:55:20 PDT 2021
Author: Aart Bik
Date: 2021-07-02T15:55:05-07:00
New Revision: b8a021dbe322c4fae196318df7c0ebb2dd0f4a31
URL: https://github.com/llvm/llvm-project/commit/b8a021dbe322c4fae196318df7c0ebb2dd0f4a31
DIFF: https://github.com/llvm/llvm-project/commit/b8a021dbe322c4fae196318df7c0ebb2dd0f4a31.diff
LOG: [mlir][sparse] support for negation and subtractions
This revision extends sparse compiler support from fp/int addition and multiplication to fp/int negation and subtraction, thereby increasing the scope of sparse kernels that can be compiled.
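
The essence of the change is that negation gets no dedicated expression kind: a unary "-a" is encoded as the subtraction "0 - a" (a synthetic kZero leaf paired with the operand), so the new kSubF/kSubI handling covers both negation and binary subtraction. Below is a minimal standalone C++ sketch of that encoding; the types and names (Exp, buildNeg, the trimmed Kind enum) are simplified stand-ins for illustration, not the actual Merger classes from the diff.

// Minimal standalone sketch (simplified stand-in types, not the actual MLIR
// Merger classes): a unary negation "-a" is stored as "0 - a", so the same
// subtraction kind drives lattice construction and code generation.
#include <iostream>
#include <string>
#include <vector>

enum class Kind { kTensor, kZero, kSubF };

struct Exp {
  Kind kind;
  unsigned e0 = -1u, e1 = -1u; // child expression indices (operations only)
};

std::vector<Exp> exps;

unsigned addExp(Kind k, unsigned e0 = -1u, unsigned e1 = -1u) {
  exps.push_back({k, e0, e1});
  return exps.size() - 1;
}

// Mirrors how buildTensorExp maps NegFOp: pair a synthetic zero leaf with
// the operand's subexpression under the subtraction kind.
unsigned buildNeg(unsigned a) {
  unsigned zero = addExp(Kind::kZero);
  return addExp(Kind::kSubF, zero, a);
}

std::string dump(unsigned e) {
  switch (exps[e].kind) {
  case Kind::kTensor: return "a";
  case Kind::kZero:   return "0";
  case Kind::kSubF:   return "(" + dump(exps[e].e0) + " - " + dump(exps[e].e1) + ")";
  }
  return "?";
}

int main() {
  unsigned a = addExp(Kind::kTensor);
  std::cout << dump(buildNeg(a)) << "\n"; // prints: (0 - a)
}
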
Reviewed By: gussmith23
Differential Revision: https://reviews.llvm.org/D105306
Added:
mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 4141c68a5e37..d7496b30a1c3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -20,34 +20,33 @@
namespace mlir {
namespace sparse_tensor {
-/// Tensor expression kind.
-enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
-
/// Dimension level type for a tensor (undef means index does not appear).
enum class Dim { kSparse, kDense, kSingle, kUndef };
-/// Children expressions of a binary TensorExp.
+/// Tensor expression kind.
+enum class Kind {
+ // Leaf.
+ kTensor,
+ kInvariant,
+ kZero,
+ // Operation.
+ kMulF,
+ kMulI,
+ kAddF,
+ kAddI,
+ kSubF,
+ kSubI
+};
+
+/// Children subexpressions of tensor operations.
struct Children {
unsigned e0;
unsigned e1;
};
/// Tensor expression. Represents a MLIR expression in tensor index notation.
-/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
-/// stored directly. For binary operations, e0 and e1 denote the index of the
-/// children tensor expressions.
struct TensorExp {
- TensorExp(Kind k, unsigned x, unsigned y, Value v) : kind(k), val(v) {
- assert((kind == Kind::kTensor && x != -1u && y == -1u && !val) ||
- (kind == Kind::kInvariant && x == -1u && y == -1u && val) ||
- (kind >= Kind::kMulF && x != -1u && y != -1u && !val));
- if (kind == Kind::kTensor) {
- tensor = x;
- } else if (kind >= Kind::kMulF) {
- children.e0 = x;
- children.e1 = y;
- }
- }
+ TensorExp(Kind k, unsigned x, unsigned y, Value v);
/// Tensor expression kind.
Kind kind;
@@ -56,7 +55,7 @@ struct TensorExp {
/// Expressions representing tensors simply have a tensor number.
unsigned tensor;
- /// Binary operations hold the indices of their child expressions.
+ /// Tensor operations hold the indices of their children.
Children children;
};
@@ -69,10 +68,8 @@ struct TensorExp {
/// loop indices (encoded in a bitvector) and the index of the corresponding
/// tensor expression.
struct LatPoint {
- LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
- bits.set(b);
- }
- LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
+ LatPoint(unsigned n, unsigned e, unsigned b);
+ LatPoint(const llvm::BitVector &b, unsigned e);
/// Conjunction of tensor loop indices as bitvector. This represents
/// all indices involved in the tensor expression
@@ -103,7 +100,8 @@ class Merger {
dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
/// Adds a tensor expression. Returns its index.
- unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
+ unsigned addExp(Kind k, unsigned e0 = -1u, unsigned e1 = -1u,
+ Value v = Value());
unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
/// Adds an iteration lattice point. Returns its index.
@@ -126,6 +124,12 @@ class Merger {
/// Returns the index of the new set.
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
+ /// Maps a zero operand over a lattice set, i.e. each lattice point on an
+ /// expression E is simply copied over, but with 0 OP E as new expression.
+ /// This is useful to deal with disjunctive, but non-commutative operators.
+ /// Returns the index of the new set.
+ unsigned mapZero(Kind kind, unsigned s0);
+
/// Optimizes the iteration lattice points in the given set. This
/// method should be called right before code generation to avoid
/// generating redundant loops and conditions.
@@ -135,7 +139,7 @@ class Merger {
/// within the given set using just two basic rules:
/// (1) multiple dense conditions are reduced to single dense, and
/// (2) a *singleton* sparse/dense is reduced to sparse/random access.
- llvm::BitVector simplifyCond(unsigned s, unsigned p0);
+ llvm::BitVector simplifyCond(unsigned s0, unsigned p0);
/// Returns true if Li > Lj.
bool latGT(unsigned i, unsigned j) const;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 813fe683ae61..775f3a140f82 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -624,24 +624,36 @@ static void genReductionEnd(Merger &merger, CodeGen &codegen,
/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
linalg::GenericOp op, unsigned exp) {
+ Location loc = op.getLoc();
if (merger.exp(exp).kind == Kind::kTensor)
return genTensorLoad(merger, codegen, rewriter, op, exp);
- else if (merger.exp(exp).kind == Kind::kInvariant)
+ if (merger.exp(exp).kind == Kind::kInvariant)
+ return genInvariantValue(merger, codegen, rewriter, exp);
+ if (merger.exp(exp).kind == Kind::kZero) {
+ Type tp = op.getOutputTensorTypes()[0].getElementType();
+ merger.exp(exp).val =
+ rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
return genInvariantValue(merger, codegen, rewriter, exp);
+ }
Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
switch (merger.exp(exp).kind) {
case Kind::kTensor:
case Kind::kInvariant:
+ case Kind::kZero:
llvm_unreachable("handled above");
case Kind::kMulF:
- return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
+ return rewriter.create<MulFOp>(loc, v0, v1);
case Kind::kMulI:
- return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
+ return rewriter.create<MulIOp>(loc, v0, v1);
case Kind::kAddF:
- return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
+ return rewriter.create<AddFOp>(loc, v0, v1);
case Kind::kAddI:
- return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
+ return rewriter.create<AddIOp>(loc, v0, v1);
+ case Kind::kSubF:
+ return rewriter.create<SubFOp>(loc, v0, v1);
+ case Kind::kSubI:
+ return rewriter.create<SubIOp>(loc, v0, v1);
}
llvm_unreachable("unexpected expression kind");
}
@@ -671,7 +683,8 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
merger.exp(exp).val =
hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
}
- } else if (merger.exp(exp).kind != Kind::kInvariant) {
+ } else if (merger.exp(exp).kind != Kind::kInvariant &&
+ merger.exp(exp).kind != Kind::kZero) {
// Traverse into the binary operations. Note that we only hoist
// tensor loads, since subsequent MLIR/LLVM passes know how to
// deal with all other kinds of derived loop invariants.
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 6150c15a0ad1..2a1ad9ad56df 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -14,6 +14,39 @@
namespace mlir {
namespace sparse_tensor {
+//
+// Constructors.
+//
+
+TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
+ : kind(k), val(v) {
+ switch (kind) {
+ case Kind::kTensor:
+ assert(x != -1u && y == -1u && !v);
+ tensor = x;
+ break;
+ case Kind::kInvariant:
+ assert(x == -1u && y == -1u && v);
+ break;
+ case Kind::kZero:
+ assert(x == -1u && y == -1u && !v);
+ break;
+ default:
+ assert(x != -1u && y != -1u && !v);
+ children.e0 = x;
+ children.e1 = y;
+ break;
+ }
+}
+
+LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
+ : bits(n, false), simple(), exp(e) {
+ bits.set(b);
+}
+
+LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
+ : bits(b), simple(), exp(e) {}
+
//
// Lattice methods.
//
@@ -56,13 +89,28 @@ unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
unsigned s = takeConj(kind, s0, s1);
+ // Followed by all in s0 and s1.
for (unsigned p : latSets[s0])
latSets[s].push_back(p);
+ if (Kind::kSubF <= kind && kind <= Kind::kSubI)
+ s1 = mapZero(kind, s1);
for (unsigned p : latSets[s1])
latSets[s].push_back(p);
return s;
}
+unsigned Merger::mapZero(Kind kind, unsigned s0) {
+ assert(Kind::kSubF <= kind && kind <= Kind::kSubI);
+ unsigned s = addSet();
+ unsigned z = addExp(Kind::kZero);
+ for (unsigned p : latSets[s0]) {
+ unsigned e = addExp(kind, z, latPoints[p].exp);
+ latPoints.push_back(LatPoint(latPoints[p].bits, e));
+ latSets[s].push_back(latPoints.size() - 1);
+ }
+ return s;
+}
+
unsigned Merger::optimizeSet(unsigned s0) {
unsigned s = addSet();
assert(latSets[s0].size() != 0);
@@ -93,11 +141,11 @@ unsigned Merger::optimizeSet(unsigned s0) {
return s;
}
-llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
+llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
// First determine if this lattice point is a *singleton*, i.e.,
// the last point in a lattice, no other is less than this one.
bool isSingleton = true;
- for (unsigned p1 : latSets[s]) {
+ for (unsigned p1 : latSets[s0]) {
if (p0 != p1 && latGT(p0, p1)) {
isSingleton = false;
break;
@@ -148,6 +196,23 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
// Print methods (for debugging).
//
+static char kindToOpSymbol(Kind kind) {
+ switch (kind) {
+ case Kind::kMulF:
+ case Kind::kMulI:
+ return '*';
+ case Kind::kAddF:
+ case Kind::kAddI:
+ return '+';
+ case Kind::kSubF:
+ case Kind::kSubI:
+ return '-';
+ default:
+ break;
+ }
+ llvm_unreachable("unexpected kind");
+}
+
void Merger::dumpExp(unsigned e) const {
switch (tensorExps[e].kind) {
case Kind::kTensor:
@@ -160,22 +225,15 @@ void Merger::dumpExp(unsigned e) const {
case Kind::kInvariant:
llvm::dbgs() << "invariant";
break;
- default:
- case Kind::kMulI:
- llvm::dbgs() << "(";
- dumpExp(tensorExps[e].children.e0);
- llvm::dbgs() << " * ";
- dumpExp(tensorExps[e].children.e1);
- llvm::dbgs() << ")";
+ case Kind::kZero:
+ llvm::dbgs() << "zero";
break;
- case Kind::kAddF:
- case Kind::kAddI:
+ default:
llvm::dbgs() << "(";
dumpExp(tensorExps[e].children.e0);
- llvm::dbgs() << " + ";
+ llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
dumpExp(tensorExps[e].children.e1);
llvm::dbgs() << ")";
- break;
}
}
@@ -184,7 +242,7 @@ void Merger::dumpLat(unsigned p) const {
dumpBits(latPoints[p].bits);
llvm::dbgs() << " :";
dumpBits(latPoints[p].simple);
- llvm::dbgs() << " / ";
+ llvm::dbgs() << " : ";
dumpExp(latPoints[p].exp);
llvm::dbgs() << " )\n";
}
@@ -230,28 +288,34 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
unsigned Merger::buildLattices(unsigned e, unsigned idx) {
Kind kind = tensorExps[e].kind;
- if (kind == Kind::kTensor || kind == Kind::kInvariant) {
+ switch (kind) {
+ case Kind::kTensor:
+ case Kind::kInvariant:
+ case Kind::kZero: {
// Either the index is really used in the tensor expression, or it is
// set to the undefined index in that dimension. An invariant expression
// is set to a synthetic tensor with undefined indices only.
unsigned s = addSet();
- unsigned t =
- kind == Kind::kTensor ? tensorExps[e].children.e0 : syntheticTensor;
+ unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
latSets[s].push_back(addLat(t, idx, e));
return s;
}
- unsigned s0 = buildLattices(tensorExps[e].children.e0, idx);
- unsigned s1 = buildLattices(tensorExps[e].children.e1, idx);
- switch (kind) {
- case Kind::kTensor:
- case Kind::kInvariant:
- llvm_unreachable("handled above");
case Kind::kMulF:
case Kind::kMulI:
- return takeConj(kind, s0, s1);
+ return takeConj(kind, // take binary conjunction
+ buildLattices(tensorExps[e].children.e0, idx),
+ buildLattices(tensorExps[e].children.e1, idx));
+ case Kind::kSubF:
+ case Kind::kSubI:
+ if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero)
+ return mapZero(kind, // maps to 0-y with just y's lattices
+ buildLattices(tensorExps[e].children.e1, idx));
+ LLVM_FALLTHROUGH;
case Kind::kAddF:
case Kind::kAddI:
- return takeDisj(kind, s0, s1);
+ return takeDisj(kind, // take binary disjunction
+ buildLattices(tensorExps[e].children.e0, idx),
+ buildLattices(tensorExps[e].children.e1, idx));
}
llvm_unreachable("unexpected expression kind");
}
@@ -281,7 +345,18 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
Operation *def = val.getDefiningOp();
if (def->getBlock() != &op.region().front())
return addExp(Kind::kInvariant, val);
- // Construct binary operations if subexpressions could be built.
+ // Construct unary operations if subexpression can be built.
+ if (def->getNumOperands() == 1) {
+ auto x = buildTensorExp(op, def->getOperand(0));
+ if (x.hasValue()) {
+ unsigned e0 = addExp(Kind::kZero);
+ unsigned e1 = x.getValue();
+ if (isa<NegFOp>(def))
+ return addExp(Kind::kSubF, e0, e1);
+ // TODO: no negi in std?
+ }
+ }
+ // Construct binary operations if subexpressions can be built.
if (def->getNumOperands() == 2) {
auto x = buildTensorExp(op, def->getOperand(0));
auto y = buildTensorExp(op, def->getOperand(1));
@@ -296,6 +371,10 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
return addExp(Kind::kAddF, e0, e1);
if (isa<AddIOp>(def))
return addExp(Kind::kAddI, e0, e1);
+ if (isa<SubFOp>(def))
+ return addExp(Kind::kSubF, e0, e1);
+ if (isa<SubIOp>(def))
+ return addExp(Kind::kSubI, e0, e1);
}
}
// Cannot build.
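
Subtraction is disjunctive (an output entry is produced whenever either operand contributes) but, unlike addition, not commutative: at positions where only the right-hand operand b is stored, the kernel must still compute 0 - b rather than simply copy b. That is what mapZero above does at the lattice level, and what the subf/subi in the "else" branches of the @sub tests below check. As a rough illustration, here is a plain C++ analogue of that loop structure for x(i) = a(i) - b(i), with a sparse a given in a simple indices/values form and a dense b; it is a sketch of the semantics under those assumptions, not code produced by the sparse compiler.

// Plain C++ analogue (a sketch, not generated code) of the co-iteration the
// @sub tests verify: where the sparse vector a has a stored entry, compute
// a[i] - b[i]; everywhere else the disjunction still fires, so emit 0 - b[i].
#include <cstdio>
#include <vector>

// x(i) = a(i) - b(i), with a in compressed (indices/values) form, b dense.
void sub(const std::vector<unsigned> &aIdx, const std::vector<double> &aVal,
         const std::vector<double> &b, std::vector<double> &x) {
  unsigned k = 0, i = 0, n = b.size();
  while (k < aIdx.size()) {     // co-iterate over stored entries of a
    unsigned j = aIdx[k];
    for (; i < j; ++i)          // gap in a: 0 - b[i]
      x[i] = 0.0 - b[i];
    x[i] = aVal[k] - b[i];      // overlap: a[i] - b[i]
    ++i; ++k;
  }
  for (; i < n; ++i)            // trailing gap in a: 0 - b[i]
    x[i] = 0.0 - b[i];
}

int main() {
  std::vector<unsigned> aIdx = {1, 3};      // a = [0, 5, 0, 7]
  std::vector<double> aVal = {5.0, 7.0};
  std::vector<double> b = {1.0, 1.0, 1.0, 1.0};
  std::vector<double> x(4);
  sub(aIdx, aVal, b, x);
  for (double v : x) std::printf("%g ", v); // prints: -1 4 -1 6
  std::printf("\n");
}
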
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
new file mode 100644
index 000000000000..86a009183c05
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -0,0 +1,215 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
+#trait1 = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // a
+ affine_map<(i) -> (i)> // x (out)
+ ],
+ iterator_types = ["parallel"],
+ doc = "x(i) = OP a(i)"
+}
+
+#trait2 = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // a
+ affine_map<(i) -> (i)>, // b
+ affine_map<(i) -> (i)> // x (out)
+ ],
+ iterator_types = ["parallel"],
+ doc = "x(i) = a(i) OP b(i)"
+}
+
+// CHECK-LABEL: func @neg(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK: %[[VAL_2:.*]] = constant 0 : index
+// CHECK: %[[VAL_3:.*]] = constant 1 : index
+// CHECK: %[[VAL_4:.*]] = constant 0.000000e+00 : f64
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_3]] {
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf64>
+// CHECK: %[[VAL_14:.*]] = subf %[[VAL_4]], %[[VAL_13]] : f64
+// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK: }
+// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xf64>
+// CHECK: return %[[VAL_15]] : tensor<32xf64>
+// CHECK: }
+func @neg(%arga: tensor<32xf64, #SV>,
+ %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+ %0 = linalg.generic #trait1
+ ins(%arga: tensor<32xf64, #SV>)
+ outs(%argx: tensor<32xf64>) {
+ ^bb(%a: f64, %x: f64):
+ %0 = negf %a : f64
+ linalg.yield %0 : f64
+ } -> tensor<32xf64>
+ return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL: func @add(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK: %[[VAL_3:.*]] = constant 32 : index
+// CHECK: %[[VAL_4:.*]] = constant 0 : index
+// CHECK: %[[VAL_5:.*]] = constant true
+// CHECK: %[[VAL_6:.*]] = constant 1 : index
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK: scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK: scf.if %[[VAL_21]] {
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK: %[[VAL_24:.*]] = addf %[[VAL_22]], %[[VAL_23]] : f64
+// CHECK: memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_5]] {
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK: memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_26:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK: %[[VAL_27:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index
+// CHECK: %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_28]], %[[VAL_29]] : index, index
+// CHECK: }
+// CHECK: scf.for %[[VAL_30:.*]] = %[[VAL_31:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<32xf64>
+// CHECK: memref.store %[[VAL_32]], %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<32xf64>
+// CHECK: }
+// CHECK: %[[VAL_33:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
+// CHECK: return %[[VAL_33]] : tensor<32xf64>
+// CHECK: }
+func @add(%arga: tensor<32xf64, #SV>,
+ %argb: tensor<32xf64>,
+ %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+ %0 = linalg.generic #trait2
+ ins(%arga, %argb: tensor<32xf64, #SV>, tensor<32xf64>)
+ outs(%argx: tensor<32xf64>) {
+ ^bb(%a: f64, %b: f64, %x: f64):
+ %0 = addf %a, %b : f64
+ linalg.yield %0 : f64
+ } -> tensor<32xf64>
+ return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL: func @sub(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK: %[[VAL_3:.*]] = constant 32 : index
+// CHECK: %[[VAL_4:.*]] = constant 0 : index
+// CHECK: %[[VAL_5:.*]] = constant true
+// CHECK: %[[VAL_6:.*]] = constant 1 : index
+// CHECK: %[[VAL_7:.*]] = constant 0.000000e+00 : f64
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK: %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_18:.*]] = cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
+// CHECK: scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK: scf.if %[[VAL_22]] {
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xf64>
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK: %[[VAL_25:.*]] = subf %[[VAL_23]], %[[VAL_24]] : f64
+// CHECK: memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_5]] {
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK: %[[VAL_27:.*]] = subf %[[VAL_7]], %[[VAL_26]] : f64
+// CHECK: memref.store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_28:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK: %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK: %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_19]] : index
+// CHECK: %[[VAL_31:.*]] = addi %[[VAL_20]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_30]], %[[VAL_31]] : index, index
+// CHECK: }
+// CHECK: scf.for %[[VAL_32:.*]] = %[[VAL_33:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_32]]] : memref<32xf64>
+// CHECK: %[[VAL_35:.*]] = subf %[[VAL_7]], %[[VAL_34]] : f64
+// CHECK: memref.store %[[VAL_35]], %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<32xf64>
+// CHECK: }
+// CHECK: %[[VAL_36:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf64>
+// CHECK: return %[[VAL_36]] : tensor<32xf64>
+// CHECK: }
+func @sub(%arga: tensor<32xf64, #SV>,
+ %argb: tensor<32xf64>,
+ %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+ %0 = linalg.generic #trait2
+ ins(%arga, %argb: tensor<32xf64, #SV>, tensor<32xf64>)
+ outs(%argx: tensor<32xf64>) {
+ ^bb(%a: f64, %b: f64, %x: f64):
+ %0 = subf %a, %b : f64
+ linalg.yield %0 : f64
+ } -> tensor<32xf64>
+ return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL: func @mul(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xf64>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<32xf64>
+// CHECK: %[[VAL_16:.*]] = mulf %[[VAL_14]], %[[VAL_15]] : f64
+// CHECK: memref.store %[[VAL_16]], %[[VAL_9]]{{\[}}%[[VAL_13]]] : memref<32xf64>
+// CHECK: }
+// CHECK: %[[VAL_17:.*]] = memref.tensor_load %[[VAL_9]] : memref<32xf64>
+// CHECK: return %[[VAL_17]] : tensor<32xf64>
+// CHECK: }
+func @mul(%arga: tensor<32xf64, #SV>,
+ %argb: tensor<32xf64>,
+ %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+ %0 = linalg.generic #trait2
+ ins(%arga, %argb: tensor<32xf64, #SV>, tensor<32xf64>)
+ outs(%argx: tensor<32xf64>) {
+ ^bb(%a: f64, %b: f64, %x: f64):
+ %0 = mulf %a, %b : f64
+ linalg.yield %0 : f64
+ } -> tensor<32xf64>
+ return %0 : tensor<32xf64>
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
new file mode 100644
index 000000000000..f306b6624099
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -0,0 +1,173 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
+#trait2 = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // a
+ affine_map<(i) -> (i)>, // b
+ affine_map<(i) -> (i)> // x (out)
+ ],
+ iterator_types = ["parallel"],
+ doc = "x(i) = a(i) OP b(i)"
+}
+
+// CHECK-LABEL: func @add(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK: %[[VAL_3:.*]] = constant 32 : index
+// CHECK: %[[VAL_4:.*]] = constant 0 : index
+// CHECK: %[[VAL_5:.*]] = constant true
+// CHECK: %[[VAL_6:.*]] = constant 1 : index
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK: scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK: scf.if %[[VAL_21]] {
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xi64>
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK: %[[VAL_24:.*]] = addi %[[VAL_22]], %[[VAL_23]] : i64
+// CHECK: memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_5]] {
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK: memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_26:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK: %[[VAL_27:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index
+// CHECK: %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_28]], %[[VAL_29]] : index, index
+// CHECK: }
+// CHECK: scf.for %[[VAL_30:.*]] = %[[VAL_31:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK: memref.store %[[VAL_32]], %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK: }
+// CHECK: %[[VAL_33:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xi64>
+// CHECK: return %[[VAL_33]] : tensor<32xi64>
+// CHECK: }
+func @add(%arga: tensor<32xi64, #SV>,
+ %argb: tensor<32xi64>,
+ %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+ %0 = linalg.generic #trait2
+ ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+ outs(%argx: tensor<32xi64>) {
+ ^bb(%a: i64, %b: i64, %x: i64):
+ %0 = addi %a, %b : i64
+ linalg.yield %0 : i64
+ } -> tensor<32xi64>
+ return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL: func @sub(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK: %[[VAL_3:.*]] = constant 32 : index
+// CHECK: %[[VAL_4:.*]] = constant 0 : index
+// CHECK: %[[VAL_5:.*]] = constant true
+// CHECK: %[[VAL_6:.*]] = constant 1 : index
+// CHECK: %[[VAL_7:.*]] = constant 0 : i64
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK: %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_18:.*]] = cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
+// CHECK: scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK: scf.if %[[VAL_22]] {
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xi64>
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xi64>
+// CHECK: %[[VAL_25:.*]] = subi %[[VAL_23]], %[[VAL_24]] : i64
+// CHECK: memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xi64>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_5]] {
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xi64>
+// CHECK: %[[VAL_27:.*]] = subi %[[VAL_7]], %[[VAL_26]] : i64
+// CHECK: memref.store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xi64>
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_28:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK: %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK: %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_19]] : index
+// CHECK: %[[VAL_31:.*]] = addi %[[VAL_20]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_30]], %[[VAL_31]] : index, index
+// CHECK: }
+// CHECK: scf.for %[[VAL_32:.*]] = %[[VAL_33:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_32]]] : memref<32xi64>
+// CHECK: %[[VAL_35:.*]] = subi %[[VAL_7]], %[[VAL_34]] : i64
+// CHECK: memref.store %[[VAL_35]], %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<32xi64>
+// CHECK: }
+// CHECK: %[[VAL_36:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xi64>
+// CHECK: return %[[VAL_36]] : tensor<32xi64>
+// CHECK: }
+func @sub(%arga: tensor<32xi64, #SV>,
+ %argb: tensor<32xi64>,
+ %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+ %0 = linalg.generic #trait2
+ ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+ outs(%argx: tensor<32xi64>) {
+ ^bb(%a: i64, %b: i64, %x: i64):
+ %0 = subi %a, %b : i64
+ linalg.yield %0 : i64
+ } -> tensor<32xi64>
+ return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL: func @mul(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xi64>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<32xi64>
+// CHECK: %[[VAL_16:.*]] = muli %[[VAL_14]], %[[VAL_15]] : i64
+// CHECK: memref.store %[[VAL_16]], %[[VAL_9]]{{\[}}%[[VAL_13]]] : memref<32xi64>
+// CHECK: }
+// CHECK: %[[VAL_17:.*]] = memref.tensor_load %[[VAL_9]] : memref<32xi64>
+// CHECK: return %[[VAL_17]] : tensor<32xi64>
+// CHECK: }
+func @mul(%arga: tensor<32xi64, #SV>,
+ %argb: tensor<32xi64>,
+ %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+ %0 = linalg.generic #trait2
+ ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+ outs(%argx: tensor<32xi64>) {
+ ^bb(%a: i64, %b: i64, %x: i64):
+ %0 = muli %a, %b : i64
+ linalg.yield %0 : i64
+ } -> tensor<32xi64>
+ return %0 : tensor<32xi64>
+}