[Mlir-commits] [mlir] cb82d37 - [mlir][sparse][vectorization] optimize reduction chains
Aart Bik
llvmlistbot at llvm.org
Sat Nov 26 12:41:03 PST 2022
Author: Aart Bik
Date: 2022-11-26T12:40:51-08:00
New Revision: cb82d375a8060bd3af83b64d7d2c94f4a59d4b97
URL: https://github.com/llvm/llvm-project/commit/cb82d375a8060bd3af83b64d7d2c94f4a59d4b97
DIFF: https://github.com/llvm/llvm-project/commit/cb82d375a8060bd3af83b64d7d2c94f4a59d4b97.diff
LOG: [mlir][sparse][vectorization] optimize reduction chains
A few more dots on the i's of the sparse vectorizer.
Also makes reduction matching less brittle.
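To make the chaining concrete, here is a hand-written, schematic example (not actual sparse-compiler output; the function name, memref-of-vector layout, and trip counts are invented for illustration). It shows the scalar round trip between two vectorized reduction loops that the new ReducChainRewriter pattern removes; the rewrite only fires when the first loop carries the sparse loop-emitter attribute (visible as "Emitted from" in the test added below).

func.func @reduc_chain(%a: memref<64xvector<8xf64>>,
                       %b: memref<64xvector<8xf64>>,
                       %vinit: vector<8xf64>) -> f64 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c64 = arith.constant 64 : index
  %zero = arith.constant dense<0.000000e+00> : vector<8xf64>
  // First partial reduction, kept in vector form.
  %v = scf.for %i = %c0 to %c64 step %c1
       iter_args(%acc = %vinit) -> (vector<8xf64>) {
    %x = memref.load %a[%i] : memref<64xvector<8xf64>>
    %s = arith.addf %acc, %x : vector<8xf64>
    scf.yield %s : vector<8xf64>
  }
  // Scalar round trip between the two loops: reduce, then re-expand.
  %r = vector.reduction <add>, %v : vector<8xf64> into f64
  %u = vector.insertelement %r, %zero[%c0 : index] : vector<8xf64>
  // Second partial reduction, seeded with the re-expanded scalar. After the
  // cleanup, %u is replaced by %v here, and the reduction/insertelement pair
  // above folds away as dead code; the final horizontal reduction below
  // yields the same total either way.
  %w = scf.for %i = %c0 to %c64 step %c1
       iter_args(%acc = %u) -> (vector<8xf64>) {
    %y = memref.load %b[%i] : memref<64xvector<8xf64>>
    %t = arith.addf %acc, %y : vector<8xf64>
    scf.yield %t : vector<8xf64>
  }
  %out = vector.reduction <add>, %w : vector<8xf64> into f64
  return %out : f64
}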
Reviewed By: qcolombet
Differential Revision: https://reviews.llvm.org/D138513
Added:
mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index aed394990428d..028a471f41c8d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -101,7 +101,9 @@ static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
}
/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
-/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
+/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
+/// that the sparse compiler can only generate indirect loads in
+/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
Value ptr, ArrayRef<Value> idxs, Value vmask) {
VectorType vtp = vectorType(vl, ptr);
@@ -118,7 +120,9 @@ static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
}
/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
-/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
+/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
+/// that the sparse compiler can only generate indirect stores in
+/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
ArrayRef<Value> idxs, Value vmask, Value rhs) {
if (idxs.back().getType().isa<VectorType>()) {
@@ -132,32 +136,60 @@ static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
}
-/// Maps operation to combining kind for reduction.
-static vector::CombiningKind getCombiningKind(Operation *def) {
- if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def) ||
- isa<arith::SubFOp>(def) || isa<arith::SubIOp>(def))
- return vector::CombiningKind::ADD;
- if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
- return vector::CombiningKind::MUL;
- if (isa<arith::AndIOp>(def))
- return vector::CombiningKind::AND;
- if (isa<arith::OrIOp>(def))
- return vector::CombiningKind::OR;
- if (isa<arith::XOrIOp>(def))
- return vector::CombiningKind::XOR;
- llvm_unreachable("unknown reduction kind");
+/// Detects a vectorizable reduction operation and, on success, returns the
+/// combining kind of the reduction in `kind`.
+static bool isVectorizableReduction(Value red, Value iter,
+ vector::CombiningKind &kind) {
+ if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
+ kind = vector::CombiningKind::ADD;
+ return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
+ }
+ if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
+ kind = vector::CombiningKind::ADD;
+ return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
+ }
+ if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
+ kind = vector::CombiningKind::ADD;
+ return subf->getOperand(0) == iter;
+ }
+ if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
+ kind = vector::CombiningKind::ADD;
+ return subi->getOperand(0) == iter;
+ }
+ if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
+ kind = vector::CombiningKind::MUL;
+ return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
+ }
+ if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
+ kind = vector::CombiningKind::MUL;
+ return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
+ }
+ if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
+ kind = vector::CombiningKind::AND;
+ return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
+ }
+ if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
+ kind = vector::CombiningKind::OR;
+ return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
+ }
+ if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
+ kind = vector::CombiningKind::XOR;
+ return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
+ }
+ return false;
}
/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
-/// The value 'r' denotes the initial value of the accumulator. Value 'rd'
-/// denotes the accumulation operation, which is solely used here to determine
-/// the kind of combining reduction (viz. addf -> sum-accumulation).
+/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
- VectorType vtp, Value r, Value rd) {
- vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
+ Value red, Value iter, Value r,
+ VectorType vtp) {
+ vector::CombiningKind kind;
+ if (!isVectorizableReduction(red, iter, kind))
+ llvm_unreachable("unknown reduction");
switch (kind) {
case vector::CombiningKind::ADD:
case vector::CombiningKind::XOR:
@@ -180,13 +212,6 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
llvm_unreachable("unknown reduction kind");
}
-/// Generates final value for a vector reduction.
-static Value genVectorReducEnd(PatternRewriter &rewriter, Location loc,
- Value vexp, Value rd) {
- vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
- return rewriter.create<vector::ReductionOp>(loc, kind, vexp);
-}
-
/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
@@ -379,10 +404,14 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
if (!yield.getResults().empty()) {
Value init = forOp.getInitArgs()[0];
VectorType vtp = vectorType(vl, init.getType());
- Value vinit =
- genVectorReducInit(rewriter, loc, vtp, init, yield->getOperand(0));
+ Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
+ forOp.getRegionIterArg(0), init, vtp);
forOpNew = rewriter.create<scf::ForOp>(
loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
+ forOpNew->setAttr(
+ SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(),
+ forOp->getAttr(
+ SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()));
rewriter.setInsertionPointToStart(forOpNew.getBody());
} else {
forOp.setStep(step);
@@ -395,20 +424,22 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
// Sparse for-loops either are terminated by a non-empty yield operation
// (reduction loop) or otherwise by a store operation (parallel loop).
if (!yield.getResults().empty()) {
+ // Analyze/vectorize reduction.
if (yield->getNumOperands() != 1)
return false;
- Value redOp = yield->getOperand(0);
- // Analyze/vectorize reduction.
- // TODO: use linalg utils to verify the actual reduction?
+ Value red = yield->getOperand(0);
+ Value iter = forOp.getRegionIterArg(0);
+ vector::CombiningKind kind;
Value vrhs;
- if (vectorizeExpr(rewriter, forOp, vl, redOp, codegen, vmask, vrhs)) {
+ if (isVectorizableReduction(red, iter, kind) &&
+ vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
if (codegen) {
- Value vpass =
- genVectorInvariantValue(rewriter, vl, forOp.getRegionIterArg(0));
+ Value partial = forOpNew.getResult(0);
+ Value vpass = genVectorInvariantValue(rewriter, vl, iter);
Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
rewriter.create<scf::YieldOp>(loc, vred);
rewriter.setInsertionPointAfter(forOpNew);
- Value vres = genVectorReducEnd(rewriter, loc, forOpNew.getResult(0), redOp);
+ Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
// Now do some relinking (last one is not completely type safe
// but all bad ones are removed right away). This also folds away
// nop broadcast operations.
@@ -469,6 +500,32 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
const VL vl;
};
+/// Reduction chain cleanup.
+///   v = for { }
+///   s = vsum(v)      v = for { }
+///   u = expand(s) -> for (v) { }
+///   for (u) { }
+template <typename VectorOp>
+struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
+public:
+ using OpRewritePattern<VectorOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(VectorOp op,
+ PatternRewriter &rewriter) const override {
+ Value inp = op.getSource();
+ if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
+ if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
+ if (forOp->hasAttr(
+ SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())) {
+ rewriter.replaceOp(op, redOp.getVector());
+ return success();
+ }
+ }
+ }
+ return failure();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -482,4 +539,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
bool enableSIMDIndex32) {
patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
enableVLAVectorization, enableSIMDIndex32);
+ patterns.add<ReducChainRewriter<vector::InsertElementOp>,
+ ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
new file mode 100644
index 0000000000000..612927c471920
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -0,0 +1,122 @@
+// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=8" -cse | \
+// RUN: FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","compressed"]}>
+
+#trait = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // a (in)
+ affine_map<(i,j) -> (i,j)>, // b (in)
+ affine_map<(i,j) -> ()> // x (out)
+ ],
+ iterator_types = ["reduction", "reduction"]
+}
+
+//
+// Verifies that the SIMD reductions in the two for-loops after the
+// while-loop are chained before being horizontally reduced back to a scalar.
+//
+// CHECK-LABEL: func.func @sparse_matrix_sum(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<f64>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) -> tensor<f64> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<0.000000e+00> : vector<8xf64>
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 64 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xf64>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_2]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_2]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xf64>
+// CHECK: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f64>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_14]][] : memref<f64>
+// CHECK: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) {
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_17]], %[[VAL_7]] : index
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_24:.*]]:3 = scf.while (%[[VAL_25:.*]] = %[[VAL_19]], %[[VAL_26:.*]] = %[[VAL_22]], %[[VAL_27:.*]] = %[[VAL_18]]) : (index, index, f64) -> (index, index, f64) {
+// CHECK: %[[VAL_28:.*]] = arith.cmpi ult, %[[VAL_25]], %[[VAL_21]] : index
+// CHECK: %[[VAL_29:.*]] = arith.cmpi ult, %[[VAL_26]], %[[VAL_23]] : index
+// CHECK: %[[VAL_30:.*]] = arith.andi %[[VAL_28]], %[[VAL_29]] : i1
+// CHECK: scf.condition(%[[VAL_30]]) %[[VAL_25]], %[[VAL_26]], %[[VAL_27]] : index, index, f64
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index, %[[VAL_33:.*]]: f64):
+// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xindex>
+// CHECK: %[[VAL_36:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_34]] : index
+// CHECK: %[[VAL_37:.*]] = arith.select %[[VAL_36]], %[[VAL_35]], %[[VAL_34]] : index
+// CHECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_37]] : index
+// CHECK: %[[VAL_39:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_37]] : index
+// CHECK: %[[VAL_40:.*]] = arith.andi %[[VAL_38]], %[[VAL_39]] : i1
+// CHECK: %[[VAL_41:.*]] = scf.if %[[VAL_40]] -> (f64) {
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f64
+// CHECK: %[[VAL_45:.*]] = arith.addf %[[VAL_33]], %[[VAL_44]] : f64
+// CHECK: scf.yield %[[VAL_45]] : f64
+// CHECK: } else {
+// CHECK: %[[VAL_46:.*]] = scf.if %[[VAL_38]] -> (f64) {
+// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK: %[[VAL_48:.*]] = arith.addf %[[VAL_33]], %[[VAL_47]] : f64
+// CHECK: scf.yield %[[VAL_48]] : f64
+// CHECK: } else {
+// CHECK: %[[VAL_49:.*]] = scf.if %[[VAL_39]] -> (f64) {
+// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_51:.*]] = arith.addf %[[VAL_33]], %[[VAL_50]] : f64
+// CHECK: scf.yield %[[VAL_51]] : f64
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_33]] : f64
+// CHECK: }
+// CHECK: scf.yield %[[VAL_52:.*]] : f64
+// CHECK: }
+// CHECK: scf.yield %[[VAL_53:.*]] : f64
+// CHECK: }
+// CHECK: %[[VAL_54:.*]] = arith.addi %[[VAL_31]], %[[VAL_7]] : index
+// CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_38]], %[[VAL_54]], %[[VAL_31]] : index
+// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
+// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
+// CHECK: scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
+// CHECK: } attributes {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
+// CHECK: %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
+// CHECK: %[[VAL_64:.*]] = affine.min #map2(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
+// CHECK: %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>
+// CHECK: %[[VAL_66:.*]] = vector.maskedload %[[VAL_10]]{{\[}}%[[VAL_62]]], %[[VAL_65]], %[[VAL_4]] : memref<?xf64>, vector<8xi1>, vector<8xf64> into vector<8xf64>
+// CHECK: %[[VAL_67:.*]] = arith.addf %[[VAL_63]], %[[VAL_66]] : vector<8xf64>
+// CHECK: %[[VAL_68:.*]] = arith.select %[[VAL_65]], %[[VAL_67]], %[[VAL_63]] : vector<8xi1>, vector<8xf64>
+// CHECK: scf.yield %[[VAL_68]] : vector<8xf64>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_69:.*]] = scf.for %[[VAL_70:.*]] = %[[VAL_60]]#1 to %[[VAL_23]] step %[[VAL_3]] iter_args(%[[VAL_71:.*]] = %[[VAL_61]]) -> (vector<8xf64>) {
+// CHECK: %[[VAL_73:.*]] = affine.min #map2(%[[VAL_23]], %[[VAL_70]]){{\[}}%[[VAL_3]]]
+// CHECK: %[[VAL_74:.*]] = vector.create_mask %[[VAL_73]] : vector<8xi1>
+// CHECK: %[[VAL_75:.*]] = vector.maskedload %[[VAL_13]]{{\[}}%[[VAL_70]]], %[[VAL_74]], %[[VAL_4]] : memref<?xf64>, vector<8xi1>, vector<8xf64> into vector<8xf64>
+// CHECK: %[[VAL_76:.*]] = arith.addf %[[VAL_71]], %[[VAL_75]] : vector<8xf64>
+// CHECK: %[[VAL_77:.*]] = arith.select %[[VAL_74]], %[[VAL_76]], %[[VAL_71]] : vector<8xi1>, vector<8xf64>
+// CHECK: scf.yield %[[VAL_77]] : vector<8xf64>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_78:.*]] = vector.reduction <add>, %[[VAL_69]] : vector<8xf64> into f64
+// CHECK: scf.yield %[[VAL_78]] : f64
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: memref.store %[[VAL_80:.*]], %[[VAL_14]][] : memref<f64>
+// CHECK: %[[VAL_81:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<f64>
+// CHECK: return %[[VAL_81]] : tensor<f64>
+// CHECK: }
+func.func @sparse_matrix_sum(%argx: tensor<f64>,
+ %arga: tensor<64x32xf64, #SparseMatrix>,
+ %argb: tensor<64x32xf64, #SparseMatrix>) -> tensor<f64> {
+ %0 = linalg.generic #trait
+ ins(%arga, %argb: tensor<64x32xf64, #SparseMatrix>,
+ tensor<64x32xf64, #SparseMatrix>)
+ outs(%argx: tensor<f64>) {
+ ^bb(%a: f64, %b: f64, %x: f64):
+ %m = arith.addf %a, %b : f64
+ %t = arith.addf %x, %m : f64
+ linalg.yield %t : f64
+ } -> tensor<f64>
+ return %0 : tensor<f64>
+}