[Mlir-commits] [mlir] 7d4da4e - [mlir][sparse] generalize sparse tensor output implementation
Aart Bik
llvmlistbot at llvm.org
Mon Nov 29 16:16:01 PST 2021
Author: Aart Bik
Date: 2021-11-29T16:15:53-08:00
New Revision: 7d4da4e1ab7f79e51db0d5c2a0f5ef1711122dd7
URL: https://github.com/llvm/llvm-project/commit/7d4da4e1ab7f79e51db0d5c2a0f5ef1711122dd7
DIFF: https://github.com/llvm/llvm-project/commit/7d4da4e1ab7f79e51db0d5c2a0f5ef1711122dd7.diff
LOG: [mlir][sparse] generalize sparse tensor output implementation
Moves sparse tensor output support forward by generalizing it from injective
insertions only to also include reductions. This revision accepts the case where
all outer loops are parallel and all inner loops are reductions, since that case
can still be handled with injective insertions. The next revision will allow the
innermost parallel loop to move inside the reduction loops (but that will require
"access pattern expansion", a.k.a. a "workspace").
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D114399
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/test/Dialect/SparseTensor/sparse_out.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index d396f7a50ef50..8724ff3c52ece 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -122,7 +122,7 @@ class Merger {
/// invariant expressions in the kernel.
Merger(unsigned t, unsigned l)
: outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
- dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
+ hasSparseOut(false), dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
/// Adds a tensor expression. Returns its index.
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
@@ -200,6 +200,9 @@ class Merger {
/// Dimension setter.
void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
+ /// Setter for whether the kernel has a truly dynamic sparse output tensor.
+ void setHasSparseOut(bool s) { hasSparseOut = s; }
+
/// Convenience getters to immediately access the stored nodes.
/// Typically it is inadvisible to keep the reference around, as in
/// "TensorExpr &te = merger.exp(e))", since insertions into the merger
@@ -230,6 +233,7 @@ class Merger {
Value v1);
private:
+ /// Private helpers.
bool maybeZero(unsigned e) const;
bool isInvariant(unsigned e) const;
Type inferType(unsigned e, Value src);
@@ -237,11 +241,12 @@ class Merger {
/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);
+ /// Merger data structures.
const unsigned outTensor;
const unsigned syntheticTensor;
const unsigned numTensors;
const unsigned numLoops;
-
+ bool hasSparseOut;
std::vector<std::vector<Dim>> dims;
llvm::SmallVector<TensorExp, 32> tensorExps;
llvm::SmallVector<LatPoint, 16> latPoints;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 3b7a0cfd41e0f..d640af0699004 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -46,15 +46,15 @@ enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
// Code generation.
struct CodeGen {
CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
- OpOperand *op)
+ OpOperand *op, unsigned nest)
: options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
pointers(numTensors, std::vector<Value>(numLoops)),
indices(numTensors, std::vector<Value>(numLoops)),
highs(numTensors, std::vector<Value>(numLoops)),
pidxs(numTensors, std::vector<Value>(numLoops)),
idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
- redKind(kNoReduc), sparseOut(op), lexIdx(), curVecLength(1),
- curVecMask() {}
+ redKind(kNoReduc), sparseOut(op), outerParNest(nest), lexIdx(),
+ curVecLength(1), curVecMask() {}
/// Sparsification options.
SparsificationOptions options;
/// Universal dense indices and upper bounds (by index). The loops array
@@ -79,8 +79,11 @@ struct CodeGen {
unsigned redExp;
Value redVal;
Reduction redKind;
- // Sparse tensor as output.
+ // Sparse tensor as output. Implemented either through direct injective
+ // insertion in lexicographic index order (where indices are updated
+ // in the temporary array `lexIdx`) or (TODO) access pattern expansion.
OpOperand *sparseOut;
+ unsigned outerParNest;
Value lexIdx;
// Current vector length and mask.
unsigned curVecLength;
@@ -288,10 +291,13 @@ static bool isMaterializing(Value val) {
/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
-/// whether the output tensor in the tensor expression codegen is admissable.
-/// Sets `sparseOut` when a "truly dynamic" sparse tensor output occurs.
+/// whether the out tensor in the tensor expression codegen is admissable.
+/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
+/// nesting depth when a "truly dynamic" sparse tensor output occurs.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
- unsigned exp, OpOperand **sparseOut) {
+ std::vector<unsigned> &topSort, unsigned exp,
+ OpOperand **sparseOut,
+ unsigned &outerParNest) {
OpOperand *lhs = op.getOutputOperand(0);
unsigned tensor = lhs->getOperandNumber();
auto enc = getSparseTensorEncoding(lhs->get().getType());
@@ -302,7 +308,8 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
// An all-dense annotated "sparse" output tensor becomes a linearized random
// access 1-dim memref. Also admissable since insertions cannot occur.
bool allDense = true;
- unsigned numLoops = op.iterator_types().getValue().size();
+ auto iteratorTypes = op.iterator_types().getValue();
+ unsigned numLoops = iteratorTypes.size();
for (unsigned i = 0; i < numLoops; i++)
if (merger.isDim(tensor, i, Dim::kSparse)) {
allDense = false;
@@ -319,15 +326,20 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
// Accept "truly dynamic" if the output tensor materializes uninitialized
// into the computation and insertions occur in lexicographic index order.
if (isMaterializing(lhs->get())) {
- // In this first sparse tensor output implementation, this is enforced by
- // rejecting any reduction loops (since the sparse parallel loops give a
- // lexicographically sorted and injective view into that tensor).
- // TODO: generalize to include reductions
- for (auto attr : op.iterator_types())
- if (isReductionIterator(attr))
- return false;
- *sparseOut = lhs;
- return true;
+ unsigned nest = 0;
+ for (unsigned i = 0; i < numLoops; i++) {
+ if (isReductionIterator(iteratorTypes[topSort[i]]))
+ break; // terminate at first reduction
+ nest++;
+ }
+ // Determine admissable dynamic insertion situations:
+ // (1) fully injective, since there are no reductions,
+ // (2) admissable 1-d expansion in innermost dimension. TODO: accept
+ if (nest == op.getRank(lhs)) {
+ *sparseOut = lhs;
+ outerParNest = nest;
+ return true;
+ }
}
return false;
}
@@ -704,9 +716,15 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
return genVectorInvariantValue(codegen, rewriter, val);
return val;
}
+ // Insertion (a sparse tensor output "loads" as zero).
+ OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
+ if (t == codegen.sparseOut) {
+ Type tp = getElementTypeOrSelf(t->get().getType());
+ return rewriter.create<arith::ConstantOp>(op.getLoc(), tp,
+ rewriter.getZeroAttr(tp));
+ }
// Actual load.
SmallVector<Value, 4> args;
- OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
Value ptr = genSubscript(codegen, rewriter, op, t, args);
if (codegen.curVecLength > 1)
return genVectorLoad(codegen, rewriter, ptr, args);
@@ -1515,11 +1533,14 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// Rejects an inadmissable tensor expression.
OpOperand *sparseOut = nullptr;
- if (!isAdmissableTensorExp(merger, op, exp, &sparseOut))
+ unsigned outerParNest = 0;
+ if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
+ outerParNest))
return failure();
// Recursively generates code.
- CodeGen codegen(options, numTensors, numLoops, sparseOut);
+ merger.setHasSparseOut(sparseOut != nullptr);
+ CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
genBuffers(merger, codegen, rewriter, op);
genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
genResult(merger, codegen, rewriter, op);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index cd75911556cd2..466191dd38b70 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -415,9 +415,13 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
case kInvariant: {
// Either the index is really used in the tensor expression, or it is
// set to the undefined index in that dimension. An invariant expression
- // is set to a synthetic tensor with undefined indices only.
+ // and a truly dynamic sparse output tensor are set to a synthetic tensor
+ // with undefined indices only to ensure the iteration space is not
+ // skipped as a result of their contents.
unsigned s = addSet();
unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
+ if (hasSparseOut && t == outTensor)
+ t = syntheticTensor;
latSets[s].push_back(addLat(t, i, e));
return s;
}
@@ -593,8 +597,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
}
}
// Construct binary operations if subexpressions can be built.
- // TODO: see buildLattices() for an explanation of rejecting
- // certain division and shift operations
+ // See buildLattices() for an explanation of rejecting certain
+ // division and shift operations.
if (def->getNumOperands() == 2) {
auto x = buildTensorExp(op, def->getOperand(0));
auto y = buildTensorExp(op, def->getOperand(1));
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 90ba2ff4d6df0..5481c518d521d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -11,6 +11,10 @@
dimOrdering = affine_map<(i,j) -> (i,j)>
}>
+#SparseTensor = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed", "compressed" ]
+}>
+
#trait_scale_inpl = {
indexing_maps = [
affine_map<(i,j) -> (i,j)> // X (out)
@@ -182,3 +186,161 @@ func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32,
} -> tensor<10x20xf32, #DCSR>
return %0 : tensor<10x20xf32, #DCSR>
}
+
+#trait_sumred = {
+ indexing_maps = [
+ affine_map<(i,j,k) -> (i,j,k)>, // A (in)
+ affine_map<(i,j,k) -> (i,j,k)>, // B (in)
+ affine_map<(i,j,k) -> (i,j)> // X (out), k is reduced away
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"], // i, j parallel; k reduction
+ doc = "X(i,j) = SUM_k A(i,j,k) * B(i,j,k)"
+}
+
+// CHECK-LABEL: func @sumred(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>)
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>>
+// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.init{{\[}}%[[VAL_6]], %[[VAL_7]]] : tensor<?x?xi32, #{{.*}}>>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xi32>
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_21:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xi32>
+// CHECK: %[[VAL_23:.*]] = memref.alloca(%[[VAL_4]]) : memref<?xindex>
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_28:.*]]:2 = scf.while (%[[VAL_29:.*]] = %[[VAL_24]], %[[VAL_30:.*]] = %[[VAL_26]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_31:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_25]] : index
+// CHECK: %[[VAL_32:.*]] = arith.cmpi ult, %[[VAL_30]], %[[VAL_27]] : index
+// CHECK: %[[VAL_33:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1
+// CHECK: scf.condition(%[[VAL_33]]) %[[VAL_29]], %[[VAL_30]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index):
+// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_37]], %[[VAL_36]] : index
+// CHECK: %[[VAL_39:.*]] = select %[[VAL_38]], %[[VAL_37]], %[[VAL_36]] : index
+// CHECK: memref.store %[[VAL_39]], %[[VAL_23]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
+// CHECK: %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
+// CHECK: %[[VAL_42:.*]] = arith.andi %[[VAL_40]], %[[VAL_41]] : i1
+// CHECK: scf.if %[[VAL_42]] {
+// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_34]], %[[VAL_3]] : index
+// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_44]]] : memref<?xindex>
+// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_35]], %[[VAL_3]] : index
+// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_47]]] : memref<?xindex>
+// CHECK: %[[VAL_49:.*]]:2 = scf.while (%[[VAL_50:.*]] = %[[VAL_43]], %[[VAL_51:.*]] = %[[VAL_46]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_50]], %[[VAL_45]] : index
+// CHECK: %[[VAL_53:.*]] = arith.cmpi ult, %[[VAL_51]], %[[VAL_48]] : index
+// CHECK: %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1
+// CHECK: scf.condition(%[[VAL_54]]) %[[VAL_50]], %[[VAL_51]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_55:.*]]: index, %[[VAL_56:.*]]: index):
+// CHECK: %[[VAL_57:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_55]]] : memref<?xindex>
+// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_56]]] : memref<?xindex>
+// CHECK: %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_57]] : index
+// CHECK: %[[VAL_60:.*]] = select %[[VAL_59]], %[[VAL_58]], %[[VAL_57]] : index
+// CHECK: memref.store %[[VAL_60]], %[[VAL_23]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_61:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
+// CHECK: %[[VAL_62:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
+// CHECK: %[[VAL_63:.*]] = arith.andi %[[VAL_61]], %[[VAL_62]] : i1
+// CHECK: scf.if %[[VAL_63]] {
+// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_55]]] : memref<?xindex>
+// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_55]], %[[VAL_3]] : index
+// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_65]]] : memref<?xindex>
+// CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_56]]] : memref<?xindex>
+// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_56]], %[[VAL_3]] : index
+// CHECK: %[[VAL_69:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_68]]] : memref<?xindex>
+// CHECK: %[[VAL_70:.*]]:3 = scf.while (%[[VAL_71:.*]] = %[[VAL_64]], %[[VAL_72:.*]] = %[[VAL_67]], %[[VAL_73:.*]] = %[[VAL_5]]) : (index, index, i32) -> (index, index, i32) {
+// CHECK: %[[VAL_74:.*]] = arith.cmpi ult, %[[VAL_71]], %[[VAL_66]] : index
+// CHECK: %[[VAL_75:.*]] = arith.cmpi ult, %[[VAL_72]], %[[VAL_69]] : index
+// CHECK: %[[VAL_76:.*]] = arith.andi %[[VAL_74]], %[[VAL_75]] : i1
+// CHECK: scf.condition(%[[VAL_76]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]] : index, index, i32
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_77:.*]]: index, %[[VAL_78:.*]]: index, %[[VAL_79:.*]]: i32):
+// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_77]]] : memref<?xindex>
+// CHECK: %[[VAL_81:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_78]]] : memref<?xindex>
+// CHECK: %[[VAL_82:.*]] = arith.cmpi ult, %[[VAL_81]], %[[VAL_80]] : index
+// CHECK: %[[VAL_83:.*]] = select %[[VAL_82]], %[[VAL_81]], %[[VAL_80]] : index
+// CHECK: memref.store %[[VAL_83]], %[[VAL_23]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index
+// CHECK: %[[VAL_85:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index
+// CHECK: %[[VAL_86:.*]] = arith.andi %[[VAL_84]], %[[VAL_85]] : i1
+// CHECK: %[[VAL_87:.*]] = scf.if %[[VAL_86]] -> (i32) {
+// CHECK: %[[VAL_88:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_77]]] : memref<?xi32>
+// CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_78]]] : memref<?xi32>
+// CHECK: %[[VAL_90:.*]] = arith.muli %[[VAL_88]], %[[VAL_89]] : i32
+// CHECK: %[[VAL_91:.*]] = arith.addi %[[VAL_79]], %[[VAL_90]] : i32
+// CHECK: scf.yield %[[VAL_91]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_79]] : i32
+// CHECK: }
+// CHECK: %[[VAL_92:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index
+// CHECK: %[[VAL_93:.*]] = arith.addi %[[VAL_77]], %[[VAL_3]] : index
+// CHECK: %[[VAL_94:.*]] = select %[[VAL_92]], %[[VAL_93]], %[[VAL_77]] : index
+// CHECK: %[[VAL_95:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index
+// CHECK: %[[VAL_96:.*]] = arith.addi %[[VAL_78]], %[[VAL_3]] : index
+// CHECK: %[[VAL_97:.*]] = select %[[VAL_95]], %[[VAL_96]], %[[VAL_78]] : index
+// CHECK: scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_98:.*]] : index, index, i32
+// CHECK: }
+// CHECK: sparse_tensor.lex_insert %[[VAL_8]], %[[VAL_23]], %[[VAL_99:.*]]#2 : tensor<?x?xi32, #{{.*}}>, memref<?xindex>, i32
+// CHECK: } else {
+// CHECK: }
+// CHECK: %[[VAL_100:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
+// CHECK: %[[VAL_101:.*]] = arith.addi %[[VAL_55]], %[[VAL_3]] : index
+// CHECK: %[[VAL_102:.*]] = select %[[VAL_100]], %[[VAL_101]], %[[VAL_55]] : index
+// CHECK: %[[VAL_103:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
+// CHECK: %[[VAL_104:.*]] = arith.addi %[[VAL_56]], %[[VAL_3]] : index
+// CHECK: %[[VAL_105:.*]] = select %[[VAL_103]], %[[VAL_104]], %[[VAL_56]] : index
+// CHECK: scf.yield %[[VAL_102]], %[[VAL_105]] : index, index
+// CHECK: }
+// CHECK: } else {
+// CHECK: }
+// CHECK: %[[VAL_106:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
+// CHECK: %[[VAL_107:.*]] = arith.addi %[[VAL_34]], %[[VAL_3]] : index
+// CHECK: %[[VAL_108:.*]] = select %[[VAL_106]], %[[VAL_107]], %[[VAL_34]] : index
+// CHECK: %[[VAL_109:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
+// CHECK: %[[VAL_110:.*]] = arith.addi %[[VAL_35]], %[[VAL_3]] : index
+// CHECK: %[[VAL_111:.*]] = select %[[VAL_109]], %[[VAL_110]], %[[VAL_35]] : index
+// CHECK: scf.yield %[[VAL_108]], %[[VAL_111]] : index, index
+// CHECK: }
+// CHECK: %[[VAL_112:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor<?x?xi32, #{{.*}}>
+// CHECK: return %[[VAL_112]] : tensor<?x?xi32, #{{.*}}>
+// CHECK: }
+func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>, // sum-reduction into a sparse output
+ %argb: tensor<?x?x?xi32, #SparseTensor>) -> tensor<?x?xi32, #DCSR> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arga, %c0 : tensor<?x?x?xi32, #SparseTensor>
+ %d1 = tensor.dim %arga, %c1 : tensor<?x?x?xi32, #SparseTensor>
+ %xinit = sparse_tensor.init [%d0, %d1] : tensor<?x?xi32, #DCSR> // d0 x d1 output X
+ %0 = linalg.generic #trait_sumred // X(i,j) = SUM_k A(i,j,k) * B(i,j,k)
+ ins(%arga, %argb: tensor<?x?x?xi32, #SparseTensor>,
+ tensor<?x?x?xi32, #SparseTensor>)
+ outs(%xinit: tensor<?x?xi32, #DCSR>) {
+ ^bb(%a: i32, %b: i32, %x: i32):
+ %0 = arith.muli %a, %b : i32 // A(i,j,k) * B(i,j,k)
+ %1 = arith.addi %x, %0 : i32 // accumulate into X(i,j) across reduction loop k
+ linalg.yield %1 : i32
+ } -> tensor<?x?xi32, #DCSR>
+ return %0 : tensor<?x?xi32, #DCSR>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
new file mode 100644
index 0000000000000..08343231bdb72
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
@@ -0,0 +1,99 @@
+// RUN: mlir-opt %s \
+// RUN: --sparsification --sparse-tensor-conversion \
+// RUN: --linalg-bufferize --convert-linalg-to-loops \
+// RUN: --convert-vector-to-scf --convert-scf-to-std \
+// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN: --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN: --convert-vector-to-llvm --convert-memref-to-llvm --convert-math-to-llvm \
+// RUN: --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed" ]
+}>
+
+#SparseTensor = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed", "compressed" ]
+}>
+
+#redsum = {
+ indexing_maps = [
+ affine_map<(i,j,k) -> (i,j,k)>, // A (in)
+ affine_map<(i,j,k) -> (i,j,k)>, // B (in)
+ affine_map<(i,j,k) -> (i,j)> // X (out), k is reduced away
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"], // i, j parallel; k reduction
+ doc = "X(i,j) = SUM_k A(i,j,k) * B(i,j,k)"
+}
+
+module {
+ func @redsum(%arga: tensor<?x?x?xi32, #SparseTensor>, // kernel under test
+ %argb: tensor<?x?x?xi32, #SparseTensor>)
+ -> tensor<?x?xi32, #SparseMatrix> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arga, %c0 : tensor<?x?x?xi32, #SparseTensor>
+ %d1 = tensor.dim %arga, %c1 : tensor<?x?x?xi32, #SparseTensor>
+ %xinit = sparse_tensor.init [%d0, %d1] : tensor<?x?xi32, #SparseMatrix> // d0 x d1 output X
+ %0 = linalg.generic #redsum // X(i,j) = SUM_k A(i,j,k) * B(i,j,k)
+ ins(%arga, %argb: tensor<?x?x?xi32, #SparseTensor>,
+ tensor<?x?x?xi32, #SparseTensor>)
+ outs(%xinit: tensor<?x?xi32, #SparseMatrix>) {
+ ^bb(%a: i32, %b: i32, %x: i32):
+ %0 = arith.muli %a, %b : i32 // A(i,j,k) * B(i,j,k)
+ %1 = arith.addi %x, %0 : i32 // accumulate into X(i,j) across reduction loop k
+ linalg.yield %1 : i32
+ } -> tensor<?x?xi32, #SparseMatrix>
+ return %0 : tensor<?x?xi32, #SparseMatrix>
+ }
+
+ // Driver method that calls the kernel and prints its result for verification.
+ func @entry() {
+ %c0 = arith.constant 0 : index
+ %i0 = arith.constant -1 : i32 // padding value for the vector.transfer_read ops below
+
+ // Set up two very sparse 3x3x4 tensors A and B.
+ %t1 = arith.constant sparse<
+ [ [1,1,3], [2,0,0], [2,2,1], [2,2,2], [2,2,3] ], [ 1, 2, 3, 4, 5 ]
+ > : tensor<3x3x4xi32>
+ %t2 = arith.constant sparse<
+ [ [1,0,0], [1,1,3], [2,2,1], [2,2,3] ], [ 6, 7, 8, 9 ]
+ > : tensor<3x3x4xi32>
+ %st1 = sparse_tensor.convert %t1
+ : tensor<3x3x4xi32> to tensor<?x?x?xi32, #SparseTensor>
+ %st2 = sparse_tensor.convert %t2
+ : tensor<3x3x4xi32> to tensor<?x?x?xi32, #SparseTensor>
+
+
+ // Call kernel: X(i,j) = SUM_k A(i,j,k) * B(i,j,k).
+ %0 = call @redsum(%st1, %st2)
+ : (tensor<?x?x?xi32, #SparseTensor>,
+ tensor<?x?x?xi32, #SparseTensor>) -> tensor<?x?xi32, #SparseMatrix>
+
+ //
+ // Verify results. Only two entries are stored: (1,1) = 1*7 = 7 and
+ // (2,2) = 3*8 + 5*9 = 69; the trailing -1s are transfer_read padding.
+ // CHECK: ( 7, 69, -1, -1 )
+ // CHECK-NEXT: ( ( 0, 0, 0 ), ( 0, 7, 0 ), ( 0, 0, 69 ) )
+ //
+ %val = sparse_tensor.values %0
+ : tensor<?x?xi32, #SparseMatrix> to memref<?xi32>
+ %vv = vector.transfer_read %val[%c0], %i0: memref<?xi32>, vector<4xi32>
+ vector.print %vv : vector<4xi32>
+ %dm = sparse_tensor.convert %0
+ : tensor<?x?xi32, #SparseMatrix> to tensor<?x?xi32>
+ %db = bufferization.to_memref %dm : memref<?x?xi32>
+ %vm = vector.transfer_read %db[%c0, %c0], %i0: memref<?x?xi32>, vector<3x3xi32>
+ vector.print %vm : vector<3x3xi32>
+
+ // Release the resources.
+ sparse_tensor.release %st1 : tensor<?x?x?xi32, #SparseTensor>
+ sparse_tensor.release %st2 : tensor<?x?x?xi32, #SparseTensor>
+ sparse_tensor.release %0 : tensor<?x?xi32, #SparseMatrix>
+ memref.dealloc %db : memref<?x?xi32>
+ return
+ }
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
index 3d2da32ea51f0..08e380d4324ed 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
@@ -144,7 +144,7 @@ module {
return %0 : tensor<f64>
}
- // Dumps just the values array of the sparse vector.
+ // Dumps a sparse vector.
func @dump(%arg0: tensor<?xf64, #SparseVector>) {
// Dump the values array to verify only sparse contents are stored.
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list